Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc_util: Fix RecvBufferPool deactivation issues #6766

Merged
merged 3 commits into from Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
196 changes: 149 additions & 47 deletions experimental/shared_buffer_pool_test.go
Expand Up @@ -26,12 +26,12 @@ import (
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/experimental"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)

type s struct {
Expand All @@ -44,59 +44,161 @@ func Test(t *testing.T) {

const defaultTestTimeout = 10 * time.Second

func (s) TestRecvBufferPool(t *testing.T) {
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for i := 0; i < 10; i++ {
preparedMsg := &grpc.PreparedMsg{}
err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{
Body: []byte{'0' + uint8(i)},
},
})
func (s) TestRecvBufferPoolStream(t *testing.T) {
tcs := []struct {
name string
callOpts []grpc.CallOption
}{
{
name: "default",
},
{
name: "useCompressor",
callOpts: []grpc.CallOption{
grpc.UseCompressor(gzip.Name),
},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
const reqCount = 10

ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for i := 0; i < reqCount; i++ {
preparedMsg := &grpc.PreparedMsg{}
if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
Payload: &testgrpc.Payload{
Body: []byte{'0' + uint8(i)},
},
}); err != nil {
return err
}
stream.SendMsg(preparedMsg)
}
return nil
},
}

pool := &checkBufferPool{}
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
if err != nil {
t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
}

var ngot int
var buf bytes.Buffer
for {
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return err
t.Fatal(err)
}
stream.SendMsg(preparedMsg)
ngot++
if buf.Len() > 0 {
buf.WriteByte(',')
}
buf.Write(reply.GetPayload().GetBody())
}
return nil
},
if want := 10; ngot != want {
t.Fatalf("Got %d replies, want %d", ngot, want)
}
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
t.Fatalf("Got replies %q; want %q", got, want)
}

if len(pool.puts) != reqCount {
t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
}
})
}
sopts := []grpc.ServerOption{experimental.RecvBufferPool(grpc.NewSharedBufferPool())}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}

func (s) TestRecvBufferPoolUnary(t *testing.T) {
tcs := []struct {
name string
callOpts []grpc.CallOption
}{
{
name: "default",
},
{
name: "useCompressor",
callOpts: []grpc.CallOption{
grpc.UseCompressor(gzip.Name),
},
},
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
const largeSize = 1024

stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("ss.Client.FullDuplexCall failed: %f", err)
}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
return &testgrpc.SimpleResponse{
Payload: &testgrpc.Payload{
Body: make([]byte, largeSize),
},
}, nil
},
}

var ngot int
var buf bytes.Buffer
for {
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
ngot++
if buf.Len() > 0 {
buf.WriteByte(',')
}
buf.Write(reply.GetPayload().GetBody())
}
if want := 10; ngot != want {
t.Errorf("Got %d replies, want %d", ngot, want)
}
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
t.Errorf("Got replies %q; want %q", got, want)
pool := &checkBufferPool{}
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

const reqCount = 10
for i := 0; i < reqCount; i++ {
if _, err := ss.Client.UnaryCall(
ctx,
&testgrpc.SimpleRequest{
Payload: &testgrpc.Payload{
Body: make([]byte, largeSize),
},
},
tc.callOpts...,
); err != nil {
t.Fatalf("ss.Client.UnaryCall failed: %v", err)
}
}

const bufferCount = reqCount * 2 // req + resp
if len(pool.puts) != bufferCount {
t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
}
})
}
}

type checkBufferPool struct {
puts [][]byte
}

func (p *checkBufferPool) Get(size int) []byte {
return make([]byte, size)
}

func (p *checkBufferPool) Put(bs *[]byte) {
p.puts = append(p.puts, *bs)
}
54 changes: 35 additions & 19 deletions rpc_util.go
Expand Up @@ -726,39 +726,55 @@
uncompressedBytes []byte
}

func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.

Check warning on line 732 in rpc_util.go

View check run for this annotation

Codecov / codecov/patch

rpc_util.go#L731-L732

Added lines #L731 - L732 were not covered by tests
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)

Check warning on line 735 in rpc_util.go

View check run for this annotation

Codecov / codecov/patch

rpc_util.go#L734-L735

Added lines #L734 - L735 were not covered by tests
if err != nil {
return nil, err
}
if payInfo != nil {
payInfo.compressedLength = len(buf)
return nil, nil, err

Check warning on line 737 in rpc_util.go

View check run for this annotation

Codecov / codecov/patch

rpc_util.go#L737

Added line #L737 was not covered by tests
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return nil, st.Err()
return nil, nil, st.Err()
}

var size int
if pf == compressionMade {
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
buf, err = dc.Do(bytes.NewReader(buf))
size = len(buf)
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
size = len(uncompressedBuf)
} else {
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if size > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)

Check warning on line 760 in rpc_util.go

View check run for this annotation

Codecov / codecov/patch

rpc_util.go#L760

Added line #L760 was not covered by tests
}
} else {
uncompressedBuf = compressedBuf
}
return buf, nil

if payInfo != nil {
payInfo.compressedLength = len(compressedBuf)
payInfo.uncompressedBytes = uncompressedBuf

cancel = func() {}
} else {
cancel = func() {
p.recvBufferPool.Put(&compressedBuf)
}
}

return uncompressedBuf, cancel, nil

Check warning on line 777 in rpc_util.go

View check run for this annotation

Codecov / codecov/patch

rpc_util.go#L776-L777

Added lines #L776 - L777 were not covered by tests
}

// Using compressor, decompress d, returning data and size.
Expand All @@ -778,6 +794,9 @@
// size is used as an estimate to size the buffer, but we
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
//
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
// we can also utilize the recv buffer pool here.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
Expand All @@ -793,18 +812,15 @@
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
if err != nil {
return err
}
defer cancel()

if err := c.Unmarshal(buf, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
if payInfo != nil {
payInfo.uncompressedBytes = buf
} else {
p.recvBufferPool.Put(&buf)
}
return nil
}

Expand Down
5 changes: 4 additions & 1 deletion server.go
Expand Up @@ -1319,7 +1319,8 @@
if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{}
}
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)

d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
Expand All @@ -1330,6 +1331,8 @@
t.IncrMsgRecv()
}
df := func(v any) error {
defer cancel()

Check warning on line 1335 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L1335

Added line #L1335 was not covered by tests
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
Expand Down