Skip to content

Commit

Permalink
Allow skipping default headers for api clients (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed Jul 25, 2022
1 parent 0ff23cb commit 3058e5a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 13 deletions.
28 changes: 20 additions & 8 deletions internal/api/http.go
Expand Up @@ -112,6 +112,12 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
transport = logger.RoundTripper(transport)
}

if opts.Headers == nil {
opts.Headers = map[string]string{}
}
if !opts.SkipDefaultHeaders {
resolveHeaders(opts.Headers)
}
transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)

return http.Client{Transport: transport, Timeout: opts.Timeout}
Expand Down Expand Up @@ -148,10 +154,7 @@ type headerRoundTripper struct {
rt http.RoundTripper
}

func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper {
if headers == nil {
headers = map[string]string{}
}
func resolveHeaders(headers map[string]string) {
if _, ok := headers[contentType]; !ok {
headers[contentType] = jsonContentType
}
Expand All @@ -167,11 +170,11 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
}
}
}
if _, ok := headers[authorization]; !ok && authToken != "" {
headers[authorization] = fmt.Sprintf("token %s", authToken)
}
if _, ok := headers[timeZone]; !ok {
headers[timeZone] = currentTimeZone()
tz := currentTimeZone()
if tz != "" {
headers[timeZone] = tz
}
}
if _, ok := headers[accept]; !ok {
// Preview for PullRequest.mergeStateStatus.
Expand All @@ -180,6 +183,15 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
a += ", application/vnd.github.nebula-preview"
headers[accept] = a
}
}

func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper {
if _, ok := headers[authorization]; !ok && authToken != "" {
headers[authorization] = fmt.Sprintf("token %s", authToken)
}
if len(headers) == 0 {
return rt
}
return headerRoundTripper{host: host, headers: headers, rt: rt}
}

Expand Down
22 changes: 18 additions & 4 deletions internal/api/http_test.go
Expand Up @@ -30,6 +30,7 @@ func TestNewHTTPClient(t *testing.T) {
log *bytes.Buffer
host string
headers map[string]string
skipHeaders bool
wantHeaders http.Header
}{
{
Expand Down Expand Up @@ -94,6 +95,18 @@ func TestNewHTTPClient(t *testing.T) {
host: "TeSt.CoM",
wantHeaders: defaultHeaders(),
},
{
name: "skips default headers",
skipHeaders: true,
wantHeaders: func() http.Header {
h := defaultHeaders()
h.Del(accept)
h.Del(contentType)
h.Del(timeZone)
h.Del(userAgent)
return h
}(),
},
}

for _, tt := range tests {
Expand All @@ -102,10 +115,11 @@ func TestNewHTTPClient(t *testing.T) {
tt.host = "test.com"
}
opts := api.ClientOptions{
Host: tt.host,
AuthToken: "oauth_token",
Headers: tt.headers,
Transport: reflectHTTP,
Host: tt.host,
AuthToken: "oauth_token",
Headers: tt.headers,
SkipDefaultHeaders: tt.skipHeaders,
Transport: reflectHTTP,
}
if tt.enableLog {
opts.Log = tt.log
Expand Down
5 changes: 4 additions & 1 deletion pkg/api/client.go
Expand Up @@ -27,7 +27,7 @@ type ClientOptions struct {
EnableCache bool

// Headers are the headers that will be sent with every API request.
// Default headers set are Accept, Authorization, Content-Type, Time-Zone, and User-Agent.
// Default headers set are Accept, Content-Type, Time-Zone, and User-Agent.
// Default headers will be overridden by keys specified in Headers.
Headers map[string]string

Expand All @@ -38,6 +38,9 @@ type ClientOptions struct {
// Default is no logging.
Log io.Writer

// SkipDefaultHeaders disables setting of the default headers.
SkipDefaultHeaders bool

// Timeout specifies a time limit for each API request.
// Default is no timeout.
Timeout time.Duration
Expand Down

0 comments on commit 3058e5a

Please sign in to comment.