diff --git a/gh_test.go b/gh_test.go index 9bc127d..68b19e8 100644 --- a/gh_test.go +++ b/gh_test.go @@ -4,11 +4,11 @@ import ( "fmt" "net/http" "os" - "path/filepath" "strings" "testing" "github.com/cli/go-gh/pkg/api" + "github.com/cli/go-gh/pkg/config" "github.com/stretchr/testify/assert" "gopkg.in/h2non/gock.v1" ) @@ -49,20 +49,12 @@ func TestRunError(t *testing.T) { } func TestRESTClient(t *testing.T) { + stubConfig(t, testConfig()) t.Cleanup(gock.Off) - tempDir := t.TempDir() - orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR") - orig_GH_TOKEN := os.Getenv("GH_TOKEN") - t.Cleanup(func() { - os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR) - os.Setenv("GH_TOKEN", orig_GH_TOKEN) - }) - os.Setenv("GH_CONFIG_DIR", tempDir) - os.Setenv("GH_TOKEN", "GH_TOKEN") gock.New("https://api.github.com"). Get("/some/test/path"). - MatchHeader("Authorization", "token GH_TOKEN"). + MatchHeader("Authorization", "token abc123"). Reply(200). JSON(`{"message": "success"}`) @@ -77,20 +69,12 @@ func TestRESTClient(t *testing.T) { } func TestGQLClient(t *testing.T) { + stubConfig(t, testConfig()) t.Cleanup(gock.Off) - tempDir := t.TempDir() - orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR") - orig_GH_TOKEN := os.Getenv("GH_TOKEN") - t.Cleanup(func() { - os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR) - os.Setenv("GH_TOKEN", orig_GH_TOKEN) - }) - os.Setenv("GH_CONFIG_DIR", tempDir) - os.Setenv("GH_TOKEN", "GH_TOKEN") gock.New("https://api.github.com"). Post("/graphql"). - MatchHeader("Authorization", "token GH_TOKEN"). + MatchHeader("Authorization", "token abc123"). BodyString(`{"query":"QUERY","variables":{"var":"test"}}`). Reply(200). JSON(`{"data":{"viewer":{"login":"hubot"}}}`) @@ -107,20 +91,12 @@ func TestGQLClient(t *testing.T) { } func TestGQLClientError(t *testing.T) { + stubConfig(t, testConfig()) t.Cleanup(gock.Off) - tempDir := t.TempDir() - orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR") - orig_GH_TOKEN := os.Getenv("GH_TOKEN") - t.Cleanup(func() { - os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR) - os.Setenv("GH_TOKEN", orig_GH_TOKEN) - }) - os.Setenv("GH_CONFIG_DIR", tempDir) - os.Setenv("GH_TOKEN", "GH_TOKEN") gock.New("https://api.github.com"). Post("/graphql"). - MatchHeader("Authorization", "token GH_TOKEN"). + MatchHeader("Authorization", "token abc123"). BodyString(`{"query":"QUERY","variables":null}`). Reply(200). JSON(`{"errors":[{"type":"NOT_FOUND","path":["organization"],"message":"Could not resolve to an Organization with the login of 'cli'."}]}`) @@ -135,20 +111,12 @@ func TestGQLClientError(t *testing.T) { } func TestHTTPClient(t *testing.T) { + stubConfig(t, testConfig()) t.Cleanup(gock.Off) - tempDir := t.TempDir() - orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR") - orig_GH_TOKEN := os.Getenv("GH_TOKEN") - t.Cleanup(func() { - os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR) - os.Setenv("GH_TOKEN", orig_GH_TOKEN) - }) - os.Setenv("GH_CONFIG_DIR", tempDir) - os.Setenv("GH_TOKEN", "GH_TOKEN") gock.New("https://api.github.com"). Get("/some/test/path"). - MatchHeader("Authorization", "token GH_TOKEN"). + MatchHeader("Authorization", "token abc123"). Reply(200). JSON(`{"message": "success"}`) @@ -162,18 +130,7 @@ func TestHTTPClient(t *testing.T) { } func TestResolveOptions(t *testing.T) { - tempDir := t.TempDir() - orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR") - t.Cleanup(func() { - os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR) - }) - os.Setenv("GH_CONFIG_DIR", tempDir) - globalFilePath := filepath.Join(tempDir, "config.yml") - hostsFilePath := filepath.Join(tempDir, "hosts.yml") - err := os.WriteFile(globalFilePath, []byte(testGlobalData()), 0755) - assert.NoError(t, err) - err = os.WriteFile(hostsFilePath, []byte(testHostsData()), 0755) - assert.NoError(t, err) + stubConfig(t, testConfigWithSocket()) tests := []struct { name string @@ -314,23 +271,6 @@ func TestOptionsNeedResolution(t *testing.T) { } } -func testGlobalData() string { - var data = ` -http_unix_socket: socket -` - return data -} - -func testHostsData() string { - var data = ` -github.com: - user: user1 - oauth_token: token - git_protocol: ssh -` - return data -} - func printPendingMocks(mocks []gock.Mock) string { paths := []string{} for _, mock := range mocks { @@ -338,3 +278,35 @@ func printPendingMocks(mocks []gock.Mock) string { } return fmt.Sprintf("%d unmatched mocks: %s", len(paths), strings.Join(paths, ", ")) } + +func stubConfig(t *testing.T, cfgStr string) { + t.Helper() + old := config.Read + config.Read = func() (*config.Config, error) { + return config.ReadFromString(cfgStr), nil + } + t.Cleanup(func() { + config.Read = old + }) +} + +func testConfig() string { + return ` +hosts: + github.com: + user: user1 + oauth_token: abc123 + git_protocol: ssh +` +} + +func testConfigWithSocket() string { + return ` +http_unix_socket: socket +hosts: + github.com: + user: user1 + oauth_token: token + git_protocol: ssh +` +} diff --git a/internal/yamlmap/yaml_map.go b/internal/yamlmap/yaml_map.go index ac2d54d..6c904ba 100644 --- a/internal/yamlmap/yaml_map.go +++ b/internal/yamlmap/yaml_map.go @@ -147,6 +147,10 @@ func (m *Map) SetEntry(key string, value *Map) { // has no impact for our purposes. func (m *Map) SetModified() { // Can not mark a non-mapping node as modified + if m.Node.Kind != yaml.MappingNode && m.Node.Tag == "!!null" { + m.Node.Kind = yaml.MappingNode + m.Node.Tag = "!!map" + } if m.Node.Kind == yaml.MappingNode { m.Node.Value = modified } diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 6c663ee..1bcd5eb 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -4,6 +4,7 @@ package auth import ( "os" + "strconv" "strings" "github.com/cli/go-gh/internal/set" @@ -11,6 +12,7 @@ import ( ) const ( + codespaces = "CODESPACES" defaultSource = "default" ghEnterpriseToken = "GH_ENTERPRISE_TOKEN" ghHost = "GH_HOST" @@ -40,6 +42,11 @@ func tokenForHost(cfg *config.Config, host string) (string, string) { if token := os.Getenv(githubEnterpriseToken); token != "" { return token, githubEnterpriseToken } + if isCodespaces, _ := strconv.ParseBool(os.Getenv(codespaces)); isCodespaces { + if token := os.Getenv(githubToken); token != "" { + return token, githubToken + } + } if cfg != nil { token, _ := cfg.Get([]string{hostsKey, host, oauthToken}) return token, oauthToken diff --git a/pkg/config/config.go b/pkg/config/config.go index 307434d..20d0bee 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "runtime" + "sync" "github.com/cli/go-gh/internal/yamlmap" ) @@ -22,12 +23,18 @@ const ( xdgStateHome = "XDG_STATE_HOME" ) +var ( + cfg *Config + once sync.Once +) + // Config is a in memory representation of the gh configuration files. // It can be thought of as map where entries consist of a key that // correspond to either a string value or a map value, allowing for // multi-level maps. type Config struct { entries *yamlmap.Map + mu sync.RWMutex } // Get a string value from a Config. @@ -36,6 +43,8 @@ type Config struct { // if trying to retrieve a key that corresponds to a map value. // Returns "", KeyNotFoundError if any of the keys can not be found. func (c *Config) Get(keys []string) (string, error) { + c.mu.RLock() + defer c.mu.RUnlock() m := c.entries for _, key := range keys { var err error @@ -52,6 +61,8 @@ func (c *Config) Get(keys []string) (string, error) { // map values can be have their keys enumerated. // Returns nil, KeyNotFoundError if any of the keys can not be found. func (c *Config) Keys(keys []string) ([]string, error) { + c.mu.RLock() + defer c.mu.RUnlock() m := c.entries for _, key := range keys { var err error @@ -69,6 +80,8 @@ func (c *Config) Keys(keys []string) ([]string, error) { // entries removes those also. // Returns KeyNotFoundError if any of the keys can not be found. func (c *Config) Remove(keys []string) error { + c.mu.Lock() + defer c.mu.Unlock() m := c.entries for i := 0; i < len(keys)-1; i++ { var err error @@ -90,6 +103,8 @@ func (c *Config) Remove(keys []string) error { // entries can be set. If any of the keys do not exist they will // be created. func (c *Config) Set(keys []string, value string) { + c.mu.Lock() + defer c.mu.Unlock() m := c.entries for i := 0; i < len(keys)-1; i++ { key := keys[i] @@ -105,10 +120,12 @@ func (c *Config) Set(keys []string, value string) { // Read gh configuration files from the local file system and // return a Config. -func Read() (*Config, error) { - // TODO: Make global config singleton using sync.Once - // so as not to read from file every time. - return load(generalConfigFile(), hostsConfigFile()) +var Read = func() (*Config, error) { + var err error + once.Do(func() { + cfg, err = load(generalConfigFile(), hostsConfigFile()) + }) + return cfg, err } // ReadFromString takes a yaml string and returns a Config. @@ -119,14 +136,16 @@ func ReadFromString(str string) *Config { if m == nil { m = yamlmap.MapValue() } - return &Config{m} + return &Config{entries: m} } // Write gh configuration files to the local file system. // It will only write gh configuration files that have been modified // since last being read. -func Write(config *Config) error { - hosts, err := config.entries.FindEntry("hosts") +func Write(c *Config) error { + c.mu.Lock() + defer c.mu.Unlock() + hosts, err := c.entries.FindEntry("hosts") if err == nil && hosts.IsModified() { err := writeFile(hostsConfigFile(), []byte(hosts.String())) if err != nil { @@ -135,20 +154,20 @@ func Write(config *Config) error { hosts.SetUnmodified() } - if config.entries.IsModified() { + if c.entries.IsModified() { // Hosts gets written to a different file above so remove it // before writing and add it back in after writing. - hostsMap, hostsErr := config.entries.FindEntry("hosts") + hostsMap, hostsErr := c.entries.FindEntry("hosts") if hostsErr == nil { - _ = config.entries.RemoveEntry("hosts") + _ = c.entries.RemoveEntry("hosts") } - err := writeFile(generalConfigFile(), []byte(config.entries.String())) + err := writeFile(generalConfigFile(), []byte(c.entries.String())) if err != nil { return err } - config.entries.SetUnmodified() + c.entries.SetUnmodified() if hostsErr == nil { - config.entries.AddEntry("hosts", hostsMap) + c.entries.AddEntry("hosts", hostsMap) } } @@ -182,15 +201,15 @@ func load(generalFilePath, hostsFilePath string) (*Config, error) { generalMap.AddEntry("hosts", hostsMap) } - return &Config{generalMap}, nil + return &Config{entries: generalMap}, nil } func generalConfigFile() string { - return filepath.Join(configDir(), "config.yml") + return filepath.Join(ConfigDir(), "config.yml") } func hostsConfigFile() string { - return filepath.Join(configDir(), "hosts.yml") + return filepath.Join(ConfigDir(), "hosts.yml") } func mapFromFile(filename string) (*yamlmap.Map, error) { @@ -206,7 +225,7 @@ func mapFromString(str string) (*yamlmap.Map, error) { } // Config path precedence: GH_CONFIG_DIR, XDG_CONFIG_HOME, AppData (windows only), HOME. -func configDir() string { +func ConfigDir() string { var path string if a := os.Getenv(ghConfigDir); a != "" { path = a @@ -222,7 +241,7 @@ func configDir() string { } // State path precedence: XDG_STATE_HOME, LocalAppData (windows only), HOME. -func stateDir() string { +func StateDir() string { var path string if a := os.Getenv(xdgStateHome); a != "" { path = filepath.Join(a, "gh") @@ -236,7 +255,7 @@ func stateDir() string { } // Data path precedence: XDG_DATA_HOME, LocalAppData (windows only), HOME. -func dataDir() string { +func DataDir() string { var path string if a := os.Getenv(xdgDataHome); a != "" { path = filepath.Join(a, "gh") diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index f926fea..52c3c72 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -92,7 +92,7 @@ func TestConfigDir(t *testing.T) { defer os.Setenv(k, old) } } - assert.Equal(t, tt.output, configDir()) + assert.Equal(t, tt.output, ConfigDir()) }) } } @@ -156,7 +156,7 @@ func TestStateDir(t *testing.T) { defer os.Setenv(k, old) } } - assert.Equal(t, tt.output, stateDir()) + assert.Equal(t, tt.output, StateDir()) }) } } @@ -220,7 +220,7 @@ func TestDataDir(t *testing.T) { defer os.Setenv(k, old) } } - assert.Equal(t, tt.output, dataDir()) + assert.Equal(t, tt.output, DataDir()) }) } } @@ -380,7 +380,7 @@ func TestWrite(t *testing.T) { cfg := tt.createConfig() err := Write(cfg) assert.NoError(t, err) - loadedCfg, err := Read() + loadedCfg, err := load(generalConfigFile(), hostsConfigFile()) assert.NoError(t, err) wantCfg := cfg if tt.wantConfig != nil { diff --git a/pkg/repository/repository_test.go b/pkg/repository/repository_test.go index 065183b..e0cb2f1 100644 --- a/pkg/repository/repository_test.go +++ b/pkg/repository/repository_test.go @@ -2,13 +2,15 @@ package repository import ( "os" - "path/filepath" "testing" + "github.com/cli/go-gh/pkg/config" "github.com/stretchr/testify/assert" ) func TestParse(t *testing.T) { + stubConfig(t, "") + tests := []struct { name string input string @@ -102,29 +104,14 @@ func TestParse(t *testing.T) { } func TestParse_hostFromConfig(t *testing.T) { - tempDir := t.TempDir() - old := os.Getenv("GH_CONFIG_DIR") - os.Setenv("GH_CONFIG_DIR", tempDir) - t.Cleanup(func() { - os.Setenv("GH_CONFIG_DIR", old) - }) - - var configData = ` -git_protocol: ssh -editor: -prompt: enabled -pager: less -` - var hostData = ` -enterprise.com: - user: user2 - oauth_token: yyyyyyyyyyyyyyyyyyyy - git_protocol: https + var cfgStr = ` +hosts: + enterprise.com: + user: user2 + oauth_token: yyyyyyyyyyyyyyyyyyyy + git_protocol: https ` - err := os.WriteFile(filepath.Join(tempDir, "config.yml"), []byte(configData), 0644) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(tempDir, "hosts.yml"), []byte(hostData), 0644) - assert.NoError(t, err) + stubConfig(t, cfgStr) r, err := Parse("OWNER/REPO") assert.NoError(t, err) assert.Equal(t, "enterprise.com", r.Host()) @@ -207,3 +194,14 @@ func TestParseWithHost(t *testing.T) { }) } } + +func stubConfig(t *testing.T, cfgStr string) { + t.Helper() + old := config.Read + config.Read = func() (*config.Config, error) { + return config.ReadFromString(cfgStr), nil + } + t.Cleanup(func() { + config.Read = old + }) +}