Skip to content

Commit

Permalink
Fix issue #377 (#475) - Adds validation to reject handlers that would…
Browse files Browse the repository at this point in the history
… result in a panic when constructing the context

* add test case for #377

* fix panicking

* use interface{} instead of any

* add a comment for argumentType.NumMethod() == 0
  • Loading branch information
shogo82148 committed Dec 22, 2022
1 parent 65f8ccd commit ad74310
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lambda/entry_generic_test.go
Expand Up @@ -27,7 +27,7 @@ func TestStartHandlerFunc(t *testing.T) {

handlerType := reflect.TypeOf(f)

handlerTakesContext, err := validateArguments(handlerType)
handlerTakesContext, err := handlerTakesContext(handlerType)
assert.NoError(t, err)
assert.True(t, handlerTakesContext)

Expand Down
36 changes: 26 additions & 10 deletions lambda/handler.go
Expand Up @@ -99,20 +99,36 @@ func WithEnableSIGTERM(callbacks ...func()) Option {
})
}

func validateArguments(handler reflect.Type) (bool, error) {
handlerTakesContext := false
if handler.NumIn() > 2 {
return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn())
} else if handler.NumIn() > 0 {
// handlerTakesContext returns whether the handler takes a context.Context as its first argument.
func handlerTakesContext(handler reflect.Type) (bool, error) {
switch handler.NumIn() {
case 0:
return false, nil
case 1:
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
argumentType := handler.In(0)
handlerTakesContext = argumentType.Implements(contextType)
if handler.NumIn() > 1 && !handlerTakesContext {
if argumentType.Kind() != reflect.Interface {
return false, nil
}

// handlers like func(event any) are valid.
if argumentType.NumMethod() == 0 {
return false, nil
}

if !contextType.Implements(argumentType) || !argumentType.Implements(contextType) {
return false, fmt.Errorf("handler takes an interface, but it is not context.Context: %q", argumentType.Name())
}
return true, nil
case 2:
contextType := reflect.TypeOf((*context.Context)(nil)).Elem()
argumentType := handler.In(0)
if argumentType.Kind() != reflect.Interface || !contextType.Implements(argumentType) || !argumentType.Implements(contextType) {
return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", argumentType.Kind())
}
return true, nil
}

return handlerTakesContext, nil
return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", handler.NumIn())
}

func validateReturns(handler reflect.Type) error {
Expand Down Expand Up @@ -198,7 +214,7 @@ func reflectHandler(handlerFunc interface{}, h *handlerOptions) Handler {
return errorHandler(fmt.Errorf("handler kind %s is not %s", handlerType.Kind(), reflect.Func))
}

takesContext, err := validateArguments(handlerType)
takesContext, err := handlerTakesContext(handlerType)
if err != nil {
return errorHandler(err)
}
Expand Down
64 changes: 63 additions & 1 deletion lambda/handler_test.go
Expand Up @@ -7,13 +7,29 @@ import (
"errors"
"fmt"
"testing"
"time"

"github.com/aws/aws-lambda-go/lambda/handlertrace"
"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/stretchr/testify/assert"
)

func TestInvalidHandlers(t *testing.T) {
type valuer interface {
Value(key interface{}) interface{}
}

type customContext interface {
context.Context
MyCustomMethod()
}

type myContext interface {
Deadline() (deadline time.Time, ok bool)
Done() <-chan struct{}
Err() error
Value(key interface{}) interface{}
}

testCases := []struct {
name string
Expand Down Expand Up @@ -72,12 +88,58 @@ func TestInvalidHandlers(t *testing.T) {
handler: func() {
},
},
{
name: "the handler takes the empty interface",
expected: nil,
handler: func(v interface{}) error {
if _, ok := v.(context.Context); ok {
return errors.New("v should not be a Context")
}
return nil
},
},
{
name: "the handler takes a subset of context.Context",
expected: errors.New("handler takes an interface, but it is not context.Context: \"valuer\""),
handler: func(ctx valuer) {
},
},
{
name: "the handler takes a same interface with context.Context",
expected: nil,
handler: func(ctx myContext) {
},
},
{
name: "the handler takes a superset of context.Context",
expected: errors.New("handler takes an interface, but it is not context.Context: \"customContext\""),
handler: func(ctx customContext) {
},
},
{
name: "the handler takes two arguments and first argument is a subset of context.Context",
expected: errors.New("handler takes two arguments, but the first is not Context. got interface"),
handler: func(ctx valuer, v interface{}) {
},
},
{
name: "the handler takes two arguments and first argument is a same interface with context.Context",
expected: nil,
handler: func(ctx myContext, v interface{}) {
},
},
{
name: "the handler takes two arguments and first argument is a superset of context.Context",
expected: errors.New("handler takes two arguments, but the first is not Context. got interface"),
handler: func(ctx customContext, v interface{}) {
},
},
}
for i, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
lambdaHandler := NewHandler(testCase.handler)
_, err := lambdaHandler.Invoke(context.TODO(), make([]byte, 0))
_, err := lambdaHandler.Invoke(context.TODO(), []byte("{}"))
assert.Equal(t, testCase.expected, err)
})
}
Expand Down

0 comments on commit ad74310

Please sign in to comment.