diff --git a/example_gh_test.go b/example_gh_test.go index f7f4c08..b15f672 100644 --- a/example_gh_test.go +++ b/example_gh_test.go @@ -125,3 +125,32 @@ func ExampleCurrentRepository() { } fmt.Printf("%s/%s/%s\n", repo.Host(), repo.Owner(), repo.Name()) } + +// Use SkipResolution ClientOption to change a http.Client into a api.RESTClient. +func ExampleHTTPClient_skipResolution() { + host := "github.com" + httpOpts := api.ClientOptions{ + Host: host, + } + httpClient, err := HTTPClient(&httpOpts) + if err != nil { + log.Fatal(err) + } + // Use SkipResolution as our http.Client does the handling of + // options and headers. + restOpts := api.ClientOptions{ + SkipResolution: true, + Host: host, + Transport: httpClient.Transport, + } + restClient, err := RESTClient(&restOpts) + if err != nil { + log.Fatal(err) + } + response := []struct{ Name string }{} + err = restClient.Get("repos/cli/cli/tags", &response) + if err != nil { + log.Fatal(err) + } + fmt.Println(response) +} diff --git a/gh.go b/gh.go index 4597400..200ec16 100644 --- a/gh.go +++ b/gh.go @@ -141,6 +141,9 @@ func CurrentRepository() (repo.Repository, error) { } func optionsNeedResolution(opts *api.ClientOptions) bool { + if opts.SkipResolution { + return false + } if opts.Host == "" { return true } diff --git a/gh_test.go b/gh_test.go index 68b19e8..b02ada3 100644 --- a/gh_test.go +++ b/gh_test.go @@ -262,6 +262,13 @@ func TestOptionsNeedResolution(t *testing.T) { }, out: true, }, + { + name: "SkipResolution specified", + opts: &api.ClientOptions{ + SkipResolution: true, + }, + out: false, + }, } for _, tt := range tests { diff --git a/internal/api/http.go b/internal/api/http.go index 4104057..80f24ae 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -112,6 +112,12 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { transport = logger.RoundTripper(transport) } + if opts.Headers == nil { + opts.Headers = map[string]string{} + } + if !opts.SkipResolution { + resolveHeaders(opts.Headers) + } transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) return http.Client{Transport: transport, Timeout: opts.Timeout} @@ -148,10 +154,7 @@ type headerRoundTripper struct { rt http.RoundTripper } -func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper { - if headers == nil { - headers = map[string]string{} - } +func resolveHeaders(headers map[string]string) { if _, ok := headers[contentType]; !ok { headers[contentType] = jsonContentType } @@ -167,11 +170,11 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str } } } - if _, ok := headers[authorization]; !ok && authToken != "" { - headers[authorization] = fmt.Sprintf("token %s", authToken) - } if _, ok := headers[timeZone]; !ok { - headers[timeZone] = currentTimeZone() + tz := currentTimeZone() + if tz != "" { + headers[timeZone] = tz + } } if _, ok := headers[accept]; !ok { // Preview for PullRequest.mergeStateStatus. @@ -180,6 +183,15 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str a += ", application/vnd.github.nebula-preview" headers[accept] = a } +} + +func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper { + if _, ok := headers[authorization]; !ok && authToken != "" { + headers[authorization] = fmt.Sprintf("token %s", authToken) + } + if len(headers) == 0 { + return rt + } return headerRoundTripper{host: host, headers: headers, rt: rt} } diff --git a/internal/api/http_test.go b/internal/api/http_test.go index 4384094..19ac6b0 100644 --- a/internal/api/http_test.go +++ b/internal/api/http_test.go @@ -25,12 +25,13 @@ func TestNewHTTPClient(t *testing.T) { } tests := []struct { - name string - enableLog bool - log *bytes.Buffer - host string - headers map[string]string - wantHeaders http.Header + name string + enableLog bool + log *bytes.Buffer + host string + headers map[string]string + skipResolution bool + wantHeaders http.Header }{ { name: "sets default headers", @@ -94,6 +95,18 @@ func TestNewHTTPClient(t *testing.T) { host: "TeSt.CoM", wantHeaders: defaultHeaders(), }, + { + name: "skips resolving headers", + skipResolution: true, + wantHeaders: func() http.Header { + h := defaultHeaders() + h.Del(accept) + h.Del(contentType) + h.Del(timeZone) + h.Del(userAgent) + return h + }(), + }, } for _, tt := range tests { @@ -102,10 +115,11 @@ func TestNewHTTPClient(t *testing.T) { tt.host = "test.com" } opts := api.ClientOptions{ - Host: tt.host, - AuthToken: "oauth_token", - Headers: tt.headers, - Transport: reflectHTTP, + Host: tt.host, + AuthToken: "oauth_token", + Headers: tt.headers, + SkipResolution: tt.skipResolution, + Transport: reflectHTTP, } if tt.enableLog { opts.Log = tt.log diff --git a/pkg/api/client.go b/pkg/api/client.go index 4a4eccc..84e937e 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -38,6 +38,12 @@ type ClientOptions struct { // Default is no logging. Log io.Writer + // SkipResolution disables all automatic resolution of options and headers. + // This option is best used in conjunction with the Transport option, + // where the Transport mechanism already provides the necessary information + // for interacting with the API. + SkipResolution bool + // Timeout specifies a time limit for each API request. // Default is no timeout. Timeout time.Duration