Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for http_unix_socket #30

Merged
merged 2 commits into from May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 25 additions & 11 deletions gh.go
Expand Up @@ -16,6 +16,7 @@ import (

iapi "github.com/cli/go-gh/internal/api"
"github.com/cli/go-gh/internal/config"
iconfig "github.com/cli/go-gh/internal/config"
"github.com/cli/go-gh/internal/git"
irepo "github.com/cli/go-gh/internal/repository"
"github.com/cli/go-gh/internal/ssh"
Expand Down Expand Up @@ -54,8 +55,9 @@ func run(path string, env []string, args ...string) (stdOut, stdErr bytes.Buffer
}

// RESTClient builds a client to send requests to GitHub REST API endpoints.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument.
func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand All @@ -72,8 +74,9 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) {
}

// GQLClient builds a client to send requests to GitHub GraphQL API endpoints.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument.
func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand All @@ -90,12 +93,14 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) {
}

// HTTPClient builds a client that can be passed to another library.
// As part of the configuration a hostname, auth token, and default set of headers are resolved
// from the gh environment configuration. These behaviors can be overridden using the opts argument.
// In this instance providing opts.Host will not change the destination of your request as it is
// the responsibility of the consumer to configure this. However, if opts.Host does not match the request
// host, the auth token will not be added to the headers. This is to protect against the case where tokens
// could be sent to an arbitrary host.
// As part of the configuration a hostname, auth token, default set of headers,
// and unix domain socket are resolved from the gh environment configuration.
// These behaviors can be overridden using the opts argument. In this instance
// providing opts.Host will not change the destination of your request as it is
// the responsibility of the consumer to configure this. However, if opts.Host
// does not match the request host, the auth token will not be added to the headers.
// This is to protect against the case where tokens could be sent to an arbitrary
// host.
func HTTPClient(opts *api.ClientOptions) (*http.Client, error) {
if opts == nil {
opts = &api.ClientOptions{}
Expand Down Expand Up @@ -156,10 +161,19 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error {
if opts.AuthToken == "" {
token, err = cfg.AuthToken(opts.Host)
if err != nil {
return err
var notFoundError iconfig.NotFoundError
if errors.As(err, &notFoundError) {
return fmt.Errorf("authentication token not found for host %s", opts.Host)
} else {
return err
}
}
opts.AuthToken = token
}
if opts.UnixDomainSocket == "" {
unixSocket, _ := cfg.Get("http_unix_socket")
opts.UnixDomainSocket = unixSocket
}
return nil
}

Expand Down
10 changes: 8 additions & 2 deletions gh_test.go
Expand Up @@ -168,21 +168,25 @@ func TestResolveOptions(t *testing.T) {
opts *api.ClientOptions
wantAuthToken string
wantHost string
wantSocket string
}{
{
name: "honors consumer provided ClientOptions",
opts: &api.ClientOptions{
Host: "test.com",
AuthToken: "token_from_opts",
Host: "test.com",
AuthToken: "token_from_opts",
UnixDomainSocket: "socket_from_opts",
},
wantAuthToken: "token_from_opts",
wantHost: "test.com",
wantSocket: "socket_from_opts",
},
{
name: "uses config values if there are no consumer provided ClientOptions",
opts: &api.ClientOptions{},
wantAuthToken: "token",
wantHost: "github.com",
wantSocket: "socket",
},
}

Expand All @@ -192,6 +196,7 @@ func TestResolveOptions(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, tt.wantHost, tt.opts.Host)
assert.Equal(t, tt.wantAuthToken, tt.opts.AuthToken)
assert.Equal(t, tt.wantSocket, tt.opts.UnixDomainSocket)
})
}
}
Expand All @@ -203,6 +208,7 @@ hosts:
user: user1
oauth_token: token
git_protocol: ssh
http_unix_socket: socket
`
cfg, _ := config.FromString(data)
return cfg
Expand Down
59 changes: 39 additions & 20 deletions internal/api/http.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -74,11 +75,25 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
}

transport := http.DefaultTransport

if opts.UnixDomainSocket != "" {
transport = newUnixDomainSocketRoundTripper(opts.UnixDomainSocket)
}

if opts.Transport != nil {
transport = opts.Transport
}

transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)
if opts.EnableCache {
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)
}

if opts.Log != nil {
logger := &httpretty.Logger{
Expand All @@ -99,25 +114,17 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
transport = logger.RoundTripper(transport)
}

if opts.EnableCache {
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)
}
transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)

return http.Client{Transport: transport, Timeout: opts.Timeout}
}

// TODO: Export function in near future.
mislav marked this conversation as resolved.
Show resolved Hide resolved
func handleHTTPError(resp *http.Response) error {
httpError := api.HTTPError{
StatusCode: resp.StatusCode,
RequestURL: resp.Request.URL,
OAuthScopes: resp.Header.Get("X-Oauth-Scopes"),
StatusCode: resp.StatusCode,
RequestURL: resp.Request.URL,
Headers: resp.Header,
}

if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) {
Expand Down Expand Up @@ -198,19 +205,19 @@ func isEnterprise(host string) bool {
}

type headerRoundTripper struct {
host string
headers map[string]string
host string
rt http.RoundTripper
}

func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper {
if headers == nil {
headers = map[string]string{}
}
if headers[contentType] == "" {
if _, ok := headers[contentType]; !ok {
headers[contentType] = jsonContentType
}
if headers[userAgent] == "" {
if _, ok := headers[userAgent]; !ok {
headers[userAgent] = "go-gh"
info, ok := debug.ReadBuildInfo()
if ok {
Expand All @@ -222,13 +229,13 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
}
}
}
if headers[authorization] == "" && authToken != "" {
if _, ok := headers[authorization]; !ok && authToken != "" {
headers[authorization] = fmt.Sprintf("token %s", authToken)
}
if headers[timeZone] == "" {
if _, ok := headers[timeZone]; !ok {
headers[timeZone] = currentTimeZone()
}
if headers[accept] == "" {
if _, ok := headers[accept]; !ok {
// Preview for PullRequest.mergeStateStatus.
a := "application/vnd.github.merge-info-preview+json"
// Preview for visibility when RESTing repos into an org.
Expand Down Expand Up @@ -260,6 +267,18 @@ func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
return hrt.rt.RoundTrip(req)
}

func newUnixDomainSocketRoundTripper(socketPath string) http.RoundTripper {
dial := func(network, addr string) (net.Conn, error) {
return net.Dial("unix", socketPath)
}

return &http.Transport{
Dial: dial,
DialTLS: dial,
DisableKeepAlives: true,
}
}

func currentTimeZone() string {
tz := time.Local.String()
if tz == "Local" {
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Expand Up @@ -138,6 +138,7 @@ func defaultConfig() Config {
return config{global: configMap{Root: defaultGlobal().Content[0]}}
}

//TODO: Add caching so as not to load config multiple times.
func Load() (Config, error) {
return load(configFile(), hostsConfigFile())
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/api/client.go
Expand Up @@ -32,7 +32,7 @@ type ClientOptions struct {
// Default headers will be overridden by keys specified in Headers.
Headers map[string]string

// Host is the host that every API request will be sent to.
// Host is the default host that API requests will be sent to.
Host string

// Log specifies a writer to write API request logs to.
Expand All @@ -44,8 +44,18 @@ type ClientOptions struct {
Timeout time.Duration

// Transport specifies the mechanism by which individual API requests are made.
// If both Transport and UnixDomainSocket are specified then Transport takes
// precedence. Due to this behavior any value set for Transport needs to manually
// handle routing to UnixDomainSocket if necessary. Generally, setting Transport
// should be reserved for testing purposes.
// Default is http.DefaultTransport.
Transport http.RoundTripper

// UnixDomainSocket specifies the Unix domain socket address by which individual
// API requests will be routed. If specifed, this will form the base of the API
// request transport chain.
// Default is no socket address.
UnixDomainSocket string
mislav marked this conversation as resolved.
Show resolved Hide resolved
}

// RESTClient is the interface that wraps methods for the different types of
Expand Down
11 changes: 6 additions & 5 deletions pkg/api/http.go
Expand Up @@ -2,17 +2,18 @@ package api

import (
"fmt"
"net/http"
"net/url"
"strings"
)

// HTTPError represents an error response from the GitHub API.
type HTTPError struct {
StatusCode int
RequestURL *url.URL
Message string
OAuthScopes string
Errors []HTTPErrorItem
Errors []HTTPErrorItem
Headers http.Header
Message string
RequestURL *url.URL
StatusCode int
}

// HTTPErrorItem stores additional information about an error response
Expand Down