From 593c1300b949e52be87814d166227ce57f622b1c Mon Sep 17 00:00:00 2001 From: Will Roden Date: Fri, 1 Dec 2023 15:35:49 -0600 Subject: [PATCH] Don't update httpClient passed to NewClient --- github/github.go | 8 ++++++- github/github_test.go | 51 +++++++++++++++++++++++++++++-------------- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/github/github.go b/github/github.go index c248b256f6..5342ecc681 100644 --- a/github/github.go +++ b/github/github.go @@ -220,6 +220,8 @@ type service struct { } // Client returns the http.Client used by this GitHub client. +// This should only be used for requests to the GitHub API because +// request headers will contain an authorization token. func (c *Client) Client() *http.Client { c.clientMu.Lock() defer c.clientMu.Unlock() @@ -315,7 +317,11 @@ func addOptions(s string, opts interface{}) (string, error) { // an http.Client that will perform the authentication for you (such as that // provided by the golang.org/x/oauth2 library). func NewClient(httpClient *http.Client) *Client { - c := &Client{client: httpClient} + if httpClient == nil { + httpClient = &http.Client{} + } + httpClient2 := *httpClient + c := &Client{client: &httpClient2} c.initialize() return c } diff --git a/github/github_test.go b/github/github_test.go index b994496cc0..3e4ea79d26 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -321,26 +321,45 @@ func TestClient(t *testing.T) { func TestWithAuthToken(t *testing.T) { token := "gh_test_token" - var gotAuthHeaderVals []string - wantAuthHeaderVals := []string{"Bearer " + token} - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotAuthHeaderVals = r.Header["Authorization"] - })) - validate := func(c *Client) { + + validate := func(t *testing.T, c *http.Client, token string) { t.Helper() - gotAuthHeaderVals = nil - _, err := c.Client().Get(srv.URL) - if err != nil { - t.Fatalf("Get returned unexpected error: %v", err) + want := token + if want != "" { + want = "Bearer " + want + } + gotReq := false + headerVal := "" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotReq = true + headerVal = r.Header.Get("Authorization") + })) + _, err := c.Get(srv.URL) + assertNilError(t, err) + if !gotReq { + t.Error("request not sent") } - diff := cmp.Diff(wantAuthHeaderVals, gotAuthHeaderVals) - if diff != "" { - t.Errorf("Authorization header values mismatch (-want +got):\n%s", diff) + if headerVal != want { + t.Errorf("Authorization header is %v, want %v", headerVal, want) } } - validate(NewClient(nil).WithAuthToken(token)) - validate(new(Client).WithAuthToken(token)) - validate(NewTokenClient(context.Background(), token)) + + t.Run("zero-value Client", func(t *testing.T) { + c := new(Client).WithAuthToken(token) + validate(t, c.Client(), token) + }) + + t.Run("NewClient", func(t *testing.T) { + httpClient := &http.Client{} + client := NewClient(httpClient).WithAuthToken(token) + validate(t, client.Client(), token) + // make sure the original client isn't setting auth headers now + validate(t, httpClient, "") + }) + + t.Run("NewTokenClient", func(t *testing.T) { + validate(t, NewTokenClient(context.Background(), token).Client(), token) + }) } func TestWithEnterpriseURLs(t *testing.T) {