Skip to content

Commit

Permalink
Turn config into a singleton (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed Jun 21, 2022
1 parent 3c417a3 commit f06a22b
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 116 deletions.
112 changes: 42 additions & 70 deletions gh_test.go
Expand Up @@ -4,11 +4,11 @@ import (
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"testing"

"github.com/cli/go-gh/pkg/api"
"github.com/cli/go-gh/pkg/config"
"github.com/stretchr/testify/assert"
"gopkg.in/h2non/gock.v1"
)
Expand Down Expand Up @@ -49,20 +49,12 @@ func TestRunError(t *testing.T) {
}

func TestRESTClient(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Get("/some/test/path").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
Reply(200).
JSON(`{"message": "success"}`)

Expand All @@ -77,20 +69,12 @@ func TestRESTClient(t *testing.T) {
}

func TestGQLClient(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Post("/graphql").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
BodyString(`{"query":"QUERY","variables":{"var":"test"}}`).
Reply(200).
JSON(`{"data":{"viewer":{"login":"hubot"}}}`)
Expand All @@ -107,20 +91,12 @@ func TestGQLClient(t *testing.T) {
}

func TestGQLClientError(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Post("/graphql").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
BodyString(`{"query":"QUERY","variables":null}`).
Reply(200).
JSON(`{"errors":[{"type":"NOT_FOUND","path":["organization"],"message":"Could not resolve to an Organization with the login of 'cli'."}]}`)
Expand All @@ -135,20 +111,12 @@ func TestGQLClientError(t *testing.T) {
}

func TestHTTPClient(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Get("/some/test/path").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
Reply(200).
JSON(`{"message": "success"}`)

Expand All @@ -162,18 +130,7 @@ func TestHTTPClient(t *testing.T) {
}

func TestResolveOptions(t *testing.T) {
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
globalFilePath := filepath.Join(tempDir, "config.yml")
hostsFilePath := filepath.Join(tempDir, "hosts.yml")
err := os.WriteFile(globalFilePath, []byte(testGlobalData()), 0755)
assert.NoError(t, err)
err = os.WriteFile(hostsFilePath, []byte(testHostsData()), 0755)
assert.NoError(t, err)
stubConfig(t, testConfigWithSocket())

tests := []struct {
name string
Expand Down Expand Up @@ -314,27 +271,42 @@ func TestOptionsNeedResolution(t *testing.T) {
}
}

func testGlobalData() string {
var data = `
http_unix_socket: socket
`
return data
}

func testHostsData() string {
var data = `
github.com:
user: user1
oauth_token: token
git_protocol: ssh
`
return data
}

func printPendingMocks(mocks []gock.Mock) string {
paths := []string{}
for _, mock := range mocks {
paths = append(paths, mock.Request().URLStruct.String())
}
return fmt.Sprintf("%d unmatched mocks: %s", len(paths), strings.Join(paths, ", "))
}

func stubConfig(t *testing.T, cfgStr string) {
t.Helper()
old := config.Read
config.Read = func() (*config.Config, error) {
return config.ReadFromString(cfgStr), nil
}
t.Cleanup(func() {
config.Read = old
})
}

func testConfig() string {
return `
hosts:
github.com:
user: user1
oauth_token: abc123
git_protocol: ssh
`
}

func testConfigWithSocket() string {
return `
http_unix_socket: socket
hosts:
github.com:
user: user1
oauth_token: token
git_protocol: ssh
`
}
4 changes: 4 additions & 0 deletions internal/yamlmap/yaml_map.go
Expand Up @@ -147,6 +147,10 @@ func (m *Map) SetEntry(key string, value *Map) {
// has no impact for our purposes.
func (m *Map) SetModified() {
// Can not mark a non-mapping node as modified
if m.Node.Kind != yaml.MappingNode && m.Node.Tag == "!!null" {
m.Node.Kind = yaml.MappingNode
m.Node.Tag = "!!map"
}
if m.Node.Kind == yaml.MappingNode {
m.Node.Value = modified
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/auth/auth.go
Expand Up @@ -4,13 +4,15 @@ package auth

import (
"os"
"strconv"
"strings"

"github.com/cli/go-gh/internal/set"
"github.com/cli/go-gh/pkg/config"
)

const (
codespaces = "CODESPACES"
defaultSource = "default"
ghEnterpriseToken = "GH_ENTERPRISE_TOKEN"
ghHost = "GH_HOST"
Expand Down Expand Up @@ -40,6 +42,11 @@ func tokenForHost(cfg *config.Config, host string) (string, string) {
if token := os.Getenv(githubEnterpriseToken); token != "" {
return token, githubEnterpriseToken
}
if isCodespaces, _ := strconv.ParseBool(os.Getenv(codespaces)); isCodespaces {
if token := os.Getenv(githubToken); token != "" {
return token, githubToken
}
}
if cfg != nil {
token, _ := cfg.Get([]string{hostsKey, host, oauthToken})
return token, oauthToken
Expand Down
57 changes: 38 additions & 19 deletions pkg/config/config.go
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"runtime"
"sync"

"github.com/cli/go-gh/internal/yamlmap"
)
Expand All @@ -22,12 +23,18 @@ const (
xdgStateHome = "XDG_STATE_HOME"
)

var (
cfg *Config
once sync.Once
)

// Config is a in memory representation of the gh configuration files.
// It can be thought of as map where entries consist of a key that
// correspond to either a string value or a map value, allowing for
// multi-level maps.
type Config struct {
entries *yamlmap.Map
mu sync.RWMutex
}

// Get a string value from a Config.
Expand All @@ -36,6 +43,8 @@ type Config struct {
// if trying to retrieve a key that corresponds to a map value.
// Returns "", KeyNotFoundError if any of the keys can not be found.
func (c *Config) Get(keys []string) (string, error) {
c.mu.RLock()
defer c.mu.RUnlock()
m := c.entries
for _, key := range keys {
var err error
Expand All @@ -52,6 +61,8 @@ func (c *Config) Get(keys []string) (string, error) {
// map values can be have their keys enumerated.
// Returns nil, KeyNotFoundError if any of the keys can not be found.
func (c *Config) Keys(keys []string) ([]string, error) {
c.mu.RLock()
defer c.mu.RUnlock()
m := c.entries
for _, key := range keys {
var err error
Expand All @@ -69,6 +80,8 @@ func (c *Config) Keys(keys []string) ([]string, error) {
// entries removes those also.
// Returns KeyNotFoundError if any of the keys can not be found.
func (c *Config) Remove(keys []string) error {
c.mu.Lock()
defer c.mu.Unlock()
m := c.entries
for i := 0; i < len(keys)-1; i++ {
var err error
Expand All @@ -90,6 +103,8 @@ func (c *Config) Remove(keys []string) error {
// entries can be set. If any of the keys do not exist they will
// be created.
func (c *Config) Set(keys []string, value string) {
c.mu.Lock()
defer c.mu.Unlock()
m := c.entries
for i := 0; i < len(keys)-1; i++ {
key := keys[i]
Expand All @@ -105,10 +120,12 @@ func (c *Config) Set(keys []string, value string) {

// Read gh configuration files from the local file system and
// return a Config.
func Read() (*Config, error) {
// TODO: Make global config singleton using sync.Once
// so as not to read from file every time.
return load(generalConfigFile(), hostsConfigFile())
var Read = func() (*Config, error) {
var err error
once.Do(func() {
cfg, err = load(generalConfigFile(), hostsConfigFile())
})
return cfg, err
}

// ReadFromString takes a yaml string and returns a Config.
Expand All @@ -119,14 +136,16 @@ func ReadFromString(str string) *Config {
if m == nil {
m = yamlmap.MapValue()
}
return &Config{m}
return &Config{entries: m}
}

// Write gh configuration files to the local file system.
// It will only write gh configuration files that have been modified
// since last being read.
func Write(config *Config) error {
hosts, err := config.entries.FindEntry("hosts")
func Write(c *Config) error {
c.mu.Lock()
defer c.mu.Unlock()
hosts, err := c.entries.FindEntry("hosts")
if err == nil && hosts.IsModified() {
err := writeFile(hostsConfigFile(), []byte(hosts.String()))
if err != nil {
Expand All @@ -135,20 +154,20 @@ func Write(config *Config) error {
hosts.SetUnmodified()
}

if config.entries.IsModified() {
if c.entries.IsModified() {
// Hosts gets written to a different file above so remove it
// before writing and add it back in after writing.
hostsMap, hostsErr := config.entries.FindEntry("hosts")
hostsMap, hostsErr := c.entries.FindEntry("hosts")
if hostsErr == nil {
_ = config.entries.RemoveEntry("hosts")
_ = c.entries.RemoveEntry("hosts")
}
err := writeFile(generalConfigFile(), []byte(config.entries.String()))
err := writeFile(generalConfigFile(), []byte(c.entries.String()))
if err != nil {
return err
}
config.entries.SetUnmodified()
c.entries.SetUnmodified()
if hostsErr == nil {
config.entries.AddEntry("hosts", hostsMap)
c.entries.AddEntry("hosts", hostsMap)
}
}

Expand Down Expand Up @@ -182,15 +201,15 @@ func load(generalFilePath, hostsFilePath string) (*Config, error) {
generalMap.AddEntry("hosts", hostsMap)
}

return &Config{generalMap}, nil
return &Config{entries: generalMap}, nil
}

func generalConfigFile() string {
return filepath.Join(configDir(), "config.yml")
return filepath.Join(ConfigDir(), "config.yml")
}

func hostsConfigFile() string {
return filepath.Join(configDir(), "hosts.yml")
return filepath.Join(ConfigDir(), "hosts.yml")
}

func mapFromFile(filename string) (*yamlmap.Map, error) {
Expand All @@ -206,7 +225,7 @@ func mapFromString(str string) (*yamlmap.Map, error) {
}

// Config path precedence: GH_CONFIG_DIR, XDG_CONFIG_HOME, AppData (windows only), HOME.
func configDir() string {
func ConfigDir() string {
var path string
if a := os.Getenv(ghConfigDir); a != "" {
path = a
Expand All @@ -222,7 +241,7 @@ func configDir() string {
}

// State path precedence: XDG_STATE_HOME, LocalAppData (windows only), HOME.
func stateDir() string {
func StateDir() string {
var path string
if a := os.Getenv(xdgStateHome); a != "" {
path = filepath.Join(a, "gh")
Expand All @@ -236,7 +255,7 @@ func stateDir() string {
}

// Data path precedence: XDG_DATA_HOME, LocalAppData (windows only), HOME.
func dataDir() string {
func DataDir() string {
var path string
if a := os.Getenv(xdgDataHome); a != "" {
path = filepath.Join(a, "gh")
Expand Down

0 comments on commit f06a22b

Please sign in to comment.