diff --git a/go.mod b/go.mod index d291551691c..4e93b3be970 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/google/s2a-go v0.1.4 github.com/google/uuid v1.3.0 github.com/googleapis/enterprise-certificate-proxy v0.2.5 - github.com/googleapis/gax-go/v2 v2.11.0 + github.com/googleapis/gax-go/v2 v2.12.0 go.opencensus.io v0.24.0 golang.org/x/net v0.12.0 golang.org/x/oauth2 v0.9.0 diff --git a/go.sum b/go.sum index 52cbfc11b72..f1acacb3875 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.2.5 h1:UR4rDjcgpgEnqpIEvkiqTYKBCKLNmlge2eVjoZfySzM= github.com/googleapis/enterprise-certificate-proxy v0.2.5/go.mod h1:RxW0N9901Cko1VOCW3SXCpWP+mlIEkk2tP7jnHy9a3w= -github.com/googleapis/gax-go/v2 v2.11.0 h1:9V9PWXEsWnPpQhu/PeQIkS4eGzMlTLGgt80cUUI8Ki4= -github.com/googleapis/gax-go/v2 v2.11.0/go.mod h1:DxmR61SGKkGLa2xigwuZIQpkCI2S5iydzRfb3peWZJI= +github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= +github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/internal/gensupport/send.go b/internal/gensupport/send.go index 693a1b1abaf..f39dd00d99f 100644 --- a/internal/gensupport/send.go +++ b/internal/gensupport/send.go @@ -15,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/googleapis/gax-go/v2" + "github.com/googleapis/gax-go/v2/callctx" ) // Use this error type to return an error which allows introspection of both @@ -43,6 +44,16 @@ func (e wrappedCallErr) Is(target error) bool { // req.WithContext, then calls any functions returned by the hooks in // reverse order. func SendRequest(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { + // Add headers set in context metadata. + if ctx != nil { + headers := callctx.HeadersFromContext(ctx) + for k, vals := range headers { + for _, v := range vals { + req.Header.Add(k, v) + } + } + } + // Disallow Accept-Encoding because it interferes with the automatic gzip handling // done by the default http.Transport. See https://github.com/google/google-api-go-client/issues/219. if _, ok := req.Header["Accept-Encoding"]; ok { @@ -77,6 +88,16 @@ func send(ctx context.Context, client *http.Client, req *http.Request) (*http.Re // req.WithContext, then calls any functions returned by the hooks in // reverse order. func SendRequestWithRetry(ctx context.Context, client *http.Client, req *http.Request, retry *RetryConfig) (*http.Response, error) { + // Add headers set in context metadata. + if ctx != nil { + headers := callctx.HeadersFromContext(ctx) + for k, vals := range headers { + for _, v := range vals { + req.Header.Add(k, v) + } + } + } + // Disallow Accept-Encoding because it interferes with the automatic gzip handling // done by the default http.Transport. See https://github.com/google/google-api-go-client/issues/219. if _, ok := req.Header["Accept-Encoding"]; ok { diff --git a/internal/gensupport/send_test.go b/internal/gensupport/send_test.go index bdc074c81af..59534408393 100644 --- a/internal/gensupport/send_test.go +++ b/internal/gensupport/send_test.go @@ -7,8 +7,12 @@ package gensupport import ( "context" "errors" + "fmt" "net/http" "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/gax-go/v2/callctx" ) func TestSendRequest(t *testing.T) { @@ -31,6 +35,39 @@ func TestSendRequestWithRetry(t *testing.T) { } } +type headerRoundTripper struct { + wantHeader http.Header +} + +func (rt *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + // Ignore x-goog headers sent by SendRequestWithRetry + r.Header.Del("X-Goog-Api-Client") + r.Header.Del("X-Goog-Gcs-Idempotency-Token") + if diff := cmp.Diff(r.Header, rt.wantHeader); diff != "" { + return nil, fmt.Errorf("headers don't match: %v", diff) + } + return &http.Response{StatusCode: 200}, nil +} + +// Ensure that headers set via the context are passed through to the request as expected. +func TestSendRequestHeader(t *testing.T) { + ctx := context.Background() + ctx = callctx.SetHeaders(ctx, "foo", "100", "bar", "200") + client := http.Client{ + Transport: &headerRoundTripper{ + wantHeader: map[string][]string{"Foo": {"100"}, "Bar": {"200"}}, + }, + } + req, _ := http.NewRequest("GET", "url", nil) + if _, err := SendRequest(ctx, &client, req); err != nil { + t.Errorf("SendRequest: %v", err) + } + req2, _ := http.NewRequest("GET", "url", nil) + if _, err := SendRequestWithRetry(ctx, &client, req2, nil); err != nil { + t.Errorf("SendRequest: %v", err) + } +} + type brokenRoundTripper struct{} func (t *brokenRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {