Skip to content

Commit

Permalink
client: set grpc-accept-encoding header with all registered compressors
Browse files Browse the repository at this point in the history
  • Loading branch information
jronak committed Aug 18, 2022
1 parent b695a7f commit 98a34eb
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 1 deletion.
3 changes: 3 additions & 0 deletions encoding/encoding.go
Expand Up @@ -28,6 +28,8 @@ package encoding
import (
"io"
"strings"

"google.golang.org/grpc/internal/grpcutil"
)

// Identity specifies the optional encoding for uncompressed streams.
Expand Down Expand Up @@ -73,6 +75,7 @@ var registeredCompressor = make(map[string]Compressor)
// registered with the same name, the one registered last will take effect.
func RegisterCompressor(c Compressor) {
registeredCompressor[c.Name()] = c
grpcutil.RegisteredCompressorNames = append(grpcutil.RegisteredCompressorNames, c.Name())
}

// GetCompressor returns Compressor for the given compressor name.
Expand Down
62 changes: 62 additions & 0 deletions internal/grpcutil/compress_test.go
@@ -0,0 +1,62 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcutil

import (
"testing"
)

func TestAdvertisedCompressors(t *testing.T) {
tests := []struct {
name string
compressors []string
additionalCompressor string
want string
}{
{
name: "no_registered_compressors",
},
{
name: "with_registered_compressors",
compressors: []string{"gzip", "snappy"},
want: "identity,gzip,snappy",
},
{
name: "with_additional_compressors",
additionalCompressor: "gzip",
want: "identity,gzip",
},
{
name: "with_registered_and_additional_compressors",
compressors: []string{"gzip", "snappy"},
additionalCompressor: "test",
want: "identity,gzip,snappy,test",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
RegisteredCompressorNames = test.compressors
t.Cleanup(func() { RegisteredCompressorNames = []string{} })

if got := AdvertisedCompressors(test.additionalCompressor); got != test.want {
t.Fatalf("got \"%s\", want \"%s\"", got, test.want)
}
})
}
}
56 changes: 56 additions & 0 deletions internal/grpcutil/compressor.go
@@ -0,0 +1,56 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcutil

import (
"strings"
)

// RegisteredCompressorNames holds names of the registered compressors.
var RegisteredCompressorNames []string

// IsCompressorNameRegistered returns true when name is available in registry.
func IsCompressorNameRegistered(name string) bool {
for _, compressor := range RegisteredCompressorNames {
if compressor == name {
return true
}
}
return false
}

// AdvertisedCompressors returns a string of registered compressor names
// and provided additional compressor separated by comma.
func AdvertisedCompressors(additionalCompressor string) string {
if len(RegisteredCompressorNames) == 0 && additionalCompressor == "" {
return ""
}

b := strings.Builder{}
b.WriteString("identity")
for _, compressor := range RegisteredCompressorNames {
b.WriteRune(',')
b.WriteString(compressor)
}
if additionalCompressor != "" {
b.WriteRune(',')
b.WriteString(additionalCompressor)
}
return b.String()
}
14 changes: 13 additions & 1 deletion internal/transport/http2_client.go
Expand Up @@ -109,6 +109,7 @@ type http2Client struct {
streamsQuotaAvailable chan struct{}
waitingStreams uint32
nextID uint32
advertisedCompressors string

// Do not access controlBuf with mu held.
mu sync.Mutex // guard the following variables
Expand Down Expand Up @@ -299,6 +300,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
ctxDone: ctx.Done(), // Cache Done chan.
cancel: cancel,
userAgent: opts.UserAgent,
advertisedCompressors: grpcutil.AdvertisedCompressors(""),
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
Expand Down Expand Up @@ -505,9 +507,19 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)})
}

advertisedCompressors := t.advertisedCompressors
if callHdr.SendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: callHdr.SendCompress})
// Include the outgoing compressor name when compressor is not registered
// via encoding.RegisterCompressor. This is possible when client uses
// WithCompressor dial option.
if !grpcutil.IsCompressorNameRegistered(callHdr.SendCompress) {
advertisedCompressors = grpcutil.AdvertisedCompressors(callHdr.SendCompress)
}
}

if advertisedCompressors != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: advertisedCompressors})
}
if dl, ok := ctx.Deadline(); ok {
// Send out timeout regardless its value. The server can detect timeout context by itself.
Expand Down
77 changes: 77 additions & 0 deletions test/end2end_test.go
Expand Up @@ -3249,6 +3249,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
delete(header, "date") // the Date header is also optional
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
}
if !reflect.DeepEqual(header, testMetadata) {
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
Expand Down Expand Up @@ -3288,6 +3289,7 @@ func testMetadataOrderUnaryRPC(t *testing.T, e env) {
delete(header, "date") // the Date header is also optional
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
}

if !reflect.DeepEqual(header, newMetadata) {
Expand Down Expand Up @@ -3400,6 +3402,8 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")

expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3444,6 +3448,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3487,6 +3492,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3527,6 +3533,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3590,6 +3597,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3650,6 +3658,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3981,6 +3990,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
delete(headerMD, "trailer") // ignore if present
delete(headerMD, "user-agent")
delete(headerMD, "content-type")
delete(headerMD, "grpc-accept-encoding")
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
Expand All @@ -3989,6 +3999,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
delete(headerMD, "trailer") // ignore if present
delete(headerMD, "user-agent")
delete(headerMD, "content-type")
delete(headerMD, "grpc-accept-encoding")
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
Expand Down Expand Up @@ -5431,6 +5442,72 @@ func (s) TestForceServerCodec(t *testing.T) {
}
}

// renameCompressor is a grpc.Compressor wrapper that allows customizing the
// Type() of another compressor.
type renameCompressor struct {
grpc.Compressor
name string
}

func (r *renameCompressor) Type() string { return r.name }

// renameDecompressor is a grpc.Decompressor wrapper that allows customizing the
// Type() of another Decompressor.
type renameDecompressor struct {
grpc.Decompressor
name string
}

func (r *renameDecompressor) Type() string { return r.name }

func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
wantGrpcAcceptEncodingCh := make(chan []string, 1)
defer close(wantGrpcAcceptEncodingCh)

compressor := renameCompressor{Compressor: grpc.NewGZIPCompressor(), name: "testgzip"}
decompressor := renameDecompressor{Decompressor: grpc.NewGZIPDecompressor(), name: "testgzip"}

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Internal, "no metadata in context")
}
if got, want := md["grpc-accept-encoding"], <-wantGrpcAcceptEncodingCh; !reflect.DeepEqual(got, want) {
return nil, status.Errorf(codes.Internal, "got grpc-accept-encoding=%q; want [%q]", got, want)
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start([]grpc.ServerOption{grpc.RPCDecompressor(&decompressor)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

wantGrpcAcceptEncodingCh <- []string{"identity,gzip"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}

wantGrpcAcceptEncodingCh <- []string{"identity,gzip"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor("gzip")); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}

// Use compressor directly which is not registered via
// encoding.RegisterCompressor.
if err := ss.StartClient(grpc.WithCompressor(&compressor)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
wantGrpcAcceptEncodingCh <- []string{"identity,gzip,testgzip"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
}

func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
const mdkey = "somedata"

Expand Down

0 comments on commit 98a34eb

Please sign in to comment.