Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce AccessHandlerWithData #562

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,16 @@ c = c.Append(hlog.AccessHandler(func(r *http.Request, status, size int, duration
Dur("duration", duration).
Msg("")
}))
c = c.Append(hlog.AccessHandlerWithData(func(data hlog.AccessHandlerData) {
hlog.FromRequest(data.Request).Info().
Str("method", data.Request.Method).
Stringer("url", data.Request.URL).
Int("status", data.Status).
Int("sizeWritten", data.BytesWritten).
Int64("sizeRead", data.BytesRead).
Dur("duration", data.Duration).
Msg("")
}))
c = c.Append(hlog.RemoteAddrHandler("ip"))
c = c.Append(hlog.UserAgentHandler("user_agent"))
c = c.Append(hlog.RefererHandler("referer"))
Expand Down
30 changes: 30 additions & 0 deletions hlog/hlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,36 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat
}
}

type AccessHandlerData struct {
Request *http.Request
Duration time.Duration
Status int
BytesWritten int
BytesRead int64
}

// AccessHandlerWithData returns a handler that call f after each request.
func AccessHandlerWithData(f func(data AccessHandlerData)) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
ww := mutil.WrapWriter(w)
body := mutil.NewByteCountReadCloser(r.Body)
r.Body = body
defer func() {
f(AccessHandlerData{
Request: r,
Duration: time.Since(start),
Status: ww.Status(),
BytesWritten: ww.BytesWritten(),
BytesRead: body.BytesRead(),
})
}()
next.ServeHTTP(ww, r)
})
}
}

// HostHandler adds the request's host as a field to the context's logger
// using fieldKey as field key. If trimPort is set to true, then port is
// removed from the host.
Expand Down
42 changes: 42 additions & 0 deletions hlog/hlog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"

"github.com/rs/xid"
Expand Down Expand Up @@ -432,3 +434,43 @@ func TestGetHost(t *testing.T) {
})
}
}

func TestAccessHandlerWithData(t *testing.T) {
bodyValue := "hello, world!"
req := httptest.NewRequest(http.MethodGet, "/", strings.NewReader(bodyValue))

handler := AccessHandlerWithData(func(data AccessHandlerData) {
expectedBytes := int64(len(bodyValue))
if data.BytesRead != expectedBytes {
t.Errorf("unexpected bytes read, got: %d, want: %d", data.BytesRead, expectedBytes)
}
if data.BytesWritten != int(expectedBytes) {
t.Errorf("unexpected bytes read, got: %d, want: %d", data.BytesWritten, expectedBytes)
}
if data.Status != http.StatusOK {
t.Errorf("unexpected status, got: %d, want: %d", data.Status, http.StatusOK)
}
if data.Request != req {
t.Error("unexpected request object")
}
})

rr := httptest.NewRecorder()

handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(w, r.Body)
})).ServeHTTP(rr, req)

if rr.Result().StatusCode != http.StatusOK {
t.Errorf("unexpected status, got: %d, want: %d", rr.Result().StatusCode, http.StatusOK)
}

b, err := io.ReadAll(rr.Result().Body)
if err != nil {
t.Errorf("unexpected error: %s", err.Error())
}

if bodyValue != string(b) {
t.Errorf("unexpected response body, got: %s, want: %s", string(b), bodyValue)
}
}
42 changes: 42 additions & 0 deletions hlog/internal/mutil/body.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package mutil

import (
"io"
"sync/atomic"
)

type byteCountReadCloser struct {
rc io.ReadCloser
read *int64
}

var _ io.ReadCloser = (*byteCountReadCloser)(nil)
var _ io.WriterTo = (*byteCountReadCloser)(nil)

func NewByteCountReadCloser(rc io.ReadCloser) *byteCountReadCloser {
read := int64(0)
return &byteCountReadCloser{
rc: rc,
read: &read,
}
}

func (b *byteCountReadCloser) Read(p []byte) (int, error) {
n, err := b.rc.Read(p)
atomic.AddInt64(b.read, int64(n))
return n, err
}

func (b *byteCountReadCloser) Close() error {
return b.rc.Close()
}

func (b *byteCountReadCloser) WriteTo(w io.Writer) (int64, error) {
n, err := io.Copy(w, b.rc)
atomic.AddInt64(b.read, n)
return n, err
}

func (b *byteCountReadCloser) BytesRead() int64 {
return atomic.LoadInt64(b.read)
}