rpc_util: Fix RecvBufferPool deactivation issues #6766

Merged
merged 3 commits on Feb 23, 2024
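
In short: recvAndDecompress now returns the received buffer together with a cancel function that releases it back to the configured RecvBufferPool. Callers (recv and processUnaryRPC) defer that cancel so the buffer is recycled once it is no longer needed, and the tests verify with a recording pool that every received message's buffer actually makes it back, including when a compressor is in use.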
Changes from 1 commit
200 changes: 153 additions & 47 deletions experimental/shared_buffer_pool_test.go
@@ -26,12 +26,12 @@ import (
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/experimental"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)

type s struct {
@@ -44,59 +44,165 @@ func Test(t *testing.T) {

const defaultTestTimeout = 10 * time.Second

func (s) TestRecvBufferPool(t *testing.T) {
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for i := 0; i < 10; i++ {
preparedMsg := &grpc.PreparedMsg{}
err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{
Body: []byte{'0' + uint8(i)},
},
})
func (s) TestRecvBufferPoolStream(t *testing.T) {
tcs := []struct {
name string
callOpts []grpc.CallOption
}{
{
name: "default",
},
{
name: "useCompressor",
callOpts: []grpc.CallOption{
grpc.UseCompressor(gzip.Name),
},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
const reqCount = 10

ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for i := 0; i < reqCount; i++ {
preparedMsg := &grpc.PreparedMsg{}
err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
Payload: &testgrpc.Payload{
Body: []byte{'0' + uint8(i)},
},
})
if err != nil {
return err
}
stream.SendMsg(preparedMsg)
}
return nil
},
}

pool := &checkBufferPool{}

sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
if err != nil {
t.Fatalf("ss.Client.FullDuplexCall failed: %f", err)
}

var ngot int
var buf bytes.Buffer
for {
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return err
t.Fatal(err)
}
stream.SendMsg(preparedMsg)
ngot++
if buf.Len() > 0 {
buf.WriteByte(',')
}
buf.Write(reply.GetPayload().GetBody())
}
return nil
},
if want := 10; ngot != want {
t.Fatalf("Got %d replies, want %d", ngot, want)
}
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
t.Fatalf("Got replies %q; want %q", got, want)
}

if len(pool.puts) != reqCount {
t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
}
})
}
sopts := []grpc.ServerOption{experimental.RecvBufferPool(grpc.NewSharedBufferPool())}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}

func (s) TestRecvBufferPoolUnary(t *testing.T) {
tcs := []struct {
name string
callOpts []grpc.CallOption
}{
{
name: "default",
},
{
name: "useCompressor",
callOpts: []grpc.CallOption{
grpc.UseCompressor(gzip.Name),
},
},
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
const largeSize = 1024

stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("ss.Client.FullDuplexCall failed: %f", err)
}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
return &testgrpc.SimpleResponse{
Payload: &testgrpc.Payload{
Body: make([]byte, largeSize),
},
}, nil
},
}

var ngot int
var buf bytes.Buffer
for {
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
ngot++
if buf.Len() > 0 {
buf.WriteByte(',')
}
buf.Write(reply.GetPayload().GetBody())
}
if want := 10; ngot != want {
t.Errorf("Got %d replies, want %d", ngot, want)
}
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
t.Errorf("Got replies %q; want %q", got, want)
pool := &checkBufferPool{}

sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

const reqCount = 10
for i := 0; i < reqCount; i++ {
_, err := ss.Client.UnaryCall(
ctx,
&testgrpc.SimpleRequest{
Payload: &testgrpc.Payload{
Body: make([]byte, largeSize),
},
},
tc.callOpts...,
)
if err != nil {
t.Fatalf("ss.Client.UnaryCall failed: %v", err)
}
}

const bufferCount = reqCount * 2 // req + resp
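// Both endpoints share the same pool instance (sopts and dopts above), so
// each RPC returns two buffers: the request on the server side and the
// response on the client side.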
if len(pool.puts) != bufferCount {
t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))

Contributor: bufferCount is 20. Should this say 20?

Contributor Author: Yes, and it's exactly twice the reqCount.

Contributor: Right, but the error message says 10 :). Switch to 20?

Contributor: Last one before merging :)

Contributor Author:
> Right, but the error message says 10 :). Switch to 20?

Got it! 389c89e

Contributor: Thanks :)!

}
})
}
}

type checkBufferPool struct {
puts [][]byte
}

func (p *checkBufferPool) Get(size int) []byte {
return make([]byte, size)
}

func (p *checkBufferPool) Put(bs *[]byte) {
p.puts = append(p.puts, *bs)
}
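
For context, these two methods are the whole surface that experimental.RecvBufferPool and experimental.WithRecvBufferPool require of a pool. A sketch of the assumed interface shape, inferred from the methods above (the canonical definition lives in grpc-go):

// Assumed interface shape, inferred from checkBufferPool above; see
// grpc-go's SharedBufferPool for the authoritative definition.
type SharedBufferPool interface {
	// Get returns a buffer with the given length.
	Get(length int) []byte
	// Put returns a buffer to the pool.
	Put(*[]byte)
}

checkBufferPool deliberately skips reuse in Get, so every Put call records a distinct buffer — which is what lets the tests count returns precisely.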
53 changes: 34 additions & 19 deletions rpc_util.go
@@ -726,39 +726,55 @@
uncompressedBytes []byte
}

func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
//
// Calling the returned cancel function releases the buffer back to the pool. The caller should therefore call cancel
// as soon as the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
}
if payInfo != nil {
payInfo.compressedLength = len(buf)
return nil, nil, err
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return nil, st.Err()
return nil, nil, st.Err()
}

var size int
if pf == compressionMade {
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
buf, err = dc.Do(bytes.NewReader(buf))
size = len(buf)
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
size = len(uncompressedBuf)
} else {
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if size > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
}
} else {
uncompressedBuf = compressedBuf
}
return buf, nil

if payInfo != nil {
payInfo.compressedLength = len(compressedBuf)
payInfo.uncompressedBytes = uncompressedBuf

// payInfo keeps a reference to the buffer, so it must not be
// recycled; return a no-op cancel instead.
cancel = func() {}
} else {
cancel = func() {
p.recvBufferPool.Put(&compressedBuf)
}
}

return uncompressedBuf, cancel, nil
}
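
The contract is easier to see in isolation. Below is a self-contained sketch of the same pattern under illustrative names (pool and recvMsg are stand-ins, not the grpc-go internals): the receive path hands back a pooled buffer together with a release closure, and the closure is a no-op whenever the buffer escapes to an observer, as it does above when payInfo is non-nil.

package main

import (
	"fmt"
	"sync"
)

// pool mirrors the Get/Put surface of the shared buffer pool used above.
type pool struct{ p sync.Pool }

func newPool() *pool {
	return &pool{p: sync.Pool{New: func() any {
		b := make([]byte, 0, 1024)
		return &b
	}}}
}

func (bp *pool) Get(size int) []byte {
	buf := *(bp.p.Get().(*[]byte))
	if cap(buf) < size {
		return make([]byte, size) // estimate was too small; allocate fresh
	}
	return buf[:size]
}

func (bp *pool) Put(bs *[]byte) { bp.p.Put(bs) }

// recvMsg copies payload into a pooled buffer and returns it with a cancel
// func, echoing recvAndDecompress above. keep plays the role of the
// payInfo != nil case: the buffer escapes, so cancel must be a no-op.
func recvMsg(bp *pool, payload []byte, keep bool) ([]byte, func()) {
	buf := bp.Get(len(payload))
	copy(buf, payload)
	if keep {
		return buf, func() {} // buffer escapes; never recycle it
	}
	return buf, func() { bp.Put(&buf) }
}

func main() {
	bp := newPool()
	buf, cancel := recvMsg(bp, []byte("hello"), false)
	fmt.Printf("got %q\n", buf)
	cancel() // done with buf: return it to the pool for reuse
}

The key design point is that ownership of the buffer transfers with the closure: whichever branch produced the buffer decides whether Put is safe, and the caller only has to defer cancel().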

// Using compressor, decompress d, returning data and size.
@@ -778,6 +794,8 @@
// size is used as an estimate to size the buffer, but we
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
//
// TODO: If we ensure that the buffer size is the same as the DecompressedSize, we can also utilize the recv buffer pool here.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
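
Two sizing details in this helper: the buffer is pre-sized to the estimate plus bytes.MinRead so ReadFrom does not reallocate when the estimate is right, and the LimitReader is capped at maxReceiveMessageSize+1 so that an oversized message reads one byte past the limit and trips the size check back in recvAndDecompress.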
@@ -793,18 +811,15 @@
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
if err != nil {
return err
}
defer cancel()

if err := c.Unmarshal(buf, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
if payInfo != nil {
payInfo.uncompressedBytes = buf
} else {
p.recvBufferPool.Put(&buf)
}
return nil
}

5 changes: 4 additions & 1 deletion server.go
@@ -1319,7 +1319,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{}
}
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)

d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
@@ -1330,6 +1331,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
t.IncrMsgRecv()
}
df := func(v any) error {
defer cancel()

if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
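
Note the ordering here: cancel is deferred inside df rather than in processUnaryRPC itself, because the request buffer d must stay valid until the handler's codec has unmarshalled it; releasing it to the pool any earlier could hand the bytes to a concurrent RPC while they are still being read.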