diff --git a/stats/stats.go b/stats/stats.go index a5ebeeb6932..0285dcc6a26 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -36,12 +36,12 @@ type RPCStats interface { IsClient() bool } -// Begin contains stats when an RPC begins. +// Begin contains stats when an RPC attempt begins. // FailFast is only valid if this Begin is from client side. type Begin struct { // Client is true if this Begin is from client side. Client bool - // BeginTime is the time when the RPC begins. + // BeginTime is the time when the RPC attempt begins. BeginTime time.Time // FailFast indicates if this RPC is failfast. FailFast bool @@ -49,6 +49,9 @@ type Begin struct { IsClientStream bool // IsServerStream indicates whether the RPC is a server streaming RPC. IsServerStream bool + // IsTransparentRetryAttempt indicates whether this attempt was initiated + // due to transparently retrying a previous attempt. + IsTransparentRetryAttempt bool } // IsClient indicates if the stats information is from client side. diff --git a/stream.go b/stream.go index e224af12d21..78b7e9e8015 100644 --- a/stream.go +++ b/stream.go @@ -274,35 +274,6 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client if c.creds != nil { callHdr.Creds = c.creds } - var trInfo *traceInfo - if EnableTracing { - trInfo = &traceInfo{ - tr: trace.New("grpc.Sent."+methodFamily(method), method), - firstLine: firstLine{ - client: true, - }, - } - if deadline, ok := ctx.Deadline(); ok { - trInfo.firstLine.deadline = time.Until(deadline) - } - trInfo.tr.LazyLog(&trInfo.firstLine, false) - ctx = trace.NewContext(ctx, trInfo.tr) - } - ctx = newContextWithRPCInfo(ctx, c.failFast, c.codec, cp, comp) - sh := cc.dopts.copts.StatsHandler - var beginTime time.Time - if sh != nil { - ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) - beginTime = time.Now() - begin := &stats.Begin{ - Client: true, - BeginTime: beginTime, - FailFast: c.failFast, - IsClientStream: desc.ClientStreams, - IsServerStream: desc.ServerStreams, - } - sh.HandleRPC(ctx, begin) - } cs := &clientStream{ callHdr: callHdr, @@ -316,7 +287,6 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client cp: cp, comp: comp, cancel: cancel, - beginTime: beginTime, firstAttempt: true, onCommit: onCommit, } @@ -325,9 +295,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client } cs.binlog = binarylog.GetMethodLogger(method) - // Only this initial attempt has stats/tracing. - // TODO(dfawley): move to newAttempt when per-attempt stats are implemented. - if err := cs.newAttemptLocked(sh, trInfo); err != nil { + if err := cs.newAttemptLocked(false /* isTransparent */); err != nil { cs.finish(err) return nil, err } @@ -375,8 +343,43 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client // newAttemptLocked creates a new attempt with a transport. // If it succeeds, then it replaces clientStream's attempt with this new attempt. 
-func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (retErr error) { +func (cs *clientStream) newAttemptLocked(isTransparent bool) (retErr error) { + ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.cp, cs.comp) + method := cs.callHdr.Method + sh := cs.cc.dopts.copts.StatsHandler + var beginTime time.Time + if sh != nil { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast}) + beginTime = time.Now() + begin := &stats.Begin{ + Client: true, + BeginTime: beginTime, + FailFast: cs.callInfo.failFast, + IsClientStream: cs.desc.ClientStreams, + IsServerStream: cs.desc.ServerStreams, + IsTransparentRetryAttempt: isTransparent, + } + sh.HandleRPC(ctx, begin) + } + + var trInfo *traceInfo + if EnableTracing { + trInfo = &traceInfo{ + tr: trace.New("grpc.Sent."+methodFamily(method), method), + firstLine: firstLine{ + client: true, + }, + } + if deadline, ok := ctx.Deadline(); ok { + trInfo.firstLine.deadline = time.Until(deadline) + } + trInfo.tr.LazyLog(&trInfo.firstLine, false) + ctx = trace.NewContext(ctx, trInfo.tr) + } + newAttempt := &csAttempt{ + ctx: ctx, + beginTime: beginTime, cs: cs, dc: cs.cc.dopts.dc, statsHandler: sh, @@ -391,15 +394,14 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r } }() - if err := cs.ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return toRPCErr(err) } - ctx := cs.ctx if cs.cc.parsedTarget.Scheme == "xds" { // Add extra metadata (metadata that will be added by transport) to context // so the balancer can see them. - ctx = grpcutil.WithExtraMetadata(cs.ctx, metadata.Pairs( + ctx = grpcutil.WithExtraMetadata(ctx, metadata.Pairs( "content-type", grpcutil.ContentType(cs.callHdr.ContentSubtype), )) } @@ -419,7 +421,7 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r func (a *csAttempt) newStream() error { cs := a.cs cs.callHdr.PreviousAttempts = cs.numRetries - s, err := a.t.NewStream(cs.ctx, cs.callHdr) + s, err := a.t.NewStream(a.ctx, cs.callHdr) if err != nil { // Return without converting to an RPC error so retry code can // inspect. @@ -444,8 +446,7 @@ type clientStream struct { cancel context.CancelFunc // cancels all attempts - sentLast bool // sent an end stream - beginTime time.Time + sentLast bool // sent an end stream methodConfig *MethodConfig @@ -485,6 +486,7 @@ type clientStream struct { // csAttempt implements a single transport stream attempt within a // clientStream. type csAttempt struct { + ctx context.Context cs *clientStream t transport.ClientTransport s *transport.Stream @@ -503,6 +505,7 @@ type csAttempt struct { trInfo *traceInfo statsHandler stats.Handler + beginTime time.Time } func (cs *clientStream) commitAttemptLocked() { @@ -520,15 +523,16 @@ func (cs *clientStream) commitAttempt() { } // shouldRetry returns nil if the RPC should be retried; otherwise it returns -// the error that should be returned by the operation. -func (cs *clientStream) shouldRetry(err error) error { +// the error that should be returned by the operation. If the RPC should be +// retried, the bool indicates whether it is being retried transparently. +func (cs *clientStream) shouldRetry(err error) (bool, error) { if cs.attempt.s == nil { // Error from NewClientStream. nse, ok := err.(*transport.NewStreamError) if !ok { // Unexpected, but assume no I/O was performed and the RPC is not // fatal, so retry indefinitely. - return nil + return true, nil } // Unwrap and convert error. 
@@ -537,19 +541,19 @@ func (cs *clientStream) shouldRetry(err error) error { // Never retry DoNotRetry errors, which indicate the RPC should not be // retried due to max header list size violation, etc. if nse.DoNotRetry { - return err + return false, err } // In the event of a non-IO operation error from NewStream, we never // attempted to write anything to the wire, so we can retry // indefinitely. if !nse.PerformedIO { - return nil + return true, nil } } if cs.finished || cs.committed { // RPC is finished or committed; cannot retry. - return err + return false, err } // Wait for the trailers. unprocessed := false @@ -559,17 +563,17 @@ func (cs *clientStream) shouldRetry(err error) error { } if cs.firstAttempt && unprocessed { // First attempt, stream unprocessed: transparently retry. - return nil + return true, nil } if cs.cc.dopts.disableRetry { - return err + return false, err } pushback := 0 hasPushback := false if cs.attempt.s != nil { if !cs.attempt.s.TrailersOnly() { - return err + return false, err } // TODO(retry): Move down if the spec changes to not check server pushback @@ -580,13 +584,13 @@ func (cs *clientStream) shouldRetry(err error) error { if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 { channelz.Infof(logger, cs.cc.channelzID, "Server retry pushback specified to abort (%q).", sps[0]) cs.retryThrottler.throttle() // This counts as a failure for throttling. - return err + return false, err } hasPushback = true } else if len(sps) > 1 { channelz.Warningf(logger, cs.cc.channelzID, "Server retry pushback specified multiple values (%q); not retrying.", sps) cs.retryThrottler.throttle() // This counts as a failure for throttling. - return err + return false, err } } @@ -599,16 +603,16 @@ func (cs *clientStream) shouldRetry(err error) error { rp := cs.methodConfig.RetryPolicy if rp == nil || !rp.RetryableStatusCodes[code] { - return err + return false, err } // Note: the ordering here is important; we count this as a failure // only if the code matched a retryable code. 
if cs.retryThrottler.throttle() { - return err + return false, err } if cs.numRetries+1 >= rp.MaxAttempts { - return err + return false, err } var dur time.Duration @@ -631,10 +635,10 @@ func (cs *clientStream) shouldRetry(err error) error { select { case <-t.C: cs.numRetries++ - return nil + return false, nil case <-cs.ctx.Done(): t.Stop() - return status.FromContextError(cs.ctx.Err()).Err() + return false, status.FromContextError(cs.ctx.Err()).Err() } } @@ -642,12 +646,13 @@ func (cs *clientStream) shouldRetry(err error) error { func (cs *clientStream) retryLocked(lastErr error) error { for { cs.attempt.finish(toRPCErr(lastErr)) - if err := cs.shouldRetry(lastErr); err != nil { + isTransparent, err := cs.shouldRetry(lastErr) + if err != nil { cs.commitAttemptLocked() return err } cs.firstAttempt = false - if err := cs.newAttemptLocked(nil, nil); err != nil { + if err := cs.newAttemptLocked(isTransparent); err != nil { return err } if lastErr = cs.replayBufferLocked(); lastErr == nil { @@ -937,7 +942,7 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { return io.EOF } if a.statsHandler != nil { - a.statsHandler.HandleRPC(cs.ctx, outPayload(true, m, data, payld, time.Now())) + a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) } if channelz.IsOn() { a.t.IncrMsgSent() @@ -985,7 +990,7 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { a.mu.Unlock() } if a.statsHandler != nil { - a.statsHandler.HandleRPC(cs.ctx, &stats.InPayload{ + a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{ Client: true, RecvTime: time.Now(), Payload: m, @@ -1047,12 +1052,12 @@ func (a *csAttempt) finish(err error) { if a.statsHandler != nil { end := &stats.End{ Client: true, - BeginTime: a.cs.beginTime, + BeginTime: a.beginTime, EndTime: time.Now(), Trailer: tr, Error: err, } - a.statsHandler.HandleRPC(a.cs.ctx, end) + a.statsHandler.HandleRPC(a.ctx, end) } if a.trInfo != nil && a.trInfo.tr != nil { if err == nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index 5702f18bdb2..53d38684285 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3724,10 +3724,12 @@ func (s) TestTransparentRetry(t *testing.T) { } defer lis.Close() server := &httpServer{ - headerFields: [][]string{{ - ":status", "200", - "content-type", "application/grpc", - "grpc-status", "0", + responses: []httpServerResponse{{ + trailers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + "grpc-status", "0", + }}, }}, refuseStream: func(i uint32) bool { switch i { @@ -7343,9 +7345,15 @@ func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) { doHTTPHeaderTest(t, codes.Internal, header, header, header) } +type httpServerResponse struct { + headers [][]string + payload []byte + trailers [][]string +} + type httpServer struct { - headerFields [][]string refuseStream func(uint32) bool + responses []httpServerResponse } func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error { @@ -7369,6 +7377,10 @@ func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields }) } +func (s *httpServer) writePayload(framer *http2.Framer, sid uint32, payload []byte) error { + return framer.WriteData(sid, false, payload) +} + func (s *httpServer) start(t *testing.T, lis net.Listener) { // Launch an HTTP server to send back header. 
go func() { @@ -7394,6 +7406,7 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { var sid uint32 // Loop until conn is closed and framer returns io.EOF + requestNum := 0 for { // Read frames until a header is received. for { @@ -7413,13 +7426,34 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { writer.Flush() } } - for i, headers := range s.headerFields { - if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil { + + response := s.responses[requestNum] + for _, header := range response.headers { + if err = s.writeHeader(framer, sid, header, false); err != nil { t.Errorf("Error at server-side while writing headers. Err: %v", err) return } writer.Flush() } + if response.payload != nil { + if err = s.writePayload(framer, sid, response.payload); err != nil { + t.Errorf("Error at server-side while writing payload. Err: %v", err) + return + } + writer.Flush() + } + for i, trailer := range response.trailers { + if err = s.writeHeader(framer, sid, trailer, i == len(response.trailers)-1); err != nil { + t.Errorf("Error at server-side while writing trailers. Err: %v", err) + return + } + writer.Flush() + } + + requestNum++ + if requestNum >= len(s.responses) { + requestNum = 0 + } } }() } @@ -7432,7 +7466,7 @@ func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string } defer lis.Close() server := &httpServer{ - headerFields: headerFields, + responses: []httpServerResponse{{trailers: headerFields}}, } server.start(t, lis) cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) diff --git a/test/retry_test.go b/test/retry_test.go index ad1268faa96..96c6a16c1cb 100644 --- a/test/retry_test.go +++ b/test/retry_test.go @@ -22,9 +22,12 @@ import ( "context" "fmt" "io" + "net" "os" + "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -34,6 +37,7 @@ import ( "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) @@ -550,3 +554,166 @@ func (s) TestRetryStreaming(t *testing.T) { }() } } + +type retryStatsHandler struct { + mu sync.Mutex + s []stats.RPCStats +} + +func (*retryStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + return ctx +} +func (h *retryStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { + h.mu.Lock() + h.s = append(h.s, s) + h.mu.Unlock() +} +func (*retryStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + return ctx +} +func (*retryStatsHandler) HandleConn(context.Context, stats.ConnStats) {} + +func (s) TestRetryStats(t *testing.T) { + defer enableRetry()() + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen. Err: %v", err) + } + defer lis.Close() + server := &httpServer{ + responses: []httpServerResponse{{ + trailers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + "grpc-status", "14", // UNAVAILABLE + "grpc-message", "unavailable retry", + "grpc-retry-pushback-ms", "10", + }}, + }, { + headers: [][]string{{ + ":status", "200", + "content-type", "application/grpc", + }}, + payload: []byte{0, 0, 0, 0, 0}, // header for 0-byte response message. 
+ trailers: [][]string{{ + "grpc-status", "0", // OK + }}, + }}, + refuseStream: func(i uint32) bool { + return i == 1 + }, + } + server.start(t, lis) + handler := &retryStatsHandler{} + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithStatsHandler(handler), + grpc.WithDefaultServiceConfig((`{ + "methodConfig": [{ + "name": [{"service": "grpc.testing.TestService"}], + "retryPolicy": { + "MaxAttempts": 4, + "InitialBackoff": ".01s", + "MaxBackoff": ".01s", + "BackoffMultiplier": 1.0, + "RetryableStatusCodes": [ "UNAVAILABLE" ] + } + }]}`))) + if err != nil { + t.Fatalf("failed to dial due to err: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + client := testpb.NewTestServiceClient(cc) + + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("unexpected EmptyCall error: %v", err) + } + handler.mu.Lock() + want := []stats.RPCStats{ + &stats.Begin{}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.End{}, + + &stats.Begin{IsTransparentRetryAttempt: true}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.InTrailer{Trailer: metadata.Pairs("content-type", "application/grpc", "grpc-retry-pushback-ms", "10")}, + &stats.End{}, + + &stats.Begin{}, + &stats.OutHeader{FullMethod: "/grpc.testing.TestService/EmptyCall"}, + &stats.OutPayload{WireLength: 5}, + &stats.InHeader{}, + &stats.InPayload{WireLength: 5}, + &stats.InTrailer{}, + &stats.End{}, + } + + // There is a race between noticing the RST_STREAM during the first RPC + // attempt and writing the payload. If we detect that the client did not + // send the OutPayload, we remove it from want. + if _, ok := handler.s[2].(*stats.End); ok { + want = append(want[:2], want[3:]...) + } + + toString := func(ss []stats.RPCStats) (ret []string) { + for _, s := range ss { + ret = append(ret, fmt.Sprintf("%T - %v", s, s)) + } + return ret + } + t.Logf("Handler received frames:\n%v\n---\nwant:\n%v\n", + strings.Join(toString(handler.s), "\n"), + strings.Join(toString(want), "\n")) + + if len(handler.s) != len(want) { + t.Fatalf("received unexpected number of RPCStats: got %v; want %v", len(handler.s), len(want)) + } + + // There is a race between receiving the payload (triggered by the + // application / gRPC library) and receiving the trailer (triggered at the + // transport layer). Adjust the received stats accordingly if necessary. + // Note: we measure from the end of the RPCStats due to the race above. 
+ tIdx, pIdx := len(handler.s)-3, len(handler.s)-2 + _, okT := handler.s[tIdx].(*stats.InTrailer) + _, okP := handler.s[pIdx].(*stats.InPayload) + if okT && okP { + handler.s[pIdx], handler.s[tIdx] = handler.s[tIdx], handler.s[pIdx] + } + + for i := range handler.s { + w, s := want[i], handler.s[i] + + // Validate the event type + if reflect.TypeOf(w) != reflect.TypeOf(s) { + t.Fatalf("at position %v: got %T; want %T", i, s, w) + } + wv, sv := reflect.ValueOf(w).Elem(), reflect.ValueOf(s).Elem() + + // Validate that Client is always true + if sv.FieldByName("Client").Interface().(bool) != true { + t.Fatalf("at position %v: got Client=false; want true", i) + } + + // Validate any set fields in want + for i := 0; i < wv.NumField(); i++ { + if !wv.Field(i).IsZero() { + if got, want := sv.Field(i).Interface(), wv.Field(i).Interface(); !reflect.DeepEqual(got, want) { + name := reflect.TypeOf(w).Elem().Field(i).Name + t.Fatalf("at position %v, field %v: got %v; want %v", i, name, got, want) + } + } + } + } + + // Validate timings between last Begin and preceding End. + end := handler.s[len(handler.s)-8].(*stats.End) + begin := handler.s[len(handler.s)-7].(*stats.Begin) + diff := begin.BeginTime.Sub(end.EndTime) + if diff < 10*time.Millisecond || diff > 50*time.Millisecond { + t.Fatalf("pushback time before final attempt = %v; want ~10ms", diff) + } +}
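
Taken together, the stream.go changes above mean a client-side stats.Handler now receives one Begin/End pair per attempt rather than one per RPC, with Begin.IsTransparentRetryAttempt distinguishing transparent retries from retries driven by the retry policy. The following sketch is illustrative only and not part of this patch (the handler type, counter fields, and log wording are invented for the example); it shows how a handler might consume both properties:

package statsexample

import (
	"context"
	"log"
	"sync/atomic"

	"google.golang.org/grpc/stats"
)

// retryStatsObserver demonstrates the two observable effects of this change:
// Begin/End are now emitted once per attempt, so End.BeginTime and End.EndTime
// span a single attempt, and Begin.IsTransparentRetryAttempt marks attempts
// created by transparent retry.
type retryStatsObserver struct {
	attempts           int64 // every attempt, including each RPC's first
	transparentRetries int64 // attempts flagged IsTransparentRetryAttempt
}

func (o *retryStatsObserver) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
	return ctx
}

func (o *retryStatsObserver) HandleRPC(_ context.Context, s stats.RPCStats) {
	switch ev := s.(type) {
	case *stats.Begin:
		if ev.Client {
			atomic.AddInt64(&o.attempts, 1)
			if ev.IsTransparentRetryAttempt {
				atomic.AddInt64(&o.transparentRetries, 1)
			}
		}
	case *stats.End:
		if ev.Client {
			// Per-attempt latency; before this change, BeginTime was the
			// begin time of the whole RPC rather than of the attempt.
			log.Printf("attempt finished in %v (err: %v)", ev.EndTime.Sub(ev.BeginTime), ev.Error)
		}
	}
}

func (o *retryStatsObserver) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
	return ctx
}

func (o *retryStatsObserver) HandleConn(context.Context, stats.ConnStats) {}

Installed with grpc.WithStatsHandler(&retryStatsObserver{}), the same way TestRetryStats wires up its retryStatsHandler, this handler would record three attempts, one of them transparent, for the scenario the test exercises: the first attempt is refused and transparently retried, the second fails with a trailers-only UNAVAILABLE status and is retried per the policy after the 10ms pushback, and the third succeeds.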