Skip to content

Commit

Permalink
Export ssh package (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed May 17, 2022
1 parent 8d34617 commit 293b1eb
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 76 deletions.
13 changes: 6 additions & 7 deletions gh.go
Expand Up @@ -10,17 +10,16 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"os/exec"

iapi "github.com/cli/go-gh/internal/api"
"github.com/cli/go-gh/internal/config"
"github.com/cli/go-gh/internal/git"
irepo "github.com/cli/go-gh/internal/repository"
"github.com/cli/go-gh/internal/ssh"
"github.com/cli/go-gh/pkg/api"
repo "github.com/cli/go-gh/pkg/repository"
"github.com/cli/go-gh/pkg/ssh"
"github.com/cli/safeexec"
)

Expand Down Expand Up @@ -128,8 +127,8 @@ func CurrentRepository() (repo.Repository, error) {
return nil, errors.New("unable to determine current repository, no git remotes configured for this repository")
}

sshConfig := ssh.ParseConfig()
translateRemotes(remotes, sshConfig.Translator())
translator := ssh.NewTranslator()
translateRemotes(remotes, translator)

cfg, err := config.Load()
if err != nil {
Expand Down Expand Up @@ -163,13 +162,13 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error {
return nil
}

func translateRemotes(remotes git.RemoteSet, urlTranslate func(*url.URL) *url.URL) {
func translateRemotes(remotes git.RemoteSet, translator ssh.Translator) {
for _, r := range remotes {
if r.FetchURL != nil {
r.FetchURL = urlTranslate(r.FetchURL)
r.FetchURL = translator.Translate(r.FetchURL)
}
if r.PushURL != nil {
r.PushURL = urlTranslate(r.PushURL)
r.PushURL = translator.Translate(r.PushURL)
}
}
}
123 changes: 65 additions & 58 deletions internal/ssh/ssh.go → pkg/ssh/ssh.go
@@ -1,3 +1,5 @@
// Package ssh is a set of types and functions for parsing and
// applying a user's SSH hostname aliases.
package ssh

import (
Expand All @@ -15,34 +17,70 @@ var (
tokenRE = regexp.MustCompile(`%[%h]`)
)

// Config encapsulates the translation of SSH hostname aliases.
type Config map[string]string
// Translator is the interface that encapsulates the SSH hostname alias translate method.
type Translator interface {
Translate(*url.URL) *url.URL
}

// Translator returns a function that applies hostname aliases to URLs.
func (m Config) Translator() func(*url.URL) *url.URL {
return func(u *url.URL) *url.URL {
if u.Scheme != "ssh" {
return u
}
resolvedHost, ok := m[u.Hostname()]
if !ok {
return u
}
if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") {
return u
}
newURL, _ := url.Parse(u.String())
newURL.Host = resolvedHost
return newURL
}
type config struct {
aliases map[string]string
}

type parser struct {
dir string
config Config
hosts []string
open func(string) (io.Reader, error)
glob func(string) ([]string, error)
dir string
cfg config
hosts []string
open func(string) (io.Reader, error)
glob func(string) ([]string, error)
}

// NewTranslator constructs a map of SSH hostname aliases based on user and system configuration files.
// It returns a Translator to apply these mappings.
func NewTranslator() Translator {
configFiles := []string{
"/etc/ssh_config",
"/etc/ssh/ssh_config",
}

p := parser{}

if sshDir, err := homeDirPath(".ssh"); err == nil {
userConfig := filepath.Join(sshDir, "config")
configFiles = append([]string{userConfig}, configFiles...)
p.dir = filepath.Dir(sshDir)
}

for _, file := range configFiles {
_ = p.read(file)
}
return p.cfg
}

func homeDirPath(subdir string) (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}

newPath := filepath.Join(homeDir, subdir)
return newPath, nil
}

// Translate applies applicable SSH hostname aliases to the specified URL and returns the resulting URL.
func (c config) Translate(u *url.URL) *url.URL {
if u.Scheme != "ssh" {
return u
}
resolvedHost, ok := c.aliases[u.Hostname()]
if !ok {
return u
}
if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") {
return u
}
newURL, _ := url.Parse(u.String())
newURL.Host = resolvedHost
return newURL
}

func (p *parser) read(fileName string) error {
Expand Down Expand Up @@ -80,10 +118,10 @@ func (p *parser) read(fileName string) error {
case "hostname":
for _, host := range p.hosts {
for _, name := range strings.Fields(arguments) {
if p.config == nil {
p.config = make(Config)
if p.cfg.aliases == nil {
p.cfg.aliases = make(map[string]string)
}
p.config[host] = expandTokens(name, host)
p.cfg.aliases[host] = expandTokens(name, host)
}
}
case "include":
Expand Down Expand Up @@ -132,37 +170,6 @@ func (p *parser) absolutePath(parentFile, path string) string {
return filepath.Join(p.dir, ".ssh", path)
}

// ParseConfig constructs a map of SSH hostname aliases based on user and system configuration files.
func ParseConfig() Config {
configFiles := []string{
"/etc/ssh_config",
"/etc/ssh/ssh_config",
}

p := parser{}

if sshDir, err := homeDirPath(".ssh"); err == nil {
userConfig := filepath.Join(sshDir, "config")
configFiles = append([]string{userConfig}, configFiles...)
p.dir = filepath.Dir(sshDir)
}

for _, file := range configFiles {
_ = p.read(file)
}
return p.config
}

func homeDirPath(subdir string) (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}

newPath := filepath.Join(homeDir, subdir)
return newPath, nil
}

func expandTokens(text, host string) string {
return tokenRE.ReplaceAllStringFunc(text, func(match string) string {
switch match {
Expand Down
24 changes: 13 additions & 11 deletions internal/ssh/ssh_test.go → pkg/ssh/ssh_test.go
Expand Up @@ -66,19 +66,19 @@ func Test_sshParser_read(t *testing.T) {
t.Fatalf("read(user config) = %v", err)
}

if got := p.config["gh"]; got != "github.com" {
if got := p.cfg.aliases["gh"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got)
}
if got := p.config["gittyhubby"]; got != "github.com" {
if got := p.cfg.aliases["gittyhubby"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got)
}
if got := p.config["example.com"]; got != "" {
if got := p.cfg.aliases["example.com"]; got != "" {
t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got)
}
if got := p.config["ex"]; got != "example.com" {
if got := p.cfg.aliases["ex"]; got != "example.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got)
}
if got := p.config["s1"]; got != "site1.net" {
if got := p.cfg.aliases["s1"]; got != "site1.net" {
t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got)
}
}
Expand Down Expand Up @@ -124,21 +124,23 @@ func Test_sshParser_absolutePath(t *testing.T) {
}
}

func Test_Translator(t *testing.T) {
m := Config{
"gh": "github.com",
"github.com": "ssh.github.com",
func Test_Translate(t *testing.T) {
m := config{
aliases: map[string]string{
"gh": "github.com",
"github.com": "ssh.github.com",
},
}
tr := m.Translator()

cases := [][]string{
{"ssh://gh/o/r", "ssh://github.com/o/r"},
{"ssh://github.com/o/r", "ssh://github.com/o/r"},
{"https://gh/o/r", "https://gh/o/r"},
}

for _, c := range cases {
u, _ := url.Parse(c[0])
got := tr(u)
got := m.Translate(u)
if got.String() != c[1] {
t.Errorf("%q: expected %q, got %q", c[0], c[1], got)
}
Expand Down

0 comments on commit 293b1eb

Please sign in to comment.