Skip to content

Commit

Permalink
mage: cancel context on SIGINT (#313)
Browse files Browse the repository at this point in the history
* mage: cancel context on SIGINT

On receiving an interrupt signal, mage cancels the context allowing the magefile
to perform any cleanup before exiting.

A second interrupt signal will kill the magefile process without delay.

The behaviour for a timeout remains unchanged (context is cancelled and the magefile
exits).

* mage: add cleanup timeout to cancel

Co-authored-by: Nate Finch <nate.finch@gmail.com>
  • Loading branch information
pmcatominey and natefinch committed Nov 29, 2022
1 parent 300bbc8 commit a920604
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Expand Up @@ -9,6 +9,7 @@ jobs:
fail-fast: false
matrix:
go-version:
- 1.18.x
- 1.17.x
- 1.16.x
- 1.15.x
Expand Down
6 changes: 6 additions & 0 deletions mage/main.go
Expand Up @@ -12,11 +12,13 @@ import (
"log"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"sort"
"strings"
"syscall"
"text/template"
"time"

Expand Down Expand Up @@ -737,6 +739,10 @@ func RunCompiled(inv Invocation, exePath string, errlog *log.Logger) int {
c.Env = append(c.Env, fmt.Sprintf("MAGEFILE_TIMEOUT=%s", inv.Timeout.String()))
}
debug.Print("running magefile with mage vars:\n", strings.Join(filter(c.Env, "MAGEFILE"), "\n"))
// catch SIGINT to allow magefile to handle them
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT)
defer signal.Stop(sigCh)
err := c.Run()
if !sh.CmdRan(err) {
errlog.Printf("failed to run compiled magefile: %v", err)
Expand Down
110 changes: 108 additions & 2 deletions mage/main_test.go
Expand Up @@ -22,6 +22,7 @@ import (
"runtime"
"strconv"
"strings"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -1292,7 +1293,7 @@ func TestCompiledFlags(t *testing.T) {
if err == nil {
t.Fatalf("expected an error because of timeout")
}
got = stdout.String()
got = stderr.String()
want = "context deadline exceeded"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
Expand Down Expand Up @@ -1384,7 +1385,7 @@ func TestCompiledEnvironmentVars(t *testing.T) {
if err == nil {
t.Fatalf("expected an error because of timeout")
}
got = stdout.String()
got = stderr.String()
want = "context deadline exceeded"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
Expand Down Expand Up @@ -1457,6 +1458,111 @@ func TestCompiledVerboseFlag(t *testing.T) {
}
}

func TestSignals(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
dir := "./testdata/signals"
compileDir, err := ioutil.TempDir(dir, "")
if err != nil {
t.Fatal(err)
}
name := filepath.Join(compileDir, "mage_out")
// The CompileOut directory is relative to the
// invocation directory, so chop off the invocation dir.
outName := "./" + name[len(dir)-1:]
defer os.RemoveAll(compileDir)
inv := Invocation{
Dir: dir,
Stdout: stdout,
Stderr: stderr,
CompileOut: outName,
}
code := Invoke(inv)
if code != 0 {
t.Errorf("expected to exit with code 0, but got %v, stderr: %s", code, stderr)
}

run := func(stdout, stderr *bytes.Buffer, filename string, target string, signals ...syscall.Signal) error {
stderr.Reset()
stdout.Reset()
cmd := exec.Command(filename, target)
cmd.Stderr = stderr
cmd.Stdout = stdout
if err := cmd.Start(); err != nil {
return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s",
filename, target, err, stdout, stderr)
}
pid := cmd.Process.Pid
go func() {
time.Sleep(time.Millisecond * 500)
for _, s := range signals {
syscall.Kill(pid, s)
time.Sleep(time.Millisecond * 50)
}
}()
if err := cmd.Wait(); err != nil {
return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s",
filename, target, err, stdout, stderr)
}
return nil
}

if err := run(stdout, stderr, name, "exitsAfterSighup", syscall.SIGHUP); err != nil {
t.Fatal(err)
}
got := stdout.String()
want := "received sighup\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}

if err := run(stdout, stderr, name, "exitsAfterSigint", syscall.SIGINT); err != nil {
t.Fatal(err)
}
got = stdout.String()
want = "exiting...done\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}
got = stderr.String()
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}

if err := run(stdout, stderr, name, "exitsAfterCancel", syscall.SIGINT); err != nil {
t.Fatal(err)
}
got = stdout.String()
want = "exiting...done\ndeferred cleanup\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}
got = stderr.String()
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}

if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT, syscall.SIGINT); err == nil {
t.Fatalf("expected an error because of force kill")
}
got = stderr.String()
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nexiting mage\nError: exit forced\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}

