Skip to content

Commit

Permalink
client: set grpc-accept-encoding header with all registered compressors
Browse files Browse the repository at this point in the history
  • Loading branch information
jronak committed Oct 3, 2022
1 parent d83070e commit 511c171
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 3 deletions.
3 changes: 3 additions & 0 deletions encoding/encoding.go
Expand Up @@ -28,6 +28,8 @@ package encoding
import (
"io"
"strings"

"google.golang.org/grpc/internal/grpcutil"
)

// Identity specifies the optional encoding for uncompressed streams.
Expand Down Expand Up @@ -73,6 +75,7 @@ var registeredCompressor = make(map[string]Compressor)
// registered with the same name, the one registered last will take effect.
func RegisterCompressor(c Compressor) {
registeredCompressor[c.Name()] = c
grpcutil.RegisteredCompressorNames = append(grpcutil.RegisteredCompressorNames, c.Name())
}

// GetCompressor returns Compressor for the given compressor name.
Expand Down
8 changes: 6 additions & 2 deletions internal/envconfig/envconfig.go
Expand Up @@ -25,11 +25,15 @@ import (
)

const (
prefix = "GRPC_GO_"
txtErrIgnoreStr = prefix + "IGNORE_TXT_ERRORS"
prefix = "GRPC_GO_"
txtErrIgnoreStr = prefix + "IGNORE_TXT_ERRORS"
disableCompressorAdStr = prefix + "DISABLE_COMPRESSOR_ADVERTISEMENT"
)

var (
// TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
TXTErrIgnore = !strings.EqualFold(os.Getenv(txtErrIgnoreStr), "false")
// DisableCompressorAd is set if registered compressor advertisement should
// be disabled ("GRPC_GO_DISABLE_COMPRESSOR_ADVERTISEMENT" is "true").
DisableCompressorAd = strings.EqualFold(os.Getenv(disableCompressorAdStr), "true")
)
47 changes: 47 additions & 0 deletions internal/grpcutil/compressor.go
@@ -0,0 +1,47 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcutil

import (
"strings"

"google.golang.org/grpc/internal/envconfig"
)

// RegisteredCompressorNames holds names of the registered compressors.
var RegisteredCompressorNames []string

// IsCompressorNameRegistered returns true when name is available in registry.
func IsCompressorNameRegistered(name string) bool {
for _, compressor := range RegisteredCompressorNames {
if compressor == name {
return true
}
}
return false
}

// RegisteredCompressors returns a string of registered compressor names
// separated by comma.
func RegisteredCompressors() string {
if envconfig.DisableCompressorAd {
return ""
}
return strings.Join(RegisteredCompressorNames, ",")
}
46 changes: 46 additions & 0 deletions internal/grpcutil/compressor_test.go
@@ -0,0 +1,46 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcutil

import (
"testing"

"google.golang.org/grpc/internal/envconfig"
)

func TestRegisteredCompressors(t *testing.T) {
defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames)
defer func(v bool) { envconfig.DisableCompressorAd = v }(envconfig.DisableCompressorAd)
RegisteredCompressorNames = []string{"gzip", "snappy"}
tests := []struct {
desc string
disableAd bool
want string
}{
{desc: "compressor_ad_disabled", disableAd: true, want: ""},
{desc: "compressor_ad_enabled", disableAd: false, want: "gzip,snappy"},
}
for _, tt := range tests {
envconfig.DisableCompressorAd = tt.disableAd
compressors := RegisteredCompressors()
if compressors != tt.want {
t.Fatalf("Unexpected compressors got:%s, want:%s", compressors, tt.want)
}
}
}
17 changes: 16 additions & 1 deletion internal/transport/http2_client.go
Expand Up @@ -109,6 +109,7 @@ type http2Client struct {
streamsQuotaAvailable chan struct{}
waitingStreams uint32
nextID uint32
registeredCompressors string

// Do not access controlBuf with mu held.
mu sync.Mutex // guard the following variables
Expand Down Expand Up @@ -299,6 +300,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
ctxDone: ctx.Done(), // Cache Done chan.
cancel: cancel,
userAgent: opts.UserAgent,
registeredCompressors: grpcutil.RegisteredCompressors(),
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
Expand Down Expand Up @@ -507,9 +509,22 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)})
}

