Skip to content

Commit

Permalink
expose internal.handleHTTP as a standard http middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
zevdg committed Mar 31, 2021
1 parent b48684e commit 32e7680
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 81 deletions.
3 changes: 3 additions & 0 deletions appengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ func Main() {
internal.Main()
}

// Middleware wraps an http handler so that it can make GAE API calls
var Middleware func(http.Handler) http.Handler = internal.Middleware

// IsDevAppServer reports whether the App Engine app is running in the
// development App Server.
func IsDevAppServer() bool {
Expand Down
154 changes: 82 additions & 72 deletions internal/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,88 +87,98 @@ func apiURL() *url.URL {
}
}

func handleHTTP(w http.ResponseWriter, r *http.Request) {
c := &context{
req: r,
outHeader: w.Header(),
apiURL: apiURL(),
}
r = r.WithContext(withContext(r.Context(), c))
c.req = r

stopFlushing := make(chan int)
// Middleware wraps an http handler so that it can make GAE API calls
func Middleware(next http.Handler) http.Handler {
return handleHTTPMiddleware(executeRequestSafelyMiddleware(next))
}

// Patch up RemoteAddr so it looks reasonable.
if addr := r.Header.Get(userIPHeader); addr != "" {
r.RemoteAddr = addr
} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
r.RemoteAddr = addr
} else {
// Should not normally reach here, but pick a sensible default anyway.
r.RemoteAddr = "127.0.0.1"
}
// The address in the headers will most likely be of these forms:
// 123.123.123.123
// 2001:db8::1
// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
// Assume the remote address is only a host; add a default port.
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
}
func handleHTTPMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := &context{
req: r,
outHeader: w.Header(),
apiURL: apiURL(),
}
r = r.WithContext(withContext(r.Context(), c))
c.req = r

stopFlushing := make(chan int)

// Patch up RemoteAddr so it looks reasonable.
if addr := r.Header.Get(userIPHeader); addr != "" {
r.RemoteAddr = addr
} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
r.RemoteAddr = addr
} else {
// Should not normally reach here, but pick a sensible default anyway.
r.RemoteAddr = "127.0.0.1"
}
// The address in the headers will most likely be of these forms:
// 123.123.123.123
// 2001:db8::1
// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
// Assume the remote address is only a host; add a default port.
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
}

if logToLogservice() {
// Start goroutine responsible for flushing app logs.
// This is done after adding c to ctx.m (and stopped before removing it)
// because flushing logs requires making an API call.
go c.logFlusher(stopFlushing)
}
if logToLogservice() {
// Start goroutine responsible for flushing app logs.
// This is done after adding c to ctx.m (and stopped before removing it)
// because flushing logs requires making an API call.
go c.logFlusher(stopFlushing)
}

executeRequestSafely(c, r)
c.outHeader = nil // make sure header changes aren't respected any more
next.ServeHTTP(c, r)
c.outHeader = nil // make sure header changes aren't respected any more

flushed := make(chan struct{})
if logToLogservice() {
stopFlushing <- 1 // any logging beyond this point will be dropped
flushed := make(chan struct{})
if logToLogservice() {
stopFlushing <- 1 // any logging beyond this point will be dropped

// Flush any pending logs asynchronously.
c.pendingLogs.Lock()
flushes := c.pendingLogs.flushes
if len(c.pendingLogs.lines) > 0 {
flushes++
// Flush any pending logs asynchronously.
c.pendingLogs.Lock()
flushes := c.pendingLogs.flushes
if len(c.pendingLogs.lines) > 0 {
flushes++
}
c.pendingLogs.Unlock()
go func() {
defer close(flushed)
// Force a log flush, because with very short requests we
// may not ever flush logs.
c.flushLog(true)
}()
w.Header().Set(logFlushHeader, strconv.Itoa(flushes))
}
c.pendingLogs.Unlock()
go func() {
defer close(flushed)
// Force a log flush, because with very short requests we
// may not ever flush logs.
c.flushLog(true)
}()
w.Header().Set(logFlushHeader, strconv.Itoa(flushes))
}

// Avoid nil Write call if c.Write is never called.
if c.outCode != 0 {
w.WriteHeader(c.outCode)
}
if c.outBody != nil {
w.Write(c.outBody)
}
if logToLogservice() {
// Wait for the last flush to complete before returning,
// otherwise the security ticket will not be valid.
<-flushed
}
// Avoid nil Write call if c.Write is never called.
if c.outCode != 0 {
w.WriteHeader(c.outCode)
}
if c.outBody != nil {
w.Write(c.outBody)
}
if logToLogservice() {
// Wait for the last flush to complete before returning,
// otherwise the security ticket will not be valid.
<-flushed
}
})
}

