Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CustomRecovery builtin middleware #2322

Merged
merged 6 commits into from Jul 9, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 33 additions & 0 deletions README.md
Expand Up @@ -496,6 +496,39 @@ func main() {
}
```

### Custom Recovery behavior
```go
func main() {
// Creates a router without any middleware by default
r := gin.New()

// Global middleware
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
// By default gin.DefaultWriter = os.Stdout
r.Use(gin.Logger())

// Recovery middleware recovers from any panics and writes a 500 if there was one.
r.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
if err, ok := recovered.(string); ok {
c.String(http.StatusInternalServerError, fmt.Sprintf("error: %s", err))
}
c.AbortWithStatus(http.StatusInternalServerError)
}))

r.GET("/panic", func(c *gin.Context) {
// panic with a string -- the custom middleware could save this to a database or report it to the user
panic("foo")
})

r.GET("/", func(c *gin.Context) {
c.String(http.StatusOK, "ohai")
})

// Listen and serve on 0.0.0.0:8080
r.Run(":8080")
}
```

### How to write log file
```go
func main() {
Expand Down
22 changes: 19 additions & 3 deletions recovery.go
Expand Up @@ -26,13 +26,26 @@ var (
slash = []byte("/")
)

// RecoveryFunc defines the function passable to CustomRecovery.
type RecoveryFunc func(c *Context, err interface{})

// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc {
return RecoveryWithWriter(DefaultErrorWriter)
}

//CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
func CustomRecovery(handle RecoveryFunc) HandlerFunc {
return CustomRecoveryWithWriter(DefaultErrorWriter, handle)
}

// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
func RecoveryWithWriter(out io.Writer) HandlerFunc {
theonlyjohnny marked this conversation as resolved.
Show resolved Hide resolved
return CustomRecoveryWithWriter(out, defaultHandleRecovery)
}

// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
var logger *log.Logger
if out != nil {
logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
Expand Down Expand Up @@ -70,20 +83,23 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc {
timeFormat(time.Now()), err, stack, reset)
}
}

// If the connection is dead, we can't write a status to it.
if brokenPipe {
// If the connection is dead, we can't write a status to it.
c.Error(err.(error)) // nolint: errcheck
c.Abort()
} else {
c.AbortWithStatus(http.StatusInternalServerError)
handle(c, err)
}
}
}()
c.Next()
}
}

func defaultHandleRecovery(c *Context, err interface{}) {
c.AbortWithStatus(http.StatusInternalServerError)
}

// stack returns a nicely formatted stack frame, skipping skip frames.
func stack(skip int) []byte {
buf := new(bytes.Buffer) // the returned data
Expand Down
69 changes: 69 additions & 0 deletions recovery_test.go
Expand Up @@ -144,3 +144,72 @@ func TestPanicWithBrokenPipe(t *testing.T) {
})
}
}

func TestCustomRecoveryWithWriter(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(CustomRecoveryWithWriter(buffer, handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), "TestCustomRecoveryWithWriter")
assert.NotContains(t, buffer.String(), "GET /recovery")

// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")

assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())

SetMode(TestMode)
}

func TestCustomRecovery(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
DefaultErrorWriter = buffer
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(CustomRecovery(handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), "TestCustomRecovery")
assert.NotContains(t, buffer.String(), "GET /recovery")

// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")

assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())

SetMode(TestMode)
}