diff --git a/gh.go b/gh.go index 05cd1be..31e1012 100644 --- a/gh.go +++ b/gh.go @@ -62,13 +62,15 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { if opts == nil { opts = &api.ClientOptions{} } - cfg, err := config.Load() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) - if err != nil { - return nil, err + if optionsNeedResolution(opts) { + cfg, err := config.Load() + if err != nil { + return nil, err + } + err = resolveOptions(opts, cfg) + if err != nil { + return nil, err + } } return iapi.NewRESTClient(opts.Host, opts), nil } @@ -81,13 +83,15 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { if opts == nil { opts = &api.ClientOptions{} } - cfg, err := config.Load() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) - if err != nil { - return nil, err + if optionsNeedResolution(opts) { + cfg, err := config.Load() + if err != nil { + return nil, err + } + err = resolveOptions(opts, cfg) + if err != nil { + return nil, err + } } return iapi.NewGQLClient(opts.Host, opts), nil } @@ -105,13 +109,15 @@ func HTTPClient(opts *api.ClientOptions) (*http.Client, error) { if opts == nil { opts = &api.ClientOptions{} } - cfg, err := config.Load() - if err != nil { - return nil, err - } - err = resolveOptions(opts, cfg) - if err != nil { - return nil, err + if optionsNeedResolution(opts) { + cfg, err := config.Load() + if err != nil { + return nil, err + } + err = resolveOptions(opts, cfg) + if err != nil { + return nil, err + } } client := iapi.NewHTTPClient(opts) return &client, nil @@ -152,6 +158,19 @@ func CurrentRepository() (repo.Repository, error) { return irepo.New(r.Host, r.Owner, r.Repo), nil } +func optionsNeedResolution(opts *api.ClientOptions) bool { + if opts.Host == "" { + return true + } + if opts.AuthToken == "" { + return true + } + if opts.UnixDomainSocket == "" && opts.Transport == nil { + return true + } + return false +} + func resolveOptions(opts *api.ClientOptions, cfg config.Config) error { var token string var err error diff --git a/gh_test.go b/gh_test.go index 25908ca..60afc3a 100644 --- a/gh_test.go +++ b/gh_test.go @@ -129,7 +129,7 @@ func TestGQLClientError(t *testing.T) { res := struct{ Organization struct{ Name string } }{} err = client.Do("QUERY", nil, &res) - assert.EqualError(t, err, "GQL error: Could not resolve to an Organization with the login of 'cli'.") + assert.EqualError(t, err, "GQL: Could not resolve to an Organization with the login of 'cli'. (organization)") assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) } diff --git a/internal/api/gql_client.go b/internal/api/gql_client.go index b4d727b..d7723b6 100644 --- a/internal/api/gql_client.go +++ b/internal/api/gql_client.go @@ -55,7 +55,7 @@ func (c gqlClient) Do(query string, variables map[string]interface{}, response i success := resp.StatusCode >= 200 && resp.StatusCode < 300 if !success { - return handleHTTPError(resp) + return api.HandleHTTPError(resp) } if resp.StatusCode == http.StatusNoContent { @@ -67,14 +67,14 @@ func (c gqlClient) Do(query string, variables map[string]interface{}, response i return err } - gr := &gqlResponse{Data: response} + gr := gqlResponse{Data: response} err = json.Unmarshal(body, &gr) if err != nil { return err } if len(gr.Errors) > 0 { - return &api.GQLError{Errors: gr.Errors} + return api.GQLError{Errors: gr.Errors} } return nil diff --git a/internal/api/gql_client_test.go b/internal/api/gql_client_test.go index 8400004..5cd7959 100644 --- a/internal/api/gql_client_test.go +++ b/internal/api/gql_client_test.go @@ -37,7 +37,7 @@ func TestGQLClientDo(t *testing.T) { JSON(`{"errors":[{"message":"OH NO"},{"message":"this is fine"}]}`) }, wantErr: true, - wantErrMsg: "GQL error: OH NO\nthis is fine", + wantErrMsg: "GQL: OH NO, this is fine", }, { name: "http fail request empty response", diff --git a/internal/api/http.go b/internal/api/http.go index de36cc7..4c0e277 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -1,9 +1,7 @@ package api import ( - "encoding/json" "fmt" - "io" "net" "net/http" "os" @@ -29,6 +27,7 @@ const ( ) var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) + var timeZoneNames = map[int]string{ -39600: "Pacific/Niue", -36000: "Pacific/Honolulu", @@ -119,78 +118,6 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { return http.Client{Transport: transport, Timeout: opts.Timeout} } -// TODO: Export function in near future. -func handleHTTPError(resp *http.Response) error { - httpError := api.HTTPError{ - StatusCode: resp.StatusCode, - RequestURL: resp.Request.URL, - AcceptedOAuthScopes: resp.Header.Get("X-Accepted-Oauth-Scopes"), - OAuthScopes: resp.Header.Get("X-Oauth-Scopes"), - } - - if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) { - httpError.Message = resp.Status - return httpError - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - httpError.Message = err.Error() - return httpError - } - - var parsedBody struct { - Message string `json:"message"` - Errors []json.RawMessage - } - if err := json.Unmarshal(body, &parsedBody); err != nil { - return httpError - } - - var messages []string - if parsedBody.Message != "" { - messages = append(messages, parsedBody.Message) - } - for _, raw := range parsedBody.Errors { - switch raw[0] { - case '"': - var errString string - _ = json.Unmarshal(raw, &errString) - messages = append(messages, errString) - httpError.Errors = append(httpError.Errors, api.HTTPErrorItem{Message: errString}) - case '{': - var errInfo api.HTTPErrorItem - _ = json.Unmarshal(raw, &errInfo) - msg := errInfo.Message - if errInfo.Code != "" && errInfo.Code != "custom" { - msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code)) - } - if msg != "" { - messages = append(messages, msg) - } - httpError.Errors = append(httpError.Errors, errInfo) - } - } - httpError.Message = strings.Join(messages, "\n") - - return httpError -} - -// Convert common error codes to human readable messages -// See https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors for more details. -func errorCodeToMessage(code string) string { - switch code { - case "missing", "missing_field": - return "is missing" - case "invalid", "unprocessable": - return "is invalid" - case "already_exists": - return "already exists" - default: - return code - } -} - func inspectableMIMEType(t string) bool { return strings.HasPrefix(t, "text/") || jsonTypeRE.MatchString(t) } diff --git a/internal/api/rest_client.go b/internal/api/rest_client.go index c009867..68e7197 100644 --- a/internal/api/rest_client.go +++ b/internal/api/rest_client.go @@ -23,6 +23,15 @@ func NewRESTClient(host string, opts *api.ClientOptions) api.RESTClient { } } +func (c restClient) Raw(method string, path string, body io.Reader) (*http.Response, error) { + url := restURL(c.host, path) + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + return c.client.Do(req) +} + func (c restClient) Do(method string, path string, body io.Reader, response interface{}) error { url := restURL(c.host, path) req, err := http.NewRequest(method, url, body) @@ -38,7 +47,7 @@ func (c restClient) Do(method string, path string, body io.Reader, response inte success := resp.StatusCode >= 200 && resp.StatusCode < 300 if !success { - return handleHTTPError(resp) + return api.HandleHTTPError(resp) } if resp.StatusCode == http.StatusNoContent { diff --git a/pkg/api/client.go b/pkg/api/client.go index 11a3cd8..e2538e3 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -2,10 +2,8 @@ package api import ( - "fmt" "io" "net/http" - "strings" "time" ) @@ -85,6 +83,12 @@ type RESTClient interface { // Put issues a PUT request to the specified path with the specified body. // The response is populated into the response argument. Put(path string, body io.Reader, response interface{}) error + + // Raw issues a request with type specified by method to the + // specified path with the specified body. + // The response is returned rather than being populated + // into a response argument. + Raw(method string, path string, body io.Reader) (*http.Response, error) } // GQLClient is the interface that wraps methods for the different types of @@ -109,23 +113,3 @@ type GQLClient interface { // to the GitHub GraphQL schema. Query(name string, query interface{}, variables map[string]interface{}) error } - -// GQLError contains GQLErrors from a GraphQL request. -type GQLError struct { - Errors []GQLErrorItem -} - -// GQLErrorItem contains error information from a GraphQL request. -type GQLErrorItem struct { - Type string - Message string -} - -// Error formats all GQLError messages. -func (gr GQLError) Error() string { - errorMessages := make([]string, 0, len(gr.Errors)) - for _, e := range gr.Errors { - errorMessages = append(errorMessages, e.Message) - } - return fmt.Sprintf("GQL error: %s", strings.Join(errorMessages, "\n")) -} diff --git a/pkg/api/http.go b/pkg/api/http.go index 876c8a9..7f92fb9 100644 --- a/pkg/api/http.go +++ b/pkg/api/http.go @@ -1,11 +1,21 @@ package api import ( + "encoding/json" "fmt" + "io" + "net/http" "net/url" + "regexp" "strings" ) +const ( + contentType = "Content-Type" +) + +var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) + // HTTPError represents an error response from the GitHub API. type HTTPError struct { AcceptedOAuthScopes string @@ -34,3 +44,130 @@ func (err HTTPError) Error() string { } return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL) } + +// GQLError represents an error response from GitHub GraphQL API. +type GQLError struct { + Errors []GQLErrorItem +} + +// GQLErrorItem stores additional information about an error response +// returned from the GitHub GraphQL API. +type GQLErrorItem struct { + Message string + Path []interface{} + Type string +} + +// Allow GQLError to satisfy error interface. +func (gr GQLError) Error() string { + errorMessages := make([]string, 0, len(gr.Errors)) + for _, e := range gr.Errors { + msg := e.Message + if p := e.pathString(); p != "" { + msg = fmt.Sprintf("%s (%s)", msg, p) + } + errorMessages = append(errorMessages, msg) + } + return fmt.Sprintf("GQL: %s", strings.Join(errorMessages, ", ")) +} + +// Match determines if the GQLError is about a specific type on a specific path. +// If the path argument ends with a ".", it will match all its subpaths. +func (gr GQLError) Match(expectType, expectPath string) bool { + for _, e := range gr.Errors { + if e.Type != expectType || !matchPath(e.pathString(), expectPath) { + return false + } + } + return true +} + +func (ge GQLErrorItem) pathString() string { + var res strings.Builder + for i, v := range ge.Path { + if i > 0 { + res.WriteRune('.') + } + fmt.Fprintf(&res, "%v", v) + } + return res.String() +} + +func matchPath(p, expect string) bool { + if strings.HasSuffix(expect, ".") { + return strings.HasPrefix(p, expect) || p == strings.TrimSuffix(expect, ".") + } + return p == expect +} + +// HandleHTTPError parses a http.Response into a HTTPError. +func HandleHTTPError(resp *http.Response) error { + httpError := HTTPError{ + StatusCode: resp.StatusCode, + RequestURL: resp.Request.URL, + AcceptedOAuthScopes: resp.Header.Get("X-Accepted-Oauth-Scopes"), + OAuthScopes: resp.Header.Get("X-Oauth-Scopes"), + } + + if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) { + httpError.Message = resp.Status + return httpError + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + httpError.Message = err.Error() + return httpError + } + + var parsedBody struct { + Message string `json:"message"` + Errors []json.RawMessage + } + if err := json.Unmarshal(body, &parsedBody); err != nil { + return httpError + } + + var messages []string + if parsedBody.Message != "" { + messages = append(messages, parsedBody.Message) + } + for _, raw := range parsedBody.Errors { + switch raw[0] { + case '"': + var errString string + _ = json.Unmarshal(raw, &errString) + messages = append(messages, errString) + httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString}) + case '{': + var errInfo HTTPErrorItem + _ = json.Unmarshal(raw, &errInfo) + msg := errInfo.Message + if errInfo.Code != "" && errInfo.Code != "custom" { + msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code)) + } + if msg != "" { + messages = append(messages, msg) + } + httpError.Errors = append(httpError.Errors, errInfo) + } + } + httpError.Message = strings.Join(messages, "\n") + + return httpError +} + +// Convert common error codes to human readable messages +// See https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors for more details. +func errorCodeToMessage(code string) string { + switch code { + case "missing", "missing_field": + return "is missing" + case "invalid", "unprocessable": + return "is invalid" + case "already_exists": + return "already exists" + default: + return code + } +}