From 3caf6c40fa3a611046d9b2b54d67dcba95c14523 Mon Sep 17 00:00:00 2001 From: Mateusz Szostok Date: Tue, 21 Jun 2022 16:44:50 +0200 Subject: [PATCH] Add context for GraphQL and REST clients (#50) * Add context for GrahpQL client * Add context for REST client --- internal/api/gql_client.go | 30 ++++++--- internal/api/gql_client_test.go | 55 ++++++++++++++++ internal/api/rest_client.go | 17 +++-- internal/api/rest_client_test.go | 104 +++++++++++++++++++++++++++++++ pkg/api/client.go | 40 ++++++++---- 5 files changed, 223 insertions(+), 23 deletions(-) diff --git a/internal/api/gql_client.go b/internal/api/gql_client.go index d7723b6..2c381ab 100644 --- a/internal/api/gql_client.go +++ b/internal/api/gql_client.go @@ -9,6 +9,7 @@ import ( "net/http" "github.com/cli/go-gh/pkg/api" + graphql "github.com/cli/shurcooL-graphql" ) @@ -35,14 +36,14 @@ func NewGQLClient(host string, opts *api.ClientOptions) api.GQLClient { } } -// Do executes a single GraphQL query request and populates the response into the data argument. -func (c gqlClient) Do(query string, variables map[string]interface{}, response interface{}) error { +// DoWithContext executes a single GraphQL query request and populates the response into the data argument. +func (c gqlClient) DoWithContext(ctx context.Context, query string, variables map[string]interface{}, response interface{}) error { reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables}) if err != nil { return err } - req, err := http.NewRequest("POST", c.host, bytes.NewBuffer(reqBody)) + req, err := http.NewRequestWithContext(ctx, "POST", c.host, bytes.NewBuffer(reqBody)) if err != nil { return err } @@ -80,18 +81,33 @@ func (c gqlClient) Do(query string, variables map[string]interface{}, response i return nil } -// Mutate executes a single GraphQL mutation request, +// Do wraps DoWithContext using context.Background. +func (c gqlClient) Do(query string, variables map[string]interface{}, response interface{}) error { + return c.DoWithContext(context.Background(), query, variables, response) +} + +// MutateWithContext executes a single GraphQL mutation request, // with a mutation derived from m, populating the response into it. // "m" should be a pointer to struct that corresponds to the GitHub GraphQL schema. +func (c gqlClient) MutateWithContext(ctx context.Context, name string, m interface{}, variables map[string]interface{}) error { + return c.client.MutateNamed(ctx, name, m, variables) +} + +// Mutate wraps MutateWithContext using context.Background. func (c gqlClient) Mutate(name string, m interface{}, variables map[string]interface{}) error { - return c.client.MutateNamed(context.Background(), name, m, variables) + return c.MutateWithContext(context.Background(), name, m, variables) } -// Query executes a single GraphQL query request, +// QueryWithContext executes a single GraphQL query request, // with a query derived from q, populating the response into it. // "q" should be a pointer to struct that corresponds to the GitHub GraphQL schema. +func (c gqlClient) QueryWithContext(ctx context.Context, name string, q interface{}, variables map[string]interface{}) error { + return c.client.QueryNamed(ctx, name, q, variables) +} + +// Query wraps QueryWithContext using context.Background. func (c gqlClient) Query(name string, q interface{}, variables map[string]interface{}) error { - return c.client.QueryNamed(context.Background(), name, q, variables) + return c.QueryWithContext(context.Background(), name, q, variables) } type gqlResponse struct { diff --git a/internal/api/gql_client_test.go b/internal/api/gql_client_test.go index 9cb2f58..c5c2eb0 100644 --- a/internal/api/gql_client_test.go +++ b/internal/api/gql_client_test.go @@ -1,7 +1,9 @@ package api import ( + "context" "testing" + "time" "github.com/stretchr/testify/assert" "gopkg.in/h2non/gock.v1" @@ -112,3 +114,56 @@ func TestGQLClientDo(t *testing.T) { }) } } + +func TestGQLClientDoWithContext(t *testing.T) { + tests := []struct { + name string + wantErrMsg string + getCtx func() context.Context + }{ + { + name: "http fail request canceled", + getCtx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + // call 'cancel' to ensure that context is already canceled + cancel() + return ctx + }, + wantErrMsg: `Post "https://api.github.com/graphql": context canceled`, + }, + { + name: "http fail request timed out", + getCtx: func() context.Context { + // pass current time to ensure that deadline has already passed + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + cancel() + return ctx + }, + wantErrMsg: `Post "https://api.github.com/graphql": context deadline exceeded`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // given + t.Cleanup(gock.Off) + gock.New("https://api.github.com"). + Post("/graphql"). + BodyString(`{"query":"QUERY","variables":{"var":"test"}}`). + Reply(200). + JSON(`{}`) + + client := NewGQLClient("github.com", nil) + vars := map[string]interface{}{"var": "test"} + res := struct{ Viewer struct{ Login string } }{} + + // when + ctx := tt.getCtx() + gotErr := client.DoWithContext(ctx, "QUERY", vars, &res) + + // then + assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) + assert.EqualError(t, gotErr, tt.wantErrMsg) + }) + } +} diff --git a/internal/api/rest_client.go b/internal/api/rest_client.go index 19d1607..8545577 100644 --- a/internal/api/rest_client.go +++ b/internal/api/rest_client.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "fmt" "io" @@ -23,9 +24,9 @@ func NewRESTClient(host string, opts *api.ClientOptions) api.RESTClient { } } -func (c restClient) Request(method string, path string, body io.Reader) (*http.Response, error) { +func (c restClient) RequestWithContext(ctx context.Context, method string, path string, body io.Reader) (*http.Response, error) { url := restURL(c.host, path) - req, err := http.NewRequest(method, url, body) + req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err } @@ -47,9 +48,13 @@ func (c restClient) Request(method string, path string, body io.Reader) (*http.R return resp, err } -func (c restClient) Do(method string, path string, body io.Reader, response interface{}) error { +func (c restClient) Request(method string, path string, body io.Reader) (*http.Response, error) { + return c.RequestWithContext(context.Background(), method, path, body) +} + +func (c restClient) DoWithContext(ctx context.Context, method string, path string, body io.Reader, response interface{}) error { url := restURL(c.host, path) - req, err := http.NewRequest(method, url, body) + req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return err } @@ -82,6 +87,10 @@ func (c restClient) Do(method string, path string, body io.Reader, response inte return nil } +func (c restClient) Do(method string, path string, body io.Reader, response interface{}) error { + return c.DoWithContext(context.Background(), method, path, body, response) +} + func (c restClient) Delete(path string, resp interface{}) error { return c.Do(http.MethodDelete, path, nil, resp) } diff --git a/internal/api/rest_client_test.go b/internal/api/rest_client_test.go index ea97aa5..fcf309e 100644 --- a/internal/api/rest_client_test.go +++ b/internal/api/rest_client_test.go @@ -2,10 +2,13 @@ package api import ( "bytes" + "context" "fmt" "io" + "net/http" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "gopkg.in/h2non/gock.v1" @@ -286,6 +289,107 @@ func TestRESTClientPut(t *testing.T) { assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) } +func TestRESTClientDoWithContext(t *testing.T) { + tests := []struct { + name string + wantErrMsg string + getCtx func() context.Context + }{ + { + name: "http fail request canceled", + getCtx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + // call 'cancel' to ensure that context is already canceled + cancel() + return ctx + }, + wantErrMsg: `Get "https://api.github.com/some/path": context canceled`, + }, + { + name: "http fail request timed out", + getCtx: func() context.Context { + // pass current time to ensure that deadline has already passed + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + cancel() + return ctx + }, + wantErrMsg: `Get "https://api.github.com/some/path": context deadline exceeded`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // given + t.Cleanup(gock.Off) + gock.New("https://api.github.com"). + Get("/some/path"). + Reply(204). + JSON(`{}`) + + client := NewRESTClient("github.com", nil) + res := struct{ Message string }{} + + // when + ctx := tt.getCtx() + gotErr := client.DoWithContext(ctx, http.MethodGet, "some/path", nil, &res) + + // then + assert.EqualError(t, gotErr, tt.wantErrMsg) + assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) + }) + } +} + +func TestRESTClientRequestWithContext(t *testing.T) { + tests := []struct { + name string + wantErrMsg string + getCtx func() context.Context + }{ + { + name: "http fail request canceled", + getCtx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + // call 'cancel' to ensure that context is already canceled + cancel() + return ctx + }, + wantErrMsg: `Get "https://api.github.com/some/path": context canceled`, + }, + { + name: "http fail request timed out", + getCtx: func() context.Context { + // pass current time to ensure that deadline has already passed + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + cancel() + return ctx + }, + wantErrMsg: `Get "https://api.github.com/some/path": context deadline exceeded`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // given + t.Cleanup(gock.Off) + gock.New("https://api.github.com"). + Get("/some/path"). + Reply(204). + JSON(`{}`) + + client := NewRESTClient("github.com", nil) + + // when + ctx := tt.getCtx() + _, gotErr := client.RequestWithContext(ctx, http.MethodGet, "some/path", nil) + + // then + assert.EqualError(t, gotErr, tt.wantErrMsg) + assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) + }) + } +} + func printPendingMocks(mocks []gock.Mock) string { paths := []string{} for _, mock := range mocks { diff --git a/pkg/api/client.go b/pkg/api/client.go index a57feba..4a4eccc 100644 --- a/pkg/api/client.go +++ b/pkg/api/client.go @@ -2,12 +2,13 @@ package api import ( + "context" "io" "net/http" "time" ) -// Available options to configure API clients. +// ClientOptions holds available options to configure API clients. type ClientOptions struct { // AuthToken is the authorization token that will be used // to authenticate against API endpoints. @@ -59,16 +60,19 @@ type ClientOptions struct { // RESTClient is the interface that wraps methods for the different types of // API requests that are supported by the server. type RESTClient interface { - // Do issues a request with type specified by method to the + // Do wraps DoWithContext with context.Background. + Do(method string, path string, body io.Reader, response interface{}) error + + // DoWithContext issues a request with type specified by method to the // specified path with the specified body. // The response is populated into the response argument. - Do(method string, path string, body io.Reader, response interface{}) error + DoWithContext(ctx context.Context, method string, path string, body io.Reader, response interface{}) error // Delete issues a DELETE request to the specified path. // The response is populated into the response argument. Delete(path string, response interface{}) error - // GET issues a GET request to the specified path. + // Get issues a GET request to the specified path. // The response is populated into the response argument. Get(path string, response interface{}) error @@ -84,32 +88,44 @@ type RESTClient interface { // The response is populated into the response argument. Put(path string, body io.Reader, response interface{}) error - // Request issues a request with type specified by method to the + // Request wraps RequestWithContext with context.Background. + Request(method string, path string, body io.Reader) (*http.Response, error) + + // RequestWithContext issues a request with type specified by method to the // specified path with the specified body. // The response is returned rather than being populated // into a response argument. - Request(method string, path string, body io.Reader) (*http.Response, error) + RequestWithContext(ctx context.Context, method string, path string, body io.Reader) (*http.Response, error) } // GQLClient is the interface that wraps methods for the different types of // API requests that are supported by the server. type GQLClient interface { - // Do executes a GraphQL query request. - // The response is populated into the response argument. + // Do wraps DoWithContext using context.Background. Do(query string, variables map[string]interface{}, response interface{}) error - // Mutate executes a GraphQL mutation request. + // DoWithContext executes a GraphQL query request. + // The response is populated into the response argument. + DoWithContext(ctx context.Context, query string, variables map[string]interface{}, response interface{}) error + + // Mutate wraps MutateWithContext using context.Background. + Mutate(name string, mutation interface{}, variables map[string]interface{}) error + + // MutateWithContext executes a GraphQL mutation request. // The mutation string is derived from the mutation argument, and the // response is populated into it. // The mutation argument should be a pointer to struct that corresponds // to the GitHub GraphQL schema. // Provided input will be set as a variable named input. - Mutate(name string, mutation interface{}, variables map[string]interface{}) error + MutateWithContext(ctx context.Context, name string, mutation interface{}, variables map[string]interface{}) error - // Query executes a GraphQL query request, + // Query wraps QueryWithContext using context.Background. + Query(name string, query interface{}, variables map[string]interface{}) error + + // QueryWithContext executes a GraphQL query request, // The query string is derived from the query argument, and the // response is populated into it. // The query argument should be a pointer to struct that corresponds // to the GitHub GraphQL schema. - Query(name string, query interface{}, variables map[string]interface{}) error + QueryWithContext(ctx context.Context, name string, query interface{}, variables map[string]interface{}) error }