From b3f8b4f3aaf62d29742276f29dfee3ba4e7e4e3b Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Tue, 3 May 2022 15:57:48 +0200 Subject: [PATCH] Add support for http_unix_socket --- gh.go | 36 +++++++++++++++++++++++++----------- gh_test.go | 10 ++++++++-- internal/api/http.go | 28 +++++++++++++++++++++++----- pkg/api/client.go | 6 ++++++ 4 files changed, 62 insertions(+), 18 deletions(-) diff --git a/gh.go b/gh.go index 8c5e8f1..05cd1be 100644 --- a/gh.go +++ b/gh.go @@ -16,6 +16,7 @@ import ( iapi "github.com/cli/go-gh/internal/api" "github.com/cli/go-gh/internal/config" + iconfig "github.com/cli/go-gh/internal/config" "github.com/cli/go-gh/internal/git" irepo "github.com/cli/go-gh/internal/repository" "github.com/cli/go-gh/internal/ssh" @@ -54,8 +55,9 @@ func run(path string, env []string, args ...string) (stdOut, stdErr bytes.Buffer } // RESTClient builds a client to send requests to GitHub REST API endpoints. -// As part of the configuration a hostname, auth token, and default set of headers are resolved -// from the gh environment configuration. These behaviors can be overridden using the opts argument. +// As part of the configuration a hostname, auth token, default set of headers, +// and unix domain socket are resolved from the gh environment configuration. +// These behaviors can be overridden using the opts argument. func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { if opts == nil { opts = &api.ClientOptions{} @@ -72,8 +74,9 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) { } // GQLClient builds a client to send requests to GitHub GraphQL API endpoints. -// As part of the configuration a hostname, auth token, and default set of headers are resolved -// from the gh environment configuration. These behaviors can be overridden using the opts argument. +// As part of the configuration a hostname, auth token, default set of headers, +// and unix domain socket are resolved from the gh environment configuration. +// These behaviors can be overridden using the opts argument. func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { if opts == nil { opts = &api.ClientOptions{} @@ -90,12 +93,14 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) { } // HTTPClient builds a client that can be passed to another library. -// As part of the configuration a hostname, auth token, and default set of headers are resolved -// from the gh environment configuration. These behaviors can be overridden using the opts argument. -// In this instance providing opts.Host will not change the destination of your request as it is -// the responsibility of the consumer to configure this. However, if opts.Host does not match the request -// host, the auth token will not be added to the headers. This is to protect against the case where tokens -// could be sent to an arbitrary host. +// As part of the configuration a hostname, auth token, default set of headers, +// and unix domain socket are resolved from the gh environment configuration. +// These behaviors can be overridden using the opts argument. In this instance +// providing opts.Host will not change the destination of your request as it is +// the responsibility of the consumer to configure this. However, if opts.Host +// does not match the request host, the auth token will not be added to the headers. +// This is to protect against the case where tokens could be sent to an arbitrary +// host. func HTTPClient(opts *api.ClientOptions) (*http.Client, error) { if opts == nil { opts = &api.ClientOptions{} @@ -156,10 +161,19 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error { if opts.AuthToken == "" { token, err = cfg.AuthToken(opts.Host) if err != nil { - return err + var notFoundError iconfig.NotFoundError + if errors.As(err, ¬FoundError) { + return fmt.Errorf("auth token not found for host %s", opts.Host) + } else { + return err + } } opts.AuthToken = token } + if opts.UnixDomainSocket == "" { + unixSocket, _ := cfg.Get("http_unix_socket") + opts.UnixDomainSocket = unixSocket + } return nil } diff --git a/gh_test.go b/gh_test.go index e0d3cea..98d370e 100644 --- a/gh_test.go +++ b/gh_test.go @@ -138,21 +138,25 @@ func TestResolveOptions(t *testing.T) { opts *api.ClientOptions wantAuthToken string wantHost string + wantSocket string }{ { name: "honors consumer provided ClientOptions", opts: &api.ClientOptions{ - Host: "test.com", - AuthToken: "token_from_opts", + Host: "test.com", + AuthToken: "token_from_opts", + UnixDomainSocket: "socket_from_opts", }, wantAuthToken: "token_from_opts", wantHost: "test.com", + wantSocket: "socket_from_opts", }, { name: "uses config values if there are no consumer provided ClientOptions", opts: &api.ClientOptions{}, wantAuthToken: "token", wantHost: "github.com", + wantSocket: "socket", }, } @@ -162,6 +166,7 @@ func TestResolveOptions(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tt.wantHost, tt.opts.Host) assert.Equal(t, tt.wantAuthToken, tt.opts.AuthToken) + assert.Equal(t, tt.wantSocket, tt.opts.UnixDomainSocket) }) } } @@ -173,6 +178,7 @@ hosts: user: user1 oauth_token: token git_protocol: ssh +http_unix_socket: socket ` cfg, _ := config.FromString(data) return cfg diff --git a/internal/api/http.go b/internal/api/http.go index f7c4bcc..97fe76f 100644 --- a/internal/api/http.go +++ b/internal/api/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "os" "path/filepath" @@ -74,10 +75,15 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client { } transport := http.DefaultTransport + if opts.Transport != nil { transport = opts.Transport } + if opts.UnixDomainSocket != "" { + transport = newUnixDomainSocketRoundTripper(opts.UnixDomainSocket) + } + transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) if opts.Log != nil { @@ -207,10 +213,10 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str if headers == nil { headers = map[string]string{} } - if headers[contentType] == "" { + if _, ok := headers[contentType]; !ok { headers[contentType] = jsonContentType } - if headers[userAgent] == "" { + if _, ok := headers[userAgent]; !ok { headers[userAgent] = "go-gh" info, ok := debug.ReadBuildInfo() if ok { @@ -222,13 +228,13 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str } } } - if headers[authorization] == "" && authToken != "" { + if _, ok := headers[authorization]; !ok && authToken != "" { headers[authorization] = fmt.Sprintf("token %s", authToken) } - if headers[timeZone] == "" { + if _, ok := headers[timeZone]; !ok { headers[timeZone] = currentTimeZone() } - if headers[accept] == "" { + if _, ok := headers[accept]; !ok { // Preview for PullRequest.mergeStateStatus. a := "application/vnd.github.merge-info-preview+json" // Preview for visibility when RESTing repos into an org. @@ -260,6 +266,18 @@ func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro return hrt.rt.RoundTrip(req) } +func newUnixDomainSocketRoundTripper(socketPath string) http.RoundTripper { + dial := func(network, addr string) (net.Conn, error) { + return net.Dial("unix", socketPath) + } + + return &http.Transport{ + Dial: dial, + DialTLS: dial, + DisableKeepAlives: true, + } +} + func currentTimeZone() string { tz := time.Local.String() if tz == "Local" { diff --git a/pkg/api/client.go b/pkg/api/client.go index 842988f..351540a 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -44,6 +44,12 @@ type ClientOptions struct { // Transport specifies the mechanism by which individual API requests are made. // Default is http.DefaultTransport. Transport http.RoundTripper + + // UnixDomainSocket specifies the Unix domain socket address by which individual + // API requests will be routed. If specifed, this will form the base of the API + // request transport chain taking precedence over Transport ClientOption. + // Default is no socket address. + UnixDomainSocket string } // RESTClient is the interface that wraps methods for the different types of