Skip to content

Commit

Permalink
Add CustomRecovery builtin middleware (#2322)
Browse files Browse the repository at this point in the history
* Add CustomRecovery and CustomRecoveryWithWriter methods

* add CustomRecovery example to README

* add test for CustomRecovery

* support RecoveryWithWriter(io.Writer, ...RecoveryFunc)
  • Loading branch information
theonlyjohnny committed Jul 9, 2020
1 parent 5e40c1d commit 4cabdd3
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 5 deletions.
33 changes: 33 additions & 0 deletions README.md
Expand Up @@ -496,6 +496,39 @@ func main() {
}
```

### Custom Recovery behavior
```go
func main() {
// Creates a router without any middleware by default
r := gin.New()

// Global middleware
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
// By default gin.DefaultWriter = os.Stdout
r.Use(gin.Logger())

// Recovery middleware recovers from any panics and writes a 500 if there was one.
r.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
if err, ok := recovered.(string); ok {
c.String(http.StatusInternalServerError, fmt.Sprintf("error: %s", err))
}
c.AbortWithStatus(http.StatusInternalServerError)
}))

r.GET("/panic", func(c *gin.Context) {
// panic with a string -- the custom middleware could save this to a database or report it to the user
panic("foo")
})

r.GET("/", func(c *gin.Context) {
c.String(http.StatusOK, "ohai")
})

// Listen and serve on 0.0.0.0:8080
r.Run(":8080")
}
```

### How to write log file
```go
func main() {
Expand Down
27 changes: 23 additions & 4 deletions recovery.go
Expand Up @@ -26,13 +26,29 @@ var (
slash = []byte("/")
)

// RecoveryFunc defines the function passable to CustomRecovery.
type RecoveryFunc func(c *Context, err interface{})

// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc {
return RecoveryWithWriter(DefaultErrorWriter)
}

//CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
func CustomRecovery(handle RecoveryFunc) HandlerFunc {
return RecoveryWithWriter(DefaultErrorWriter, handle)
}

// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
func RecoveryWithWriter(out io.Writer) HandlerFunc {
func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
if len(recovery) > 0 {
return CustomRecoveryWithWriter(out, recovery[0])
}
return CustomRecoveryWithWriter(out, defaultHandleRecovery)
}

// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
var logger *log.Logger
if out != nil {
logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
Expand Down Expand Up @@ -70,20 +86,23 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc {
timeFormat(time.Now()), err, stack, reset)
}
}

// If the connection is dead, we can't write a status to it.
if brokenPipe {
// If the connection is dead, we can't write a status to it.
c.Error(err.(error)) // nolint: errcheck
c.Abort()
} else {
c.AbortWithStatus(http.StatusInternalServerError)
handle(c, err)
}
}
}()
c.Next()
}
}

func defaultHandleRecovery(c *Context, err interface{}) {
c.AbortWithStatus(http.StatusInternalServerError)
}

// stack returns a nicely formatted stack frame, skipping skip frames.
func stack(skip int) []byte {
buf := new(bytes.Buffer) // the returned data
Expand Down
106 changes: 105 additions & 1 deletion recovery_test.go
Expand Up @@ -62,7 +62,7 @@ func TestPanicInHandler(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), "TestPanicInHandler")
assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery")

// Debug mode prints the request
Expand Down Expand Up @@ -144,3 +144,107 @@ func TestPanicWithBrokenPipe(t *testing.T) {
})
}
}

func TestCustomRecoveryWithWriter(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(CustomRecoveryWithWriter(buffer, handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery")

// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")

assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())

SetMode(TestMode)
}

func TestCustomRecovery(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
DefaultErrorWriter = buffer
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(CustomRecovery(handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery")

// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")

assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())

SetMode(TestMode)
}

func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
DefaultErrorWriter = buffer
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(RecoveryWithWriter(DefaultErrorWriter, handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery")

// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")

assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())

SetMode(TestMode)
}

0 comments on commit 4cabdd3

Please sign in to comment.