Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WithEnableSIGTERM option #457

Merged
merged 13 commits into from
Jul 28, 2022
6 changes: 5 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ jobs:

- run: go version

- name: install lambda runtime interface emulator
run: curl -L -o /usr/local/bin/aws-lambda-rie https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-x86_64
- run: chmod +x /usr/local/bin/aws-lambda-rie

- name: Check out code into the Go module directory
uses: actions/checkout@v2

- name: go test
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
Expand Down
90 changes: 90 additions & 0 deletions lambda/extensions_api_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package lambda

import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
)

const (
headerExtensionName = "Lambda-Extension-Name"
headerExtensionIdentifier = "Lambda-Extension-Identifier"
extensionAPIVersion = "2020-01-01"
)

type extensionAPIEventType string

const (
extensionInvokeEvent extensionAPIEventType = "INVOKE" //nolint:deadcode,unused,varcheck
extensionShutdownEvent extensionAPIEventType = "SHUTDOWN" //nolint:deadcode,unused,varcheck
)

type extensionAPIClient struct {
baseURL string
httpClient *http.Client
}

func newExtensionAPIClient(address string) *extensionAPIClient {
client := &http.Client{
Timeout: 0, // connections to the extensions API are never expected to time out
}
endpoint := "http://" + address + "/" + extensionAPIVersion + "/extension/"
return &extensionAPIClient{
baseURL: endpoint,
httpClient: client,
}
}

func (c *extensionAPIClient) register(name string, events ...extensionAPIEventType) (string, error) {
url := c.baseURL + "register"
body, _ := json.Marshal(struct {
Events []extensionAPIEventType `json:"events"`
}{
Events: events,
})

req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
req.Header.Add(headerExtensionName, name)
res, err := c.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to register extension: %v", err)
}
defer res.Body.Close()
_, _ = io.Copy(ioutil.Discard, res.Body)

if res.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode))
}

return res.Header.Get(headerExtensionIdentifier), nil
}

type extensionEventResponse struct {
EventType extensionAPIEventType
// ... the rest not implemented
}

func (c *extensionAPIClient) next(id string) (response extensionEventResponse, err error) {
url := c.baseURL + "event/next"

req, _ := http.NewRequest(http.MethodGet, url, nil)
req.Header.Add(headerExtensionIdentifier, id)
res, err := c.httpClient.Do(req)
if err != nil {
err = fmt.Errorf("failed to get extension event: %v", err)
return
}
defer res.Body.Close()
_, _ = io.Copy(ioutil.Discard, res.Body)

if res.StatusCode != http.StatusOK {
err = fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode))
return
}

err = json.NewDecoder(res.Body).Decode(&response)
return
}
25 changes: 25 additions & 0 deletions lambda/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type handlerOptions struct {
jsonResponseEscapeHTML bool
jsonResponseIndentPrefix string
jsonResponseIndentValue string
enableSIGTERM bool
sigtermCallbacks []func()
carlzogh marked this conversation as resolved.
Show resolved Hide resolved
}

type Option func(*handlerOptions)
Expand Down Expand Up @@ -73,6 +75,26 @@ func WithSetIndent(prefix, indent string) Option {
})
}

// WithEnableSIGTERM enables SIGTERM behavior within the Lambda platform on container spindown.
// SIGKILL will occur ~500ms after SIGTERM.
// Optionally, an array of callback functions to run on SIGTERM may be provided.
//
// Usage:
// lambda.StartWithOptions(
// func (event any) (any error) {
// return event, nil
// },
// lambda.WithEnableSIGTERM(func() {
// log.Print("function container shutting down...")
// })
// )
func WithEnableSIGTERM(callbacks ...func()) Option {
return Option(func(h *handlerOptions) {
h.sigtermCallbacks = append(h.sigtermCallbacks, callbacks...)
h.enableSIGTERM = true
})
}

