Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed May 26, 2022
1 parent 9041602 commit faad1a1
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 75 deletions.
34 changes: 9 additions & 25 deletions gh.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ func RESTClient(opts *api.ClientOptions) (api.RESTClient, error) {
opts = &api.ClientOptions{}
}
if optionsNeedResolution(opts) {
cfg, err := config.Read()
if err != nil {
return nil, err
}
err = resolveOptions(opts, cfg)
err := resolveOptions(opts)
if err != nil {
return nil, err
}
Expand All @@ -83,11 +79,7 @@ func GQLClient(opts *api.ClientOptions) (api.GQLClient, error) {
opts = &api.ClientOptions{}
}
if optionsNeedResolution(opts) {
cfg, err := config.Read()
if err != nil {
return nil, err
}
err = resolveOptions(opts, cfg)
err := resolveOptions(opts)
if err != nil {
return nil, err
}
Expand All @@ -109,11 +101,7 @@ func HTTPClient(opts *api.ClientOptions) (*http.Client, error) {
opts = &api.ClientOptions{}
}
if optionsNeedResolution(opts) {
cfg, err := config.Read()
if err != nil {
return nil, err
}
err = resolveOptions(opts, cfg)
err := resolveOptions(opts)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -141,12 +129,7 @@ func CurrentRepository() (repo.Repository, error) {
translator := ssh.NewTranslator()
translateRemotes(remotes, translator)

cfg, err := config.Read()
if err != nil {
return nil, err
}

hosts := auth.KnownHosts(cfg)
hosts := auth.KnownHosts()

filteredRemotes := remotes.FilterByHosts(hosts)
if len(filteredRemotes) == 0 {
Expand All @@ -170,17 +153,18 @@ func optionsNeedResolution(opts *api.ClientOptions) bool {
return false
}

func resolveOptions(opts *api.ClientOptions, cfg *config.Config) error {
func resolveOptions(opts *api.ClientOptions) error {
cfg, _ := config.Read()
if opts.Host == "" {
opts.Host, _ = auth.DefaultHost(cfg)
opts.Host, _ = auth.DefaultHost()
}
if opts.AuthToken == "" {
opts.AuthToken, _ = auth.TokenForHost(cfg, opts.Host)
opts.AuthToken, _ = auth.TokenForHost(opts.Host)
if opts.AuthToken == "" {
return fmt.Errorf("authentication token not found for host %s", opts.Host)
}
}
if opts.UnixDomainSocket == "" {
if opts.UnixDomainSocket == "" && cfg != nil {
opts.UnixDomainSocket, _ = config.Get(cfg, []string{"http_unix_socket"})
}
return nil
Expand Down
36 changes: 26 additions & 10 deletions gh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"testing"

"github.com/cli/go-gh/pkg/api"
"github.com/cli/go-gh/pkg/config"
"github.com/stretchr/testify/assert"
"gopkg.in/h2non/gock.v1"
)
Expand Down Expand Up @@ -162,7 +162,18 @@ func TestHTTPClient(t *testing.T) {
}

func TestResolveOptions(t *testing.T) {
cfg := testConfig()
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
globalFilePath := filepath.Join(tempDir, "config.yml")
hostsFilePath := filepath.Join(tempDir, "hosts.yml")
err := os.WriteFile(globalFilePath, []byte(testGlobalData()), 0755)
assert.NoError(t, err)
err = os.WriteFile(hostsFilePath, []byte(testHostsData()), 0755)
assert.NoError(t, err)

tests := []struct {
name string
Expand Down Expand Up @@ -193,7 +204,7 @@ func TestResolveOptions(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := resolveOptions(tt.opts, cfg)
err := resolveOptions(tt.opts)
assert.NoError(t, err)
assert.Equal(t, tt.wantHost, tt.opts.Host)
assert.Equal(t, tt.wantAuthToken, tt.opts.AuthToken)
Expand Down Expand Up @@ -303,16 +314,21 @@ func TestOptionsNeedResolution(t *testing.T) {
}
}

func testConfig() *config.Config {
func testGlobalData() string {
var data = `
hosts:
github.com:
user: user1
oauth_token: token
git_protocol: ssh
http_unix_socket: socket
`
return config.ReadFromString(data)
return data
}

func testHostsData() string {
var data = `
github.com:
user: user1
oauth_token: token
git_protocol: ssh
`
return data
}

func printPendingMocks(mocks []gock.Mock) string {
Expand Down
64 changes: 44 additions & 20 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

const (
defaultHost = "github.com"
github = "github.com"
ghEnterpriseToken = "GH_ENTERPRISE_TOKEN"
ghHost = "GH_HOST"
ghToken = "GH_TOKEN"
Expand All @@ -19,7 +19,12 @@ const (
hostsKey = "hosts"
)

func TokenForHost(cfg *config.Config, host string) (string, string) {
func TokenForHost(host string) (string, string) {
cfg, _ := config.Read()
return tokenForHost(cfg, host)
}

func tokenForHost(cfg *config.Config, host string) (string, string) {
host = normalizeHostname(host)
if isEnterprise(host) {
if token := os.Getenv(ghEnterpriseToken); token != "" {
Expand All @@ -28,53 +33,72 @@ func TokenForHost(cfg *config.Config, host string) (string, string) {
if token := os.Getenv(githubEnterpriseToken); token != "" {
return token, githubEnterpriseToken
}
token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken})
return token, oauthToken
if cfg != nil {
token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken})
return token, oauthToken
}
}
if token := os.Getenv(ghToken); token != "" {
return token, ghToken
}
if token := os.Getenv(githubToken); token != "" {
return token, githubToken
}
token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken})
return token, oauthToken
if cfg != nil {
token, _ := config.Get(cfg, []string{hostsKey, host, oauthToken})
return token, oauthToken
}
return "", ""
}

func KnownHosts() []string {
cfg, _ := config.Read()
return knownHosts(cfg)
}

func KnownHosts(cfg *config.Config) []string {
func knownHosts(cfg *config.Config) []string {
hosts := set.NewStringSet()
if host := os.Getenv(ghHost); host != "" {
hosts.Add(host)
}
if token, _ := TokenForHost(cfg, defaultHost); token != "" {
hosts.Add(defaultHost)
if token, _ := tokenForHost(cfg, github); token != "" {
hosts.Add(github)
}
keys, err := config.Keys(cfg, []string{hostsKey})
if err == nil {
hosts.AddValues(keys)
if cfg != nil {
keys, err := config.Keys(cfg, []string{hostsKey})
if err == nil {
hosts.AddValues(keys)
}
}
return hosts.ToSlice()
}

func DefaultHost(cfg *config.Config) (string, string) {
func DefaultHost() (string, string) {
cfg, _ := config.Read()
return defaultHost(cfg)
}

func defaultHost(cfg *config.Config) (string, string) {
if host := os.Getenv(ghHost); host != "" {
return host, ghHost
}
keys, err := config.Keys(cfg, []string{hostsKey})
if err == nil && len(keys) == 1 {
return keys[0], hostsKey
if cfg != nil {
keys, err := config.Keys(cfg, []string{hostsKey})
if err == nil && len(keys) == 1 {
return keys[0], hostsKey
}
}
return defaultHost, "default"
return github, "default"
}

func isEnterprise(host string) bool {
return host != defaultHost
return host != github
}

func normalizeHostname(host string) string {
hostname := strings.ToLower(host)
if strings.HasSuffix(hostname, "."+defaultHost) {
return defaultHost
if strings.HasSuffix(hostname, "."+github) {
return github
}
return hostname
}
6 changes: 3 additions & 3 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestTokenForHost(t *testing.T) {
os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.githubEnterpriseToken)
os.Setenv("GH_TOKEN", tt.ghToken)
os.Setenv("GH_ENTERPRISE_TOKEN", tt.ghEnterpriseToken)
token, source := TokenForHost(tt.config, tt.host)
token, source := tokenForHost(tt.config, tt.host)
assert.Equal(t, tt.wantToken, token)
assert.Equal(t, tt.wantSource, source)
})
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestDefaultHost(t *testing.T) {
os.Setenv(k, tt.ghHost)
defer os.Setenv(k, old)
}
host, source := DefaultHost(tt.config)
host, source := defaultHost(tt.config)
assert.Equal(t, tt.wantHost, host)
assert.Equal(t, tt.wantSource, source)
})
Expand Down Expand Up @@ -217,7 +217,7 @@ func TestKnownHosts(t *testing.T) {
os.Setenv(k, tt.ghToken)
defer os.Setenv(k, old)
}
hosts := KnownHosts(tt.config)
hosts := knownHosts(tt.config)
assert.Equal(t, tt.wantHosts, hosts)
})
}
Expand Down
13 changes: 6 additions & 7 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"path/filepath"
"runtime"

"github.com/MakeNowJust/heredoc"
"github.com/cli/go-gh/internal/yamlmap"
)

Expand All @@ -33,7 +32,7 @@ func Get(c *Config, keys []string) (string, error) {
var err error
m, err = m.FindEntry(key)
if err != nil {
return "", NotFoundError{key}
return "", KeyNotFoundError{key}
}
}
return m.Value, nil
Expand All @@ -45,7 +44,7 @@ func Keys(c *Config, keys []string) ([]string, error) {
var err error
m, err = m.FindEntry(key)
if err != nil {
return nil, NotFoundError{key}
return nil, KeyNotFoundError{key}
}
}
return m.Keys(), nil
Expand All @@ -58,12 +57,12 @@ func Remove(c *Config, keys []string) error {
key := keys[i]
m, err = m.FindEntry(key)
if err != nil {
return NotFoundError{key}
return KeyNotFoundError{key}
}
}
err := m.RemoveEntry(keys[len(keys)-1])
if err != nil {
return NotFoundError{keys[len(keys)-1]}
return KeyNotFoundError{keys[len(keys)-1]}
}
return nil
}
Expand Down Expand Up @@ -250,7 +249,7 @@ func writeFile(filename string, data []byte) error {
return err
}

var defaultGeneralEntries = heredoc.Doc(`
var defaultGeneralEntries = `
# What protocol to use when performing git operations. Supported values: ssh, https
git_protocol: https
# What editor gh should run when creating issues, pull requests, etc. If blank, will refer to environment.
Expand All @@ -266,4 +265,4 @@ aliases:
http_unix_socket:
# What web browser gh should use when opening URLs. If blank, will refer to environment.
browser:
`)
`
13 changes: 9 additions & 4 deletions pkg/config/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@ func (e InvalidConfigFileError) Error() string {
return fmt.Sprintf("invalid config file %s: %s", e.Path, e.Err)
}

// NotFoundError represents an error when trying to find a config key
// Allow InvalidConfigFileError to be unwrapped.
func (e InvalidConfigFileError) Unwrap() error {
return e.Err
}

// KeyNotFoundError represents an error when trying to find a config key
// that does not exist.
type NotFoundError struct {
type KeyNotFoundError struct {
Key string
}

// Allow NotFoundError to satisfy error interface.
func (e NotFoundError) Error() string {
// Allow KeyNotFoundError to satisfy error interface.
func (e KeyNotFoundError) Error() string {
return fmt.Sprintf("could not find key %q", e.Key)
}
7 changes: 1 addition & 6 deletions pkg/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/cli/go-gh/internal/git"
irepo "github.com/cli/go-gh/internal/repository"
"github.com/cli/go-gh/pkg/auth"
"github.com/cli/go-gh/pkg/config"
)

// Repository is the interface that wraps repository information methods.
Expand Down Expand Up @@ -48,11 +47,7 @@ func Parse(s string) (Repository, error) {
case 3:
return irepo.New(parts[0], parts[1], parts[2]), nil
case 2:
host := "github.com"
cfg, err := config.Read()
if err == nil {
host, _ = auth.DefaultHost(cfg)
}
host, _ := auth.DefaultHost()
return irepo.New(host, parts[0], parts[1]), nil
default:
return nil, fmt.Errorf(`expected the "[HOST/]OWNER/REPO" format, got %q`, s)
Expand Down

0 comments on commit faad1a1

Please sign in to comment.