Skip to content

Commit

Permalink
mage: cancel context on SIGINT
Browse files Browse the repository at this point in the history
On receiving an interrupt signal, mage cancels the context allowing the magefile
to perform any cleanup before exiting.

A second interrupt signal will kill the magefile process without delay.

The behaviour for a timeout remains unchanged (context is canclled and the magefile
exits).
  • Loading branch information
pmcatominey committed Aug 11, 2020
1 parent 707b7bd commit 0aa93f8
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 18 deletions.
6 changes: 6 additions & 0 deletions mage/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
"log"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"sort"
"strings"
"syscall"
"text/template"
"time"

Expand Down Expand Up @@ -650,6 +652,10 @@ func RunCompiled(inv Invocation, exePath string, errlog *log.Logger) int {
c.Env = append(c.Env, fmt.Sprintf("MAGEFILE_TIMEOUT=%s", inv.Timeout.String()))
}
debug.Print("running magefile with mage vars:\n", strings.Join(filter(c.Env, "MAGEFILE"), "\n"))
// catch SIGINT to allow magefile to handle them
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT)
defer signal.Stop(sigCh)
err := c.Run()
if !sh.CmdRan(err) {
errlog.Printf("failed to run compiled magefile: %v", err)
Expand Down
91 changes: 89 additions & 2 deletions mage/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"runtime"
"strconv"
"strings"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -1146,7 +1147,7 @@ func TestCompiledFlags(t *testing.T) {
if err == nil {
t.Fatalf("expected an error because of timeout")
}
got = stdout.String()
got = stderr.String()
want = "context deadline exceeded"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
Expand Down Expand Up @@ -1235,7 +1236,7 @@ func TestCompiledEnvironmentVars(t *testing.T) {
if err == nil {
t.Fatalf("expected an error because of timeout")
}
got = stdout.String()
got = stderr.String()
want = "context deadline exceeded"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
Expand Down Expand Up @@ -1305,6 +1306,92 @@ func TestCompiledVerboseFlag(t *testing.T) {
}
}

func TestSignals(t *testing.T) {
stderr := &bytes.Buffer{}
stdout := &bytes.Buffer{}
dir := "./testdata/signals"
compileDir, err := ioutil.TempDir(dir, "")
if err != nil {
t.Fatal(err)
}
name := filepath.Join(compileDir, "mage_out")
// The CompileOut directory is relative to the
// invocation directory, so chop off the invocation dir.
outName := "./" + name[len(dir)-1:]
defer os.RemoveAll(compileDir)
inv := Invocation{
Dir: dir,
Stdout: stdout,
Stderr: stderr,
CompileOut: outName,
}
code := Invoke(inv)
if code != 0 {
t.Errorf("expected to exit with code 0, but got %v, stderr: %s", code, stderr)
}

run := func(stdout, stderr *bytes.Buffer, filename string, target string, signals ...syscall.Signal) error {
stderr.Reset()
stdout.Reset()
cmd := exec.Command(filename, target)
cmd.Stderr = stderr
cmd.Stdout = stdout
if err := cmd.Start(); err != nil {
return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s",
filename, target, err, stdout, stderr)
}
pid := cmd.Process.Pid
go func() {
time.Sleep(time.Millisecond * 500)
for _, s := range signals {
syscall.Kill(pid, s)
time.Sleep(time.Millisecond * 50)
}
}()
if err := cmd.Wait(); err != nil {
return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s",
filename, target, err, stdout, stderr)
}
return nil
}

if err := run(stdout, stderr, name, "exitsAfterSighup", syscall.SIGHUP); err != nil {
t.Fatal(err)
}
got := stdout.String()
want := "received sighup\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}

if err := run(stdout, stderr, name, "exitsAfterSigint", syscall.SIGINT); err != nil {
t.Fatal(err)
}
got = stdout.String()
want = "exiting...done\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}

if err := run(stdout, stderr, name, "exitsAfterCancel", syscall.SIGINT); err != nil {
t.Fatal(err)
}
got = stdout.String()
want = "exiting...done\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}

if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT, syscall.SIGINT); err == nil {
t.Fatalf("expected an error because of force kill")
}
got = stderr.String()
want = "Error: target killed\n"
if strings.Contains(got, want) == false {
t.Errorf("got %q, does not contain %q", got, want)
}
}

func TestClean(t *testing.T) {
if err := os.RemoveAll(mg.CacheDir()); err != nil {
t.Error("error removing cache dir:", err)
Expand Down
46 changes: 30 additions & 16 deletions mage/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import (
"io/ioutil"
"log"
"os"
"os/signal"
"path/filepath"
"sort"
"strconv"
"strings"
"syscall"
"text/tabwriter"
"time"
{{range .Imports}}{{.UniqueName}} "{{.Path}}"
Expand Down Expand Up @@ -260,17 +262,19 @@ Options:
var ctxCancel func()
getContext := func() (context.Context, func()) {
if ctx != nil {
return ctx, ctxCancel
if ctx == nil {
ctx, ctxCancel = context.WithCancel(context.Background())
}
return ctx, ctxCancel
}
getTimeout := func() <-chan time.Time {
if args.Timeout != 0 {
ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout)
} else {
ctx = context.Background()
ctxCancel = func() {}
return time.After(args.Timeout)
}
return ctx, ctxCancel
return make(chan time.Time)
}
runTarget := func(fn func(context.Context) error) interface{} {
Expand All @@ -285,15 +289,25 @@ Options:
err := fn(ctx)
d <- err
}()
select {
case <-ctx.Done():
cancel()
e := ctx.Err()
fmt.Printf("ctx err: %v\n", e)
return e
case err = <-d:
cancel()
return err
timeoutCh := getTimeout()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT)
for {
select {
case <-sigCh:
select {
case <-ctx.Done():
return fmt.Errorf("target killed")
default:
cancel()
}
case <-timeoutCh:
cancel()
return fmt.Errorf("context deadline exceeded")
case err = <-d:
cancel()
return err
}
}
}
// This is necessary in case there aren't any targets, to avoid an unused
Expand Down
47 changes: 47 additions & 0 deletions mage/testdata/signals/signals.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//+build mage

package main

import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
)

// Exits after receiving SIGHUP
func ExitsAfterSighup(ctx context.Context) {
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGHUP)
<-sigC
fmt.Println("received sighup")
}

// Exits after SIGINT and wait
func ExitsAfterSigint(ctx context.Context) {
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGINT)
<-sigC
fmt.Printf("exiting...")
time.Sleep(200 * time.Millisecond)
fmt.Println("done")
}

// Exits after ctx cancel and wait
func ExitsAfterCancel(ctx context.Context) {
<-ctx.Done()
fmt.Printf("exiting...")
time.Sleep(200 * time.Millisecond)
fmt.Println("done")
}

// Ignores all signals, requires killing
func IgnoresSignals(ctx context.Context) {
sigC := make(chan os.Signal, 1)
signal.Notify(sigC, syscall.SIGINT)
for {
<-sigC
}
}

0 comments on commit 0aa93f8

Please sign in to comment.