From b09df35c614f4181c9916589e4b5725afa436ba9 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 12 Sep 2022 22:20:29 +0000 Subject: [PATCH 1/3] grpc: restrict status codes from control plane (gRFC A54) --- credentials/credentials.go | 20 +- internal/status/status.go | 10 + internal/transport/http2_client.go | 16 +- picker_wrapper.go | 7 +- stream.go | 8 + test/control_plane_status_test.go | 293 +++++++++++++++++++++++++++++ test/creds_test.go | 21 ++- 7 files changed, 355 insertions(+), 20 deletions(-) create mode 100644 test/control_plane_status_test.go diff --git a/credentials/credentials.go b/credentials/credentials.go index 96ff1877e75..5feac3aa0e4 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -36,16 +36,16 @@ import ( // PerRPCCredentials defines the common interface for the credentials which need to // attach security information to every RPC (e.g., oauth2). type PerRPCCredentials interface { - // GetRequestMetadata gets the current request metadata, refreshing - // tokens if required. This should be called by the transport layer on - // each request, and the data should be populated in headers or other - // context. If a status code is returned, it will be used as the status - // for the RPC. uri is the URI of the entry point for the request. - // When supported by the underlying implementation, ctx can be used for - // timeout and cancellation. Additionally, RequestInfo data will be - // available via ctx to this call. - // TODO(zhaoq): Define the set of the qualified keys instead of leaving - // it as an arbitrary string. + // GetRequestMetadata gets the current request metadata, refreshing tokens + // if required. This should be called by the transport layer on each + // request, and the data should be populated in headers or other + // context. If a status code is returned, it will be used as the status for + // the RPC (restricted to an allowable set of codes as defined by gRFC + // A54). uri is the URI of the entry point for the request. When supported + // by the underlying implementation, ctx can be used for timeout and + // cancellation. Additionally, RequestInfo data will be available via ctx + // to this call. TODO(zhaoq): Define the set of the qualified keys instead + // of leaving it as an arbitrary string. GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) // RequireTransportSecurity indicates whether the credentials requires // transport security. diff --git a/internal/status/status.go b/internal/status/status.go index e5c6513edd1..b0ead4f54f8 100644 --- a/internal/status/status.go +++ b/internal/status/status.go @@ -164,3 +164,13 @@ func (e *Error) Is(target error) bool { } return proto.Equal(e.s.s, tse.s.s) } + +// IsRestrictedControlPlaneCode returns whether the status includes a code +// restricted for control plane usage as defined by gRFC A54. +func IsRestrictedControlPlaneCode(s *Status) bool { + switch s.Code() { + case codes.InvalidArgument, codes.NotFound, codes.AlreadyExists, codes.FailedPrecondition, codes.Aborted, codes.OutOfRange, codes.DataLoss: + return true + } + return false +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 53643fa9747..a6a4cbeac32 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -40,6 +40,7 @@ import ( icredentials "google.golang.org/grpc/internal/credentials" "google.golang.org/grpc/internal/grpcutil" imetadata "google.golang.org/grpc/internal/metadata" + istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/keepalive" @@ -589,7 +590,11 @@ func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[s for _, c := range t.perRPCCreds { data, err := c.GetRequestMetadata(ctx, audience) if err != nil { - if _, ok := status.FromError(err); ok { + if st, ok := status.FromError(err); ok { + // Restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err) + } return nil, err } @@ -618,7 +623,14 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call } data, err := callCreds.GetRequestMetadata(ctx, audience) if err != nil { - return nil, status.Errorf(codes.Internal, "transport: %v", err) + if st, ok := status.FromError(err); ok { + // Restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err) + } + return nil, err + } + return nil, status.Errorf(codes.Internal, "transport: per-RPC creds failed due to error: %v", err) } callAuthData = make(map[string]string, len(data)) for k, v := range data { diff --git a/picker_wrapper.go b/picker_wrapper.go index 843633c910a..a5d5516ee06 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -26,6 +26,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/channelz" + istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/status" ) @@ -129,8 +130,12 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. if err == balancer.ErrNoSubConnAvailable { continue } - if _, ok := status.FromError(err); ok { + if st, ok := status.FromError(err); ok { // Status error: end the RPC unconditionally with this status. + // First restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err) + } return nil, nil, dropError{error: err} } // For all other errors, wait for ready RPCs should block and other diff --git a/stream.go b/stream.go index 446a91e323e..b678596949b 100644 --- a/stream.go +++ b/stream.go @@ -39,6 +39,7 @@ import ( imetadata "google.golang.org/grpc/internal/metadata" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/serviceconfig" + istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -195,6 +196,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth rpcInfo := iresolver.RPCInfo{Context: ctx, Method: method} rpcConfig, err := cc.safeConfigSelector.SelectConfig(rpcInfo) if err != nil { + if st, ok := status.FromError(err); ok { + // Restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "config selector returned illegal status: %v", err) + } + return nil, err + } return nil, toRPCErr(err) } diff --git a/test/control_plane_status_test.go b/test/control_plane_status_test.go new file mode 100644 index 00000000000..b5ada6372d0 --- /dev/null +++ b/test/control_plane_status_test.go @@ -0,0 +1,293 @@ +/* + * + * Copyright 2022 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "errors" + "strings" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + _ "google.golang.org/grpc/encoding/gzip" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func (s) TestConfigSelectorStatusCodes(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + ss.R = manual.NewBuilderWithScheme("confSel") + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + csErr := make(chan error, 1) + state := iresolver.SetConfigSelector(resolver.State{ + Addresses: []resolver.Address{{Addr: ss.Address}}, + ServiceConfig: parseServiceConfig(t, ss.R, "{}"), + }, funcConfigSelector{ + f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { + return nil, <-csErr + }, + }) + ss.R.UpdateState(state) // Blocks until config selector is applied + + testCases := []struct { + name string + csErr error + want error + }{{ + name: "legal status code", + csErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + csErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // In case the channel is full due to a previous iteration failure, + // do not block. + select { + case csErr <- tc.csErr: + default: + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + } + }) + } +} + +type lbBuilderWrapper struct { + builder balancer.Builder // real Builder + name string + picker func(balancer.PickInfo) error +} + +func (l *lbBuilderWrapper) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + return l.builder.Build(&lbCCWrapper{ClientConn: cc, picker: l.picker}, opts) +} + +func (l *lbBuilderWrapper) Name() string { + return l.name +} + +type lbCCWrapper struct { + balancer.ClientConn // real ClientConn + picker func(balancer.PickInfo) error +} + +func (l *lbCCWrapper) UpdateState(s balancer.State) { + s.Picker = &lbPickerWrapper{picker: l.picker, Picker: s.Picker} + l.ClientConn.UpdateState(s) +} + +type lbPickerWrapper struct { + balancer.Picker // real Picker + picker func(balancer.PickInfo) error +} + +func (lp *lbPickerWrapper) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + if err := lp.picker(info); err != nil { + return balancer.PickResult{}, err + } + return lp.Picker.Pick(info) +} + +func (s) TestPickerStatusCodes(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + pickerErr := make(chan error, 1) + balancer.Register(&lbBuilderWrapper{ + builder: balancer.Get("round_robin"), + name: "testPickerStatusCodesBalancer", + picker: func(balancer.PickInfo) error { + return <-pickerErr + }, + }) + + ss.NewServiceConfig(`{"loadBalancingConfig": [{"testPickerStatusCodesBalancer":{}}] }`) + + // Make calls until pickerErr is used. + pickerErr <- errors.New("err") + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + used := false + for !used { + ss.Client.EmptyCall(ctx, &testpb.Empty{}) + select { + case pickerErr <- errors.New("err"): + <-pickerErr + used = true + default: + } + } + + testCases := []struct { + name string + pickerErr error + want error + }{{ + name: "legal status code", + pickerErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + pickerErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // In case the channel is full due to a previous iteration failure, + // do not block. + select { + case pickerErr <- tc.pickerErr: + default: + } + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + } + }) + } +} + +func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + errChan := make(chan error, 1) + creds := &testPerRPCCredentials{errChan: errChan} + + if err := ss.Start(nil, grpc.WithPerRPCCredentials(creds)); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + testCases := []struct { + name string + credsErr error + want error + }{{ + name: "legal status code", + credsErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + credsErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // In case the channel is full due to a previous iteration failure, + // do not block. + select { + case errChan <- tc.credsErr: + default: + } + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + } + }) + } +} + +func (s) TestCallCredsFromCallOptionsStatusCodes(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + errChan := make(chan error, 1) + creds := &testPerRPCCredentials{errChan: errChan} + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + testCases := []struct { + name string + credsErr error + want error + }{{ + name: "legal status code", + credsErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + credsErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // In case the channel is full due to a previous iteration failure, + // do not block. + select { + case errChan <- tc.credsErr: + default: + } + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(creds)); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + } + }) + } +} diff --git a/test/creds_test.go b/test/creds_test.go index d886220d8a4..5323affa790 100644 --- a/test/creds_test.go +++ b/test/creds_test.go @@ -68,7 +68,7 @@ func (c *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials { if c.mode == bundleTLSOnly { return nil } - return testPerRPCCredentials{} + return testPerRPCCredentials{authdata: authdata} } func (c *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { @@ -284,10 +284,17 @@ var ( } ) -type testPerRPCCredentials struct{} +type testPerRPCCredentials struct { + authdata map[string]string + errChan chan error +} func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - return authdata, nil + var err error + if cr.errChan != nil { + err = <-cr.errChan + } + return cr.authdata, err } func (cr testPerRPCCredentials) RequireTransportSecurity() bool { @@ -320,7 +327,7 @@ func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) { func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) { te := newTest(t, e) te.tapHandle = authHandle - te.perRPCCreds = testPerRPCCredentials{} + te.perRPCCreds = testPerRPCCredentials{authdata: authdata} te.startServer(&testServer{security: e.security}) defer te.tearDown() @@ -349,7 +356,7 @@ func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { t.Fatalf("Test failed. Reason: %v", err) } } @@ -362,7 +369,7 @@ func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) { func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { te := newTest(t, e) - te.perRPCCreds = testPerRPCCredentials{} + te.perRPCCreds = testPerRPCCredentials{authdata: authdata} // When credentials are provided via both dial options and call options, // we apply both sets. te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { @@ -391,7 +398,7 @@ func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { t.Fatalf("Test failed. Reason: %v", err) } } From 757f2951ec69f66c7ec5dd444ed8fb7ec560ed95 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 27 Sep 2022 14:10:55 -0700 Subject: [PATCH 2/3] test simplifications --- test/control_plane_status_test.go | 248 ++++++++++++------------------ 1 file changed, 95 insertions(+), 153 deletions(-) diff --git a/test/control_plane_status_test.go b/test/control_plane_status_test.go index b5ada6372d0..2c09825cad8 100644 --- a/test/control_plane_status_test.go +++ b/test/control_plane_status_test.go @@ -20,14 +20,17 @@ package test import ( "context" - "errors" "strings" "testing" + "time" "google.golang.org/grpc" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" _ "google.golang.org/grpc/encoding/gzip" + "google.golang.org/grpc/internal/balancer/stub" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/resolver" @@ -37,29 +40,6 @@ import ( ) func (s) TestConfigSelectorStatusCodes(t *testing.T) { - ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - ss.R = manual.NewBuilderWithScheme("confSel") - - if err := ss.Start(nil); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() - - csErr := make(chan error, 1) - state := iresolver.SetConfigSelector(resolver.State{ - Addresses: []resolver.Address{{Addr: ss.Address}}, - ServiceConfig: parseServiceConfig(t, ss.R, "{}"), - }, funcConfigSelector{ - f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { - return nil, <-csErr - }, - }) - ss.R.UpdateState(state) // Blocks until config selector is applied - testCases := []struct { name string csErr error @@ -76,12 +56,27 @@ func (s) TestConfigSelectorStatusCodes(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // In case the channel is full due to a previous iteration failure, - // do not block. - select { - case csErr <- tc.csErr: - default: + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, } + ss.R = manual.NewBuilderWithScheme("confSel") + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + state := iresolver.SetConfigSelector(resolver.State{ + Addresses: []resolver.Address{{Addr: ss.Address}}, + ServiceConfig: parseServiceConfig(t, ss.R, "{}"), + }, funcConfigSelector{ + f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { + return nil, tc.csErr + }, + }) + ss.R.UpdateState(state) // Blocks until config selector is applied ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -92,80 +87,7 @@ func (s) TestConfigSelectorStatusCodes(t *testing.T) { } } -type lbBuilderWrapper struct { - builder balancer.Builder // real Builder - name string - picker func(balancer.PickInfo) error -} - -func (l *lbBuilderWrapper) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - return l.builder.Build(&lbCCWrapper{ClientConn: cc, picker: l.picker}, opts) -} - -func (l *lbBuilderWrapper) Name() string { - return l.name -} - -type lbCCWrapper struct { - balancer.ClientConn // real ClientConn - picker func(balancer.PickInfo) error -} - -func (l *lbCCWrapper) UpdateState(s balancer.State) { - s.Picker = &lbPickerWrapper{picker: l.picker, Picker: s.Picker} - l.ClientConn.UpdateState(s) -} - -type lbPickerWrapper struct { - balancer.Picker // real Picker - picker func(balancer.PickInfo) error -} - -func (lp *lbPickerWrapper) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - if err := lp.picker(info); err != nil { - return balancer.PickResult{}, err - } - return lp.Picker.Pick(info) -} - func (s) TestPickerStatusCodes(t *testing.T) { - ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - - if err := ss.Start(nil); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() - - pickerErr := make(chan error, 1) - balancer.Register(&lbBuilderWrapper{ - builder: balancer.Get("round_robin"), - name: "testPickerStatusCodesBalancer", - picker: func(balancer.PickInfo) error { - return <-pickerErr - }, - }) - - ss.NewServiceConfig(`{"loadBalancingConfig": [{"testPickerStatusCodesBalancer":{}}] }`) - - // Make calls until pickerErr is used. - pickerErr <- errors.New("err") - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - used := false - for !used { - ss.Client.EmptyCall(ctx, &testpb.Empty{}) - select { - case pickerErr <- errors.New("err"): - <-pickerErr - used = true - default: - } - } - testCases := []struct { name string pickerErr error @@ -182,38 +104,51 @@ func (s) TestPickerStatusCodes(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // In case the channel is full due to a previous iteration failure, - // do not block. - select { - case pickerErr <- tc.pickerErr: - default: + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, } - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { - t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) } - }) - } -} + defer ss.Stop() + + // Create a stub balancer that creates a picker that always returns + // an error. + sbf := stub.BalancerFuncs{ + UpdateClientConnState: func(d *stub.BalancerData, _ balancer.ClientConnState) error { + d.ClientConn.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: base.NewErrPicker(tc.pickerErr), + }) + return nil + }, + } + stub.Register("testPickerStatusCodesBalancer", sbf) -func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) { - ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } + ss.NewServiceConfig(`{"loadBalancingConfig": [{"testPickerStatusCodesBalancer":{}}] }`) - errChan := make(chan error, 1) - creds := &testPerRPCCredentials{errChan: errChan} + // Make calls until pickerErr is received. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() - if err := ss.Start(nil, grpc.WithPerRPCCredentials(creds)); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() + var lastErr error + for ctx.Err() == nil { + if _, lastErr = ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(lastErr) == status.Code(tc.want) && strings.Contains(lastErr.Error(), status.Convert(tc.want).Message()) { + // Success! + return + } + time.Sleep(time.Millisecond) + } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", lastErr, tc.want) + }) + } +} +func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) { testCases := []struct { name string credsErr error @@ -230,13 +165,25 @@ func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // In case the channel is full due to a previous iteration failure, - // do not block. - select { - case errChan <- tc.credsErr: - default: + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, } + errChan := make(chan error, 1) + creds := &testPerRPCCredentials{errChan: errChan} + + if err := ss.Start(nil, grpc.WithPerRPCCredentials(creds)); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + errChan <- tc.credsErr + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) } @@ -245,23 +192,6 @@ func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) { } func (s) TestCallCredsFromCallOptionsStatusCodes(t *testing.T) { - ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - - errChan := make(chan error, 1) - creds := &testPerRPCCredentials{errChan: errChan} - - if err := ss.Start(nil); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) - } - defer ss.Stop() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - testCases := []struct { name string credsErr error @@ -278,12 +208,24 @@ func (s) TestCallCredsFromCallOptionsStatusCodes(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // In case the channel is full due to a previous iteration failure, - // do not block. - select { - case errChan <- tc.credsErr: - default: + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + errChan := make(chan error, 1) + creds := &testPerRPCCredentials{errChan: errChan} + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + errChan <- tc.credsErr if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(creds)); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) From e77f5d4840d429828b3eb558269bb6fcc1cd497f Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 27 Sep 2022 14:12:23 -0700 Subject: [PATCH 3/3] remove unneeded import --- test/control_plane_status_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/test/control_plane_status_test.go b/test/control_plane_status_test.go index 2c09825cad8..be191f456b2 100644 --- a/test/control_plane_status_test.go +++ b/test/control_plane_status_test.go @@ -29,7 +29,6 @@ import ( "google.golang.org/grpc/balancer/base" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" - _ "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/internal/balancer/stub" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/stubserver"