diff --git a/gh.go b/gh.go index 8c5e8f1..8b1ce8e 100644 --- a/gh.go +++ b/gh.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "net/http" - "net/url" "os" "os/exec" @@ -18,9 +17,9 @@ import ( "github.com/cli/go-gh/internal/config" "github.com/cli/go-gh/internal/git" irepo "github.com/cli/go-gh/internal/repository" - "github.com/cli/go-gh/internal/ssh" "github.com/cli/go-gh/pkg/api" repo "github.com/cli/go-gh/pkg/repository" + "github.com/cli/go-gh/pkg/ssh" "github.com/cli/safeexec" ) @@ -128,8 +127,8 @@ func CurrentRepository() (repo.Repository, error) { return nil, errors.New("unable to determine current repository, no git remotes configured for this repository") } - sshConfig := ssh.ParseConfig() - translateRemotes(remotes, sshConfig.Translator()) + translator := ssh.NewTranslator() + translateRemotes(remotes, translator) cfg, err := config.Load() if err != nil { @@ -163,13 +162,13 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error { return nil } -func translateRemotes(remotes git.RemoteSet, urlTranslate func(*url.URL) *url.URL) { +func translateRemotes(remotes git.RemoteSet, translator ssh.Translator) { for _, r := range remotes { if r.FetchURL != nil { - r.FetchURL = urlTranslate(r.FetchURL) + r.FetchURL = translator.Translate(r.FetchURL) } if r.PushURL != nil { - r.PushURL = urlTranslate(r.PushURL) + r.PushURL = translator.Translate(r.PushURL) } } } diff --git a/internal/ssh/ssh.go b/pkg/ssh/ssh.go similarity index 69% rename from internal/ssh/ssh.go rename to pkg/ssh/ssh.go index 67563cf..0794d73 100644 --- a/internal/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -1,3 +1,5 @@ +// Package ssh is a set of types and functions for parsing and +// applying a user's SSH hostname aliases. package ssh import ( @@ -15,34 +17,70 @@ var ( tokenRE = regexp.MustCompile(`%[%h]`) ) -// Config encapsulates the translation of SSH hostname aliases. -type Config map[string]string +// Translator is the interface that encapsulates the SSH hostname alias translate method. +type Translator interface { + Translate(*url.URL) *url.URL +} -// Translator returns a function that applies hostname aliases to URLs. -func (m Config) Translator() func(*url.URL) *url.URL { - return func(u *url.URL) *url.URL { - if u.Scheme != "ssh" { - return u - } - resolvedHost, ok := m[u.Hostname()] - if !ok { - return u - } - if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") { - return u - } - newURL, _ := url.Parse(u.String()) - newURL.Host = resolvedHost - return newURL - } +type config struct { + aliases map[string]string } type parser struct { - dir string - config Config - hosts []string - open func(string) (io.Reader, error) - glob func(string) ([]string, error) + dir string + cfg config + hosts []string + open func(string) (io.Reader, error) + glob func(string) ([]string, error) +} + +// NewTranslator constructs a map of SSH hostname aliases based on user and system configuration files. +// It returns a Translator to apply these mappings. +func NewTranslator() Translator { + configFiles := []string{ + "/etc/ssh_config", + "/etc/ssh/ssh_config", + } + + p := parser{} + + if sshDir, err := homeDirPath(".ssh"); err == nil { + userConfig := filepath.Join(sshDir, "config") + configFiles = append([]string{userConfig}, configFiles...) + p.dir = filepath.Dir(sshDir) + } + + for _, file := range configFiles { + _ = p.read(file) + } + return p.cfg +} + +func homeDirPath(subdir string) (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + + newPath := filepath.Join(homeDir, subdir) + return newPath, nil +} + +// Translate applies applicable SSH hostname aliases to the specified URL and returns the resulting URL. +func (c config) Translate(u *url.URL) *url.URL { + if u.Scheme != "ssh" { + return u + } + resolvedHost, ok := c.aliases[u.Hostname()] + if !ok { + return u + } + if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") { + return u + } + newURL, _ := url.Parse(u.String()) + newURL.Host = resolvedHost + return newURL } func (p *parser) read(fileName string) error { @@ -80,10 +118,10 @@ func (p *parser) read(fileName string) error { case "hostname": for _, host := range p.hosts { for _, name := range strings.Fields(arguments) { - if p.config == nil { - p.config = make(Config) + if p.cfg.aliases == nil { + p.cfg.aliases = make(map[string]string) } - p.config[host] = expandTokens(name, host) + p.cfg.aliases[host] = expandTokens(name, host) } } case "include": @@ -132,37 +170,6 @@ func (p *parser) absolutePath(parentFile, path string) string { return filepath.Join(p.dir, ".ssh", path) } -// ParseConfig constructs a map of SSH hostname aliases based on user and system configuration files. -func ParseConfig() Config { - configFiles := []string{ - "/etc/ssh_config", - "/etc/ssh/ssh_config", - } - - p := parser{} - - if sshDir, err := homeDirPath(".ssh"); err == nil { - userConfig := filepath.Join(sshDir, "config") - configFiles = append([]string{userConfig}, configFiles...) - p.dir = filepath.Dir(sshDir) - } - - for _, file := range configFiles { - _ = p.read(file) - } - return p.config -} - -func homeDirPath(subdir string) (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", err - } - - newPath := filepath.Join(homeDir, subdir) - return newPath, nil -} - func expandTokens(text, host string) string { return tokenRE.ReplaceAllStringFunc(text, func(match string) string { switch match { diff --git a/internal/ssh/ssh_test.go b/pkg/ssh/ssh_test.go similarity index 87% rename from internal/ssh/ssh_test.go rename to pkg/ssh/ssh_test.go index d239bb6..eb66072 100644 --- a/internal/ssh/ssh_test.go +++ b/pkg/ssh/ssh_test.go @@ -66,19 +66,19 @@ func Test_sshParser_read(t *testing.T) { t.Fatalf("read(user config) = %v", err) } - if got := p.config["gh"]; got != "github.com" { + if got := p.cfg.aliases["gh"]; got != "github.com" { t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got) } - if got := p.config["gittyhubby"]; got != "github.com" { + if got := p.cfg.aliases["gittyhubby"]; got != "github.com" { t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got) } - if got := p.config["example.com"]; got != "" { + if got := p.cfg.aliases["example.com"]; got != "" { t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got) } - if got := p.config["ex"]; got != "example.com" { + if got := p.cfg.aliases["ex"]; got != "example.com" { t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got) } - if got := p.config["s1"]; got != "site1.net" { + if got := p.cfg.aliases["s1"]; got != "site1.net" { t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got) } } @@ -124,21 +124,23 @@ func Test_sshParser_absolutePath(t *testing.T) { } } -func Test_Translator(t *testing.T) { - m := Config{ - "gh": "github.com", - "github.com": "ssh.github.com", +func Test_Translate(t *testing.T) { + m := config{ + aliases: map[string]string{ + "gh": "github.com", + "github.com": "ssh.github.com", + }, } - tr := m.Translator() cases := [][]string{ {"ssh://gh/o/r", "ssh://github.com/o/r"}, {"ssh://github.com/o/r", "ssh://github.com/o/r"}, {"https://gh/o/r", "https://gh/o/r"}, } + for _, c := range cases { u, _ := url.Parse(c[0]) - got := tr(u) + got := m.Translate(u) if got.String() != c[1] { t.Errorf("%q: expected %q, got %q", c[0], c[1], got) }