From 3058e5a49e1da46f6497f64497736bb33986edd2 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 25 Jul 2022 21:23:08 +0200 Subject: [PATCH] Allow skipping default headers for api clients (#56) --- internal/api/http.go | 28 ++++++++++++++++++++-------- internal/api/http_test.go | 22 ++++++++++++++++++---- pkg/api/client.go | 5 ++++- 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/internal/api/http.go b/internal/api/http.go index 4104057..9662b5e 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -112,6 +112,12 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { transport = logger.RoundTripper(transport) } + if opts.Headers == nil { + opts.Headers = map[string]string{} + } + if !opts.SkipDefaultHeaders { + resolveHeaders(opts.Headers) + } transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) return http.Client{Transport: transport, Timeout: opts.Timeout} @@ -148,10 +154,7 @@ type headerRoundTripper struct { rt http.RoundTripper } -func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper { - if headers == nil { - headers = map[string]string{} - } +func resolveHeaders(headers map[string]string) { if _, ok := headers[contentType]; !ok { headers[contentType] = jsonContentType } @@ -167,11 +170,11 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str } } } - if _, ok := headers[authorization]; !ok && authToken != "" { - headers[authorization] = fmt.Sprintf("token %s", authToken) - } if _, ok := headers[timeZone]; !ok { - headers[timeZone] = currentTimeZone() + tz := currentTimeZone() + if tz != "" { + headers[timeZone] = tz + } } if _, ok := headers[accept]; !ok { // Preview for PullRequest.mergeStateStatus. @@ -180,6 +183,15 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str a += ", application/vnd.github.nebula-preview" headers[accept] = a } +} + +func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper { + if _, ok := headers[authorization]; !ok && authToken != "" { + headers[authorization] = fmt.Sprintf("token %s", authToken) + } + if len(headers) == 0 { + return rt + } return headerRoundTripper{host: host, headers: headers, rt: rt} } diff --git a/internal/api/http_test.go b/internal/api/http_test.go index 4384094..64b92bf 100644 --- a/internal/api/http_test.go +++ b/internal/api/http_test.go @@ -30,6 +30,7 @@ func TestNewHTTPClient(t *testing.T) { log *bytes.Buffer host string headers map[string]string + skipHeaders bool wantHeaders http.Header }{ { @@ -94,6 +95,18 @@ func TestNewHTTPClient(t *testing.T) { host: "TeSt.CoM", wantHeaders: defaultHeaders(), }, + { + name: "skips default headers", + skipHeaders: true, + wantHeaders: func() http.Header { + h := defaultHeaders() + h.Del(accept) + h.Del(contentType) + h.Del(timeZone) + h.Del(userAgent) + return h + }(), + }, } for _, tt := range tests { @@ -102,10 +115,11 @@ func TestNewHTTPClient(t *testing.T) { tt.host = "test.com" } opts := api.ClientOptions{ - Host: tt.host, - AuthToken: "oauth_token", - Headers: tt.headers, - Transport: reflectHTTP, + Host: tt.host, + AuthToken: "oauth_token", + Headers: tt.headers, + SkipDefaultHeaders: tt.skipHeaders, + Transport: reflectHTTP, } if tt.enableLog { opts.Log = tt.log diff --git a/pkg/api/client.go b/pkg/api/client.go index 4a4eccc..a31dbca 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -27,7 +27,7 @@ type ClientOptions struct { EnableCache bool // Headers are the headers that will be sent with every API request. - // Default headers set are Accept, Authorization, Content-Type, Time-Zone, and User-Agent. + // Default headers set are Accept, Content-Type, Time-Zone, and User-Agent. // Default headers will be overridden by keys specified in Headers. Headers map[string]string @@ -38,6 +38,9 @@ type ClientOptions struct { // Default is no logging. Log io.Writer + // SkipDefaultHeaders disables setting of the default headers. + SkipDefaultHeaders bool + // Timeout specifies a time limit for each API request. // Default is no timeout. Timeout time.Duration