From daf108be4f246275bb04c73ae523dc2644cfb50a Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 25 Jul 2022 13:15:00 +0200 Subject: [PATCH 1/2] Allow skipping option and header resolution for api clients --- example_gh_test.go | 29 +++++++++++++++++++++++++++++ gh.go | 3 +++ gh_test.go | 7 +++++++ internal/api/http.go | 28 ++++++++++++++++++++-------- internal/api/http_test.go | 34 ++++++++++++++++++++++++---------- pkg/api/client.go | 6 ++++++ 6 files changed, 89 insertions(+), 18 deletions(-) diff --git a/example_gh_test.go b/example_gh_test.go index f7f4c08..b15f672 100644 --- a/example_gh_test.go +++ b/example_gh_test.go @@ -125,3 +125,32 @@ func ExampleCurrentRepository() { } fmt.Printf("%s/%s/%s\n", repo.Host(), repo.Owner(), repo.Name()) } + +// Use SkipResolution ClientOption to change a http.Client into a api.RESTClient. +func ExampleHTTPClient_skipResolution() { + host := "github.com" + httpOpts := api.ClientOptions{ + Host: host, + } + httpClient, err := HTTPClient(&httpOpts) + if err != nil { + log.Fatal(err) + } + // Use SkipResolution as our http.Client does the handling of + // options and headers. + restOpts := api.ClientOptions{ + SkipResolution: true, + Host: host, + Transport: httpClient.Transport, + } + restClient, err := RESTClient(&restOpts) + if err != nil { + log.Fatal(err) + } + response := []struct{ Name string }{} + err = restClient.Get("repos/cli/cli/tags", &response) + if err != nil { + log.Fatal(err) + } + fmt.Println(response) +} diff --git a/gh.go b/gh.go index 4597400..200ec16 100644 --- a/gh.go +++ b/gh.go @@ -141,6 +141,9 @@ func CurrentRepository() (repo.Repository, error) { } func optionsNeedResolution(opts *api.ClientOptions) bool { + if opts.SkipResolution { + return false + } if opts.Host == "" { return true } diff --git a/gh_test.go b/gh_test.go index 68b19e8..b02ada3 100644 --- a/gh_test.go +++ b/gh_test.go @@ -262,6 +262,13 @@ func TestOptionsNeedResolution(t *testing.T) { }, out: true, }, + { + name: "SkipResolution specified", + opts: &api.ClientOptions{ + SkipResolution: true, + }, + out: false, + }, } for _, tt := range tests { diff --git a/internal/api/http.go b/internal/api/http.go index 4104057..80f24ae 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -112,6 +112,12 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { transport = logger.RoundTripper(transport) } + if opts.Headers == nil { + opts.Headers = map[string]string{} + } + if !opts.SkipResolution { + resolveHeaders(opts.Headers) + } transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) return http.Client{Transport: transport, Timeout: opts.Timeout} @@ -148,10 +154,7 @@ type headerRoundTripper struct { rt http.RoundTripper } -func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper { - if headers == nil { - headers = map[string]string{} - } +func resolveHeaders(headers map[string]string) { if _, ok := headers[contentType]; !ok { headers[contentType] = jsonContentType } @@ -167,11 +170,11 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str } } } - if _, ok := headers[authorization]; !ok && authToken != "" { - headers[authorization] = fmt.Sprintf("token %s", authToken) - } if _, ok := headers[timeZone]; !ok { - headers[timeZone] = currentTimeZone() + tz := currentTimeZone() + if tz != "" { + headers[timeZone] = tz + } } if _, ok := headers[accept]; !ok { // Preview for PullRequest.mergeStateStatus. @@ -180,6 +183,15 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str a += ", application/vnd.github.nebula-preview" headers[accept] = a } +} + +func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper { + if _, ok := headers[authorization]; !ok && authToken != "" { + headers[authorization] = fmt.Sprintf("token %s", authToken) + } + if len(headers) == 0 { + return rt + } return headerRoundTripper{host: host, headers: headers, rt: rt} } diff --git a/internal/api/http_test.go b/internal/api/http_test.go index 4384094..19ac6b0 100644 --- a/internal/api/http_test.go +++ b/internal/api/http_test.go @@ -25,12 +25,13 @@ func TestNewHTTPClient(t *testing.T) { } tests := []struct { - name string - enableLog bool - log *bytes.Buffer - host string - headers map[string]string - wantHeaders http.Header + name string + enableLog bool + log *bytes.Buffer + host string + headers map[string]string + skipResolution bool + wantHeaders http.Header }{ { name: "sets default headers", @@ -94,6 +95,18 @@ func TestNewHTTPClient(t *testing.T) { host: "TeSt.CoM", wantHeaders: defaultHeaders(), }, + { + name: "skips resolving headers", + skipResolution: true, + wantHeaders: func() http.Header { + h := defaultHeaders() + h.Del(accept) + h.Del(contentType) + h.Del(timeZone) + h.Del(userAgent) + return h + }(), + }, } for _, tt := range tests { @@ -102,10 +115,11 @@ func TestNewHTTPClient(t *testing.T) { tt.host = "test.com" } opts := api.ClientOptions{ - Host: tt.host, - AuthToken: "oauth_token", - Headers: tt.headers, - Transport: reflectHTTP, + Host: tt.host, + AuthToken: "oauth_token", + Headers: tt.headers, + SkipResolution: tt.skipResolution, + Transport: reflectHTTP, } if tt.enableLog { opts.Log = tt.log diff --git a/pkg/api/client.go b/pkg/api/client.go index 4a4eccc..84e937e 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -38,6 +38,12 @@ type ClientOptions struct { // Default is no logging. Log io.Writer + // SkipResolution disables all automatic resolution of options and headers. + // This option is best used in conjunction with the Transport option, + // where the Transport mechanism already provides the necessary information + // for interacting with the API. + SkipResolution bool + // Timeout specifies a time limit for each API request. // Default is no timeout. Timeout time.Duration From 02d6382e1e2f318dc82ef75167b2eae2b233209b Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 25 Jul 2022 19:36:52 +0200 Subject: [PATCH 2/2] Address PR comments --- example_gh_test.go | 29 ----------------------------- gh.go | 3 --- gh_test.go | 7 ------- internal/api/http.go | 2 +- internal/api/http_test.go | 28 ++++++++++++++-------------- pkg/api/client.go | 9 +++------ 6 files changed, 18 insertions(+), 60 deletions(-) diff --git a/example_gh_test.go b/example_gh_test.go index b15f672..f7f4c08 100644 --- a/example_gh_test.go +++ b/example_gh_test.go @@ -125,32 +125,3 @@ func ExampleCurrentRepository() { } fmt.Printf("%s/%s/%s\n", repo.Host(), repo.Owner(), repo.Name()) } - -// Use SkipResolution ClientOption to change a http.Client into a api.RESTClient. -func ExampleHTTPClient_skipResolution() { - host := "github.com" - httpOpts := api.ClientOptions{ - Host: host, - } - httpClient, err := HTTPClient(&httpOpts) - if err != nil { - log.Fatal(err) - } - // Use SkipResolution as our http.Client does the handling of - // options and headers. - restOpts := api.ClientOptions{ - SkipResolution: true, - Host: host, - Transport: httpClient.Transport, - } - restClient, err := RESTClient(&restOpts) - if err != nil { - log.Fatal(err) - } - response := []struct{ Name string }{} - err = restClient.Get("repos/cli/cli/tags", &response) - if err != nil { - log.Fatal(err) - } - fmt.Println(response) -} diff --git a/gh.go b/gh.go index 200ec16..4597400 100644 --- a/gh.go +++ b/gh.go @@ -141,9 +141,6 @@ func CurrentRepository() (repo.Repository, error) { } func optionsNeedResolution(opts *api.ClientOptions) bool { - if opts.SkipResolution { - return false - } if opts.Host == "" { return true } diff --git a/gh_test.go b/gh_test.go index b02ada3..68b19e8 100644 --- a/gh_test.go +++ b/gh_test.go @@ -262,13 +262,6 @@ func TestOptionsNeedResolution(t *testing.T) { }, out: true, }, - { - name: "SkipResolution specified", - opts: &api.ClientOptions{ - SkipResolution: true, - }, - out: false, - }, } for _, tt := range tests { diff --git a/internal/api/http.go b/internal/api/http.go index 80f24ae..9662b5e 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -115,7 +115,7 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { if opts.Headers == nil { opts.Headers = map[string]string{} } - if !opts.SkipResolution { + if !opts.SkipDefaultHeaders { resolveHeaders(opts.Headers) } transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) diff --git a/internal/api/http_test.go b/internal/api/http_test.go index 19ac6b0..64b92bf 100644 --- a/internal/api/http_test.go +++ b/internal/api/http_test.go @@ -25,13 +25,13 @@ func TestNewHTTPClient(t *testing.T) { } tests := []struct { - name string - enableLog bool - log *bytes.Buffer - host string - headers map[string]string - skipResolution bool - wantHeaders http.Header + name string + enableLog bool + log *bytes.Buffer + host string + headers map[string]string + skipHeaders bool + wantHeaders http.Header }{ { name: "sets default headers", @@ -96,8 +96,8 @@ func TestNewHTTPClient(t *testing.T) { wantHeaders: defaultHeaders(), }, { - name: "skips resolving headers", - skipResolution: true, + name: "skips default headers", + skipHeaders: true, wantHeaders: func() http.Header { h := defaultHeaders() h.Del(accept) @@ -115,11 +115,11 @@ func TestNewHTTPClient(t *testing.T) { tt.host = "test.com" } opts := api.ClientOptions{ - Host: tt.host, - AuthToken: "oauth_token", - Headers: tt.headers, - SkipResolution: tt.skipResolution, - Transport: reflectHTTP, + Host: tt.host, + AuthToken: "oauth_token", + Headers: tt.headers, + SkipDefaultHeaders: tt.skipHeaders, + Transport: reflectHTTP, } if tt.enableLog { opts.Log = tt.log diff --git a/pkg/api/client.go b/pkg/api/client.go index 84e937e..a31dbca 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -27,7 +27,7 @@ type ClientOptions struct { EnableCache bool // Headers are the headers that will be sent with every API request. - // Default headers set are Accept, Authorization, Content-Type, Time-Zone, and User-Agent. + // Default headers set are Accept, Content-Type, Time-Zone, and User-Agent. // Default headers will be overridden by keys specified in Headers. Headers map[string]string @@ -38,11 +38,8 @@ type ClientOptions struct { // Default is no logging. Log io.Writer - // SkipResolution disables all automatic resolution of options and headers. - // This option is best used in conjunction with the Transport option, - // where the Transport mechanism already provides the necessary information - // for interacting with the API. - SkipResolution bool + // SkipDefaultHeaders disables setting of the default headers. + SkipDefaultHeaders bool // Timeout specifies a time limit for each API request. // Default is no timeout.