From 869e12867570e06a3a0f4a5ad7836533ce96d9f3 Mon Sep 17 00:00:00 2001 From: Johnny Dallas Date: Thu, 16 Apr 2020 09:08:19 +0000 Subject: [PATCH 1/4] Add CustomRecovery and CustomRecoveryWithWriter methods --- recovery.go | 22 +++++++++++++++++++--- recovery_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/recovery.go b/recovery.go index bc946c03ae..2eeff85419 100644 --- a/recovery.go +++ b/recovery.go @@ -26,13 +26,26 @@ var ( slash = []byte("/") ) +// RecoveryFunc defines the function passable to CustomRecovery. +type RecoveryFunc func(c *Context, err interface{}) + // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. func Recovery() HandlerFunc { return RecoveryWithWriter(DefaultErrorWriter) } +//CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it. +func CustomRecovery(handle RecoveryFunc) HandlerFunc { + return CustomRecoveryWithWriter(DefaultErrorWriter, handle) +} + // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. func RecoveryWithWriter(out io.Writer) HandlerFunc { + return CustomRecoveryWithWriter(out, defaultHandleRecovery) +} + +// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it. +func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc { var logger *log.Logger if out != nil { logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags) @@ -70,13 +83,12 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc { timeFormat(time.Now()), err, stack, reset) } } - - // If the connection is dead, we can't write a status to it. if brokenPipe { + // If the connection is dead, we can't write a status to it. c.Error(err.(error)) // nolint: errcheck c.Abort() } else { - c.AbortWithStatus(http.StatusInternalServerError) + handle(c, err) } } }() @@ -84,6 +96,10 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc { } } +func defaultHandleRecovery(c *Context, err interface{}) { + c.AbortWithStatus(http.StatusInternalServerError) +} + // stack returns a nicely formatted stack frame, skipping skip frames. func stack(skip int) []byte { buf := new(bytes.Buffer) // the returned data diff --git a/recovery_test.go b/recovery_test.go index 21a0a48012..dbe8ebd4cd 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -144,3 +144,37 @@ func TestPanicWithBrokenPipe(t *testing.T) { }) } } + +func TestCustomRecoveryWithWriter(t *testing.T) { + errBuffer := new(bytes.Buffer) + buffer := new(bytes.Buffer) + router := New() + handleRecovery := func(c *Context, err interface{}) { + errBuffer.WriteString(err.(string)) + c.AbortWithStatus(http.StatusBadRequest) + } + router.Use(CustomRecoveryWithWriter(buffer, handleRecovery)) + router.GET("/recovery", func(_ *Context) { + panic("Oupps, Houston, we have a problem") + }) + // RUN + w := performRequest(router, "GET", "/recovery") + // TEST + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, buffer.String(), "panic recovered") + assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") + assert.Contains(t, buffer.String(), "TestCustomRecoveryWithWriter") + assert.NotContains(t, buffer.String(), "GET /recovery") + + // Debug mode prints the request + SetMode(DebugMode) + // RUN + w = performRequest(router, "GET", "/recovery") + // TEST + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, buffer.String(), "GET /recovery") + + assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String()) + + SetMode(TestMode) +} From 6427797fc0e71b44691c599a559255bcceec3d83 Mon Sep 17 00:00:00 2001 From: Johnny Dallas Date: Thu, 16 Apr 2020 09:42:50 +0000 Subject: [PATCH 2/4] add CustomRecovery example to README --- README.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/README.md b/README.md index 9dde693d68..467db3fca9 100644 --- a/README.md +++ b/README.md @@ -490,6 +490,39 @@ func main() { } ``` +### Custom Recovery behavior +```go +func main() { + // Creates a router without any middleware by default + r := gin.New() + + // Global middleware + // Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release. + // By default gin.DefaultWriter = os.Stdout + r.Use(gin.Logger()) + + // Recovery middleware recovers from any panics and writes a 500 if there was one. + r.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { + if err, ok := recovered.(string); ok { + c.String(http.StatusInternalServerError, fmt.Sprintf("error: %s", err)) + } + c.AbortWithStatus(http.StatusInternalServerError) + })) + + r.GET("/panic", func(c *gin.Context) { + // panic with a string -- the custom middleware could save this to a database or report it to the user + panic("foo") + }) + + r.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "ohai") + }) + + // Listen and serve on 0.0.0.0:8080 + r.Run(":8080") +} +``` + ### How to write log file ```go func main() { From c77034564cf66a324718fb90091f0aad4635eff8 Mon Sep 17 00:00:00 2001 From: Johnny Dallas Date: Thu, 16 Apr 2020 20:20:04 +0000 Subject: [PATCH 3/4] add test for CustomRecovery --- recovery_test.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/recovery_test.go b/recovery_test.go index dbe8ebd4cd..ed8f648985 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -178,3 +178,38 @@ func TestCustomRecoveryWithWriter(t *testing.T) { SetMode(TestMode) } + +func TestCustomRecovery(t *testing.T) { + errBuffer := new(bytes.Buffer) + buffer := new(bytes.Buffer) + router := New() + DefaultErrorWriter = buffer + handleRecovery := func(c *Context, err interface{}) { + errBuffer.WriteString(err.(string)) + c.AbortWithStatus(http.StatusBadRequest) + } + router.Use(CustomRecovery(handleRecovery)) + router.GET("/recovery", func(_ *Context) { + panic("Oupps, Houston, we have a problem") + }) + // RUN + w := performRequest(router, "GET", "/recovery") + // TEST + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, buffer.String(), "panic recovered") + assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") + assert.Contains(t, buffer.String(), "TestCustomRecovery") + assert.NotContains(t, buffer.String(), "GET /recovery") + + // Debug mode prints the request + SetMode(DebugMode) + // RUN + w = performRequest(router, "GET", "/recovery") + // TEST + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, buffer.String(), "GET /recovery") + + assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String()) + + SetMode(TestMode) +} From 5782c13d25b2ec5d137f6965f34513b71cfe79df Mon Sep 17 00:00:00 2001 From: Johnny Dallas Date: Fri, 12 Jun 2020 10:22:53 +0000 Subject: [PATCH 4/4] support RecoveryWithWriter(io.Writer, ...RecoveryFunc) --- recovery.go | 7 +++++-- recovery_test.go | 41 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/recovery.go b/recovery.go index 2f4caae031..d02b829b97 100644 --- a/recovery.go +++ b/recovery.go @@ -36,11 +36,14 @@ func Recovery() HandlerFunc { //CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it. func CustomRecovery(handle RecoveryFunc) HandlerFunc { - return CustomRecoveryWithWriter(DefaultErrorWriter, handle) + return RecoveryWithWriter(DefaultErrorWriter, handle) } // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. -func RecoveryWithWriter(out io.Writer) HandlerFunc { +func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc { + if len(recovery) > 0 { + return CustomRecoveryWithWriter(out, recovery[0]) + } return CustomRecoveryWithWriter(out, defaultHandleRecovery) } diff --git a/recovery_test.go b/recovery_test.go index ed8f648985..6cc2a47af4 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -62,7 +62,7 @@ func TestPanicInHandler(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") - assert.Contains(t, buffer.String(), "TestPanicInHandler") + assert.Contains(t, buffer.String(), t.Name()) assert.NotContains(t, buffer.String(), "GET /recovery") // Debug mode prints the request @@ -163,7 +163,7 @@ func TestCustomRecoveryWithWriter(t *testing.T) { assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") - assert.Contains(t, buffer.String(), "TestCustomRecoveryWithWriter") + assert.Contains(t, buffer.String(), t.Name()) assert.NotContains(t, buffer.String(), "GET /recovery") // Debug mode prints the request @@ -198,7 +198,42 @@ func TestCustomRecovery(t *testing.T) { assert.Equal(t, http.StatusBadRequest, w.Code) assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") - assert.Contains(t, buffer.String(), "TestCustomRecovery") + assert.Contains(t, buffer.String(), t.Name()) + assert.NotContains(t, buffer.String(), "GET /recovery") + + // Debug mode prints the request + SetMode(DebugMode) + // RUN + w = performRequest(router, "GET", "/recovery") + // TEST + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, buffer.String(), "GET /recovery") + + assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String()) + + SetMode(TestMode) +} + +func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) { + errBuffer := new(bytes.Buffer) + buffer := new(bytes.Buffer) + router := New() + DefaultErrorWriter = buffer + handleRecovery := func(c *Context, err interface{}) { + errBuffer.WriteString(err.(string)) + c.AbortWithStatus(http.StatusBadRequest) + } + router.Use(RecoveryWithWriter(DefaultErrorWriter, handleRecovery)) + router.GET("/recovery", func(_ *Context) { + panic("Oupps, Houston, we have a problem") + }) + // RUN + w := performRequest(router, "GET", "/recovery") + // TEST + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, buffer.String(), "panic recovered") + assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") + assert.Contains(t, buffer.String(), t.Name()) assert.NotContains(t, buffer.String(), "GET /recovery") // Debug mode prints the request