diff --git a/gh.go b/gh.go index 8c5e8f1..fd11005 100644 --- a/gh.go +++ b/gh.go @@ -16,6 +16,7 @@ import ( iapi "github.com/cli/go-gh/internal/api" "github.com/cli/go-gh/internal/config" + iconfig "github.com/cli/go-gh/internal/config" "github.com/cli/go-gh/internal/git" irepo "github.com/cli/go-gh/internal/repository" "github.com/cli/go-gh/internal/ssh" @@ -54,8 +55,9 @@ func run(path string, env []string, args ...string) (stdOut, stdErr bytes.Buffer } // RESTClient builds a client to send requests to GitHub REST API endpoints. -// As part of the configuration a hostname, auth token, and default set of headers are resolved -// from the gh environment configuration. These behaviors can be overridden using the opts argument. +// As part of the configuration a hostname, auth token, default set of headers, +// and unix domain socket are resolved from the gh environment configuration. +// These behaviors can be overridden using the opts argument. func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { if opts == nil { opts = &api.ClientOptions{} @@ -72,8 +74,9 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { } // GQLClient builds a client to send requests to GitHub GraphQL API endpoints. -// As part of the configuration a hostname, auth token, and default set of headers are resolved -// from the gh environment configuration. These behaviors can be overridden using the opts argument. +// As part of the configuration a hostname, auth token, default set of headers, +// and unix domain socket are resolved from the gh environment configuration. +// These behaviors can be overridden using the opts argument. func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { if opts == nil { opts = &api.ClientOptions{} @@ -90,12 +93,14 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { } // HTTPClient builds a client that can be passed to another library. -// As part of the configuration a hostname, auth token, and default set of headers are resolved -// from the gh environment configuration. These behaviors can be overridden using the opts argument. -// In this instance providing opts.Host will not change the destination of your request as it is -// the responsibility of the consumer to configure this. However, if opts.Host does not match the request -// host, the auth token will not be added to the headers. This is to protect against the case where tokens -// could be sent to an arbitrary host. +// As part of the configuration a hostname, auth token, default set of headers, +// and unix domain socket are resolved from the gh environment configuration. +// These behaviors can be overridden using the opts argument. In this instance +// providing opts.Host will not change the destination of your request as it is +// the responsibility of the consumer to configure this. However, if opts.Host +// does not match the request host, the auth token will not be added to the headers. +// This is to protect against the case where tokens could be sent to an arbitrary +// host. func HTTPClient(opts *api.ClientOptions) (*http.Client, error) { if opts == nil { opts = &api.ClientOptions{} @@ -156,10 +161,19 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error { if opts.AuthToken == "" { token, err = cfg.AuthToken(opts.Host) if err != nil { - return err + var notFoundError iconfig.NotFoundError + if errors.As(err, ¬FoundError) { + return fmt.Errorf("authentication token not found for host %s", opts.Host) + } else { + return err + } } opts.AuthToken = token } + if opts.UnixDomainSocket == "" { + unixSocket, _ := cfg.Get("http_unix_socket") + opts.UnixDomainSocket = unixSocket + } return nil } diff --git a/gh_test.go b/gh_test.go index 781c4c4..25908ca 100644 --- a/gh_test.go +++ b/gh_test.go @@ -168,21 +168,25 @@ func TestResolveOptions(t *testing.T) { opts *api.ClientOptions wantAuthToken string wantHost string + wantSocket string }{ { name: "honors consumer provided ClientOptions", opts: &api.ClientOptions{ - Host: "test.com", - AuthToken: "token_from_opts", + Host: "test.com", + AuthToken: "token_from_opts", + UnixDomainSocket: "socket_from_opts", }, wantAuthToken: "token_from_opts", wantHost: "test.com", + wantSocket: "socket_from_opts", }, { name: "uses config values if there are no consumer provided ClientOptions", opts: &api.ClientOptions{}, wantAuthToken: "token", wantHost: "github.com", + wantSocket: "socket", }, } @@ -192,6 +196,7 @@ func TestResolveOptions(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tt.wantHost, tt.opts.Host) assert.Equal(t, tt.wantAuthToken, tt.opts.AuthToken) + assert.Equal(t, tt.wantSocket, tt.opts.UnixDomainSocket) }) } } @@ -203,6 +208,7 @@ hosts: user: user1 oauth_token: token git_protocol: ssh +http_unix_socket: socket ` cfg, _ := config.FromString(data) return cfg diff --git a/internal/api/http.go b/internal/api/http.go index 228c9e6..82bb857 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "os" "path/filepath" @@ -74,11 +75,25 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { } transport := http.DefaultTransport + + if opts.UnixDomainSocket != "" { + transport = newUnixDomainSocketRoundTripper(opts.UnixDomainSocket) + } + if opts.Transport != nil { transport = opts.Transport } - transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) + if opts.EnableCache { + if opts.CacheDir == "" { + opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache") + } + if opts.CacheTTL == 0 { + opts.CacheTTL = time.Hour * 24 + } + c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL} + transport = c.RoundTripper(transport) + } if opts.Log != nil { logger := &httpretty.Logger{ @@ -99,25 +114,17 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { transport = logger.RoundTripper(transport) } - if opts.EnableCache { - if opts.CacheDir == "" { - opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache") - } - if opts.CacheTTL == 0 { - opts.CacheTTL = time.Hour * 24 - } - c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL} - transport = c.RoundTripper(transport) - } + transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) return http.Client{Transport: transport, Timeout: opts.Timeout} } +// TODO: Export function in near future. func handleHTTPError(resp *http.Response) error { httpError := api.HTTPError{ - StatusCode: resp.StatusCode, - RequestURL: resp.Request.URL, - OAuthScopes: resp.Header.Get("X-Oauth-Scopes"), + StatusCode: resp.StatusCode, + RequestURL: resp.Request.URL, + Headers: resp.Header, } if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) { @@ -198,8 +205,8 @@ func isEnterprise(host string) bool { } type headerRoundTripper struct { - host string headers map[string]string + host string rt http.RoundTripper } @@ -207,10 +214,10 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str if headers == nil { headers = map[string]string{} } - if headers[contentType] == "" { + if _, ok := headers[contentType]; !ok { headers[contentType] = jsonContentType } - if headers[userAgent] == "" { + if _, ok := headers[userAgent]; !ok { headers[userAgent] = "go-gh" info, ok := debug.ReadBuildInfo() if ok { @@ -222,13 +229,13 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str } } } - if headers[authorization] == "" && authToken != "" { + if _, ok := headers[authorization]; !ok && authToken != "" { headers[authorization] = fmt.Sprintf("token %s", authToken) } - if headers[timeZone] == "" { + if _, ok := headers[timeZone]; !ok { headers[timeZone] = currentTimeZone() } - if headers[accept] == "" { + if _, ok := headers[accept]; !ok { // Preview for PullRequest.mergeStateStatus. a := "application/vnd.github.merge-info-preview+json" // Preview for visibility when RESTing repos into an org. @@ -260,6 +267,18 @@ func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro return hrt.rt.RoundTrip(req) } +func newUnixDomainSocketRoundTripper(socketPath string) http.RoundTripper { + dial := func(network, addr string) (net.Conn, error) { + return net.Dial("unix", socketPath) + } + + return &http.Transport{ + Dial: dial, + DialTLS: dial, + DisableKeepAlives: true, + } +} + func currentTimeZone() string { tz := time.Local.String() if tz == "Local" { diff --git a/internal/config/config.go b/internal/config/config.go index b54d4d1..e595c23 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -138,6 +138,7 @@ func defaultConfig() Config { return config{global: configMap{Root: defaultGlobal().Content[0]}} } +//TODO: Add caching so as not to load config multiple times. func Load() (Config, error) { return load(configFile(), hostsConfigFile()) } diff --git a/pkg/api/client.go b/pkg/api/client.go index ce3e2ad..11a3cd8 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -32,7 +32,7 @@ type ClientOptions struct { // Default headers will be overridden by keys specified in Headers. Headers map[string]string - // Host is the host that every API request will be sent to. + // Host is the default host that API requests will be sent to. Host string // Log specifies a writer to write API request logs to. @@ -44,8 +44,18 @@ type ClientOptions struct { Timeout time.Duration // Transport specifies the mechanism by which individual API requests are made. + // If both Transport and UnixDomainSocket are specified then Transport takes + // precedence. Due to this behavior any value set for Transport needs to manually + // handle routing to UnixDomainSocket if necessary. Generally, setting Transport + // should be reserved for testing purposes. // Default is http.DefaultTransport. Transport http.RoundTripper + + // UnixDomainSocket specifies the Unix domain socket address by which individual + // API requests will be routed. If specifed, this will form the base of the API + // request transport chain. + // Default is no socket address. + UnixDomainSocket string } // RESTClient is the interface that wraps methods for the different types of diff --git a/pkg/api/http.go b/pkg/api/http.go index e996f76..4061aff 100644 --- a/pkg/api/http.go +++ b/pkg/api/http.go @@ -2,17 +2,18 @@ package api import ( "fmt" + "net/http" "net/url" "strings" ) // HTTPError represents an error response from the GitHub API. type HTTPError struct { - StatusCode int - RequestURL *url.URL - Message string - OAuthScopes string - Errors []HTTPErrorItem + Errors []HTTPErrorItem + Headers http.Header + Message string + RequestURL *url.URL + StatusCode int } // HTTPErrorItem stores additional information about an error response