Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to API package for use in cli #49

Merged
merged 4 commits into from Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion gh_test.go
Expand Up @@ -130,7 +130,7 @@ func TestGQLClientError(t *testing.T) {

res := struct{ Organization struct{ Name string } }{}
err = client.Do("QUERY", nil, &res)
assert.EqualError(t, err, "GQL: Could not resolve to an Organization with the login of 'cli'. (organization)")
assert.EqualError(t, err, "GraphQL: Could not resolve to an Organization with the login of 'cli'. (organization)")
assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending()))
}

Expand Down
30 changes: 30 additions & 0 deletions internal/api/cache.go
Expand Up @@ -82,10 +82,25 @@ func (c cache) RoundTripper(rt http.RoundTripper) http.RoundTripper {
}

func (crt cacheRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
reqDir, reqTTL := requestCacheOptions(req)

if crt.fs.ttl == 0 && reqTTL == 0 {
return crt.rt.RoundTrip(req)
}

if !isCacheableRequest(req) {
return crt.rt.RoundTrip(req)
}

origDir := crt.fs.dir
if reqDir != "" {
crt.fs.dir = reqDir
}
origTTL := crt.fs.ttl
if reqTTL != 0 {
crt.fs.ttl = reqTTL
}

key, keyErr := cacheKey(req)
if keyErr == nil {
if res, err := crt.fs.read(key); err == nil {
Expand All @@ -98,9 +113,24 @@ func (crt cacheRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
if err == nil && keyErr == nil && isCacheableResponse(res) {
_ = crt.fs.store(key, res)
}

crt.fs.dir = origDir
crt.fs.ttl = origTTL

return res, err
}

// Allow an individual request to override cache options.
func requestCacheOptions(req *http.Request) (string, time.Duration) {
var dur time.Duration
dir := req.Header.Get("X-GH-CACHE-DIR")
ttl := req.Header.Get("X-GH-CACHE-TTL")
if ttl != "" {
dur, _ = time.ParseDuration(ttl)
}
return dir, dur
}

func (fs *fileStorage) filePath(key string) string {
if len(key) >= 6 {
return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
Expand Down
93 changes: 93 additions & 0 deletions internal/api/cache_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"path/filepath"
"testing"
"time"

"github.com/cli/go-gh/pkg/api"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -91,3 +92,95 @@ func TestCacheResponse(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "7: GET http://example.com/error", res)
}

func TestCacheResponseRequestCacheOptions(t *testing.T) {
counter := 0
fakeHTTP := tripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
counter += 1
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
status := 200
if req.URL.Path == "/error" {
status = 500
}
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewBufferString(body)),
}, nil
},
}

cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")

httpClient := NewHTTPClient(
&api.ClientOptions{
Transport: fakeHTTP,
EnableCache: false,
CacheDir: cacheDir,
})

do := func(method, url string, body io.Reader) (string, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return "", err
}
req.Header.Set("X-GH-CACHE-DIR", cacheDir)
req.Header.Set("X-GH-CACHE-TTL", "1h")
res, err := httpClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
resBody, err := io.ReadAll(res.Body)
if err != nil {
err = fmt.Errorf("ReadAll: %w", err)
}
return string(resBody), err
}

var res string
var err error

res, err = do("GET", "http://example.com/path", nil)
assert.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res)
res, err = do("GET", "http://example.com/path", nil)
assert.NoError(t, err)
assert.Equal(t, "1: GET http://example.com/path", res)

res, err = do("GET", "http://example.com/path2", nil)
assert.NoError(t, err)
assert.Equal(t, "2: GET http://example.com/path2", res)

res, err = do("POST", "http://example.com/path2", nil)
assert.NoError(t, err)
assert.Equal(t, "3: POST http://example.com/path2", res)

res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
assert.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
assert.NoError(t, err)
assert.Equal(t, "4: POST http://example.com/graphql", res)

res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
assert.NoError(t, err)
assert.Equal(t, "5: POST http://example.com/graphql", res)

