Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Turn config into a singleton #45

Merged
merged 2 commits into from Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
112 changes: 42 additions & 70 deletions gh_test.go
Expand Up @@ -4,11 +4,11 @@ import (
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"testing"

"github.com/cli/go-gh/pkg/api"
"github.com/cli/go-gh/pkg/config"
"github.com/stretchr/testify/assert"
"gopkg.in/h2non/gock.v1"
)
Expand Down Expand Up @@ -49,20 +49,12 @@ func TestRunError(t *testing.T) {
}

func TestRESTClient(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Get("/some/test/path").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
Reply(200).
JSON(`{"message": "success"}`)

Expand All @@ -77,20 +69,12 @@ func TestRESTClient(t *testing.T) {
}

func TestGQLClient(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Post("/graphql").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
BodyString(`{"query":"QUERY","variables":{"var":"test"}}`).
Reply(200).
JSON(`{"data":{"viewer":{"login":"hubot"}}}`)
Expand All @@ -107,20 +91,12 @@ func TestGQLClient(t *testing.T) {
}

func TestGQLClientError(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Post("/graphql").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
BodyString(`{"query":"QUERY","variables":null}`).
Reply(200).
JSON(`{"errors":[{"type":"NOT_FOUND","path":["organization"],"message":"Could not resolve to an Organization with the login of 'cli'."}]}`)
Expand All @@ -135,20 +111,12 @@ func TestGQLClientError(t *testing.T) {
}

func TestHTTPClient(t *testing.T) {
stubConfig(t, testConfig())
t.Cleanup(gock.Off)
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
os.Setenv("GH_TOKEN", "GH_TOKEN")

gock.New("https://api.github.com").
Get("/some/test/path").
MatchHeader("Authorization", "token GH_TOKEN").
MatchHeader("Authorization", "token abc123").
Reply(200).
JSON(`{"message": "success"}`)

Expand All @@ -162,18 +130,7 @@ func TestHTTPClient(t *testing.T) {
}

func TestResolveOptions(t *testing.T) {
tempDir := t.TempDir()
orig_GH_CONFIG_DIR := os.Getenv("GH_CONFIG_DIR")
t.Cleanup(func() {
os.Setenv("GH_CONFIG_DIR", orig_GH_CONFIG_DIR)
})
os.Setenv("GH_CONFIG_DIR", tempDir)
globalFilePath := filepath.Join(tempDir, "config.yml")
hostsFilePath := filepath.Join(tempDir, "hosts.yml")
err := os.WriteFile(globalFilePath, []byte(testGlobalData()), 0755)
assert.NoError(t, err)
err = os.WriteFile(hostsFilePath, []byte(testHostsData()), 0755)
assert.NoError(t, err)
stubConfig(t, testConfigWithSocket())

tests := []struct {
name string
Expand Down Expand Up @@ -314,27 +271,42 @@ func TestOptionsNeedResolution(t *testing.T) {
}
}

func testGlobalData() string {
var data = `
http_unix_socket: socket
`
return data
}

func testHostsData() string {
var data = `
github.com:
user: user1
oauth_token: token
git_protocol: ssh
`
return data
}

func printPendingMocks(mocks []gock.Mock) string {
paths := []string{}
for _, mock := range mocks {
paths = append(paths, mock.Request().URLStruct.String())
}
return fmt.Sprintf("%d unmatched mocks: %s", len(paths), strings.Join(paths, ", "))
}

func stubConfig(t *testing.T, cfgStr string) {
t.Helper()
old := config.Read
config.Read = func() (*config.Config, error) {
return config.ReadFromString(cfgStr), nil
}
t.Cleanup(func() {
config.Read = old
})
}

func testConfig() string {
return `
hosts:
github.com:
user: user1
oauth_token: abc123
git_protocol: ssh
`
}

func testConfigWithSocket() string {
return `
http_unix_socket: socket
hosts:
github.com:
user: user1
oauth_token: token
git_protocol: ssh
`
}
4 changes: 4 additions & 0 deletions internal/yamlmap/yaml_map.go
Expand Up @@ -147,6 +147,10 @@ func (m *Map) SetEntry(key string, value *Map) {
// has no impact for our purposes.
func (m *Map) SetModified() {
// Can not mark a non-mapping node as modified
if m.Node.Kind != yaml.MappingNode && m.Node.Tag == "!!null" {
m.Node.Kind = yaml.MappingNode
m.Node.Tag = "!!map"
}
if m.Node.Kind == yaml.MappingNode {
m.Node.Value = modified
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/auth/auth.go
Expand Up @@ -4,13 +4,15 @@ package auth

import (
"os"
"strconv"
"strings"

"github.com/cli/go-gh/internal/set"
"github.com/cli/go-gh/pkg/config"
)

const (
codespaces = "CODESPACES"
defaultSource = "default"
ghEnterpriseToken = "GH_ENTERPRISE_TOKEN"
ghHost = "GH_HOST"
Expand Down Expand Up @@ -40,6 +42,11 @@ func tokenForHost(cfg *config.Config, host string) (string, string) {
if token := os.Getenv(githubEnterpriseToken); token != "" {
return token, githubEnterpriseToken
}
if isCodespaces, _ := strconv.ParseBool(os.Getenv(codespaces)); isCodespaces {
if token := os.Getenv(githubToken); token != "" {
return token, githubToken
}
}
if cfg != nil {
token, _ := cfg.Get([]string{hostsKey, host, oauthToken})
return token, oauthToken
Expand Down
57 changes: 38 additions & 19 deletions pkg/config/config.go
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"runtime"
"sync"

"github.com/cli/go-gh/internal/yamlmap"
)
Expand All @@ -22,12 +23,18 @@ const (
xdgStateHome = "XDG_STATE_HOME"
)

var (
cfg *Config
once sync.Once
)

// Config is a in memory representation of the gh configuration files.
// It can be thought of as map where entries consist of a key that
// correspond to either a string value or a map value, allowing for
// multi-level maps.
type Config struct {
entries *yamlmap.Map
mu sync.RWMutex
}

// Get a string value from a Config.
Expand All @@ -36,6 +43,8 @@ type Config struct {
// if trying to retrieve a key that corresponds to a map value.
// Returns "", KeyNotFoundError if any of the keys can not be found.
func (c *Config) Get(keys []string) (string, error) {
c.mu.RLock()
defer c.mu.RUnlock()
m := c.entries
for _, key := range keys {
var err error
Expand All @@ -52,6 +61,8 @@ func (c *Config) Get(keys []string) (string, error) {
// map values can be have their keys enumerated.
// Returns nil, KeyNotFoundError if any of the keys can not be found.
func (c *Config) Keys(keys []string) ([]string, error) {
c.mu.RLock()
defer c.mu.RUnlock()
m := c.entries
for _, key := range keys {
var err error
Expand All @@ -69,6 +80,8 @@ func (c *Config) Keys(keys []string) ([]string, error) {
// entries removes those also.
// Returns KeyNotFoundError if any of the keys can not be found.
func (c *Config) Remove(keys []string) error {
c.mu.Lock()
defer c.mu.Unlock()
m := c.entries
for i := 0; i < len(keys)-1; i++ {
var err error
Expand All @@ -90,6 +103,8 @@ func (c *Config) Remove(keys []string) error {
// entries can be set. If any of the keys do not exist they will
// be created.
func (c *Config) Set(keys []string, value string) {
c.mu.Lock()
defer c.mu.Unlock()
m := c.entries
for i := 0; i < len(keys)-1; i++ {
key := keys[i]
Expand All @@ -105,10 +120,12 @@ func (c *Config) Set(keys []string, value string) {

// Read gh configuration files from the local file system and
// return a Config.
func Read() (*Config, error) {
// TODO: Make global config singleton using sync.Once
// so as not to read from file every time.
return load(generalConfigFile(), hostsConfigFile())
var Read = func() (*Config, error) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this to be a variable for testing purposes seemed to be the best approach here to bypass having to deal with sync.Once in tests.

var err error
once.Do(func() {
cfg, err = load(generalConfigFile(), hostsConfigFile())
})
return cfg, err
}

// ReadFromString takes a yaml string and returns a Config.
Expand All @@ -119,14 +136,16 @@ func ReadFromString(str string) *Config {
if m == nil {
m = yamlmap.MapValue()
}
return &Config{m}
return &Config{entries: m}
}

// Write gh configuration files to the local file system.
// It will only write gh configuration files that have been modified
// since last being read.
func Write(config *Config) error {
hosts, err := config.entries.FindEntry("hosts")
func Write(c *Config) error {
c.mu.Lock()
defer c.mu.Unlock()
hosts, err := c.entries.FindEntry("hosts")
if err == nil && hosts.IsModified() {
err := writeFile(hostsConfigFile(), []byte(hosts.String()))
if err != nil {
Expand All @@ -135,20 +154,20 @@ func Write(config *Config) error {
hosts.SetUnmodified()
}

if config.entries.IsModified() {
if c.entries.IsModified() {
// Hosts gets written to a different file above so remove it
// before writing and add it back in after writing.
hostsMap, hostsErr := config.entries.FindEntry("hosts")
hostsMap, hostsErr := c.entries.FindEntry("hosts")
if hostsErr == nil {
_ = config.entries.RemoveEntry("hosts")
_ = c.entries.RemoveEntry("hosts")
}
err := writeFile(generalConfigFile(), []byte(config.entries.String()))
err := writeFile(generalConfigFile(), []byte(c.entries.String()))
if err != nil {
return err
}
config.entries.SetUnmodified()
c.entries.SetUnmodified()
if hostsErr == nil {
config.entries.AddEntry("hosts", hostsMap)
c.entries.AddEntry("hosts", hostsMap)
}
}

Expand Down Expand Up @@ -182,15 +201,15 @@ func load(generalFilePath, hostsFilePath string) (*Config, error) {
generalMap.AddEntry("hosts", hostsMap)
}

return &Config{generalMap}, nil
return &Config{entries: generalMap}, nil
}

func generalConfigFile() string {
return filepath.Join(configDir(), "config.yml")
return filepath.Join(ConfigDir(), "config.yml")
}

func hostsConfigFile() string {
return filepath.Join(configDir(), "hosts.yml")
return filepath.Join(ConfigDir(), "hosts.yml")
}

func mapFromFile(filename string) (*yamlmap.Map, error) {
Expand All @@ -206,7 +225,7 @@ func mapFromString(str string) (*yamlmap.Map, error) {
}

// Config path precedence: GH_CONFIG_DIR, XDG_CONFIG_HOME, AppData (windows only), HOME.
func configDir() string {
func ConfigDir() string {
var path string
if a := os.Getenv(ghConfigDir); a != "" {
path = a
Expand All @@ -222,7 +241,7 @@ func configDir() string {
}

// State path precedence: XDG_STATE_HOME, LocalAppData (windows only), HOME.
func stateDir() string {
func StateDir() string {
var path string
if a := os.Getenv(xdgStateHome); a != "" {
path = filepath.Join(a, "gh")
Expand All @@ -236,7 +255,7 @@ func stateDir() string {
}

// Data path precedence: XDG_DATA_HOME, LocalAppData (windows only), HOME.
func dataDir() string {
func DataDir() string {
var path string
if a := os.Getenv(xdgDataHome); a != "" {
path = filepath.Join(a, "gh")
Expand Down