Skip to content

Commit

Permalink
feat: Support --header for CLI commands to support proxies
Browse files Browse the repository at this point in the history
Fixes #3527.
  • Loading branch information
kylecarbs committed Sep 12, 2022
1 parent 66ad86a commit 7a42d60
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 2 deletions.
12 changes: 12 additions & 0 deletions cli/cliflag/cliflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
}

func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) {
v, ok := os.LookupEnv(env)
if !ok || v == "" {
if v == "" {
def = []string{}
} else {
def = strings.Split(v, ",")
}
}
flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env))
}

func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
val, ok := os.LookupEnv(env)
if ok {
Expand Down
5 changes: 4 additions & 1 deletion cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ func login() *cobra.Command {
serverURL.Scheme = "https"
}

client := codersdk.New(serverURL)
client, err := createUnauthenticatedClient(cmd, serverURL)
if err != nil {
return err
}

// Try to check the version of the server prior to logging in.
// It may be useful to warn the user if they are trying to login
Expand Down
41 changes: 40 additions & 1 deletion cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"flag"
"fmt"
"net/http"
"net/url"
"os"
"strings"
Expand Down Expand Up @@ -41,6 +42,7 @@ const (
varAgentToken = "agent-token"
varAgentURL = "agent-url"
varGlobalConfig = "global-config"
varHeader = "header"
varNoOpen = "no-open"
varNoVersionCheck = "no-version-warning"
varNoFeatureWarning = "no-feature-warning"
Expand Down Expand Up @@ -174,6 +176,7 @@ func Root(subcommands []*cobra.Command) *cobra.Command {
cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "Specify the URL for an agent to access your deployment.")
_ = cmd.PersistentFlags().MarkHidden(varAgentURL)
cliflag.String(cmd.PersistentFlags(), varGlobalConfig, "", "CODER_CONFIG_DIR", configdir.LocalConfig("coderv2"), "Specify the path to the global `coder` config directory.")
cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"")
cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.")
_ = cmd.PersistentFlags().MarkHidden(varForceTty)
cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.")
Expand Down Expand Up @@ -237,8 +240,32 @@ func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) {
return nil, err
}
}
client, err := createUnauthenticatedClient(cmd, serverURL)
if err != nil {
return nil, err
}
client.SessionToken = token
return client, nil
}

func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) {
client := codersdk.New(serverURL)
client.SessionToken = strings.TrimSpace(token)
headers, err := cmd.Flags().GetStringArray(varHeader)
if err != nil {
return nil, err
}
transport := &headerTransport{
transport: http.DefaultTransport,
headers: map[string]string{},
}
for _, header := range headers {
parts := strings.SplitN(header, "=", 2)
if len(parts) < 2 {
return nil, xerrors.Errorf("split header %q had less than two parts", header)
}
transport.headers[parts[0]] = parts[1]
}
client.HTTPClient.Transport = transport
return client, nil
}

Expand Down Expand Up @@ -530,3 +557,15 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {

return nil
}

type headerTransport struct {
transport http.RoundTripper
headers map[string]string
}

func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.headers {
req.Header.Add(k, v)
}
return h.transport.RoundTrip(req)
}
24 changes: 24 additions & 0 deletions cli/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package cli_test

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

Expand Down Expand Up @@ -129,4 +132,25 @@ func TestRoot(t *testing.T) {
require.Contains(t, output, buildinfo.Version(), "has version")
require.Contains(t, output, buildinfo.ExternalURL(), "has url")
})

t.Run("Header", func(t *testing.T) {
t.Parallel()

done := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "wow", r.Header.Get("X-Testing"))
w.WriteHeader(http.StatusGone)
select {
case <-done:
close(done)
default:
}
}))
defer srv.Close()
buf := new(bytes.Buffer)
cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL)
cmd.SetOut(buf)
// This won't succeed, because we're using the login cmd to assert requests.
_ = cmd.Execute()
})
}

0 comments on commit 7a42d60

Please sign in to comment.