Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: expose API to set send compressor #5744

Merged
merged 13 commits into from Jan 31, 2023
28 changes: 17 additions & 11 deletions internal/transport/handler_server_test.go
Expand Up @@ -280,31 +280,36 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
t.Errorf("stream method = %q; want %q", s.method, want)
}

err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value"))
if err != nil {
if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
t.Error(err)
}
err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value"))
if err != nil {

if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
t.Error(err)
}

if err := s.SetSendCompress("gzip"); err != nil {
t.Error(err)
}

md := metadata.Pairs("custom-header", "Another custom header value")
err = s.SendHeader(md)
delete(md, "custom-header")
if err != nil {
if err := s.SendHeader(md); err != nil {
t.Error(err)
}
delete(md, "custom-header")

err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored"))
if err == nil {
if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
t.Error("expected SetHeader call after SendHeader to fail")
}
err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well"))
if err == nil {

if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
t.Error("expected second SendHeader call to fail")
}

if err := s.SetSendCompress("snappy"); err == nil {
t.Error("expected second SetSendCompress call to fail")
}

st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
Expand All @@ -317,6 +322,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Custom-Header": {"Custom header value", "Another custom header value"},
"Grpc-Encoding": {"gzip"},
}
wantTrailer := http.Header{
"Grpc-Status": {"0"},
Expand Down
11 changes: 11 additions & 0 deletions internal/transport/http2_server.go
Expand Up @@ -404,6 +404,17 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
s.contentSubtype = contentSubtype
isGRPC = true

case "grpc-accept-encoding":
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
if hf.Value == "" {
continue
}
jronak marked this conversation as resolved.
Show resolved Hide resolved
compressors := hf.Value
if s.clientAdvertisedCompressors != "" {
compressors = s.clientAdvertisedCompressors + "," + compressors
}
s.clientAdvertisedCompressors = compressors
case "grpc-encoding":
s.recvCompress = hf.Value
case ":method":
Expand Down
23 changes: 21 additions & 2 deletions internal/transport/transport.go
Expand Up @@ -257,6 +257,9 @@ type Stream struct {
fc *inFlow
wq *writeQuota

// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
Expand Down Expand Up @@ -345,8 +348,24 @@ func (s *Stream) RecvCompress() string {
}

// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
func (s *Stream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}

s.sendCompress = name
return nil
}

// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}

// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
}

// Done returns a channel which is closed when it receives the final status
Expand Down
100 changes: 96 additions & 4 deletions server.go
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -1263,6 +1264,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var comp, decomp encoding.Compressor
var cp Compressor
var dc Decompressor
var sendCompressorName string

// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
Expand All @@ -1283,12 +1285,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
stream.SetSendCompress(cp.Type())
sendCompressorName = cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
stream.SetSendCompress(rc)
sendCompressorName = comp.Name()
}
}

if sendCompressorName != "" {
if err := stream.SetSendCompress(sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}

Expand Down Expand Up @@ -1375,6 +1383,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
opts := &transport.Options{Last: true}

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if stream.SendCompress() != sendCompressorName {
jronak marked this conversation as resolved.
Show resolved Hide resolved
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
Expand Down Expand Up @@ -1597,12 +1610,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
stream.SetSendCompress(s.opts.cp.Type())
ss.sendCompressorName = s.opts.cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
stream.SetSendCompress(rc)
ss.sendCompressorName = rc
}
}

if ss.sendCompressorName != "" {
if err := stream.SetSendCompress(ss.sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}

Expand Down Expand Up @@ -1935,6 +1954,60 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// SetSendCompressor sets a compressor for outbound messages from the server.
// It must not be called after any event that causes headers to be sent
// (see ServerStream.SetHeader for the complete list). Provided compressor is
// used when below conditions are met:
//
// - compressor is registered via encoding.RegisterCompressor
// - compressor name must exist in the client advertised compressor names
// sent in grpc-accept-encoding header. Use ClientSupportedCompressors to
// get client supported compressor names.
//
// The context provided must be the context passed to the server's handler.
// It must be noted that compressor name encoding.Identity disables the
// outbound compression.
// By default, server messages will be sent using the same compressor with
// which request messages were sent.
//
// It is not safe to call SetSendCompressor concurrently with SendHeader and
// SendMsg.
//
easwars marked this conversation as resolved.
Show resolved Hide resolved
// # Experimental
jronak marked this conversation as resolved.
Show resolved Hide resolved
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
easwars marked this conversation as resolved.
Show resolved Hide resolved
if !ok || stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
}

if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
return fmt.Errorf("unable to set send compressor: %w", err)
}

return stream.SetSendCompress(name)
}

// ClientSupportedCompressors returns compressor names advertised by the client
// via grpc-accept-encoding header.
//
// The context provided must be the context passed to the server's handler.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}

return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil
}

// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
//
Expand Down Expand Up @@ -1969,3 +2042,22 @@ type channelzServer struct {
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
return c.s.channelzMetric()
}

// validateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func validateSendCompressor(name, clientCompressors string) error {
if name == encoding.Identity {
return nil
}

if !grpcutil.IsCompressorNameRegistered(name) {
return fmt.Errorf("compressor not registered %q", name)
}

for _, c := range strings.Split(clientCompressors, ",") {
if c == name {
return nil // found match
}
}
return fmt.Errorf("client does not support compressor %q", name)
}
9 changes: 9 additions & 0 deletions stream.go
Expand Up @@ -1511,6 +1511,8 @@ type serverStream struct {
comp encoding.Compressor
decomp encoding.Compressor

sendCompressorName string
easwars marked this conversation as resolved.
Show resolved Hide resolved

maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
Expand Down Expand Up @@ -1603,6 +1605,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
ss.comp = encoding.GetCompressor(sendCompressorsName)
ss.sendCompressorName = sendCompressorsName
}

// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil {
Expand Down