From b19a201d2c6edd70a2dc74503a25d568f670b50f Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Thu, 27 Oct 2022 00:35:25 +0530 Subject: [PATCH 01/12] server: expose API to set send compressor --- internal/grpcutil/compressor.go | 39 +++++++++ internal/grpcutil/compressor_test.go | 42 +++++++++ internal/transport/handler_server_test.go | 9 ++ internal/transport/http2_server.go | 5 ++ internal/transport/transport.go | 21 ++++- server.go | 44 +++++++++- stream.go | 7 ++ test/end2end_test.go | 101 ++++++++++++++++++++++ 8 files changed, 263 insertions(+), 5 deletions(-) diff --git a/internal/grpcutil/compressor.go b/internal/grpcutil/compressor.go index 9f409096798..bab64208119 100644 --- a/internal/grpcutil/compressor.go +++ b/internal/grpcutil/compressor.go @@ -19,6 +19,7 @@ package grpcutil import ( + "fmt" "strings" "google.golang.org/grpc/internal/envconfig" @@ -45,3 +46,41 @@ func RegisteredCompressors() string { } return strings.Join(RegisteredCompressorNames, ",") } + +// ValidateSendCompressor returns an error when given compressor name cannot be +// handled by the server or the client based on the advertised compressors. +func ValidateSendCompressor(name, clientAdvertisedCompressors string) error { + if name == "identity" { + return nil + } + + if !IsCompressorNameRegistered(name) { + return fmt.Errorf("compressor not registered: %s", name) + } + + if !compressorExists(name, clientAdvertisedCompressors) { + return fmt.Errorf("client does not support compressor: %s", name) + } + + return nil +} + +// compressorExists returns true when the given name exists in the comma +// separated compressor list. +func compressorExists(name, compressors string) bool { + var ( + i = 0 + length = len(compressors) + ) + for j := 0; j <= length; j++ { + if j < length && compressors[j] != ',' { + continue + } + + if compressors[i:j] == name { + return true + } + i = j + 1 + } + return false +} diff --git a/internal/grpcutil/compressor_test.go b/internal/grpcutil/compressor_test.go index 0d639422a9a..6c080976ed8 100644 --- a/internal/grpcutil/compressor_test.go +++ b/internal/grpcutil/compressor_test.go @@ -19,6 +19,7 @@ package grpcutil import ( + "fmt" "testing" "google.golang.org/grpc/internal/envconfig" @@ -44,3 +45,44 @@ func TestRegisteredCompressors(t *testing.T) { } } } + +func TestValidateSendCompressors(t *testing.T) { + defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames) + RegisteredCompressorNames = []string{"gzip", "snappy"} + tests := []struct { + desc string + name string + advertisedCompressors string + wantErr error + }{ + { + desc: "success_when_identity_compressor", + name: "identity", + advertisedCompressors: "gzip,snappy", + }, + { + desc: "success_when_compressor_exists", + name: "snappy", + advertisedCompressors: "testcomp,gzip,snappy", + }, + { + desc: "failure_when_compressor_not_registered", + name: "testcomp", + advertisedCompressors: "testcomp,gzip,snappy", + wantErr: fmt.Errorf("compressor not registered: testcomp"), + }, + { + desc: "failure_when_compressor_not_advertised", + name: "gzip", + advertisedCompressors: "testcomp,snappy", + wantErr: fmt.Errorf("client does not support compressor: gzip"), + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if err := ValidateSendCompressor(tt.name, tt.advertisedCompressors); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) { + t.Fatalf("Unexpected validation got:%v, want:%v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index b08dcaaf3c4..25420b647b1 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -270,6 +270,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { if err != nil { t.Error(err) } + err = s.SetSendCompress("gzip") + if err != nil { + t.Error(err) + } md := metadata.Pairs("custom-header", "Another custom header value") err = s.SendHeader(md) @@ -286,6 +290,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { if err == nil { t.Error("expected second SendHeader call to fail") } + err = s.SetSendCompress("snappy") + if err == nil { + t.Error("expected second SetSendCompress call to fail") + } st.bodyw.Close() // no body st.ht.WriteStatus(s, status.New(codes.OK, "")) @@ -299,6 +307,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { "Content-Type": {"application/grpc"}, "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, "Custom-Header": {"Custom header value", "Another custom header value"}, + "Grpc-Encoding": {"gzip"}, } wantTrailer := http.Header{ "Grpc-Status": {"0"}, diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 3dd15647bc8..df89d511fe9 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -27,6 +27,7 @@ import ( "net" "net/http" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -456,6 +457,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return false } + if encodings := mdata["grpc-accept-encoding"]; len(encodings) != 0 { + s.clientAdvertisedCompressors = strings.Join(encodings, ",") + } + if !isGRPC || headerError { t.controlBuf.put(&cleanupStream{ streamID: streamID, diff --git a/internal/transport/transport.go b/internal/transport/transport.go index e21587b5321..ade0ff7049b 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -253,6 +253,9 @@ type Stream struct { fc *inFlow wq *writeQuota + // Holds compressor names passed in grpc-accept-encoding metadata from the + // client. This is empty for the client side Stream. + clientAdvertisedCompressors string // Callback to state application's intentions to read data. This // is used to adjust flow control, if needed. requestRead func(int) @@ -341,8 +344,24 @@ func (s *Stream) RecvCompress() string { } // SetSendCompress sets the compression algorithm to the stream. -func (s *Stream) SetSendCompress(str string) { +func (s *Stream) SetSendCompress(str string) error { + if s.isHeaderSent() || s.getState() == streamDone { + return status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") + } + s.sendCompress = str + return nil +} + +// SendCompress returns the send compressor name. +func (s *Stream) SendCompress() string { + return s.sendCompress +} + +// ClientAdvertisedCompressors returns the advertised compressor names by the +// client. +func (s *Stream) ClientAdvertisedCompressors() string { + return s.clientAdvertisedCompressors } // Done returns a channel which is closed when it receives the final status diff --git a/server.go b/server.go index f4dde72b41f..13b346876ab 100644 --- a/server.go +++ b/server.go @@ -45,6 +45,7 @@ import ( "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" @@ -1267,6 +1268,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. var comp, decomp encoding.Compressor var cp Compressor var dc Decompressor + var sendCompressorName string // If dc is set and matches the stream's compression, use it. Otherwise, try // to find a matching registered compressor for decomp. @@ -1287,12 +1289,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { cp = s.opts.cp - stream.SetSendCompress(cp.Type()) + _ = stream.SetSendCompress(cp.Type()) + sendCompressorName = cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. comp = encoding.GetCompressor(rc) if comp != nil { - stream.SetSendCompress(rc) + _ = stream.SetSendCompress(rc) + sendCompressorName = comp.Name() } } @@ -1379,6 +1383,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } opts := &transport.Options{Last: true} + if stream.SendCompress() != sendCompressorName { + comp = encoding.GetCompressor(stream.SendCompress()) + } if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { if err == io.EOF { // The entire stream is done (for unary RPC only). @@ -1606,12 +1613,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { ss.cp = s.opts.cp - stream.SetSendCompress(s.opts.cp.Type()) + _ = stream.SetSendCompress(s.opts.cp.Type()) + ss.sendCompressorName = s.opts.cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. ss.comp = encoding.GetCompressor(rc) if ss.comp != nil { - stream.SetSendCompress(rc) + _ = stream.SetSendCompress(rc) + ss.sendCompressorName = rc } } @@ -1944,6 +1953,33 @@ func SendHeader(ctx context.Context, md metadata.MD) error { return nil } +// SetSendCompressor sets the compressor that will be used when sending +// RPC payload back to the client. It may be called at most once, and must not +// be called after any event that causes headers to be sent (see SetHeader for +// a complete list). Provided compressor is used when below conditions are met: +// +// - compressor is registered via encoding.RegisterCompressor +// - compressor name exists in the client advertised compressor names sent in +// grpc-accept-encoding metadata. +// +// The context provided must be the context passed to the server's handler. +// +// The error returned is compatible with the status package. However, the +// status code will often not match the RPC status as seen by the client +// application, and therefore, should not be relied upon for this purpose. +func SetSendCompressor(ctx context.Context, name string) error { + stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) + if !ok || stream == nil { + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + } + + if err := grpcutil.ValidateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { + return status.Errorf(codes.Internal, "grpc: failed to set send compressor %v", err) + } + + return stream.SetSendCompress(name) +} + // SetTrailer sets the trailer metadata that will be sent when an RPC returns. // When called more than once, all the provided metadata will be merged. // diff --git a/stream.go b/stream.go index b10ab1ab632..01c1066e2e4 100644 --- a/stream.go +++ b/stream.go @@ -1481,6 +1481,8 @@ type serverStream struct { comp encoding.Compressor decomp encoding.Compressor + sendCompressorName string + maxReceiveMessageSize int maxSendMessageSize int trInfo *traceInfo @@ -1573,6 +1575,11 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } }() + if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName { + ss.comp = encoding.GetCompressor(sendCompressorsName) + ss.sendCompressorName = sendCompressorsName + } + // load hdr, payload, data hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp) if err != nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index 438b43ca82f..02fd7550eb4 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -59,6 +59,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" @@ -5509,6 +5510,106 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) { } } +func (s) TestServerHandlerSetSendCompression(t *testing.T) { + for _, tt := range []struct { + desc string + reqCompressor string // request compressor name + resCompressor string // response compressor name + disableClientAdvertise bool // don't advertise compressor if true + sendMetadataEarly bool // server handler sends metadata immediately + wantErr error + }{ + { + desc: "success_with_gzip_req_res", + reqCompressor: "gzip", + resCompressor: "gzip", + }, + { + desc: "success_with_gzip_req_identity_res", + reqCompressor: "gzip", + resCompressor: "identity", + }, + { + desc: "fail_on_unregistered_res_compressor", + resCompressor: "snappy2", + wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor compressor not registered: snappy2"), + }, + { + desc: "fail_on_unadvertised_res_compressor", + resCompressor: "gzip", + disableClientAdvertise: true, + wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor client does not support compressor: gzip"), + }, + { + desc: "fail_when_set_compressor_called_after_headers_sent", + sendMetadataEarly: true, + resCompressor: "gzip", + wantErr: fmt.Errorf("rpc error: code = Internal desc = transport: set send compressor called after headers sent or stream done"), + }, + } { + t.Run(tt.desc, func(t *testing.T) { + if tt.disableClientAdvertise { + defer func(b bool) { envconfig.AdvertiseCompressors = b }(envconfig.AdvertiseCompressors) + envconfig.AdvertiseCompressors = false + } + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if tt.sendMetadataEarly { + grpc.SendHeader(ctx, metadata.MD{}) + } + if tt.resCompressor != "" { + if err := grpc.SetSendCompressor(ctx, tt.resCompressor); err != nil { + return nil, err + } + } + return &testpb.Empty{}, nil + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + if tt.sendMetadataEarly { + grpc.SendHeader(stream.Context(), metadata.MD{}) + } + if _, err := stream.Recv(); err != nil { + return err + } + + if tt.resCompressor != "" { + if err := grpc.SetSendCompressor(stream.Context(), tt.resCompressor); err != nil { + return err + } + } + + return stream.Send(&testpb.StreamingOutputCallResponse{}) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor(tt.reqCompressor)) + if fmt.Sprint(err) != fmt.Sprint(tt.wantErr) { + t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, tt.wantErr) + } + + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) + } + + if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err) + } + + if _, err := s.Recv(); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, tt.wantErr) + } + }) + } +} + func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { const mdkey = "somedata" From 6d03a483cea9f25dba32b66be634caedb965131e Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Mon, 21 Nov 2022 22:59:08 +0530 Subject: [PATCH 02/12] address comments --- internal/grpcutil/compressor.go | 39 --------------------- internal/grpcutil/compressor_test.go | 42 ---------------------- internal/transport/http2_server.go | 9 +++-- internal/transport/transport.go | 8 ++--- server.go | 52 ++++++++++++++++++++++------ stream.go | 2 ++ test/end2end_test.go | 4 +-- 7 files changed, 54 insertions(+), 102 deletions(-) diff --git a/internal/grpcutil/compressor.go b/internal/grpcutil/compressor.go index bab64208119..9f409096798 100644 --- a/internal/grpcutil/compressor.go +++ b/internal/grpcutil/compressor.go @@ -19,7 +19,6 @@ package grpcutil import ( - "fmt" "strings" "google.golang.org/grpc/internal/envconfig" @@ -46,41 +45,3 @@ func RegisteredCompressors() string { } return strings.Join(RegisteredCompressorNames, ",") } - -// ValidateSendCompressor returns an error when given compressor name cannot be -// handled by the server or the client based on the advertised compressors. -func ValidateSendCompressor(name, clientAdvertisedCompressors string) error { - if name == "identity" { - return nil - } - - if !IsCompressorNameRegistered(name) { - return fmt.Errorf("compressor not registered: %s", name) - } - - if !compressorExists(name, clientAdvertisedCompressors) { - return fmt.Errorf("client does not support compressor: %s", name) - } - - return nil -} - -// compressorExists returns true when the given name exists in the comma -// separated compressor list. -func compressorExists(name, compressors string) bool { - var ( - i = 0 - length = len(compressors) - ) - for j := 0; j <= length; j++ { - if j < length && compressors[j] != ',' { - continue - } - - if compressors[i:j] == name { - return true - } - i = j + 1 - } - return false -} diff --git a/internal/grpcutil/compressor_test.go b/internal/grpcutil/compressor_test.go index 6c080976ed8..0d639422a9a 100644 --- a/internal/grpcutil/compressor_test.go +++ b/internal/grpcutil/compressor_test.go @@ -19,7 +19,6 @@ package grpcutil import ( - "fmt" "testing" "google.golang.org/grpc/internal/envconfig" @@ -45,44 +44,3 @@ func TestRegisteredCompressors(t *testing.T) { } } } - -func TestValidateSendCompressors(t *testing.T) { - defer func(c []string) { RegisteredCompressorNames = c }(RegisteredCompressorNames) - RegisteredCompressorNames = []string{"gzip", "snappy"} - tests := []struct { - desc string - name string - advertisedCompressors string - wantErr error - }{ - { - desc: "success_when_identity_compressor", - name: "identity", - advertisedCompressors: "gzip,snappy", - }, - { - desc: "success_when_compressor_exists", - name: "snappy", - advertisedCompressors: "testcomp,gzip,snappy", - }, - { - desc: "failure_when_compressor_not_registered", - name: "testcomp", - advertisedCompressors: "testcomp,gzip,snappy", - wantErr: fmt.Errorf("compressor not registered: testcomp"), - }, - { - desc: "failure_when_compressor_not_advertised", - name: "gzip", - advertisedCompressors: "testcomp,snappy", - wantErr: fmt.Errorf("client does not support compressor: gzip"), - }, - } - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - if err := ValidateSendCompressor(tt.name, tt.advertisedCompressors); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) { - t.Fatalf("Unexpected validation got:%v, want:%v", err, tt.wantErr) - } - }) - } -} diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index df89d511fe9..507a589e49f 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -27,7 +27,6 @@ import ( "net" "net/http" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -404,6 +403,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( mdata[hf.Name] = append(mdata[hf.Name], hf.Value) s.contentSubtype = contentSubtype isGRPC = true + + case "grpc-accept-encoding": + s.clientAdvertisedCompressors = hf.Value + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) case "grpc-encoding": s.recvCompress = hf.Value case ":method": @@ -457,10 +460,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return false } - if encodings := mdata["grpc-accept-encoding"]; len(encodings) != 0 { - s.clientAdvertisedCompressors = strings.Join(encodings, ",") - } - if !isGRPC || headerError { t.controlBuf.put(&cleanupStream{ streamID: streamID, diff --git a/internal/transport/transport.go b/internal/transport/transport.go index ade0ff7049b..f45cfa899b3 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -344,12 +344,12 @@ func (s *Stream) RecvCompress() string { } // SetSendCompress sets the compression algorithm to the stream. -func (s *Stream) SetSendCompress(str string) error { +func (s *Stream) SetSendCompress(name string) error { if s.isHeaderSent() || s.getState() == streamDone { return status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") } - s.sendCompress = str + s.sendCompress = name return nil } @@ -358,8 +358,8 @@ func (s *Stream) SendCompress() string { return s.sendCompress } -// ClientAdvertisedCompressors returns the advertised compressor names by the -// client. +// ClientAdvertisedCompressors returns the compressor names advertised by the +// client via :grpc-accept-encoding header. func (s *Stream) ClientAdvertisedCompressors() string { return s.clientAdvertisedCompressors } diff --git a/server.go b/server.go index 13b346876ab..07ea638ffac 100644 --- a/server.go +++ b/server.go @@ -1289,17 +1289,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { cp = s.opts.cp - _ = stream.SetSendCompress(cp.Type()) sendCompressorName = cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. comp = encoding.GetCompressor(rc) if comp != nil { - _ = stream.SetSendCompress(rc) sendCompressorName = comp.Name() } } + if sendCompressorName != "" { + // Safe to ignore returned error value as we are guaranteed to succeed here + _ = stream.SetSendCompress(sendCompressorName) + } + var payInfo *payloadInfo if len(shs) != 0 || len(binlogs) != 0 { payInfo = &payloadInfo{} @@ -1383,6 +1386,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } opts := &transport.Options{Last: true} + // Server handler could have set new compressor by calling SetSendCompressor. + // In case it is set, we need to use it for compressing outbound message. if stream.SendCompress() != sendCompressorName { comp = encoding.GetCompressor(stream.SendCompress()) } @@ -1613,17 +1618,20 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { ss.cp = s.opts.cp - _ = stream.SetSendCompress(s.opts.cp.Type()) ss.sendCompressorName = s.opts.cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. ss.comp = encoding.GetCompressor(rc) if ss.comp != nil { - _ = stream.SetSendCompress(rc) ss.sendCompressorName = rc } } + if ss.sendCompressorName != "" { + // Safe to ignore returned error value as we are guaranteed to succeed here + _ = stream.SetSendCompress(ss.sendCompressorName) + } + ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp) if trInfo != nil { @@ -1953,27 +1961,32 @@ func SendHeader(ctx context.Context, md metadata.MD) error { return nil } -// SetSendCompressor sets the compressor that will be used when sending -// RPC payload back to the client. It may be called at most once, and must not -// be called after any event that causes headers to be sent (see SetHeader for -// a complete list). Provided compressor is used when below conditions are met: +// SetSendCompressor sets a compressor for outbound messages. +// It must not be called after any event that causes headers to be sent +// (see SetHeader for a complete list). Provided compressor is used when below +// conditions are met: // // - compressor is registered via encoding.RegisterCompressor // - compressor name exists in the client advertised compressor names sent in -// grpc-accept-encoding metadata. +// :grpc-accept-encoding header. // // The context provided must be the context passed to the server's handler. // // The error returned is compatible with the status package. However, the // status code will often not match the RPC status as seen by the client // application, and therefore, should not be relied upon for this purpose. +// +// # Experimental +// +// Notice: This type is EXPERIMENTAL and may be changed or removed in a +// later release. func SetSendCompressor(ctx context.Context, name string) error { stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) if !ok || stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } - if err := grpcutil.ValidateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { + if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { return status.Errorf(codes.Internal, "grpc: failed to set send compressor %v", err) } @@ -2014,3 +2027,22 @@ type channelzServer struct { func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric { return c.s.channelzMetric() } + +// validateSendCompressor returns an error when given compressor name cannot be +// handled by the server or the client based on the advertised compressors. +func validateSendCompressor(name, clientCompressors string) error { + if name == "identity" { + return nil + } + + if !grpcutil.IsCompressorNameRegistered(name) { + return fmt.Errorf("compressor not registered %s", name) + } + + for _, clientCompressor := range strings.Split(clientCompressors, ",") { + if clientCompressor == name { + return nil // found match + } + } + return fmt.Errorf("client does not support compressor %s", name) +} diff --git a/stream.go b/stream.go index 01c1066e2e4..d261f30ba07 100644 --- a/stream.go +++ b/stream.go @@ -1575,6 +1575,8 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } }() + // Server handler could have set new compressor by calling SetSendCompressor. + // In case it is set, we need to use it for compressing outbound message. if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName { ss.comp = encoding.GetCompressor(sendCompressorsName) ss.sendCompressorName = sendCompressorsName diff --git a/test/end2end_test.go b/test/end2end_test.go index 02fd7550eb4..30b7a14aefe 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5532,13 +5532,13 @@ func (s) TestServerHandlerSetSendCompression(t *testing.T) { { desc: "fail_on_unregistered_res_compressor", resCompressor: "snappy2", - wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor compressor not registered: snappy2"), + wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor compressor not registered snappy2"), }, { desc: "fail_on_unadvertised_res_compressor", resCompressor: "gzip", disableClientAdvertise: true, - wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor client does not support compressor: gzip"), + wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor client does not support compressor gzip"), }, { desc: "fail_when_set_compressor_called_after_headers_sent", From 3dfdd22ad0bbd981e66836765baeca1dd7658f44 Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Fri, 9 Dec 2022 01:16:47 +0530 Subject: [PATCH 03/12] Split tests into multiple functions Address nits across naming and test assertions Update method doc on setSendCompressor --- internal/transport/handler_server_test.go | 41 +++-- server.go | 13 +- test/end2end_test.go | 173 +++++++++++++--------- 3 files changed, 125 insertions(+), 102 deletions(-) diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index 25420b647b1..8bd700d2e62 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -262,37 +262,34 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { t.Errorf("stream method = %q; want %q", s.method, want) } - err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")) - if err != nil { - t.Error(err) + if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil { + t.Fatal(err) } - err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")) - if err != nil { - t.Error(err) + + if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil { + t.Fatal(err) } - err = s.SetSendCompress("gzip") - if err != nil { - t.Error(err) + + if err := s.SetSendCompress("gzip"); err != nil { + t.Fatal(err) } md := metadata.Pairs("custom-header", "Another custom header value") - err = s.SendHeader(md) - delete(md, "custom-header") - if err != nil { - t.Error(err) + if err := s.SendHeader(md); err != nil { + t.Fatal(err) } + delete(md, "custom-header") - err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")) - if err == nil { - t.Error("expected SetHeader call after SendHeader to fail") + if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil { + t.Fatal("expected SetHeader call after SendHeader to fail") } - err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")) - if err == nil { - t.Error("expected second SendHeader call to fail") + + if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil { + t.Fatal("expected second SendHeader call to fail") } - err = s.SetSendCompress("snappy") - if err == nil { - t.Error("expected second SetSendCompress call to fail") + + if err := s.SetSendCompress("snappy"); err == nil { + t.Fatal("expected second SetSendCompress call to fail") } st.bodyw.Close() // no body diff --git a/server.go b/server.go index 07ea638ffac..3124b432004 100644 --- a/server.go +++ b/server.go @@ -1975,7 +1975,8 @@ func SendHeader(ctx context.Context, md metadata.MD) error { // The error returned is compatible with the status package. However, the // status code will often not match the RPC status as seen by the client // application, and therefore, should not be relied upon for this purpose. -// +// It is not safe to call SetSendCompressor concurrently with SendHeader and +// SendMsg. // # Experimental // // Notice: This type is EXPERIMENTAL and may be changed or removed in a @@ -1987,7 +1988,7 @@ func SetSendCompressor(ctx context.Context, name string) error { } if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { - return status.Errorf(codes.Internal, "grpc: failed to set send compressor %v", err) + return status.Errorf(codes.Internal, "grpc: unable to set send compressor: %v", err) } return stream.SetSendCompress(name) @@ -2036,13 +2037,13 @@ func validateSendCompressor(name, clientCompressors string) error { } if !grpcutil.IsCompressorNameRegistered(name) { - return fmt.Errorf("compressor not registered %s", name) + return fmt.Errorf("compressor not registered %q", name) } - for _, clientCompressor := range strings.Split(clientCompressors, ",") { - if clientCompressor == name { + for _, c := range strings.Split(clientCompressors, ",") { + if c == name { return nil // found match } } - return fmt.Errorf("client does not support compressor %s", name) + return fmt.Errorf("client does not support compressor %q", name) } diff --git a/test/end2end_test.go b/test/end2end_test.go index 30b7a14aefe..c16d3e9800a 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5510,103 +5510,128 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) { } } -func (s) TestServerHandlerSetSendCompression(t *testing.T) { +func (s) TestServerSetSendCompressor(t *testing.T) { for _, tt := range []struct { - desc string - reqCompressor string // request compressor name - resCompressor string // response compressor name - disableClientAdvertise bool // don't advertise compressor if true - sendMetadataEarly bool // server handler sends metadata immediately - wantErr error + desc string + resCompressor string + wantErr error }{ { - desc: "success_with_gzip_req_res", - reqCompressor: "gzip", + desc: "gzip_response_compressor", resCompressor: "gzip", }, { - desc: "success_with_gzip_req_identity_res", - reqCompressor: "gzip", + desc: "identity_response_compressor", resCompressor: "identity", }, { - desc: "fail_on_unregistered_res_compressor", + desc: "unregistered_snappy2_response_compressor", resCompressor: "snappy2", - wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor compressor not registered snappy2"), - }, - { - desc: "fail_on_unadvertised_res_compressor", - resCompressor: "gzip", - disableClientAdvertise: true, - wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: failed to set send compressor client does not support compressor gzip"), - }, - { - desc: "fail_when_set_compressor_called_after_headers_sent", - sendMetadataEarly: true, - resCompressor: "gzip", - wantErr: fmt.Errorf("rpc error: code = Internal desc = transport: set send compressor called after headers sent or stream done"), + wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: unable to set send compressor: compressor not registered \"snappy2\""), }, } { t.Run(tt.desc, func(t *testing.T) { - if tt.disableClientAdvertise { - defer func(b bool) { envconfig.AdvertiseCompressors = b }(envconfig.AdvertiseCompressors) - envconfig.AdvertiseCompressors = false - } - ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - if tt.sendMetadataEarly { - grpc.SendHeader(ctx, metadata.MD{}) - } - if tt.resCompressor != "" { - if err := grpc.SetSendCompressor(ctx, tt.resCompressor); err != nil { - return nil, err - } - } - return &testpb.Empty{}, nil - }, - FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { - if tt.sendMetadataEarly { - grpc.SendHeader(stream.Context(), metadata.MD{}) - } - if _, err := stream.Recv(); err != nil { - return err - } + testSetSendCompressorHandler(t, tt.resCompressor, tt.wantErr) + }) + } +} - if tt.resCompressor != "" { - if err := grpc.SetSendCompressor(stream.Context(), tt.resCompressor); err != nil { - return err - } - } +func (s) TestUnadvertisedSendCompressorFailure(t *testing.T) { + // disable client compressor advertisement. + defer func(b bool) { envconfig.AdvertiseCompressors = b }(envconfig.AdvertiseCompressors) + envconfig.AdvertiseCompressors = false - return stream.Send(&testpb.StreamingOutputCallResponse{}) - }, + wantErr := fmt.Errorf("rpc error: code = Internal desc = grpc: unable to set send compressor: client does not support compressor \"gzip\"") + testSetSendCompressorHandler(t, "gzip", wantErr) +} + +func testSetSendCompressorHandler(t *testing.T, resCompressor string, wantErr error) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil { + return nil, err } - if err := ss.Start(nil); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) + return &testpb.Empty{}, nil + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + if _, err := stream.Recv(); err != nil { + return err } - defer ss.Stop() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor(tt.reqCompressor)) - if fmt.Sprint(err) != fmt.Sprint(tt.wantErr) { - t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, tt.wantErr) + if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil { + return err } - s, err := ss.Client.FullDuplexCall(ctx) - if err != nil { - t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) - } + return stream.Send(&testpb.StreamingOutputCallResponse{}) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() - if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil { - t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err) - } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor("gzip")) + if fmt.Sprint(err) != fmt.Sprint(wantErr) { + t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) + } + + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) + } + + if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err) + } - if _, err := s.Recv(); fmt.Sprint(err) != fmt.Sprint(tt.wantErr) { - t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, tt.wantErr) + if _, err := s.Recv(); fmt.Sprint(err) != fmt.Sprint(wantErr) { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr) + } +} + +func (s) TestSendCompressFailureAfterHeaderSend(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + grpc.SendHeader(ctx, metadata.MD{}) + err := grpc.SetSendCompressor(ctx, "gzip") + if err == nil { + t.Fatalf("Wanted set send compressor error") } - }) + return nil, err + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + grpc.SendHeader(stream.Context(), metadata.MD{}) + err := grpc.SetSendCompressor(stream.Context(), "gzip") + if err == nil { + t.Fatalf("Wanted set send compressor error") + } + return err + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + wantErr := fmt.Errorf("rpc error: code = Internal desc = transport: set send compressor called after headers sent or stream done") + _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) + if fmt.Sprint(err) != fmt.Sprint(wantErr) { + t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) + } + + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) + } + + if _, err := s.Recv(); fmt.Sprint(err) != fmt.Sprint(wantErr) { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr) } } From 805e90e39ce693ad374a84142349fed2b9381043 Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Fri, 9 Dec 2022 01:25:46 +0530 Subject: [PATCH 04/12] Fix lint errors due to t.Fatal usage Vet complains when t.Fatal is used in non-test goroutine ``` Error: internal/transport/handler_server_test.go:299:21: call to (*T).Fatal from a non-test goroutine ``` --- internal/transport/handler_server_test.go | 14 +++++++------- test/end2end_test.go | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index 8bd700d2e62..728d2752107 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -263,33 +263,33 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { } if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil { - t.Fatal(err) + t.Error(err) } if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil { - t.Fatal(err) + t.Error(err) } if err := s.SetSendCompress("gzip"); err != nil { - t.Fatal(err) + t.Error(err) } md := metadata.Pairs("custom-header", "Another custom header value") if err := s.SendHeader(md); err != nil { - t.Fatal(err) + t.Error(err) } delete(md, "custom-header") if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil { - t.Fatal("expected SetHeader call after SendHeader to fail") + t.Error("expected SetHeader call after SendHeader to fail") } if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil { - t.Fatal("expected second SendHeader call to fail") + t.Error("expected second SendHeader call to fail") } if err := s.SetSendCompress("snappy"); err == nil { - t.Fatal("expected second SetSendCompress call to fail") + t.Error("expected second SetSendCompress call to fail") } st.bodyw.Close() // no body diff --git a/test/end2end_test.go b/test/end2end_test.go index c16d3e9800a..70c57af71fd 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5598,7 +5598,7 @@ func (s) TestSendCompressFailureAfterHeaderSend(t *testing.T) { grpc.SendHeader(ctx, metadata.MD{}) err := grpc.SetSendCompressor(ctx, "gzip") if err == nil { - t.Fatalf("Wanted set send compressor error") + t.Error("Wanted set send compressor error") } return nil, err }, @@ -5606,7 +5606,7 @@ func (s) TestSendCompressFailureAfterHeaderSend(t *testing.T) { grpc.SendHeader(stream.Context(), metadata.MD{}) err := grpc.SetSendCompressor(stream.Context(), "gzip") if err == nil { - t.Fatalf("Wanted set send compressor error") + t.Error("Wanted set send compressor error") } return err }, From 52bc0e94dc511ebd0c5d073098be0ebf7f8bd05c Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Fri, 9 Dec 2022 01:53:13 +0530 Subject: [PATCH 05/12] Nit: rename test name and update comment --- internal/transport/transport.go | 2 +- test/end2end_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index f45cfa899b3..bb48b6d40c7 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -254,7 +254,7 @@ type Stream struct { wq *writeQuota // Holds compressor names passed in grpc-accept-encoding metadata from the - // client. This is empty for the client side Stream. + // client. This is empty for the client side stream. clientAdvertisedCompressors string // Callback to state application's intentions to read data. This // is used to adjust flow control, if needed. diff --git a/test/end2end_test.go b/test/end2end_test.go index 70c57af71fd..16905250f0a 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5510,7 +5510,7 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) { } } -func (s) TestServerSetSendCompressor(t *testing.T) { +func (s) TestServerHandlerSetSendCompressor(t *testing.T) { for _, tt := range []struct { desc string resCompressor string @@ -5592,7 +5592,7 @@ func testSetSendCompressorHandler(t *testing.T, resCompressor string, wantErr er } } -func (s) TestSendCompressFailureAfterHeaderSend(t *testing.T) { +func (s) TestSendCompressorFailureAfterHeaderSend(t *testing.T) { ss := &stubserver.StubServer{ EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { grpc.SendHeader(ctx, metadata.MD{}) From 1de5f58fda2646e1c3449dae914d3306e0a594e2 Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Fri, 9 Dec 2022 14:26:04 +0530 Subject: [PATCH 06/12] Refactor tests --- server.go | 1 + test/end2end_test.go | 161 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 134 insertions(+), 28 deletions(-) diff --git a/server.go b/server.go index 3124b432004..0afde513220 100644 --- a/server.go +++ b/server.go @@ -1977,6 +1977,7 @@ func SendHeader(ctx context.Context, md metadata.MD) error { // application, and therefore, should not be relied upon for this purpose. // It is not safe to call SetSendCompressor concurrently with SendHeader and // SendMsg. +// // # Experimental // // Notice: This type is EXPERIMENTAL and may be changed or removed in a diff --git a/test/end2end_test.go b/test/end2end_test.go index 16905250f0a..c65c1bdc757 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5510,7 +5510,7 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) { } } -func (s) TestServerHandlerSetSendCompressor(t *testing.T) { +func (s) TestSetSendCompressorSuccess(t *testing.T) { for _, tt := range []struct { desc string resCompressor string @@ -5524,28 +5524,108 @@ func (s) TestServerHandlerSetSendCompressor(t *testing.T) { desc: "identity_response_compressor", resCompressor: "identity", }, - { - desc: "unregistered_snappy2_response_compressor", - resCompressor: "snappy2", - wantErr: fmt.Errorf("rpc error: code = Internal desc = grpc: unable to set send compressor: compressor not registered \"snappy2\""), - }, } { t.Run(tt.desc, func(t *testing.T) { - testSetSendCompressorHandler(t, tt.resCompressor, tt.wantErr) + t.Run("unary", func(t *testing.T) { + testUnarySetSendCompressorSuccess(t, tt.resCompressor) + }) + + t.Run("stream", func(t *testing.T) { + testStreamSetSendCompressorSuccess(t, tt.resCompressor) + }) }) } } -func (s) TestUnadvertisedSendCompressorFailure(t *testing.T) { - // disable client compressor advertisement. +func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil { + return nil, err + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) + } +} + +func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) { + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + if _, err := stream.Recv(); err != nil { + return err + } + + if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil { + return err + } + + return stream.Send(&testpb.StreamingOutputCallResponse{}) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v, want: nil", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) + } + + if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err) + } + + if _, err := s.Recv(); err != nil { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err) + } +} + +func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) { + resCompressor := "snappy2" + wantErr := status.Error(codes.Internal, "grpc: unable to set send compressor: compressor not registered \"snappy2\"") + + t.Run("unary", func(t *testing.T) { + testUnarySetSendCompressorFailure(t, resCompressor, wantErr) + }) + + t.Run("stream", func(t *testing.T) { + testStreamSetSendCompressorFailure(t, resCompressor, wantErr) + }) +} + +func (s) TestUnadvertisedSetSendCompressorFailure(t *testing.T) { + // Disable client compressor advertisement. defer func(b bool) { envconfig.AdvertiseCompressors = b }(envconfig.AdvertiseCompressors) envconfig.AdvertiseCompressors = false - wantErr := fmt.Errorf("rpc error: code = Internal desc = grpc: unable to set send compressor: client does not support compressor \"gzip\"") - testSetSendCompressorHandler(t, "gzip", wantErr) + resCompressor := "gzip" + wantErr := status.Error(codes.Internal, "grpc: unable to set send compressor: client does not support compressor \"gzip\"") + + t.Run("unary", func(t *testing.T) { + testUnarySetSendCompressorFailure(t, resCompressor, wantErr) + }) + + t.Run("stream", func(t *testing.T) { + testStreamSetSendCompressorFailure(t, resCompressor, wantErr) + }) } -func testSetSendCompressorHandler(t *testing.T, resCompressor string, wantErr error) { +func testUnarySetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) { ss := &stubserver.StubServer{ EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil { @@ -5553,6 +5633,22 @@ func testSetSendCompressorHandler(t *testing.T, resCompressor string, wantErr er } return &testpb.Empty{}, nil }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) { + t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) + } +} + +func testStreamSetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) { + ss := &stubserver.StubServer{ FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { if _, err := stream.Recv(); err != nil { return err @@ -5566,18 +5662,13 @@ func testSetSendCompressorHandler(t *testing.T, resCompressor string, wantErr er }, } if err := ss.Start(nil); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) + t.Fatalf("Error starting endpoint server: %v, want: nil", err) } defer ss.Stop() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor("gzip")) - if fmt.Sprint(err) != fmt.Sprint(wantErr) { - t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) - } - s, err := ss.Client.FullDuplexCall(ctx) if err != nil { t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) @@ -5587,14 +5678,15 @@ func testSetSendCompressorHandler(t *testing.T, resCompressor string, wantErr er t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err) } - if _, err := s.Recv(); fmt.Sprint(err) != fmt.Sprint(wantErr) { - t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr) + if _, err := s.Recv(); !equalError(err, wantErr) { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err) } } -func (s) TestSendCompressorFailureAfterHeaderSend(t *testing.T) { +func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) { ss := &stubserver.StubServer{ EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + // Send headers early and then set send compressor. grpc.SendHeader(ctx, metadata.MD{}) err := grpc.SetSendCompressor(ctx, "gzip") if err == nil { @@ -5602,7 +5694,25 @@ func (s) TestSendCompressorFailureAfterHeaderSend(t *testing.T) { } return nil, err }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + wantErr := status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) { + t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) + } +} + +func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) { + ss := &stubserver.StubServer{ FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + // Send headers early and then set send compressor. grpc.SendHeader(stream.Context(), metadata.MD{}) err := grpc.SetSendCompressor(stream.Context(), "gzip") if err == nil { @@ -5619,18 +5729,13 @@ func (s) TestSendCompressorFailureAfterHeaderSend(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - wantErr := fmt.Errorf("rpc error: code = Internal desc = transport: set send compressor called after headers sent or stream done") - _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) - if fmt.Sprint(err) != fmt.Sprint(wantErr) { - t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) - } - + wantErr := status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") s, err := ss.Client.FullDuplexCall(ctx) if err != nil { t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) } - if _, err := s.Recv(); fmt.Sprint(err) != fmt.Sprint(wantErr) { + if _, err := s.Recv(); !equalError(err, wantErr) { t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr) } } From e7fc4fc545d93b0aa707a523728045f09e55dc13 Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Thu, 12 Jan 2023 00:23:34 +0530 Subject: [PATCH 07/12] address comments --- internal/transport/transport.go | 4 +-- server.go | 47 ++++++++++++++++++++++++--------- test/end2end_test.go | 38 +++++++++++++++++++++++--- 3 files changed, 70 insertions(+), 19 deletions(-) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index bb48b6d40c7..91e59db7242 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -346,7 +346,7 @@ func (s *Stream) RecvCompress() string { // SetSendCompress sets the compression algorithm to the stream. func (s *Stream) SetSendCompress(name string) error { if s.isHeaderSent() || s.getState() == streamDone { - return status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") + return errors.New("transport: set send compressor called after headers sent or stream done") } s.sendCompress = name @@ -359,7 +359,7 @@ func (s *Stream) SendCompress() string { } // ClientAdvertisedCompressors returns the compressor names advertised by the -// client via :grpc-accept-encoding header. +// client via grpc-accept-encoding header. func (s *Stream) ClientAdvertisedCompressors() string { return s.clientAdvertisedCompressors } diff --git a/server.go b/server.go index 0afde513220..463bceccde1 100644 --- a/server.go +++ b/server.go @@ -1299,8 +1299,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if sendCompressorName != "" { - // Safe to ignore returned error value as we are guaranteed to succeed here - _ = stream.SetSendCompress(sendCompressorName) + if err := stream.SetSendCompress(sendCompressorName); err != nil { + return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) + } } var payInfo *payloadInfo @@ -1628,8 +1629,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp } if ss.sendCompressorName != "" { - // Safe to ignore returned error value as we are guaranteed to succeed here - _ = stream.SetSendCompress(ss.sendCompressorName) + if err := stream.SetSendCompress(ss.sendCompressorName); err != nil { + return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) + } } ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp) @@ -1963,38 +1965,57 @@ func SendHeader(ctx context.Context, md metadata.MD) error { // SetSendCompressor sets a compressor for outbound messages. // It must not be called after any event that causes headers to be sent -// (see SetHeader for a complete list). Provided compressor is used when below +// (see _ServerStream_.SetHeader for a complete list). Provided compressor is used when below // conditions are met: // // - compressor is registered via encoding.RegisterCompressor // - compressor name exists in the client advertised compressor names sent in -// :grpc-accept-encoding header. +// grpc-accept-encoding header. Use _ServerStream_.ClientAdvertisedCompressors +// to get client advertised compressor names. // // The context provided must be the context passed to the server's handler. +// It must be noted that compressor name "identity" disables the outbound compression. +// By default, server messages will be sent using the same compressor with which +// request messages were sent. // -// The error returned is compatible with the status package. However, the -// status code will often not match the RPC status as seen by the client -// application, and therefore, should not be relied upon for this purpose. // It is not safe to call SetSendCompressor concurrently with SendHeader and // SendMsg. // // # Experimental // -// Notice: This type is EXPERIMENTAL and may be changed or removed in a +// Notice: This _function_ is EXPERIMENTAL and may be changed or removed in a // later release. func SetSendCompressor(ctx context.Context, name string) error { stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) if !ok || stream == nil { - return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + return fmt.Errorf("failed to fetch the stream from the given context") } if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { - return status.Errorf(codes.Internal, "grpc: unable to set send compressor: %v", err) + return fmt.Errorf("unable to set send compressor: %w", err) } return stream.SetSendCompress(name) } +// ClientAdvertisedCompressors returns compressor names advertised by the client +// via grpc-accept-encoding header. +// +// The context provided must be the context passed to the server's handler. +// +// # Experimental +// +// Notice: This _function_ is EXPERIMENTAL and may be changed or removed in a +// later release. +func ClientAdvertisedCompressors(ctx context.Context) ([]string, error) { + stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) + if !ok || stream == nil { + return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx) + } + + return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil +} + // SetTrailer sets the trailer metadata that will be sent when an RPC returns. // When called more than once, all the provided metadata will be merged. // @@ -2033,7 +2054,7 @@ func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric { // validateSendCompressor returns an error when given compressor name cannot be // handled by the server or the client based on the advertised compressors. func validateSendCompressor(name, clientCompressors string) error { - if name == "identity" { + if name == encoding.Identity { return nil } diff --git a/test/end2end_test.go b/test/end2end_test.go index c65c1bdc757..a6e8a8d8962 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5597,7 +5597,7 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) { func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) { resCompressor := "snappy2" - wantErr := status.Error(codes.Internal, "grpc: unable to set send compressor: compressor not registered \"snappy2\"") + wantErr := status.Error(codes.Unknown, "unable to set send compressor: compressor not registered \"snappy2\"") t.Run("unary", func(t *testing.T) { testUnarySetSendCompressorFailure(t, resCompressor, wantErr) @@ -5614,7 +5614,7 @@ func (s) TestUnadvertisedSetSendCompressorFailure(t *testing.T) { envconfig.AdvertiseCompressors = false resCompressor := "gzip" - wantErr := status.Error(codes.Internal, "grpc: unable to set send compressor: client does not support compressor \"gzip\"") + wantErr := status.Error(codes.Unknown, "unable to set send compressor: client does not support compressor \"gzip\"") t.Run("unary", func(t *testing.T) { testUnarySetSendCompressorFailure(t, resCompressor, wantErr) @@ -5703,7 +5703,7 @@ func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - wantErr := status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") + wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done") if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) { t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) } @@ -5729,7 +5729,7 @@ func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - wantErr := status.Error(codes.Internal, "transport: set send compressor called after headers sent or stream done") + wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done") s, err := ss.Client.FullDuplexCall(ctx) if err != nil { t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) @@ -5740,6 +5740,36 @@ func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) { } } +func (s) TestClientAdvertisedCompressors(t *testing.T) { + expectedCompressors := []string{"gzip"} + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + compressors, err := grpc.ClientAdvertisedCompressors(ctx) + if err != nil { + return nil, err + } + + if !reflect.DeepEqual(compressors, expectedCompressors) { + t.Errorf("unexpected client compressors got: %v, want: %v", compressors, expectedCompressors) + } + + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v, want: nil", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) + } +} + func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { const mdkey = "somedata" From de7681775ba3555627f84023cd3fdb2af2fcb9ba Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Thu, 19 Jan 2023 20:31:34 +0530 Subject: [PATCH 08/12] Rename non-internal ClientAdvertisedCompressors to ClientSupportedCompressors Fix documentation nits --- internal/transport/http2_server.go | 6 +++++- server.go | 25 +++++++++++++------------ test/end2end_test.go | 25 +++++++++++++++++++------ 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 507a589e49f..3cfdc135cd4 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -405,8 +405,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( isGRPC = true case "grpc-accept-encoding": - s.clientAdvertisedCompressors = hf.Value mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + compressors := hf.Value + if s.clientAdvertisedCompressors != "" { + compressors = s.clientAdvertisedCompressors + "," + compressors + } + s.clientAdvertisedCompressors = compressors case "grpc-encoding": s.recvCompress = hf.Value case ":method": diff --git a/server.go b/server.go index 463bceccde1..c08f1bc87a6 100644 --- a/server.go +++ b/server.go @@ -1965,25 +1965,26 @@ func SendHeader(ctx context.Context, md metadata.MD) error { // SetSendCompressor sets a compressor for outbound messages. // It must not be called after any event that causes headers to be sent -// (see _ServerStream_.SetHeader for a complete list). Provided compressor is used when below -// conditions are met: +// (see ServerStream.SetHeader for the complete list). Provided compressor is +// used when below conditions are met: // // - compressor is registered via encoding.RegisterCompressor -// - compressor name exists in the client advertised compressor names sent in -// grpc-accept-encoding header. Use _ServerStream_.ClientAdvertisedCompressors -// to get client advertised compressor names. +// - compressor name must exist in the client advertised compressor names +// sent in grpc-accept-encoding header. Use ClientSupportedCompressors to +// get client supported compressor names. // // The context provided must be the context passed to the server's handler. -// It must be noted that compressor name "identity" disables the outbound compression. -// By default, server messages will be sent using the same compressor with which -// request messages were sent. +// It must be noted that compressor name encoding.Identity disables the +// outbound compression. +// By default, server messages will be sent using the same compressor with +// which request messages were sent. // // It is not safe to call SetSendCompressor concurrently with SendHeader and // SendMsg. // // # Experimental // -// Notice: This _function_ is EXPERIMENTAL and may be changed or removed in a +// Notice: This function_ is EXPERIMENTAL and may be changed or removed in a // later release. func SetSendCompressor(ctx context.Context, name string) error { stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) @@ -1998,16 +1999,16 @@ func SetSendCompressor(ctx context.Context, name string) error { return stream.SetSendCompress(name) } -// ClientAdvertisedCompressors returns compressor names advertised by the client +// ClientSupportedCompressors returns compressor names advertised by the client // via grpc-accept-encoding header. // // The context provided must be the context passed to the server's handler. // // # Experimental // -// Notice: This _function_ is EXPERIMENTAL and may be changed or removed in a +// Notice: This function is EXPERIMENTAL and may be changed or removed in a // later release. -func ClientAdvertisedCompressors(ctx context.Context) ([]string, error) { +func ClientSupportedCompressors(ctx context.Context) ([]string, error) { stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) if !ok || stream == nil { return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx) diff --git a/test/end2end_test.go b/test/end2end_test.go index a6e8a8d8962..63be1f50409 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5740,17 +5740,30 @@ func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) { } } -func (s) TestClientAdvertisedCompressors(t *testing.T) { - expectedCompressors := []string{"gzip"} +func (s) TestClientSupportedCompressors(t *testing.T) { + t.Run("No repeated grpc-accept-encoding header", func(t *testing.T) { + testClientSupportedCompressors(t, context.Background(), []string{"gzip"}) + }) + + t.Run("Repeated grpc-accept-encoding header", func(t *testing.T) { + ctx := metadata.AppendToOutgoingContext(context.Background(), + "grpc-accept-encoding", "test-compressor-1", + "grpc-accept-encoding", "test-compressor-2", + ) + testClientSupportedCompressors(t, ctx, []string{"gzip", "test-compressor-1", "test-compressor-2"}) + }) +} + +func testClientSupportedCompressors(t *testing.T, ctx context.Context, want []string) { ss := &stubserver.StubServer{ EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - compressors, err := grpc.ClientAdvertisedCompressors(ctx) + got, err := grpc.ClientSupportedCompressors(ctx) if err != nil { return nil, err } - if !reflect.DeepEqual(compressors, expectedCompressors) { - t.Errorf("unexpected client compressors got: %v, want: %v", compressors, expectedCompressors) + if !reflect.DeepEqual(got, want) { + t.Errorf("unexpected client compressors got: %v, want: %v", got, want) } return &testpb.Empty{}, nil @@ -5761,7 +5774,7 @@ func (s) TestClientAdvertisedCompressors(t *testing.T) { } defer ss.Stop() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout) defer cancel() _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) From 3e07c8f631030d5166d90bf6674d6cb68ea5ed6a Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Thu, 19 Jan 2023 20:45:13 +0530 Subject: [PATCH 09/12] Ignore empty grpc-accept-encoding header --- internal/transport/http2_server.go | 4 ++ test/end2end_test.go | 81 ++++++++++++++++++------------ 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 3cfdc135cd4..25d94e52bbc 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -405,6 +405,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( isGRPC = true case "grpc-accept-encoding": + if hf.Value == "" { + continue + } + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) compressors := hf.Value if s.clientAdvertisedCompressors != "" { diff --git a/test/end2end_test.go b/test/end2end_test.go index 63be1f50409..194e0f7d735 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5741,45 +5741,60 @@ func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) { } func (s) TestClientSupportedCompressors(t *testing.T) { - t.Run("No repeated grpc-accept-encoding header", func(t *testing.T) { - testClientSupportedCompressors(t, context.Background(), []string{"gzip"}) - }) + for _, tt := range []struct { + desc string + ctx context.Context + want []string + }{ + { + desc: "No additional grpc-accept-encoding header", + ctx: context.Background(), + want: []string{"gzip"}, + }, + { + desc: "With additional grpc-accept-encoding header", + ctx: metadata.AppendToOutgoingContext(context.Background(), + "grpc-accept-encoding", "test-compressor-1", + "grpc-accept-encoding", "test-compressor-2", + ), + want: []string{"gzip", "test-compressor-1", "test-compressor-2"}, + }, + { + desc: "With additional empty grpc-accept-encoding header", + ctx: metadata.AppendToOutgoingContext(context.Background(), + "grpc-accept-encoding", "", + ), + want: []string{"gzip"}, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + got, err := grpc.ClientSupportedCompressors(ctx) + if err != nil { + return nil, err + } - t.Run("Repeated grpc-accept-encoding header", func(t *testing.T) { - ctx := metadata.AppendToOutgoingContext(context.Background(), - "grpc-accept-encoding", "test-compressor-1", - "grpc-accept-encoding", "test-compressor-2", - ) - testClientSupportedCompressors(t, ctx, []string{"gzip", "test-compressor-1", "test-compressor-2"}) - }) -} + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unexpected client compressors got: %v, want: %v", got, tt.want) + } -func testClientSupportedCompressors(t *testing.T, ctx context.Context, want []string) { - ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - got, err := grpc.ClientSupportedCompressors(ctx) - if err != nil { - return nil, err + return &testpb.Empty{}, nil + }, } - - if !reflect.DeepEqual(got, want) { - t.Errorf("unexpected client compressors got: %v, want: %v", got, want) + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v, want: nil", err) } + defer ss.Stop() - return &testpb.Empty{}, nil - }, - } - if err := ss.Start(nil); err != nil { - t.Fatalf("Error starting endpoint server: %v, want: nil", err) - } - defer ss.Stop() - - ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout) - defer cancel() + ctx, cancel := context.WithTimeout(tt.ctx, defaultTestTimeout) + defer cancel() - _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) - if err != nil { - t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) + _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) + } + }) } } From 6297002884c230dc120db3c6723685433ecd161b Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Mon, 23 Jan 2023 20:12:18 +0530 Subject: [PATCH 10/12] update test to assert compressor invokes --- internal/transport/http2_server.go | 3 +- server.go | 4 +- test/end2end_test.go | 72 +++++++++++++++++++++++------- 3 files changed, 60 insertions(+), 19 deletions(-) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 25d94e52bbc..fad6c151b85 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -405,11 +405,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( isGRPC = true case "grpc-accept-encoding": + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) if hf.Value == "" { continue } - - mdata[hf.Name] = append(mdata[hf.Name], hf.Value) compressors := hf.Value if s.clientAdvertisedCompressors != "" { compressors = s.clientAdvertisedCompressors + "," + compressors diff --git a/server.go b/server.go index c08f1bc87a6..01700a93f7f 100644 --- a/server.go +++ b/server.go @@ -1963,7 +1963,7 @@ func SendHeader(ctx context.Context, md metadata.MD) error { return nil } -// SetSendCompressor sets a compressor for outbound messages. +// SetSendCompressor sets a compressor for outbound messages from the server. // It must not be called after any event that causes headers to be sent // (see ServerStream.SetHeader for the complete list). Provided compressor is // used when below conditions are met: @@ -1984,7 +1984,7 @@ func SendHeader(ctx context.Context, md metadata.MD) error { // // # Experimental // -// Notice: This function_ is EXPERIMENTAL and may be changed or removed in a +// Notice: This function is EXPERIMENTAL and may be changed or removed in a // later release. func SetSendCompressor(ctx context.Context, name string) error { stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) diff --git a/test/end2end_test.go b/test/end2end_test.go index 194e0f7d735..6774a1ff2e2 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5510,34 +5510,64 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) { } } +// wrapCompressor is a wrapper of encoding.Compressor which maintains count of +// Compressor method invokes. +type wrapCompressor struct { + encoding.Compressor + compressInvokes int32 +} + +func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) { + atomic.AddInt32(&wc.compressInvokes, 1) + return wc.Compressor.Compress(w) +} + +func setupGzipWrapCompressor(t *testing.T) *wrapCompressor { + oldC := encoding.GetCompressor("gzip") + c := &wrapCompressor{Compressor: oldC} + encoding.RegisterCompressor(c) + t.Cleanup(func() { + encoding.RegisterCompressor(oldC) + }) + return c +} + func (s) TestSetSendCompressorSuccess(t *testing.T) { for _, tt := range []struct { - desc string - resCompressor string - wantErr error + name string + desc string + dialOpts []grpc.DialOption + resCompressor string + wantCompressInvokes int32 }{ { - desc: "gzip_response_compressor", - resCompressor: "gzip", + name: "identity_request_and_gzip_response", + desc: "request is uncompressed and response is gzip compressed", + resCompressor: "gzip", + wantCompressInvokes: 1, }, { - desc: "identity_response_compressor", - resCompressor: "identity", + name: "gzip_request_and_identity_response", + desc: "request is gzip compressed and response is uncompressed with identity", + resCompressor: "identity", + dialOpts: []grpc.DialOption{grpc.WithCompressor(grpc.NewGZIPCompressor())}, + wantCompressInvokes: 0, }, } { - t.Run(tt.desc, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { t.Run("unary", func(t *testing.T) { - testUnarySetSendCompressorSuccess(t, tt.resCompressor) + testUnarySetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) }) t.Run("stream", func(t *testing.T) { - testStreamSetSendCompressorSuccess(t, tt.resCompressor) + testStreamSetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) }) }) } } -func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string) { +func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { + wc := setupGzipWrapCompressor(t) ss := &stubserver.StubServer{ EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil { @@ -5546,7 +5576,7 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string) { return &testpb.Empty{}, nil }, } - if err := ss.Start(nil); err != nil { + if err := ss.Start(nil, dialOpts...); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() @@ -5557,9 +5587,15 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string) { if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) } + + compressorInvokes := atomic.LoadInt32(&wc.compressInvokes) + if compressorInvokes != wantCompressInvokes { + t.Fatalf("Unexpected number of Compress calls, got:%d, want: %d", compressorInvokes, wantCompressInvokes) + } } -func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) { +func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { + wc := setupGzipWrapCompressor(t) ss := &stubserver.StubServer{ FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { if _, err := stream.Recv(); err != nil { @@ -5573,8 +5609,8 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) { return stream.Send(&testpb.StreamingOutputCallResponse{}) }, } - if err := ss.Start(nil); err != nil { - t.Fatalf("Error starting endpoint server: %v, want: nil", err) + if err := ss.Start(nil, dialOpts...); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() @@ -5593,6 +5629,11 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string) { if _, err := s.Recv(); err != nil { t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err) } + + compressorInvokes := atomic.LoadInt32(&wc.compressInvokes) + if compressorInvokes != wantCompressInvokes { + t.Fatalf("Unexpected number of Compress calls, got:%d, want: %d", compressorInvokes, wantCompressInvokes) + } } func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) { @@ -5691,6 +5732,7 @@ func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) { err := grpc.SetSendCompressor(ctx, "gzip") if err == nil { t.Error("Wanted set send compressor error") + return &testpb.Empty{}, nil } return nil, err }, From cedc506a8ec87deaa8f39ac158548dbaf0e79fbc Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Tue, 24 Jan 2023 09:28:06 +0530 Subject: [PATCH 11/12] fix test var name --- test/end2end_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index 6774a1ff2e2..42ccdb839a2 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5588,9 +5588,9 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantC t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) } - compressorInvokes := atomic.LoadInt32(&wc.compressInvokes) - if compressorInvokes != wantCompressInvokes { - t.Fatalf("Unexpected number of Compress calls, got:%d, want: %d", compressorInvokes, wantCompressInvokes) + compressInvokes := atomic.LoadInt32(&wc.compressInvokes) + if compressInvokes != wantCompressInvokes { + t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes) } } @@ -5630,9 +5630,9 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, want t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err) } - compressorInvokes := atomic.LoadInt32(&wc.compressInvokes) - if compressorInvokes != wantCompressInvokes { - t.Fatalf("Unexpected number of Compress calls, got:%d, want: %d", compressorInvokes, wantCompressInvokes) + compressInvokes := atomic.LoadInt32(&wc.compressInvokes) + if compressInvokes != wantCompressInvokes { + t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes) } } From c52a550f3f557d171c4c44b7578caad2576754e1 Mon Sep 17 00:00:00 2001 From: Ronak Jain Date: Tue, 31 Jan 2023 23:57:41 +0530 Subject: [PATCH 12/12] add test comment --- test/end2end_test.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index 9f8fd3c9aba..d3c339ccb8d 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5118,10 +5118,14 @@ func (s) TestSetSendCompressorSuccess(t *testing.T) { wantCompressInvokes: 1, }, { - name: "gzip_request_and_identity_response", - desc: "request is gzip compressed and response is uncompressed with identity", - resCompressor: "identity", - dialOpts: []grpc.DialOption{grpc.WithCompressor(grpc.NewGZIPCompressor())}, + name: "gzip_request_and_identity_response", + desc: "request is gzip compressed and response is uncompressed with identity", + resCompressor: "identity", + dialOpts: []grpc.DialOption{ + // Use WithCompressor instead of UseCompressor to avoid counting + // the client's compressor usage. + grpc.WithCompressor(grpc.NewGZIPCompressor()), + }, wantCompressInvokes: 0, }, } {