diff --git a/hooks.go b/hooks.go index 3ea7a000140..571ac252105 100644 --- a/hooks.go +++ b/hooks.go @@ -7,6 +7,7 @@ type OnGroupHandler = func(Group) error type OnGroupNameHandler = OnGroupHandler type OnListenHandler = func() error type OnShutdownHandler = OnListenHandler +type OnForkHandler = func(int) error type hooks struct { // Embed app @@ -19,6 +20,7 @@ type hooks struct { onGroupName []OnGroupNameHandler onListen []OnListenHandler onShutdown []OnShutdownHandler + onFork []OnForkHandler } func newHooks(app *App) *hooks { @@ -30,6 +32,7 @@ func newHooks(app *App) *hooks { onName: make([]OnNameHandler, 0), onListen: make([]OnListenHandler, 0), onShutdown: make([]OnShutdownHandler, 0), + onFork: make([]OnForkHandler, 0), } } @@ -83,6 +86,13 @@ func (h *hooks) OnShutdown(handler ...OnShutdownHandler) { h.app.mutex.Unlock() } +// OnFork is a hook to execute user function after fork process. +func (h *hooks) OnFork(handler ...OnForkHandler) { + h.app.mutex.Lock() + h.onFork = append(h.onFork, handler...) + h.app.mutex.Unlock() +} + func (h *hooks) executeOnRouteHooks(route Route) error { for _, v := range h.onRoute { if err := v(route); err != nil { @@ -138,3 +148,9 @@ func (h *hooks) executeOnShutdownHooks() { _ = v() } } + +func (h *hooks) executeOnForkHooks(pid int) { + for _, v := range h.onFork { + _ = v(pid) + } +} diff --git a/hooks_test.go b/hooks_test.go index 2339e5306c8..5a800920d2d 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -176,3 +176,23 @@ func Test_Hook_OnListen(t *testing.T) { utils.AssertEqual(t, "ready", buf.String()) } + +func Test_Hook_OnHook(t *testing.T) { + // Reset test var + testPreforkMaster = true + testOnPrefork = true + + app := New() + + go func() { + time.Sleep(1000 * time.Millisecond) + utils.AssertEqual(t, nil, app.Shutdown()) + }() + + app.Hooks().OnFork(func(pid int) error { + utils.AssertEqual(t, 1, pid) + return nil + }) + + utils.AssertEqual(t, nil, app.prefork(NetworkTCP4, ":3000", nil)) +} diff --git a/prefork.go b/prefork.go index 459e4089306..b3049abc60a 100644 --- a/prefork.go +++ b/prefork.go @@ -20,6 +20,7 @@ const ( ) var testPreforkMaster = false +var testOnPrefork = false // IsChild determines if the current process is a child of Prefork func IsChild() bool { @@ -102,6 +103,15 @@ func (app *App) prefork(network, addr string, tlsConfig *tls.Config) (err error) childs[pid] = cmd pids = append(pids, strconv.Itoa(pid)) + // execute fork hook + if app.hooks != nil { + if testOnPrefork { + app.hooks.executeOnForkHooks(dummyPid) + } else { + app.hooks.executeOnForkHooks(pid) + } + } + // notify master if child crashes go func() { channel <- child{pid, cmd.Wait()} @@ -146,3 +156,5 @@ func dummyCmd() *exec.Cmd { } return exec.Command(dummyChildCmd, "version") } + +var dummyPid = 1