Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export ssh package #40

Merged
merged 1 commit into from May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 6 additions & 7 deletions gh.go
Expand Up @@ -10,17 +10,16 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"os/exec"

iapi "github.com/cli/go-gh/internal/api"
"github.com/cli/go-gh/internal/config"
"github.com/cli/go-gh/internal/git"
irepo "github.com/cli/go-gh/internal/repository"
"github.com/cli/go-gh/internal/ssh"
"github.com/cli/go-gh/pkg/api"
repo "github.com/cli/go-gh/pkg/repository"
"github.com/cli/go-gh/pkg/ssh"
"github.com/cli/safeexec"
)

Expand Down Expand Up @@ -128,8 +127,8 @@ func CurrentRepository() (repo.Repository, error) {
return nil, errors.New("unable to determine current repository, no git remotes configured for this repository")
}

sshConfig := ssh.ParseConfig()
translateRemotes(remotes, sshConfig.Translator())
translator := ssh.NewTranslator()
translateRemotes(remotes, translator)

cfg, err := config.Load()
if err != nil {
Expand Down Expand Up @@ -163,13 +162,13 @@ func resolveOptions(opts *api.ClientOptions, cfg config.Config) error {
return nil
}

func translateRemotes(remotes git.RemoteSet, urlTranslate func(*url.URL) *url.URL) {
func translateRemotes(remotes git.RemoteSet, translator ssh.Translator) {
for _, r := range remotes {
if r.FetchURL != nil {
r.FetchURL = urlTranslate(r.FetchURL)
r.FetchURL = translator.Translate(r.FetchURL)
}
if r.PushURL != nil {
r.PushURL = urlTranslate(r.PushURL)
r.PushURL = translator.Translate(r.PushURL)
}
}
}
123 changes: 65 additions & 58 deletions internal/ssh/ssh.go → pkg/ssh/ssh.go
@@ -1,3 +1,5 @@
// Package ssh is a set of types and functions for parsing and
// applying a user's SSH hostname aliases.
package ssh

import (
Expand All @@ -15,34 +17,70 @@ var (
tokenRE = regexp.MustCompile(`%[%h]`)
)

// Config encapsulates the translation of SSH hostname aliases.
type Config map[string]string
// Translator is the interface that encapsulates the SSH hostname alias translate method.
type Translator interface {
Translate(*url.URL) *url.URL
}

// Translator returns a function that applies hostname aliases to URLs.
func (m Config) Translator() func(*url.URL) *url.URL {
return func(u *url.URL) *url.URL {
if u.Scheme != "ssh" {
return u
}
resolvedHost, ok := m[u.Hostname()]
if !ok {
return u
}
if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") {
return u
}
newURL, _ := url.Parse(u.String())
newURL.Host = resolvedHost
return newURL
}
type config struct {
aliases map[string]string
}

type parser struct {
dir string
config Config
hosts []string
open func(string) (io.Reader, error)
glob func(string) ([]string, error)
dir string
cfg config
hosts []string
open func(string) (io.Reader, error)
glob func(string) ([]string, error)
}

// NewTranslator constructs a map of SSH hostname aliases based on user and system configuration files.
// It returns a Translator to apply these mappings.
func NewTranslator() Translator {
configFiles := []string{
"/etc/ssh_config",
"/etc/ssh/ssh_config",
}

p := parser{}

if sshDir, err := homeDirPath(".ssh"); err == nil {
userConfig := filepath.Join(sshDir, "config")
configFiles = append([]string{userConfig}, configFiles...)
p.dir = filepath.Dir(sshDir)
}

for _, file := range configFiles {
_ = p.read(file)
}
return p.cfg
}

func homeDirPath(subdir string) (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}

newPath := filepath.Join(homeDir, subdir)
return newPath, nil
}

// Translate applies applicable SSH hostname aliases to the specified URL and returns the resulting URL.
func (c config) Translate(u *url.URL) *url.URL {
if u.Scheme != "ssh" {
return u
}
resolvedHost, ok := c.aliases[u.Hostname()]
if !ok {
return u
}
if strings.EqualFold(u.Hostname(), "github.com") && strings.EqualFold(resolvedHost, "ssh.github.com") {
return u
}
newURL, _ := url.Parse(u.String())
newURL.Host = resolvedHost
return newURL
}

func (p *parser) read(fileName string) error {
Expand Down Expand Up @@ -80,10 +118,10 @@ func (p *parser) read(fileName string) error {
case "hostname":
for _, host := range p.hosts {
for _, name := range strings.Fields(arguments) {
if p.config == nil {
p.config = make(Config)
if p.cfg.aliases == nil {
p.cfg.aliases = make(map[string]string)
}
p.config[host] = expandTokens(name, host)
p.cfg.aliases[host] = expandTokens(name, host)
}
}
case "include":
Expand Down Expand Up @@ -132,37 +170,6 @@ func (p *parser) absolutePath(parentFile, path string) string {
return filepath.Join(p.dir, ".ssh", path)
}

// ParseConfig constructs a map of SSH hostname aliases based on user and system configuration files.
func ParseConfig() Config {
configFiles := []string{
"/etc/ssh_config",
"/etc/ssh/ssh_config",
}

p := parser{}

if sshDir, err := homeDirPath(".ssh"); err == nil {
userConfig := filepath.Join(sshDir, "config")
configFiles = append([]string{userConfig}, configFiles...)
p.dir = filepath.Dir(sshDir)
}

for _, file := range configFiles {
_ = p.read(file)
}
return p.config
}

func homeDirPath(subdir string) (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}

newPath := filepath.Join(homeDir, subdir)
return newPath, nil
}

func expandTokens(text, host string) string {
return tokenRE.ReplaceAllStringFunc(text, func(match string) string {
switch match {
Expand Down
24 changes: 13 additions & 11 deletions internal/ssh/ssh_test.go → pkg/ssh/ssh_test.go
Expand Up @@ -66,19 +66,19 @@ func Test_sshParser_read(t *testing.T) {
t.Fatalf("read(user config) = %v", err)
}

if got := p.config["gh"]; got != "github.com" {
if got := p.cfg.aliases["gh"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got)
}
if got := p.config["gittyhubby"]; got != "github.com" {
if got := p.cfg.aliases["gittyhubby"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got)
}
if got := p.config["example.com"]; got != "" {
if got := p.cfg.aliases["example.com"]; got != "" {
t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got)
}
if got := p.config["ex"]; got != "example.com" {
if got := p.cfg.aliases["ex"]; got != "example.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got)
}
if got := p.config["s1"]; got != "site1.net" {
if got := p.cfg.aliases["s1"]; got != "site1.net" {
t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got)
}
}
Expand Down Expand Up @@ -124,21 +124,23 @@ func Test_sshParser_absolutePath(t *testing.T) {
}
}

func Test_Translator(t *testing.T) {
m := Config{
"gh": "github.com",
"github.com": "ssh.github.com",
func Test_Translate(t *testing.T) {
m := config{
aliases: map[string]string{
"gh": "github.com",
"github.com": "ssh.github.com",
},
}
tr := m.Translator()

cases := [][]string{
{"ssh://gh/o/r", "ssh://github.com/o/r"},
{"ssh://github.com/o/r", "ssh://github.com/o/r"},
{"https://gh/o/r", "https://gh/o/r"},
}

for _, c := range cases {
u, _ := url.Parse(c[0])
got := tr(u)
got := m.Translate(u)
if got.String() != c[1] {
t.Errorf("%q: expected %q, got %q", c[0], c[1], got)
}
Expand Down