registeredCompressors := t.registeredCompressors
if callHdr.SendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: callHdr.SendCompress})
// Include the outgoing compressor name when compressor is not registered
// via encoding.RegisterCompressor. This is possible when client uses
// WithCompressor dial option.
if !grpcutil.IsCompressorNameRegistered(callHdr.SendCompress) {
if registeredCompressors != "" {
registeredCompressors += ","
}
registeredCompressors += callHdr.SendCompress
}
}

if registeredCompressors != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: registeredCompressors})
}
if dl, ok := ctx.Deadline(); ok {
// Send out timeout regardless its value. The server can detect timeout context by itself.
Expand Down
77 changes: 77 additions & 0 deletions test/end2end_test.go
Expand Up @@ -3249,6 +3249,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
delete(header, "date") // the Date header is also optional
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
}
if !reflect.DeepEqual(header, testMetadata) {
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
Expand Down Expand Up @@ -3288,6 +3289,7 @@ func testMetadataOrderUnaryRPC(t *testing.T, e env) {
delete(header, "date") // the Date header is also optional
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
}

if !reflect.DeepEqual(header, newMetadata) {
Expand Down Expand Up @@ -3400,6 +3402,8 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")

expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3444,6 +3448,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3487,6 +3492,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3527,6 +3533,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3590,6 +3597,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3650,6 +3658,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
}
delete(header, "user-agent")
delete(header, "content-type")
delete(header, "grpc-accept-encoding")
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
Expand Down Expand Up @@ -3981,6 +3990,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
delete(headerMD, "trailer") // ignore if present
delete(headerMD, "user-agent")
delete(headerMD, "content-type")
delete(headerMD, "grpc-accept-encoding")
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
Expand All @@ -3989,6 +3999,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
delete(headerMD, "trailer") // ignore if present
delete(headerMD, "user-agent")
delete(headerMD, "content-type")
delete(headerMD, "grpc-accept-encoding")
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
Expand Down Expand Up @@ -5431,6 +5442,72 @@ func (s) TestForceServerCodec(t *testing.T) {
}
}

// renameCompressor is a grpc.Compressor wrapper that allows customizing the
// Type() of another compressor.
type renameCompressor struct {
grpc.Compressor
name string
}

func (r *renameCompressor) Type() string { return r.name }

// renameDecompressor is a grpc.Decompressor wrapper that allows customizing the
// Type() of another Decompressor.
type renameDecompressor struct {
grpc.Decompressor
name string
}

func (r *renameDecompressor) Type() string { return r.name }

func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
wantGrpcAcceptEncodingCh := make(chan []string, 1)
defer close(wantGrpcAcceptEncodingCh)

compressor := renameCompressor{Compressor: grpc.NewGZIPCompressor(), name: "testgzip"}
decompressor := renameDecompressor{Decompressor: grpc.NewGZIPDecompressor(), name: "testgzip"}

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Internal, "no metadata in context")
}
if got, want := md["grpc-accept-encoding"], <-wantGrpcAcceptEncodingCh; !reflect.DeepEqual(got, want) {
return nil, status.Errorf(codes.Internal, "got grpc-accept-encoding=%q; want [%q]", got, want)
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start([]grpc.ServerOption{grpc.RPCDecompressor(&decompressor)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

wantGrpcAcceptEncodingCh <- []string{"gzip"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}

wantGrpcAcceptEncodingCh <- []string{"gzip"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor("gzip")); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}

// Use compressor directly which is not registered via
// encoding.RegisterCompressor.
if err := ss.StartClient(grpc.WithCompressor(&compressor)); err != nil {
t.Fatalf("Error starting client: %v", err)
}
wantGrpcAcceptEncodingCh <- []string{"gzip,testgzip"}
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
}

func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
const mdkey = "somedata"

Expand Down

0 comments on commit 511c171

Please sign in to comment.