diff --git a/CHANGELOG.md b/CHANGELOG.md index d8fb57d3..f2cba32f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * Replace logrus with zerolog. (#413, @miry) * Log HTTP requests to API server. (#413, #421, @miry) * Add TimeoutHandler for the HTTP API server. (#420, @miry) +* Set Write and Read timeouts for HTTP API server connections. (@miry) # [2.4.0] - 2022-03-07 diff --git a/api.go b/api.go index 62dbb177..960da575 100644 --- a/api.go +++ b/api.go @@ -46,12 +46,12 @@ func (server *ApiServer) PopulateConfig(filename string) { } } -func StopBrowsersMiddleware(h http.Handler) http.Handler { +func stopBrowsersMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.UserAgent(), "Mozilla/") { http.Error(w, "User agent not allowed", 403) } else { - h.ServeHTTP(w, r) + next.ServeHTTP(w, r) } }) } @@ -74,6 +74,7 @@ func (server *ApiServer) Listen(host string, port string) { Dur("duration", duration). Msg("") })) + r.Use(stopBrowsersMiddleware) r.Use(timeoutMiddleware) r.HandleFunc("/reset", server.ResetState).Methods("POST") @@ -95,8 +96,6 @@ func (server *ApiServer) Listen(host string, port string) { r.Handle("/metrics", server.Metrics.handler()) } - http.Handle("/", StopBrowsersMiddleware(r)) - server.Logger. Info(). Str("host", host). @@ -104,7 +103,14 @@ func (server *ApiServer) Listen(host string, port string) { Str("version", Version). Msgf("Starting HTTP server on endpoint %s:%s", host, port) - err := http.ListenAndServe(net.JoinHostPort(host, port), nil) + srv := &http.Server{ + Handler: r, + Addr: net.JoinHostPort(host, port), + WriteTimeout: 10 * time.Second, + ReadTimeout: 10 * time.Second, + } + + err := srv.ListenAndServe() if err != nil { server.Logger.Fatal().Err(err).Msg("ListenAndServe finished with error") }