func validateArguments(handler reflect.Type) (bool, error) {
handlerTakesContext := false
if handler.NumIn() > 2 {
Expand Down Expand Up @@ -139,6 +161,9 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
for _, option := range options {
option(h)
}
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
h.Handler = reflectHandler(handlerFunc, h)
return h
}
Expand Down
53 changes: 53 additions & 0 deletions lambda/sigterm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
carlzogh marked this conversation as resolved.
Show resolved Hide resolved

package lambda

import (
"log"
"os"
"os/signal"
"syscall"
)

// enableSIGTERM configures an optional list of sigtermHandlers to run on process shutdown.
// This non-default behavior is enabled within Lambda using the extensions API.
func enableSIGTERM(sigtermHandlers []func()) {
// for fun, we'll also optionally register SIGTERM handlers
if len(sigtermHandlers) > 0 {
signaled := make(chan os.Signal, 1)
signal.Notify(signaled, syscall.SIGTERM)
go func() {
<-signaled
for _, f := range sigtermHandlers {
f()
}
}()
}

// detect if we're actually running within Lambda
endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API")
if endpoint == "" {
log.Print("WARNING! AWS_LAMBDA_RUNTIME_API environment variable not found. Skipping attempt to register internal extension...")
return
}

// Now to do the AWS Lambda specific stuff.
// The default Lambda behavior is for functions to get SIGKILL at the end of lifetime, or after a timeout.
// Any use of the Lambda extension register API enables SIGTERM to be sent to the function process before the SIGKILL.
// We'll register an extension that does not listen for any lifecycle events named "GoLangEnableSIGTERM".
// The API will respond with an ID we need to pass in future requests.
client := newExtensionAPIClient(endpoint)
id, err := client.register("GoLangEnableSIGTERM")
if err != nil {
log.Printf("WARNING! Failed to register internal extension! SIGTERM events may not be enabled! err: %v", err)
return
}

// We didn't actually register for any events, but we need to call /next anyways to let the API know we're done initalizing.
// Because we didn't register for any events, /next will never return, so we'll do this in a go routine that is doomed to stay blocked.
go func() {
_, err := client.next(id)
log.Printf("WARNING! Reached expected unreachable code! Extension /next call expected to block forever! err: %v", err)
}()

}
93 changes: 93 additions & 0 deletions lambda/sigterm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//go:build go1.15
// +build go1.15

package lambda

import (
"io/ioutil"
"net/http"
"os"
"os/exec"
"path"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
)

func TestEnableSigterm(t *testing.T) {
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
}

testDir := t.TempDir()

// compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie
handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "sigterm.handler"), "./testdata/sigterm.go")
handlerBuild.Stderr = os.Stderr
handlerBuild.Stdout = os.Stderr
require.NoError(t, handlerBuild.Run())

for name, opts := range map[string]struct {
envVars []string
assertLogs func(t *testing.T, logs string)
}{
"baseline": {
assertLogs: func(t *testing.T, logs string) {
assert.NotContains(t, logs, "Hello SIGTERM!")
assert.NotContains(t, logs, "I've been TERMINATED!")
},
},
"sigterm enabled": {
envVars: []string{"ENABLE_SIGTERM=please"},
assertLogs: func(t *testing.T, logs string) {
assert.Contains(t, logs, "Hello SIGTERM!")
assert.Contains(t, logs, "I've been TERMINATED!")
},
},
} {
t.Run(name, func(t *testing.T) {
// run the runtime interface emulator, capture the logs for assertion
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
cmd.Env = append([]string{
"PATH=" + testDir,
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
}, opts.envVars...)
cmd.Stderr = os.Stderr
stdout, err := cmd.StdoutPipe()
require.NoError(t, err)
var logs string
done := make(chan interface{}) // closed on completion of log flush
go func() {
logBytes, err := ioutil.ReadAll(stdout)
require.NoError(t, err)
logs = string(logBytes)
close(done)
}()
require.NoError(t, cmd.Start())
t.Cleanup(func() { _ = cmd.Process.Kill() })

// give a moment for the port to bind
time.Sleep(500 * time.Millisecond)

client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie
resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}"))
require.NoError(t, err)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, string(body), "Task timed out after 2.00 seconds")

require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained
<-done
t.Logf("stdout:\n%s", logs)
opts.assertLogs(t, logs)
})
}
}
42 changes: 42 additions & 0 deletions lambda/testdata/sigterm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package main

import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"

"github.com/aws/aws-lambda-go/lambda"
)

func init() {
// conventional SIGTERM callback
signaled := make(chan os.Signal, 1)
signal.Notify(signaled, syscall.SIGTERM)
go func() {
<-signaled
fmt.Println("I've been TERMINATED!")
}()

}

func main() {
// lambda option to enable sigterm, plus optional extra sigterm callbacks
sigtermOption := lambda.WithEnableSIGTERM(func() {
fmt.Println("Hello SIGTERM!")
})
handlerOptions := []lambda.Option{}
if os.Getenv("ENABLE_SIGTERM") != "" {
handlerOptions = append(handlerOptions, sigtermOption)
}
lambda.StartWithOptions(
func(ctx context.Context) {
deadline, _ := ctx.Deadline()
<-time.After(time.Until(deadline) + time.Second)
panic("unreachable line reached!")
},
handlerOptions...,
)
}