Skip to content

Commit

Permalink
Add support for http_unix_socket
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed May 4, 2022
1 parent c41a127 commit b3f8b4f
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 18 deletions.
36 changes: 25 additions & 11 deletions gh.go
Expand Up @@ -16,6 +16,7 @@ import (

iapi "github.com/cli/go-gh/internal/api"
"github.com/cli/go-gh/internal/config"
iconfig "github.com/cli/go-gh/internal/config"
"github.com/cli/go-gh/internal/git"
irepo "github.com/cli/go-gh/internal/repository"
"github.com/cli/go-gh/internal/ssh"
Expand Down Expand Up @@ -54,8 +55,9 @@ func run(path string, env []string, args ...string) (stdOut, stdErr bytes.Buffer
}

// RESTClient builds a client to send requests to GitHub REST API endpoints.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument.
func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand All @@ -72,8 +74,9 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) {
}

// GQLClient builds a client to send requests to GitHub GraphQL API endpoints.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument.
func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand All @@ -90,12 +93,14 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) {
}

// HTTPClient builds a client that can be passed to another library.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// In this instance providing opts.Host will not change the destination of your request as it is
// the responsibility of the consumer to configure this. However, if opts.Host does not match the request
// host, the auth token will not be added to the headers. This is to protect against the case where tokens
// could be sent to an arbitrary host.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument. In this instance
// providing opts.Host will not change the destination of your request as it is
// the responsibility of the consumer to configure this. However, if opts.Host
// does not match the request host, the auth token will not be added to the headers.
// This is to protect against the case where tokens could be sent to an arbitrary
// host.
func HTTPClient(opts *api.ClientOptions) (*http.Client, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand Down Expand Up @@ -156,10 +161,19 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error {
if opts.AuthToken == "" {
token, err = cfg.AuthToken(opts.Host)
if err != nil {
return err
var notFoundError iconfig.NotFoundError
if errors.As(err, &notFoundError) {
return fmt.Errorf("auth token not found for host %s", opts.Host)
} else {
return err
}
}
opts.AuthToken = token
}
if opts.UnixDomainSocket == "" {
unixSocket, _ := cfg.Get("http_unix_socket")
opts.UnixDomainSocket = unixSocket
}
return nil
}

Expand Down
10 changes: 8 additions & 2 deletions gh_test.go
Expand Up @@ -138,21 +138,25 @@ func TestResolveOptions(t *testing.T) {
opts *api.ClientOptions
wantAuthToken string
wantHost string
wantSocket string
}{
{
name: "honors consumer provided ClientOptions",
opts: &api.ClientOptions{
Host: "test.com",
AuthToken: "token_from_opts",
Host: "test.com",
AuthToken: "token_from_opts",
UnixDomainSocket: "socket_from_opts",
},
wantAuthToken: "token_from_opts",
wantHost: "test.com",
wantSocket: "socket_from_opts",
},
{
name: "uses config values if there are no consumer provided ClientOptions",
opts: &api.ClientOptions{},
wantAuthToken: "token",
wantHost: "github.com",
wantSocket: "socket",
},
}

Expand All @@ -162,6 +166,7 @@ func TestResolveOptions(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, tt.wantHost, tt.opts.Host)
assert.Equal(t, tt.wantAuthToken, tt.opts.AuthToken)
assert.Equal(t, tt.wantSocket, tt.opts.UnixDomainSocket)
})
}
}
Expand All @@ -173,6 +178,7 @@ hosts:
user: user1
oauth_token: token
git_protocol: ssh
http_unix_socket: socket
`
cfg, _ := config.FromString(data)
return cfg
Expand Down
28 changes: 23 additions & 5 deletions internal/api/http.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -74,10 +75,15 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
}

transport := http.DefaultTransport

if opts.Transport != nil {
transport = opts.Transport
}

if opts.UnixDomainSocket != "" {
transport = newUnixDomainSocketRoundTripper(opts.UnixDomainSocket)
}

transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)

if opts.Log != nil {
Expand Down Expand Up @@ -207,10 +213,10 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
if headers == nil {
headers = map[string]string{}
}
if headers[contentType] == "" {
if _, ok := headers[contentType]; !ok {
headers[contentType] = jsonContentType
}
if headers[userAgent] == "" {
if _, ok := headers[userAgent]; !ok {
headers[userAgent] = "go-gh"
info, ok := debug.ReadBuildInfo()
if ok {
Expand All @@ -222,13 +228,13 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
}
}
}
if headers[authorization] == "" && authToken != "" {
if _, ok := headers[authorization]; !ok && authToken != "" {
headers[authorization] = fmt.Sprintf("token %s", authToken)
}
if headers[timeZone] == "" {
if _, ok := headers[timeZone]; !ok {
headers[timeZone] = currentTimeZone()
}
if headers[accept] == "" {
if _, ok := headers[accept]; !ok {
// Preview for PullRequest.mergeStateStatus.
a := "application/vnd.github.merge-info-preview+json"
// Preview for visibility when RESTing repos into an org.
Expand Down Expand Up @@ -260,6 +266,18 @@ func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
return hrt.rt.RoundTrip(req)
}

func newUnixDomainSocketRoundTripper(socketPath string) http.RoundTripper {
dial := func(network, addr string) (net.Conn, error) {
return net.Dial("unix", socketPath)
}

return &http.Transport{
Dial: dial,
DialTLS: dial,
DisableKeepAlives: true,
}
}

func currentTimeZone() string {
tz := time.Local.String()
if tz == "Local" {
Expand Down
6 changes: 6 additions & 0 deletions pkg/api/client.go
Expand Up @@ -44,6 +44,12 @@ type ClientOptions struct {
// Transport specifies the mechanism by which individual API requests are made.
// Default is http.DefaultTransport.
Transport http.RoundTripper

// UnixDomainSocket specifies the Unix domain socket address by which individual
// API requests will be routed. If specifed, this will form the base of the API
// request transport chain taking precedence over Transport ClientOption.
// Default is no socket address.
UnixDomainSocket string
}

// RESTClient is the interface that wraps methods for the different types of
Expand Down

0 comments on commit b3f8b4f

Please sign in to comment.