Skip to content

Commit

Permalink
Improvements to API package for use in cli (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
samcoe committed Jun 14, 2022
1 parent 9dbbfe2 commit ef2bca9
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 18 deletions.
2 changes: 1 addition & 1 deletion gh_test.go
Expand Up @@ -130,7 +130,7 @@ func TestGQLClientError(t *testing.T) {

res := struct{ Organization struct{ Name string } }{}
err = client.Do("QUERY", nil, &res)
assert.EqualError(t, err, "GQL: Could not resolve to an Organization with the login of 'cli'. (organization)")
assert.EqualError(t, err, "GraphQL: Could not resolve to an Organization with the login of 'cli'. (organization)")
assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending()))
}

Expand Down
30 changes: 30 additions & 0 deletions internal/api/cache.go
Expand Up @@ -82,10 +82,25 @@ func (c cache) RoundTripper(rt http.RoundTripper) http.RoundTripper {
}

func (crt cacheRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
reqDir, reqTTL := requestCacheOptions(req)

if crt.fs.ttl == 0 && reqTTL == 0 {
return crt.rt.RoundTrip(req)
}

if !isCacheableRequest(req) {
return crt.rt.RoundTrip(req)
}

origDir := crt.fs.dir
if reqDir != "" {
crt.fs.dir = reqDir
}
origTTL := crt.fs.ttl
if reqTTL != 0 {
crt.fs.ttl = reqTTL
}

key, keyErr := cacheKey(req)
if keyErr == nil {
if res, err := crt.fs.read(key); err == nil {
Expand All @@ -98,9 +113,24 @@ func (crt cacheRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
if err == nil && keyErr == nil && isCacheableResponse(res) {
_ = crt.fs.store(key, res)
}

crt.fs.dir = origDir
crt.fs.ttl = origTTL

return res, err
}

// Allow an individual request to override cache options.
func requestCacheOptions(req *http.Request) (string, time.Duration) {
var dur time.Duration
dir := req.Header.Get("X-GH-CACHE-DIR")
ttl := req.Header.Get("X-GH-CACHE-TTL")
if ttl != "" {
dur, _ = time.ParseDuration(ttl)
}
return dir, dur
}

func (fs *fileStorage) filePath(key string) string {
if len(key) >= 6 {
return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
Expand Down
93 changes: 93 additions & 0 deletions internal/api/cache_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"path/filepath"
"testing"
"time"

"github.com/cli/go-gh/pkg/api"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -91,3 +92,95 @@ func TestCacheResponse(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "7: GET http://example.com/error", res)
}

func TestCacheResponseRequestCacheOptions(t *testing.T) {
counter := 0
fakeHTTP := tripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
counter += 1
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
status := 200
if req.URL.Path == "/error" {
status = 500
}
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewBufferString(body)),
}, nil
},
}

cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")

httpClient := NewHTTPClient(
&api.ClientOptions{
Transport: fakeHTTP,
EnableCache: false,
CacheDir: cacheDir,
})

do := func(method, url string, body io.Reader) (string, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return "", err
}
req.Header.Set("X-GH-CACHE-DIR", cacheDir)
req.Header.Set("X-GH-CACHE-TTL", "1h")
res, err := httpClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
resBody, err := io.ReadAll(res.Body)
if err != nil {
err = fmt.Errorf("ReadAll: %w", err)
}
return string(resBody), err
}

var res string
var err error

res, err = do("GET", "http://example.com/path", nil)
assert.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res)
res, err = do("GET", "http://example.com/path", nil)
assert.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res)

res, err = do("GET", "http://example.com/path2", nil)
assert.NoError(t, err)
assert.Equal(t, "2: GET http://example.com/path2", res)

res, err = do("POST", "http://example.com/path2", nil)
assert.NoError(t, err)
assert.Equal(t, "3: POST http://example.com/path2", res)

res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
assert.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
assert.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)

res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
assert.NoError(t, err)
assert.Equal(t, "5: POST http://example.com/graphql", res)

res, err = do("GET", "http://example.com/error", nil)
assert.NoError(t, err)
assert.Equal(t, "6: GET http://example.com/error", res)
res, err = do("GET", "http://example.com/error", nil)
assert.NoError(t, err)
assert.Equal(t, "7: GET http://example.com/error", res)
}

func TestRequestCacheOptions(t *testing.T) {
req, err := http.NewRequest("GET", "some/url", nil)
assert.NoError(t, err)
req.Header.Set("X-GH-CACHE-DIR", "some/dir/path")
req.Header.Set("X-GH-CACHE-TTL", "1h")
dir, ttl := requestCacheOptions(req)
assert.Equal(t, dir, "some/dir/path")
assert.Equal(t, ttl, time.Hour)
}
2 changes: 1 addition & 1 deletion internal/api/gql_client_test.go
Expand Up @@ -37,7 +37,7 @@ func TestGQLClientDo(t *testing.T) {
JSON(`{"errors":[{"message":"OH NO"},{"message":"this is fine"}]}`)
},
wantErr: true,
wantErrMsg: "GQL: OH NO, this is fine",
wantErrMsg: "GraphQL: OH NO, this is fine",
},
{
name: "http fail request empty response",
Expand Down
20 changes: 7 additions & 13 deletions internal/api/http.go
Expand Up @@ -83,16 +83,14 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
transport = opts.Transport
}

if opts.EnableCache {
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.EnableCache && opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)

if opts.Log != nil {
logger := &httpretty.Logger{
Expand Down Expand Up @@ -168,10 +166,6 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
a := "application/vnd.github.merge-info-preview+json"
// Preview for visibility when RESTing repos into an org.
a += ", application/vnd.github.nebula-preview"
// Preview for Commit.statusCheckRollup for old GHES versions.
a += ", application/vnd.github.antiope-preview"
// Preview for // PullRequest.isDraft for old GHES versions.
a += ", application/vnd.github.shadow-cat-preview"
headers[accept] = a
}
return headerRoundTripper{host: host, headers: headers, rt: rt}
Expand Down
2 changes: 0 additions & 2 deletions internal/api/http_test.go
Expand Up @@ -133,8 +133,6 @@ func defaultHeaders() http.Header {
h := http.Header{}
a := "application/vnd.github.merge-info-preview+json"
a += ", application/vnd.github.nebula-preview"
a += ", application/vnd.github.antiope-preview"
a += ", application/vnd.github.shadow-cat-preview"
h.Set(contentType, jsonContentType)
h.Set(userAgent, "go-gh")
h.Set(authorization, fmt.Sprintf("token %s", "oauth_token"))
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/errors.go
Expand Up @@ -67,7 +67,7 @@ func (gr GQLError) Error() string {
}
errorMessages = append(errorMessages, msg)
}
return fmt.Sprintf("GQL: %s", strings.Join(errorMessages, ", "))
return fmt.Sprintf("GraphQL: %s", strings.Join(errorMessages, ", "))
}

// Match determines if the GQLError is about a specific type on a specific path.
Expand Down
3 changes: 3 additions & 0 deletions pkg/repository/repository_test.go
Expand Up @@ -80,6 +80,9 @@ func TestParse(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldDir := os.Getenv("GH_CONFIG_DIR")
os.Setenv("GH_CONFIG_DIR", "nonexistant")
defer os.Setenv("GH_CONFIG_DIR", oldDir)
if tt.hostOverride != "" {
old := os.Getenv("GH_HOST")
os.Setenv("GH_HOST", tt.hostOverride)
Expand Down

0 comments on commit ef2bca9

Please sign in to comment.