if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT); err == nil {
t.Fatalf("expected an error because of force kill")
}
got = stderr.String()
want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nError: cleanup timeout exceeded\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}
}

func TestCompiledDeterministic(t *testing.T) {
dir := "./testdata/compiled"
compileDir, err := ioutil.TempDir(dir, "")
Expand Down
48 changes: 37 additions & 11 deletions mage/template.go
Expand Up @@ -14,10 +14,12 @@ import (
_ioutil "io/ioutil"
_log "log"
"os"
"os/signal"
_filepath "path/filepath"
_sort "sort"
"strconv"
_strings "strings"
"syscall"
_tabwriter "text/tabwriter"
"time"
{{range .Imports}}{{.UniqueName}} "{{.Path}}"
Expand Down Expand Up @@ -256,23 +258,27 @@ Options:
}
var ctx context.Context
var ctxCancel func()
ctxCancel := func(){}
// by deferring in a closure, we let the cancel function get replaced
// by the getContext function.
defer func() {
ctxCancel()
}()
getContext := func() (context.Context, func()) {
if ctx != nil {
return ctx, ctxCancel
if ctx == nil {
if args.Timeout != 0 {
ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout)
} else {
ctx, ctxCancel = context.WithCancel(context.Background())
}
}
if args.Timeout != 0 {
ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout)
} else {
ctx = context.Background()
ctxCancel = func() {}
}
return ctx, ctxCancel
}
runTarget := func(fn func(context.Context) error) interface{} {
runTarget := func(logger *_log.Logger, fn func(context.Context) error) interface{} {
var err interface{}
ctx, cancel := getContext()
d := make(chan interface{})
Expand All @@ -284,14 +290,34 @@ Options:
err := fn(ctx)
d <- err
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT)
select {
case <-sigCh:
logger.Println("cancelling mage targets, waiting up to 5 seconds for cleanup...")
cancel()
cleanupCh := time.After(5 * time.Second)
select {
// target exited by itself
case err = <-d:
return err
// cleanup timeout exceeded
case <-cleanupCh:
return _fmt.Errorf("cleanup timeout exceeded")
// second SIGINT received
case <-sigCh:
logger.Println("exiting mage")
return _fmt.Errorf("exit forced")
}
case <-ctx.Done():
cancel()
e := ctx.Err()
_fmt.Printf("ctx err: %v\n", e)
return e
case err = <-d:
cancel()
// we intentionally don't cancel the context here, because
// the next target will need to run with the same context.
return err
}
}
Expand Down
50 changes: 50 additions & 0 deletions mage/testdata/signals/signals.go
@@ -0,0 +1,50 @@
//+build mage

package main

import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
)

// Exits after receiving SIGHUP
func ExitsAfterSighup(ctx context.Context) {
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGHUP)
<-sigC
fmt.Println("received sighup")
}

// Exits after SIGINT and wait
func ExitsAfterSigint(ctx context.Context) {
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGINT)
<-sigC
fmt.Printf("exiting...")
time.Sleep(200 * time.Millisecond)
fmt.Println("done")
}

// Exits after ctx cancel and wait
func ExitsAfterCancel(ctx context.Context) {
defer func() {
fmt.Println("deferred cleanup")
}()
<-ctx.Done()
fmt.Printf("exiting...")
time.Sleep(200 * time.Millisecond)
fmt.Println("done")
}

// Ignores all signals, requires killing via timeout or second SIGINT
func IgnoresSignals(ctx context.Context) {
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGINT)
for {
<-sigC
}
}
2 changes: 1 addition & 1 deletion parse/parse.go
Expand Up @@ -169,7 +169,7 @@ func (f Function) ExecCode() string {
}
out += `
}
ret := runTarget(wrapFn)`
ret := runTarget(logger, wrapFn)`
return out
}

Expand Down

0 comments on commit a920604

Please sign in to comment.