Skip to content

Commit

Permalink
Add support for http_unix_socket (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed May 17, 2022
1 parent 293b1eb commit d4b5b6b
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 39 deletions.
36 changes: 25 additions & 11 deletions gh.go
Expand Up @@ -15,6 +15,7 @@ import (

iapi "github.com/cli/go-gh/internal/api"
"github.com/cli/go-gh/internal/config"
iconfig "github.com/cli/go-gh/internal/config"
"github.com/cli/go-gh/internal/git"
irepo "github.com/cli/go-gh/internal/repository"
"github.com/cli/go-gh/pkg/api"
Expand Down Expand Up @@ -53,8 +54,9 @@ func run(path string, env []string, args ...string) (stdOut, stdErr bytes.Buffer
}

// RESTClient builds a client to send requests to GitHub REST API endpoints.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument.
func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand All @@ -71,8 +73,9 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) {
}

// GQLClient builds a client to send requests to GitHub GraphQL API endpoints.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument.
func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand All @@ -89,12 +92,14 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) {
}

// HTTPClient builds a client that can be passed to another library.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// In this instance providing opts.Host will not change the destination of your request as it is
// the responsibility of the consumer to configure this. However, if opts.Host does not match the request
// host, the auth token will not be added to the headers. This is to protect against the case where tokens
// could be sent to an arbitrary host.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument. In this instance
// providing opts.Host will not change the destination of your request as it is
// the responsibility of the consumer to configure this. However, if opts.Host
// does not match the request host, the auth token will not be added to the headers.
// This is to protect against the case where tokens could be sent to an arbitrary
// host.
func HTTPClient(opts *api.ClientOptions) (*http.Client, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand Down Expand Up @@ -155,10 +160,19 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error {
if opts.AuthToken == "" {
token, err = cfg.AuthToken(opts.Host)
if err != nil {
return err
var notFoundError iconfig.NotFoundError
if errors.As(err, &notFoundError) {
return fmt.Errorf("authentication token not found for host %s", opts.Host)
} else {
return err
}
}
opts.AuthToken = token
}
if opts.UnixDomainSocket == "" {
unixSocket, _ := cfg.Get("http_unix_socket")
opts.UnixDomainSocket = unixSocket
}
return nil
}

Expand Down
10 changes: 8 additions & 2 deletions gh_test.go
Expand Up @@ -168,21 +168,25 @@ func TestResolveOptions(t *testing.T) {
opts *api.ClientOptions
wantAuthToken string
wantHost string
wantSocket string
}{
{
name: "honors consumer provided ClientOptions",
opts: &api.ClientOptions{
Host: "test.com",
AuthToken: "token_from_opts",
Host: "test.com",
AuthToken: "token_from_opts",
UnixDomainSocket: "socket_from_opts",
},
wantAuthToken: "token_from_opts",
wantHost: "test.com",
wantSocket: "socket_from_opts",
},
{
name: "uses config values if there are no consumer provided ClientOptions",
opts: &api.ClientOptions{},
wantAuthToken: "token",
wantHost: "github.com",
wantSocket: "socket",
},
}

Expand All @@ -192,6 +196,7 @@ func TestResolveOptions(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, tt.wantHost, tt.opts.Host)
assert.Equal(t, tt.wantAuthToken, tt.opts.AuthToken)
assert.Equal(t, tt.wantSocket, tt.opts.UnixDomainSocket)
})
}
}
Expand All @@ -203,6 +208,7 @@ hosts:
user: user1
oauth_token: token
git_protocol: ssh
http_unix_socket: socket
`
cfg, _ := config.FromString(data)
return cfg
Expand Down
59 changes: 39 additions & 20 deletions internal/api/http.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -74,11 +75,25 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
}

transport := http.DefaultTransport

if opts.UnixDomainSocket != "" {
transport = newUnixDomainSocketRoundTripper(opts.UnixDomainSocket)
}

if opts.Transport != nil {
transport = opts.Transport
}

transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)
if opts.EnableCache {
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)
}

if opts.Log != nil {
logger := &httpretty.Logger{
Expand All @@ -99,25 +114,17 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
transport = logger.RoundTripper(transport)
}

if opts.EnableCache {
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)
}
transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)

return http.Client{Transport: transport, Timeout: opts.Timeout}
}

// TODO: Export function in near future.
func handleHTTPError(resp *http.Response) error {
httpError := api.HTTPError{
StatusCode: resp.StatusCode,
RequestURL: resp.Request.URL,
OAuthScopes: resp.Header.Get("X-Oauth-Scopes"),
StatusCode: resp.StatusCode,
RequestURL: resp.Request.URL,
Headers: resp.Header,
}

if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) {
Expand Down Expand Up @@ -198,19 +205,19 @@ func isEnterprise(host string) bool {
}

type headerRoundTripper struct {
host string
headers map[string]string
host string
rt http.RoundTripper
}

func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper {
if headers == nil {
headers = map[string]string{}
}
if headers[contentType] == "" {
if _, ok := headers[contentType]; !ok {
headers[contentType] = jsonContentType
}
if headers[userAgent] == "" {
if _, ok := headers[userAgent]; !ok {
headers[userAgent] = "go-gh"
info, ok := debug.ReadBuildInfo()
if ok {
Expand All @@ -222,13 +229,13 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
}
}
}
if headers[authorization] == "" && authToken != "" {
if _, ok := headers[authorization]; !ok && authToken != "" {
headers[authorization] = fmt.Sprintf("token %s", authToken)
}
if headers[timeZone] == "" {
if _, ok := headers[timeZone]; !ok {
headers[timeZone] = currentTimeZone()
}
if headers[accept] == "" {
if _, ok := headers[accept]; !ok {
// Preview for PullRequest.mergeStateStatus.
a := "application/vnd.github.merge-info-preview+json"
// Preview for visibility when RESTing repos into an org.
Expand Down Expand Up @@ -260,6 +267,18 @@ func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
return hrt.rt.RoundTrip(req)
}

func newUnixDomainSocketRoundTripper(socketPath string) http.RoundTripper {
dial := func(network, addr string) (net.Conn, error) {
return net.Dial("unix", socketPath)
}

return &http.Transport{
Dial: dial,
DialTLS: dial,
DisableKeepAlives: true,
}
}

func currentTimeZone() string {
tz := time.Local.String()
if tz == "Local" {
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Expand Up @@ -138,6 +138,7 @@ func defaultConfig() Config {
return config{global: configMap{Root: defaultGlobal().Content[0]}}
}

//TODO: Add caching so as not to load config multiple times.
func Load() (Config, error) {
return load(configFile(), hostsConfigFile())
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/api/client.go
Expand Up @@ -32,7 +32,7 @@ type ClientOptions struct {
// Default headers will be overridden by keys specified in Headers.
Headers map[string]string

// Host is the host that every API request will be sent to.
// Host is the default host that API requests will be sent to.
Host string

// Log specifies a writer to write API request logs to.
Expand All @@ -44,8 +44,18 @@ type ClientOptions struct {
Timeout time.Duration

// Transport specifies the mechanism by which individual API requests are made.
// If both Transport and UnixDomainSocket are specified then Transport takes
// precedence. Due to this behavior any value set for Transport needs to manually
// handle routing to UnixDomainSocket if necessary. Generally, setting Transport
// should be reserved for testing purposes.
// Default is http.DefaultTransport.
Transport http.RoundTripper

// UnixDomainSocket specifies the Unix domain socket address by which individual
// API requests will be routed. If specifed, this will form the base of the API
// request transport chain.
// Default is no socket address.
UnixDomainSocket string
}

// RESTClient is the interface that wraps methods for the different types of
Expand Down
11 changes: 6 additions & 5 deletions pkg/api/http.go
Expand Up @@ -2,17 +2,18 @@ package api

import (
"fmt"
"net/http"
"net/url"
"strings"
)

// HTTPError represents an error response from the GitHub API.
type HTTPError struct {
StatusCode int
RequestURL *url.URL
Message string
OAuthScopes string
Errors []HTTPErrorItem
Errors []HTTPErrorItem
Headers http.Header
Message string
RequestURL *url.URL
StatusCode int
}

// HTTPErrorItem stores additional information about an error response
Expand Down

0 comments on commit d4b5b6b

Please sign in to comment.