diff --git a/gh.go b/gh.go index d492215..05dc749 100644 --- a/gh.go +++ b/gh.go @@ -61,13 +61,15 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { if opts == nil { opts = &api.ClientOptions{} } - cfg, err := config.Load() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) - if err != nil { - return nil, err + if optionsNeedResolution(opts) { + cfg, err := config.Load() + if err != nil { + return nil, err + } + err = resolveOptions(opts, cfg) + if err != nil { + return nil, err + } } return iapi.NewRESTClient(opts.Host, opts), nil } @@ -80,13 +82,15 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { if opts == nil { opts = &api.ClientOptions{} } - cfg, err := config.Load() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) - if err != nil { - return nil, err + if optionsNeedResolution(opts) { + cfg, err := config.Load() + if err != nil { + return nil, err + } + err = resolveOptions(opts, cfg) + if err != nil { + return nil, err + } } return iapi.NewGQLClient(opts.Host, opts), nil } @@ -104,13 +108,15 @@ func HTTPClient(opts *api.ClientOptions) (*http.Client, error) { if opts == nil { opts = &api.ClientOptions{} } - cfg, err := config.Load() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) - if err != nil { - return nil, err + if optionsNeedResolution(opts) { + cfg, err := config.Load() + if err != nil { + return nil, err + } + err = resolveOptions(opts, cfg) + if err != nil { + return nil, err + } } client := iapi.NewHTTPClient(opts) return &client, nil @@ -151,6 +157,19 @@ func CurrentRepository() (repo.Repository, error) { return irepo.New(r.Host, r.Owner, r.Repo), nil } +func optionsNeedResolution(opts *api.ClientOptions) bool { + if opts.Host == "" { + return true + } + if opts.AuthToken == "" { + return true + } + if opts.UnixDomainSocket == "" && opts.Transport == nil { + return true + } + return false +} + func resolveOptions(opts *api.ClientOptions, cfg config.Config) error { var token string var err error diff --git a/gh_test.go b/gh_test.go index 25908ca..9e6f62d 100644 --- a/gh_test.go +++ b/gh_test.go @@ -2,6 +2,7 @@ package gh import ( "fmt" + "net/http" "os" "strings" "testing" @@ -129,7 +130,7 @@ func TestGQLClientError(t *testing.T) { res := struct{ Organization struct{ Name string } }{} err = client.Do("QUERY", nil, &res) - assert.EqualError(t, err, "GQL error: Could not resolve to an Organization with the login of 'cli'.") + assert.EqualError(t, err, "GQL: Could not resolve to an Organization with the login of 'cli'. (organization)") assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) } @@ -201,6 +202,107 @@ func TestResolveOptions(t *testing.T) { } } +func TestOptionsNeedResolution(t *testing.T) { + tests := []struct { + name string + opts *api.ClientOptions + out bool + }{ + { + name: "Host, AuthToken, and UnixDomainSocket specified", + opts: &api.ClientOptions{ + Host: "test.com", + AuthToken: "token", + UnixDomainSocket: "socket", + }, + out: false, + }, + { + name: "Host, AuthToken, and Transport specified", + opts: &api.ClientOptions{ + Host: "test.com", + AuthToken: "token", + Transport: http.DefaultTransport, + }, + out: false, + }, + { + name: "Host, and AuthToken specified", + opts: &api.ClientOptions{ + Host: "test.com", + AuthToken: "token", + }, + out: true, + }, + { + name: "Host, and UnixDomainSocket specified", + opts: &api.ClientOptions{ + Host: "test.com", + UnixDomainSocket: "socket", + }, + out: true, + }, + { + name: "Host, and Transport specified", + opts: &api.ClientOptions{ + Host: "test.com", + Transport: http.DefaultTransport, + }, + out: true, + }, + { + name: "AuthToken, and UnixDomainSocket specified", + opts: &api.ClientOptions{ + AuthToken: "token", + UnixDomainSocket: "socket", + }, + out: true, + }, + { + name: "AuthToken, and Transport specified", + opts: &api.ClientOptions{ + AuthToken: "token", + Transport: http.DefaultTransport, + }, + out: true, + }, + { + name: "Host specified", + opts: &api.ClientOptions{ + Host: "test.com", + }, + out: true, + }, + { + name: "AuthToken specified", + opts: &api.ClientOptions{ + AuthToken: "token", + }, + out: true, + }, + { + name: "UnixDomainSocket specified", + opts: &api.ClientOptions{ + UnixDomainSocket: "socket", + }, + out: true, + }, + { + name: "Transport specified", + opts: &api.ClientOptions{ + Transport: http.DefaultTransport, + }, + out: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.out, optionsNeedResolution(tt.opts)) + }) + } +} + func testConfig() config.Config { var data = ` hosts: diff --git a/internal/api/gql_client.go b/internal/api/gql_client.go index b4d727b..d7723b6 100644 --- a/internal/api/gql_client.go +++ b/internal/api/gql_client.go @@ -55,7 +55,7 @@ func (c gqlClient) Do(query string, variables map[string]interface{}, response i success := resp.StatusCode >= 200 && resp.StatusCode < 300 if !success { - return handleHTTPError(resp) + return api.HandleHTTPError(resp) } if resp.StatusCode == http.StatusNoContent { @@ -67,14 +67,14 @@ func (c gqlClient) Do(query string, variables map[string]interface{}, response i return err } - gr := &gqlResponse{Data: response} + gr := gqlResponse{Data: response} err = json.Unmarshal(body, &gr) if err != nil { return err } if len(gr.Errors) > 0 { - return &api.GQLError{Errors: gr.Errors} + return api.GQLError{Errors: gr.Errors} } return nil diff --git a/internal/api/gql_client_test.go b/internal/api/gql_client_test.go index 8400004..5cd7959 100644 --- a/internal/api/gql_client_test.go +++ b/internal/api/gql_client_test.go @@ -37,7 +37,7 @@ func TestGQLClientDo(t *testing.T) { JSON(`{"errors":[{"message":"OH NO"},{"message":"this is fine"}]}`) }, wantErr: true, - wantErrMsg: "GQL error: OH NO\nthis is fine", + wantErrMsg: "GQL: OH NO, this is fine", }, { name: "http fail request empty response", diff --git a/internal/api/http.go b/internal/api/http.go index 82bb857..4c0e277 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -1,9 +1,7 @@ package api import ( - "encoding/json" "fmt" - "io" "net" "net/http" "os" @@ -29,6 +27,7 @@ const ( ) var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) + var timeZoneNames = map[int]string{ -39600: "Pacific/Niue", -36000: "Pacific/Honolulu", @@ -119,77 +118,6 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { return http.Client{Transport: transport, Timeout: opts.Timeout} } -// TODO: Export function in near future. -func handleHTTPError(resp *http.Response) error { - httpError := api.HTTPError{ - StatusCode: resp.StatusCode, - RequestURL: resp.Request.URL, - Headers: resp.Header, - } - - if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) { - httpError.Message = resp.Status - return httpError - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - httpError.Message = err.Error() - return httpError - } - - var parsedBody struct { - Message string `json:"message"` - Errors []json.RawMessage - } - if err := json.Unmarshal(body, &parsedBody); err != nil { - return httpError - } - - var messages []string - if parsedBody.Message != "" { - messages = append(messages, parsedBody.Message) - } - for _, raw := range parsedBody.Errors { - switch raw[0] { - case '"': - var errString string - _ = json.Unmarshal(raw, &errString) - messages = append(messages, errString) - httpError.Errors = append(httpError.Errors, api.HTTPErrorItem{Message: errString}) - case '{': - var errInfo api.HTTPErrorItem - _ = json.Unmarshal(raw, &errInfo) - msg := errInfo.Message - if errInfo.Code != "" && errInfo.Code != "custom" { - msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code)) - } - if msg != "" { - messages = append(messages, msg) - } - httpError.Errors = append(httpError.Errors, errInfo) - } - } - httpError.Message = strings.Join(messages, "\n") - - return httpError -} - -// Convert common error codes to human readable messages -// See https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors for more details. -func errorCodeToMessage(code string) string { - switch code { - case "missing", "missing_field": - return "is missing" - case "invalid", "unprocessable": - return "is invalid" - case "already_exists": - return "already exists" - default: - return code - } -} - func inspectableMIMEType(t string) bool { return strings.HasPrefix(t, "text/") || jsonTypeRE.MatchString(t) } diff --git a/internal/api/rest_client.go b/internal/api/rest_client.go index c009867..f1651e2 100644 --- a/internal/api/rest_client.go +++ b/internal/api/rest_client.go @@ -23,6 +23,31 @@ func NewRESTClient(host string, opts *api.ClientOptions) api.RESTClient { } } +func (c restClient) Request(method string, path string, body io.Reader) (*http.Response, error) { + url := restURL(c.host, path) + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + resp, err := c.client.Do(req) + if err != nil { + return resp, err + } + + success := resp.StatusCode >= 200 && resp.StatusCode < 300 + if !success { + err = api.HTTPError{ + StatusCode: resp.StatusCode, + RequestURL: resp.Request.URL, + AcceptedOAuthScopes: resp.Header.Get("X-Accepted-Oauth-Scopes"), + OAuthScopes: resp.Header.Get("X-Oauth-Scopes"), + } + } + + return resp, err +} + func (c restClient) Do(method string, path string, body io.Reader, response interface{}) error { url := restURL(c.host, path) req, err := http.NewRequest(method, url, body) @@ -38,7 +63,7 @@ func (c restClient) Do(method string, path string, body io.Reader, response inte success := resp.StatusCode >= 200 && resp.StatusCode < 300 if !success { - return handleHTTPError(resp) + return api.HandleHTTPError(resp) } if resp.StatusCode == http.StatusNoContent { diff --git a/internal/api/rest_client_test.go b/internal/api/rest_client_test.go index 48fa149..ea97aa5 100644 --- a/internal/api/rest_client_test.go +++ b/internal/api/rest_client_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "fmt" + "io" "strings" "testing" @@ -10,6 +11,114 @@ import ( "gopkg.in/h2non/gock.v1" ) +func TestRESTClientRequest(t *testing.T) { + tests := []struct { + name string + host string + path string + httpMocks func() + wantErr bool + wantErrMsg string + wantBody string + }{ + { + name: "success request empty response", + path: "some/test/path", + httpMocks: func() { + gock.New("https://api.github.com"). + Get("/some/test/path"). + Reply(204). + JSON(`{}`) + }, + wantBody: `{}`, + }, + { + name: "success request non-empty response", + path: "some/test/path", + httpMocks: func() { + gock.New("https://api.github.com"). + Get("/some/test/path"). + Reply(200). + JSON(`{"message": "success"}`) + }, + wantBody: `{"message": "success"}`, + }, + { + name: "fail request empty response", + path: "some/test/path", + httpMocks: func() { + gock.New("https://api.github.com"). + Get("/some/test/path"). + Reply(404). + JSON(`{}`) + }, + wantErr: true, + wantErrMsg: "HTTP 404 (https://api.github.com/some/test/path)", + wantBody: `{}`, + }, + { + name: "fail request non-empty response", + path: "some/test/path", + httpMocks: func() { + gock.New("https://api.github.com"). + Get("/some/test/path"). + Reply(422). + JSON(`{"message": "OH NO"}`) + }, + wantErr: true, + wantErrMsg: "HTTP 422 (https://api.github.com/some/test/path)", + wantBody: `{"message": "OH NO"}`, + }, + { + name: "support full urls", + path: "https://example.com/someother/test/path", + httpMocks: func() { + gock.New("https://example.com"). + Get("/someother/test/path"). + Reply(204). + JSON(`{}`) + }, + wantBody: `{}`, + }, + { + name: "support enterprise hosts", + host: "enterprise.com", + path: "some/test/path", + httpMocks: func() { + gock.New("https://enterprise.com"). + Get("/some/test/path"). + Reply(204). + JSON(`{}`) + }, + wantBody: `{}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Cleanup(gock.Off) + if tt.host == "" { + tt.host = "github.com" + } + if tt.httpMocks != nil { + tt.httpMocks() + } + client := NewRESTClient(tt.host, nil) + resp, err := client.Request("GET", tt.path, nil) + t.Cleanup(func() { resp.Body.Close() }) + body, readErr := io.ReadAll(resp.Body) + assert.NoError(t, readErr) + if tt.wantErr { + assert.EqualError(t, err, tt.wantErrMsg) + } else { + assert.NoError(t, err) + } + assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) + assert.Equal(t, tt.wantBody, string(body)) + }) + } +} + func TestRESTClientDo(t *testing.T) { tests := []struct { name string diff --git a/pkg/api/client.go b/pkg/api/client.go index 11a3cd8..f4feb0f 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -2,10 +2,8 @@ package api import ( - "fmt" "io" "net/http" - "strings" "time" ) @@ -85,6 +83,12 @@ type RESTClient interface { // Put issues a PUT request to the specified path with the specified body. // The response is populated into the response argument. Put(path string, body io.Reader, response interface{}) error + + // Request issues a request with type specified by method to the + // specified path with the specified body. + // The response is returned rather than being populated + // into a response argument. + Request(method string, path string, body io.Reader) (*http.Response, error) } // GQLClient is the interface that wraps methods for the different types of @@ -109,23 +113,3 @@ type GQLClient interface { // to the GitHub GraphQL schema. Query(name string, query interface{}, variables map[string]interface{}) error } - -// GQLError contains GQLErrors from a GraphQL request. -type GQLError struct { - Errors []GQLErrorItem -} - -// GQLErrorItem contains error information from a GraphQL request. -type GQLErrorItem struct { - Type string - Message string -} - -// Error formats all GQLError messages. -func (gr GQLError) Error() string { - errorMessages := make([]string, 0, len(gr.Errors)) - for _, e := range gr.Errors { - errorMessages = append(errorMessages, e.Message) - } - return fmt.Sprintf("GQL error: %s", strings.Join(errorMessages, "\n")) -} diff --git a/pkg/api/errors.go b/pkg/api/errors.go new file mode 100644 index 0000000..7f92fb9 --- /dev/null +++ b/pkg/api/errors.go @@ -0,0 +1,173 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" +) + +const ( + contentType = "Content-Type" +) + +var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) + +// HTTPError represents an error response from the GitHub API. +type HTTPError struct { + AcceptedOAuthScopes string + Errors []HTTPErrorItem + Message string + OAuthScopes string + RequestURL *url.URL + StatusCode int +} + +// HTTPErrorItem stores additional information about an error response +// returned from the GitHub API. +type HTTPErrorItem struct { + Message string + Resource string + Field string + Code string +} + +// Allow HTTPError to satisfy error interface. +func (err HTTPError) Error() string { + if msgs := strings.SplitN(err.Message, "\n", 2); len(msgs) > 1 { + return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, msgs[0], err.RequestURL, msgs[1]) + } else if err.Message != "" { + return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL) + } + return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL) +} + +// GQLError represents an error response from GitHub GraphQL API. +type GQLError struct { + Errors []GQLErrorItem +} + +// GQLErrorItem stores additional information about an error response +// returned from the GitHub GraphQL API. +type GQLErrorItem struct { + Message string + Path []interface{} + Type string +} + +// Allow GQLError to satisfy error interface. +func (gr GQLError) Error() string { + errorMessages := make([]string, 0, len(gr.Errors)) + for _, e := range gr.Errors { + msg := e.Message + if p := e.pathString(); p != "" { + msg = fmt.Sprintf("%s (%s)", msg, p) + } + errorMessages = append(errorMessages, msg) + } + return fmt.Sprintf("GQL: %s", strings.Join(errorMessages, ", ")) +} + +// Match determines if the GQLError is about a specific type on a specific path. +// If the path argument ends with a ".", it will match all its subpaths. +func (gr GQLError) Match(expectType, expectPath string) bool { + for _, e := range gr.Errors { + if e.Type != expectType || !matchPath(e.pathString(), expectPath) { + return false + } + } + return true +} + +func (ge GQLErrorItem) pathString() string { + var res strings.Builder + for i, v := range ge.Path { + if i > 0 { + res.WriteRune('.') + } + fmt.Fprintf(&res, "%v", v) + } + return res.String() +} + +func matchPath(p, expect string) bool { + if strings.HasSuffix(expect, ".") { + return strings.HasPrefix(p, expect) || p == strings.TrimSuffix(expect, ".") + } + return p == expect +} + +// HandleHTTPError parses a http.Response into a HTTPError. +func HandleHTTPError(resp *http.Response) error { + httpError := HTTPError{ + StatusCode: resp.StatusCode, + RequestURL: resp.Request.URL, + AcceptedOAuthScopes: resp.Header.Get("X-Accepted-Oauth-Scopes"), + OAuthScopes: resp.Header.Get("X-Oauth-Scopes"), + } + + if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) { + httpError.Message = resp.Status + return httpError + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + httpError.Message = err.Error() + return httpError + } + + var parsedBody struct { + Message string `json:"message"` + Errors []json.RawMessage + } + if err := json.Unmarshal(body, &parsedBody); err != nil { + return httpError + } + + var messages []string + if parsedBody.Message != "" { + messages = append(messages, parsedBody.Message) + } + for _, raw := range parsedBody.Errors { + switch raw[0] { + case '"': + var errString string + _ = json.Unmarshal(raw, &errString) + messages = append(messages, errString) + httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString}) + case '{': + var errInfo HTTPErrorItem + _ = json.Unmarshal(raw, &errInfo) + msg := errInfo.Message + if errInfo.Code != "" && errInfo.Code != "custom" { + msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code)) + } + if msg != "" { + messages = append(messages, msg) + } + httpError.Errors = append(httpError.Errors, errInfo) + } + } + httpError.Message = strings.Join(messages, "\n") + + return httpError +} + +// Convert common error codes to human readable messages +// See https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors for more details. +func errorCodeToMessage(code string) string { + switch code { + case "missing", "missing_field": + return "is missing" + case "invalid", "unprocessable": + return "is invalid" + case "already_exists": + return "already exists" + default: + return code + } +} diff --git a/pkg/api/errors_test.go b/pkg/api/errors_test.go new file mode 100644 index 0000000..d1f8783 --- /dev/null +++ b/pkg/api/errors_test.go @@ -0,0 +1,60 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGQLErrorMatch(t *testing.T) { + tests := []struct { + name string + error GQLError + kind string + path string + wantMatch bool + }{ + { + name: "matches path and type", + error: GQLError{Errors: []GQLErrorItem{ + {Path: []interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, + }}, + kind: "NOT_FOUND", + path: "repository.issue", + wantMatch: true, + }, + { + name: "matches base path and type", + error: GQLError{Errors: []GQLErrorItem{ + {Path: []interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, + }}, + kind: "NOT_FOUND", + path: "repository.", + wantMatch: true, + }, + { + name: "does not match path but matches type", + error: GQLError{Errors: []GQLErrorItem{ + {Path: []interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, + }}, + kind: "NOT_FOUND", + path: "label.title", + wantMatch: false, + }, + { + name: "matches path but not type", + error: GQLError{Errors: []GQLErrorItem{ + {Path: []interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, + }}, + kind: "UNKNOWN", + path: "repository.issue", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantMatch, tt.error.Match(tt.kind, tt.path)) + }) + } +} diff --git a/pkg/api/http.go b/pkg/api/http.go deleted file mode 100644 index 4061aff..0000000 --- a/pkg/api/http.go +++ /dev/null @@ -1,36 +0,0 @@ -package api - -import ( - "fmt" - "net/http" - "net/url" - "strings" -) - -// HTTPError represents an error response from the GitHub API. -type HTTPError struct { - Errors []HTTPErrorItem - Headers http.Header - Message string - RequestURL *url.URL - StatusCode int -} - -// HTTPErrorItem stores additional information about an error response -// returned from the GitHub API. -type HTTPErrorItem struct { - Message string - Resource string - Field string - Code string -} - -// Allow HTTPError to satisfy error interface. -func (err HTTPError) Error() string { - if msgs := strings.SplitN(err.Message, "\n", 2); len(msgs) > 1 { - return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, msgs[0], err.RequestURL, msgs[1]) - } else if err.Message != "" { - return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL) - } - return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL) -}