func executeRequestSafely(c *context, r *http.Request) {
defer func() {
if x := recover(); x != nil {
logf(c, 4, "%s", renderPanic(x)) // 4 == critical
c.outCode = 500
}
}()
func executeRequestSafelyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if x := recover(); x != nil {
c := w.(*context)
logf(c, 4, "%s", renderPanic(x)) // 4 == critical
c.outCode = 500
}
}()

http.DefaultServeMux.ServeHTTP(c, r)
next.ServeHTTP(w, r)
})
}

func renderPanic(x interface{}) string {
Expand Down
4 changes: 0 additions & 4 deletions internal/api_classic.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,6 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
return err
}

func handleHTTP(w http.ResponseWriter, r *http.Request) {
panic("handleHTTP called; this should be impossible")
}

func logf(c appengine.Context, level int64, format string, args ...interface{}) {
var fn func(format string, args ...interface{})
switch level {
Expand Down
8 changes: 4 additions & 4 deletions internal/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func TestDelayedLogFlushing(t *testing.T) {
handled := make(chan struct{})
go func() {
defer close(handled)
handleHTTP(w, r)
Middleware(http.DefaultServeMux).ServeHTTP(w, r)
}()
// Check that the log flush eventually comes in.
time.Sleep(1200 * time.Millisecond)
Expand Down Expand Up @@ -360,7 +360,7 @@ func TestLogFlushing(t *testing.T) {
}
w := httptest.NewRecorder()

handleHTTP(w, r)
Middleware(http.DefaultServeMux).ServeHTTP(w, r)
const hdr = "X-AppEngine-Log-Flush-Count"
if got := w.HeaderMap.Get(hdr); got != tc.wantHeader {
t.Errorf("%s header = %q, want %q", hdr, got, tc.wantHeader)
Expand Down Expand Up @@ -403,7 +403,7 @@ func TestRemoteAddr(t *testing.T) {
Header: tc.headers,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
}
handleHTTP(httptest.NewRecorder(), r)
Middleware(http.DefaultServeMux).ServeHTTP(httptest.NewRecorder(), r)
if addr != tc.addr {
t.Errorf("Header %v, got %q, want %q", tc.headers, addr, tc.addr)
}
Expand All @@ -420,7 +420,7 @@ func TestPanickingHandler(t *testing.T) {
Body: ioutil.NopCloser(bytes.NewReader(nil)),
}
rec := httptest.NewRecorder()
handleHTTP(rec, r)
Middleware(http.DefaultServeMux).ServeHTTP(rec, r)
if rec.Code != 500 {
t.Errorf("Panicking handler returned HTTP %d, want HTTP %d", rec.Code, 500)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/main_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func Main() {
if IsDevAppServer() {
host = "127.0.0.1"
}
if err := http.ListenAndServe(host+":"+port, http.HandlerFunc(handleHTTP)); err != nil {
if err := http.ListenAndServe(host+":"+port, Middleware(http.DefaultServeMux)); err != nil {
log.Fatalf("http.ListenAndServe: %v", err)
}
}
Expand Down

0 comments on commit 32e7680

Please sign in to comment.