Skip to content

Commit

Permalink
Allow skipping option and header resolution for api clients
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed Jul 25, 2022
1 parent 91ca4ef commit 4eb2de5
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 18 deletions.
30 changes: 30 additions & 0 deletions example_gh_test.go
Expand Up @@ -125,3 +125,33 @@ func ExampleCurrentRepository() {
}
fmt.Printf("%s/%s/%s\n", repo.Host(), repo.Owner(), repo.Name())
}

// Example of using SkipResolution option to change a http.Client
// into a api.RESTClient.
func ExampleSkipResolution() {
host := "github.com"
httpOpts := api.ClientOptions{
Host: host,
}
httpClient, err := HTTPClient(&httpOpts)
if err != nil {
log.Fatal(err)
}
// Use SkipResolution as our http.Client does the handling of
// options and headers.
restOpts := api.ClientOptions{
SkipResolution: true,
Host: host,
Transport: httpClient.Transport,
}
restClient, err := RESTClient(&restOpts)
if err != nil {
log.Fatal(err)
}
response := []struct{ Name string }{}
err = restClient.Get("repos/cli/cli/tags", &response)
if err != nil {
log.Fatal(err)
}
fmt.Println(response)
}
3 changes: 3 additions & 0 deletions gh.go
Expand Up @@ -141,6 +141,9 @@ func CurrentRepository() (repo.Repository, error) {
}

func optionsNeedResolution(opts *api.ClientOptions) bool {
if opts.SkipResolution {
return false
}
if opts.Host == "" {
return true
}
Expand Down
7 changes: 7 additions & 0 deletions gh_test.go
Expand Up @@ -262,6 +262,13 @@ func TestOptionsNeedResolution(t *testing.T) {
},
out: true,
},
{
name: "SkipResolution specified",
opts: &api.ClientOptions{
SkipResolution: true,
},
out: false,
},
}

for _, tt := range tests {
Expand Down
28 changes: 20 additions & 8 deletions internal/api/http.go
Expand Up @@ -112,6 +112,12 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
transport = logger.RoundTripper(transport)
}

if opts.Headers == nil {
opts.Headers = map[string]string{}
}
if !opts.SkipResolution {
resolveHeaders(opts.Headers)
}
transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)

return http.Client{Transport: transport, Timeout: opts.Timeout}
Expand Down Expand Up @@ -148,10 +154,7 @@ type headerRoundTripper struct {
rt http.RoundTripper
}

func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper {
if headers == nil {
headers = map[string]string{}
}
func resolveHeaders(headers map[string]string) {
if _, ok := headers[contentType]; !ok {
headers[contentType] = jsonContentType
}
Expand All @@ -167,11 +170,11 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
}
}
}
if _, ok := headers[authorization]; !ok && authToken != "" {
headers[authorization] = fmt.Sprintf("token %s", authToken)
}
if _, ok := headers[timeZone]; !ok {
headers[timeZone] = currentTimeZone()
tz := currentTimeZone()
if tz != "" {
headers[timeZone] = tz
}
}
if _, ok := headers[accept]; !ok {
// Preview for PullRequest.mergeStateStatus.
Expand All @@ -180,6 +183,15 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
a += ", application/vnd.github.nebula-preview"
headers[accept] = a
}
}

func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper {
if _, ok := headers[authorization]; !ok && authToken != "" {
headers[authorization] = fmt.Sprintf("token %s", authToken)
}
if len(headers) == 0 {
return rt
}
return headerRoundTripper{host: host, headers: headers, rt: rt}
}

Expand Down
34 changes: 24 additions & 10 deletions internal/api/http_test.go
Expand Up @@ -25,12 +25,13 @@ func TestNewHTTPClient(t *testing.T) {
}

tests := []struct {
name string
enableLog bool
log *bytes.Buffer
host string
headers map[string]string
wantHeaders http.Header
name string
enableLog bool
log *bytes.Buffer
host string
headers map[string]string
skipResolution bool
wantHeaders http.Header
}{
{
name: "sets default headers",
Expand Down Expand Up @@ -94,6 +95,18 @@ func TestNewHTTPClient(t *testing.T) {
host: "TeSt.CoM",
wantHeaders: defaultHeaders(),
},
{
name: "skips resolving headers",
skipResolution: true,
wantHeaders: func() http.Header {
h := defaultHeaders()
h.Del(accept)
h.Del(contentType)
h.Del(timeZone)
h.Del(userAgent)
return h
}(),
},
}

for _, tt := range tests {
Expand All @@ -102,10 +115,11 @@ func TestNewHTTPClient(t *testing.T) {
tt.host = "test.com"
}
opts := api.ClientOptions{
Host: tt.host,
AuthToken: "oauth_token",
Headers: tt.headers,
Transport: reflectHTTP,
Host: tt.host,
AuthToken: "oauth_token",
Headers: tt.headers,
SkipResolution: tt.skipResolution,
Transport: reflectHTTP,
}
if tt.enableLog {
opts.Log = tt.log
Expand Down
6 changes: 6 additions & 0 deletions pkg/api/client.go
Expand Up @@ -38,6 +38,12 @@ type ClientOptions struct {
// Default is no logging.
Log io.Writer

// SkipResolution disables all automatic resolution of options and headers.
// This option is best used in conjunction with the Transport option,
// where the Transport mechanism already provides the necessary information
// for interacting with the API.
SkipResolution bool

// Timeout specifies a time limit for each API request.
// Default is no timeout.
Timeout time.Duration
Expand Down

0 comments on commit 4eb2de5

Please sign in to comment.