From 6cd5d410ec9d42871fafad0b40679b0996d83374 Mon Sep 17 00:00:00 2001 From: Glenn Lewis <6598971+gmlewis@users.noreply.github.com> Date: Thu, 24 Mar 2022 22:33:41 -0400 Subject: [PATCH] Remove code duplication (#2321) --- github/actions_artifacts.go | 31 +++---------------------------- github/actions_workflow_jobs.go | 31 +++---------------------------- github/actions_workflow_runs.go | 4 +++- github/github.go | 29 +++++++++++++++++++++++++++++ github/repos.go | 29 +---------------------------- github/repos_contents.go | 32 ++++---------------------------- github/repos_contents_test.go | 4 ++-- 7 files changed, 45 insertions(+), 115 deletions(-) diff --git a/github/actions_artifacts.go b/github/actions_artifacts.go index 4aa7dc4404..2925921208 100644 --- a/github/actions_artifacts.go +++ b/github/actions_artifacts.go @@ -110,45 +110,20 @@ func (s *ActionsService) GetArtifact(ctx context.Context, owner, repo string, ar func (s *ActionsService) DownloadArtifact(ctx context.Context, owner, repo string, artifactID int64, followRedirects bool) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/artifacts/%v/zip", owner, repo, artifactID) - resp, err := s.getDownloadArtifactFromURL(ctx, u, followRedirects) + resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, followRedirects) if err != nil { return nil, nil, err } + defer resp.Body.Close() if resp.StatusCode != http.StatusFound { return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) } + parsedURL, err := url.Parse(resp.Header.Get("Location")) return parsedURL, newResponse(resp), nil } -func (s *ActionsService) getDownloadArtifactFromURL(ctx context.Context, u string, followRedirects bool) (*http.Response, error) { - req, err := s.client.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - - var resp *http.Response - // Use http.DefaultTransport if no custom Transport is configured - req = withContext(ctx, req) - if s.client.client.Transport == nil { - resp, err = http.DefaultTransport.RoundTrip(req) - } else { - resp, err = s.client.client.Transport.RoundTrip(req) - } - if err != nil { - return nil, err - } - resp.Body.Close() - - // If redirect response is returned, follow it - if followRedirects && resp.StatusCode == http.StatusMovedPermanently { - u = resp.Header.Get("Location") - resp, err = s.getDownloadArtifactFromURL(ctx, u, false) - } - return resp, err -} - // DeleteArtifact deletes a workflow run artifact. // // GitHub API docs: https://docs.github.com/en/free-pro-team@latest/rest/reference/actions/#delete-an-artifact diff --git a/github/actions_workflow_jobs.go b/github/actions_workflow_jobs.go index 66b8ff6edb..ad0e7b6e2a 100644 --- a/github/actions_workflow_jobs.go +++ b/github/actions_workflow_jobs.go @@ -114,41 +114,16 @@ func (s *ActionsService) GetWorkflowJobByID(ctx context.Context, owner, repo str func (s *ActionsService) GetWorkflowJobLogs(ctx context.Context, owner, repo string, jobID int64, followRedirects bool) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/jobs/%v/logs", owner, repo, jobID) - resp, err := s.getWorkflowLogsFromURL(ctx, u, followRedirects) + resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, followRedirects) if err != nil { return nil, nil, err } + defer resp.Body.Close() if resp.StatusCode != http.StatusFound { return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) } + parsedURL, err := url.Parse(resp.Header.Get("Location")) return parsedURL, newResponse(resp), err } - -func (s *ActionsService) getWorkflowLogsFromURL(ctx context.Context, u string, followRedirects bool) (*http.Response, error) { - req, err := s.client.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - - var resp *http.Response - // Use http.DefaultTransport if no custom Transport is configured - req = withContext(ctx, req) - if s.client.client.Transport == nil { - resp, err = http.DefaultTransport.RoundTrip(req) - } else { - resp, err = s.client.client.Transport.RoundTrip(req) - } - if err != nil { - return nil, err - } - resp.Body.Close() - - // If redirect response is returned, follow it - if followRedirects && resp.StatusCode == http.StatusMovedPermanently { - u = resp.Header.Get("Location") - resp, err = s.getWorkflowLogsFromURL(ctx, u, false) - } - return resp, err -} diff --git a/github/actions_workflow_runs.go b/github/actions_workflow_runs.go index 273d2cc466..05d55c03f7 100644 --- a/github/actions_workflow_runs.go +++ b/github/actions_workflow_runs.go @@ -231,14 +231,16 @@ func (s *ActionsService) CancelWorkflowRunByID(ctx context.Context, owner, repo func (s *ActionsService) GetWorkflowRunLogs(ctx context.Context, owner, repo string, runID int64, followRedirects bool) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/runs/%v/logs", owner, repo, runID) - resp, err := s.getWorkflowLogsFromURL(ctx, u, followRedirects) + resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, followRedirects) if err != nil { return nil, nil, err } + defer resp.Body.Close() if resp.StatusCode != http.StatusFound { return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) } + parsedURL, err := url.Parse(resp.Header.Get("Location")) return parsedURL, newResponse(resp), err } diff --git a/github/github.go b/github/github.go index 164b2d0678..be46c31028 100644 --- a/github/github.go +++ b/github/github.go @@ -1279,6 +1279,35 @@ func formatRateReset(d time.Duration) string { return fmt.Sprintf("[rate reset in %v]", timeString) } +// When using roundTripWithOptionalFollowRedirect, note that it +// is the responsibility of the caller to close the response body. +func (c *Client) roundTripWithOptionalFollowRedirect(ctx context.Context, u string, followRedirects bool) (*http.Response, error) { + req, err := c.NewRequest("GET", u, nil) + if err != nil { + return nil, err + } + + var resp *http.Response + // Use http.DefaultTransport if no custom Transport is configured + req = withContext(ctx, req) + if c.client.Transport == nil { + resp, err = http.DefaultTransport.RoundTrip(req) + } else { + resp, err = c.client.Transport.RoundTrip(req) + } + if err != nil { + return nil, err + } + + // If redirect response is returned, follow it + if followRedirects && resp.StatusCode == http.StatusMovedPermanently { + resp.Body.Close() + u = resp.Header.Get("Location") + resp, err = c.roundTripWithOptionalFollowRedirect(ctx, u, false) + } + return resp, err +} + // Bool is a helper routine that allocates a new bool value // to store v and returns a pointer to it. func Bool(v bool) *bool { return &v } diff --git a/github/repos.go b/github/repos.go index b1ce079691..3e8b421787 100644 --- a/github/repos.go +++ b/github/repos.go @@ -1061,7 +1061,7 @@ func (s *RepositoriesService) ListBranches(ctx context.Context, owner string, re func (s *RepositoriesService) GetBranch(ctx context.Context, owner, repo, branch string, followRedirects bool) (*Branch, *Response, error) { u := fmt.Sprintf("repos/%v/%v/branches/%v", owner, repo, branch) - resp, err := s.getBranchFromURL(ctx, u, followRedirects) + resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, followRedirects) if err != nil { return nil, nil, err } @@ -1076,33 +1076,6 @@ func (s *RepositoriesService) GetBranch(ctx context.Context, owner, repo, branch return b, newResponse(resp), err } -func (s *RepositoriesService) getBranchFromURL(ctx context.Context, u string, followRedirects bool) (*http.Response, error) { - req, err := s.client.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - - var resp *http.Response - // Use http.DefaultTransport if no custom Transport is configured - req = withContext(ctx, req) - if s.client.client.Transport == nil { - resp, err = http.DefaultTransport.RoundTrip(req) - } else { - resp, err = s.client.client.Transport.RoundTrip(req) - } - if err != nil { - return nil, err - } - - // If redirect response is returned, follow it - if followRedirects && resp.StatusCode == http.StatusMovedPermanently { - resp.Body.Close() - u = resp.Header.Get("Location") - resp, err = s.getBranchFromURL(ctx, u, false) - } - return resp, err -} - // renameBranchRequest represents a request to rename a branch. type renameBranchRequest struct { NewName string `json:"new_name"` diff --git a/github/repos_contents.go b/github/repos_contents.go index 86e11c0a75..3fe2710132 100644 --- a/github/repos_contents.go +++ b/github/repos_contents.go @@ -284,40 +284,16 @@ func (s *RepositoriesService) GetArchiveLink(ctx context.Context, owner, repo st if opts != nil && opts.Ref != "" { u += fmt.Sprintf("/%s", opts.Ref) } - resp, err := s.getArchiveLinkFromURL(ctx, u, followRedirects) + resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, followRedirects) if err != nil { return nil, nil, err } + defer resp.Body.Close() + if resp.StatusCode != http.StatusFound { return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) } + parsedURL, err := url.Parse(resp.Header.Get("Location")) return parsedURL, newResponse(resp), err } - -func (s *RepositoriesService) getArchiveLinkFromURL(ctx context.Context, u string, followRedirects bool) (*http.Response, error) { - req, err := s.client.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - - var resp *http.Response - // Use http.DefaultTransport if no custom Transport is configured - req = withContext(ctx, req) - if s.client.client.Transport == nil { - resp, err = http.DefaultTransport.RoundTrip(req) - } else { - resp, err = s.client.client.Transport.RoundTrip(req) - } - if err != nil { - return nil, err - } - resp.Body.Close() - - // If redirect response is returned, follow it - if followRedirects && resp.StatusCode == http.StatusMovedPermanently { - u = resp.Header.Get("Location") - resp, err = s.getArchiveLinkFromURL(ctx, u, false) - } - return resp, err -} diff --git a/github/repos_contents_test.go b/github/repos_contents_test.go index befb2fadff..e0465e93d0 100644 --- a/github/repos_contents_test.go +++ b/github/repos_contents_test.go @@ -670,12 +670,12 @@ func TestRepositoriesService_DeleteFile(t *testing.T) { func TestRepositoriesService_GetArchiveLink(t *testing.T) { client, mux, _, teardown := setup() defer teardown() - mux.HandleFunc("/repos/o/r/tarball", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/repos/o/r/tarball/yo", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") http.Redirect(w, r, "http://github.com/a", http.StatusFound) }) ctx := context.Background() - url, resp, err := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, true) + url, resp, err := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{Ref: "yo"}, true) if err != nil { t.Errorf("Repositories.GetArchiveLink returned error: %v", err) }