res, err = do("GET", "http://example.com/error", nil)
assert.NoError(t, err)
assert.Equal(t, "6: GET http://example.com/error", res)
res, err = do("GET", "http://example.com/error", nil)
assert.NoError(t, err)
assert.Equal(t, "7: GET http://example.com/error", res)
}

func TestRequestCacheOptions(t *testing.T) {
req, err := http.NewRequest("GET", "some/url", nil)
assert.NoError(t, err)
req.Header.Set("X-GH-CACHE-DIR", "some/dir/path")
req.Header.Set("X-GH-CACHE-TTL", "1h")
dir, ttl := requestCacheOptions(req)
assert.Equal(t, dir, "some/dir/path")
assert.Equal(t, ttl, time.Hour)
}
2 changes: 1 addition & 1 deletion internal/api/gql_client_test.go
Expand Up @@ -37,7 +37,7 @@ func TestGQLClientDo(t *testing.T) {
JSON(`{"errors":[{"message":"OH NO"},{"message":"this is fine"}]}`)
},
wantErr: true,
wantErrMsg: "GQL: OH NO, this is fine",
wantErrMsg: "GraphQL: OH NO, this is fine",
},
{
name: "http fail request empty response",
Expand Down
20 changes: 7 additions & 13 deletions internal/api/http.go
Expand Up @@ -83,16 +83,14 @@ func NewHTTPClient(opts *api.ClientOptions) http.Client {
transport = opts.Transport
}

if opts.EnableCache {
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)
if opts.CacheDir == "" {
opts.CacheDir = filepath.Join(os.TempDir(), "gh-cli-cache")
}
if opts.EnableCache && opts.CacheTTL == 0 {
opts.CacheTTL = time.Hour * 24
}
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
transport = c.RoundTripper(transport)
Copy link
Contributor Author

@samcoe samcoe Jun 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All requests will now hit the cache layer but if opts.EnableCache is not set the TTL will be 0 and it is a no-op.


if opts.Log != nil {
logger := &httpretty.Logger{
Expand Down Expand Up @@ -168,10 +166,6 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
a := "application/vnd.github.merge-info-preview+json"
// Preview for visibility when RESTing repos into an org.
a += ", application/vnd.github.nebula-preview"
// Preview for Commit.statusCheckRollup for old GHES versions.
a += ", application/vnd.github.antiope-preview"
// Preview for // PullRequest.isDraft for old GHES versions.
a += ", application/vnd.github.shadow-cat-preview"
headers[accept] = a
}
return headerRoundTripper{host: host, headers: headers, rt: rt}
Expand Down
2 changes: 0 additions & 2 deletions internal/api/http_test.go
Expand Up @@ -133,8 +133,6 @@ func defaultHeaders() http.Header {
h := http.Header{}
a := "application/vnd.github.merge-info-preview+json"
a += ", application/vnd.github.nebula-preview"
a += ", application/vnd.github.antiope-preview"
a += ", application/vnd.github.shadow-cat-preview"
h.Set(contentType, jsonContentType)
h.Set(userAgent, "go-gh")
h.Set(authorization, fmt.Sprintf("token %s", "oauth_token"))
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/errors.go
Expand Up @@ -67,7 +67,7 @@ func (gr GQLError) Error() string {
}
errorMessages = append(errorMessages, msg)
}
return fmt.Sprintf("GQL: %s", strings.Join(errorMessages, ", "))
return fmt.Sprintf("GraphQL: %s", strings.Join(errorMessages, ", "))
}

// Match determines if the GQLError is about a specific type on a specific path.
Expand Down
3 changes: 3 additions & 0 deletions pkg/repository/repository_test.go
Expand Up @@ -80,6 +80,9 @@ func TestParse(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldDir := os.Getenv("GH_CONFIG_DIR")
os.Setenv("GH_CONFIG_DIR", "nonexistant")
defer os.Setenv("GH_CONFIG_DIR", oldDir)
if tt.hostOverride != "" {
old := os.Getenv("GH_HOST")
os.Setenv("GH_HOST", tt.hostOverride)
Expand Down