diff --git a/credentials/credentials.go b/credentials/credentials.go index 96ff1877e75..5feac3aa0e4 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -36,16 +36,16 @@ import ( // PerRPCCredentials defines the common interface for the credentials which need to // attach security information to every RPC (e.g., oauth2). type PerRPCCredentials interface { - // GetRequestMetadata gets the current request metadata, refreshing - // tokens if required. This should be called by the transport layer on - // each request, and the data should be populated in headers or other - // context. If a status code is returned, it will be used as the status - // for the RPC. uri is the URI of the entry point for the request. - // When supported by the underlying implementation, ctx can be used for - // timeout and cancellation. Additionally, RequestInfo data will be - // available via ctx to this call. - // TODO(zhaoq): Define the set of the qualified keys instead of leaving - // it as an arbitrary string. + // GetRequestMetadata gets the current request metadata, refreshing tokens + // if required. This should be called by the transport layer on each + // request, and the data should be populated in headers or other + // context. If a status code is returned, it will be used as the status for + // the RPC (restricted to an allowable set of codes as defined by gRFC + // A54). uri is the URI of the entry point for the request. When supported + // by the underlying implementation, ctx can be used for timeout and + // cancellation. Additionally, RequestInfo data will be available via ctx + // to this call. TODO(zhaoq): Define the set of the qualified keys instead + // of leaving it as an arbitrary string. GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) // RequireTransportSecurity indicates whether the credentials requires // transport security. diff --git a/internal/status/status.go b/internal/status/status.go index e5c6513edd1..b0ead4f54f8 100644 --- a/internal/status/status.go +++ b/internal/status/status.go @@ -164,3 +164,13 @@ func (e *Error) Is(target error) bool { } return proto.Equal(e.s.s, tse.s.s) } + +// IsRestrictedControlPlaneCode returns whether the status includes a code +// restricted for control plane usage as defined by gRFC A54. +func IsRestrictedControlPlaneCode(s *Status) bool { + switch s.Code() { + case codes.InvalidArgument, codes.NotFound, codes.AlreadyExists, codes.FailedPrecondition, codes.Aborted, codes.OutOfRange, codes.DataLoss: + return true + } + return false +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 53643fa9747..a6a4cbeac32 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -40,6 +40,7 @@ import ( icredentials "google.golang.org/grpc/internal/credentials" "google.golang.org/grpc/internal/grpcutil" imetadata "google.golang.org/grpc/internal/metadata" + istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/keepalive" @@ -589,7 +590,11 @@ func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[s for _, c := range t.perRPCCreds { data, err := c.GetRequestMetadata(ctx, audience) if err != nil { - if _, ok := status.FromError(err); ok { + if st, ok := status.FromError(err); ok { + // Restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err) + } return nil, err } @@ -618,7 +623,14 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call } data, err := callCreds.GetRequestMetadata(ctx, audience) if err != nil { - return nil, status.Errorf(codes.Internal, "transport: %v", err) + if st, ok := status.FromError(err); ok { + // Restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err) + } + return nil, err + } + return nil, status.Errorf(codes.Internal, "transport: per-RPC creds failed due to error: %v", err) } callAuthData = make(map[string]string, len(data)) for k, v := range data { diff --git a/picker_wrapper.go b/picker_wrapper.go index 843633c910a..a5d5516ee06 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -26,6 +26,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/channelz" + istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/status" ) @@ -129,8 +130,12 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. if err == balancer.ErrNoSubConnAvailable { continue } - if _, ok := status.FromError(err); ok { + if st, ok := status.FromError(err); ok { // Status error: end the RPC unconditionally with this status. + // First restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err) + } return nil, nil, dropError{error: err} } // For all other errors, wait for ready RPCs should block and other diff --git a/stream.go b/stream.go index 446a91e323e..b678596949b 100644 --- a/stream.go +++ b/stream.go @@ -39,6 +39,7 @@ import ( imetadata "google.golang.org/grpc/internal/metadata" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/serviceconfig" + istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -195,6 +196,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth rpcInfo := iresolver.RPCInfo{Context: ctx, Method: method} rpcConfig, err := cc.safeConfigSelector.SelectConfig(rpcInfo) if err != nil { + if st, ok := status.FromError(err); ok { + // Restrict the code to the list allowed by gRFC A54. + if istatus.IsRestrictedControlPlaneCode(st) { + err = status.Errorf(codes.Internal, "config selector returned illegal status: %v", err) + } + return nil, err + } return nil, toRPCErr(err) } diff --git a/test/control_plane_status_test.go b/test/control_plane_status_test.go new file mode 100644 index 00000000000..be191f456b2 --- /dev/null +++ b/test/control_plane_status_test.go @@ -0,0 +1,234 @@ +/* + * + * Copyright 2022 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "strings" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/balancer/stub" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func (s) TestConfigSelectorStatusCodes(t *testing.T) { + testCases := []struct { + name string + csErr error + want error + }{{ + name: "legal status code", + csErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + csErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + ss.R = manual.NewBuilderWithScheme("confSel") + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + state := iresolver.SetConfigSelector(resolver.State{ + Addresses: []resolver.Address{{Addr: ss.Address}}, + ServiceConfig: parseServiceConfig(t, ss.R, "{}"), + }, funcConfigSelector{ + f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { + return nil, tc.csErr + }, + }) + ss.R.UpdateState(state) // Blocks until config selector is applied + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + } + }) + } +} + +func (s) TestPickerStatusCodes(t *testing.T) { + testCases := []struct { + name string + pickerErr error + want error + }{{ + name: "legal status code", + pickerErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + pickerErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + // Create a stub balancer that creates a picker that always returns + // an error. + sbf := stub.BalancerFuncs{ + UpdateClientConnState: func(d *stub.BalancerData, _ balancer.ClientConnState) error { + d.ClientConn.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: base.NewErrPicker(tc.pickerErr), + }) + return nil + }, + } + stub.Register("testPickerStatusCodesBalancer", sbf) + + ss.NewServiceConfig(`{"loadBalancingConfig": [{"testPickerStatusCodesBalancer":{}}] }`) + + // Make calls until pickerErr is received. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + var lastErr error + for ctx.Err() == nil { + if _, lastErr = ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(lastErr) == status.Code(tc.want) && strings.Contains(lastErr.Error(), status.Convert(tc.want).Message()) { + // Success! + return + } + time.Sleep(time.Millisecond) + } + + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", lastErr, tc.want) + }) + } +} + +func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) { + testCases := []struct { + name string + credsErr error + want error + }{{ + name: "legal status code", + credsErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + credsErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + errChan := make(chan error, 1) + creds := &testPerRPCCredentials{errChan: errChan} + + if err := ss.Start(nil, grpc.WithPerRPCCredentials(creds)); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + errChan <- tc.credsErr + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + } + }) + } +} + +func (s) TestCallCredsFromCallOptionsStatusCodes(t *testing.T) { + testCases := []struct { + name string + credsErr error + want error + }{{ + name: "legal status code", + credsErr: status.Errorf(codes.Unavailable, "this error is fine"), + want: status.Errorf(codes.Unavailable, "this error is fine"), + }, { + name: "illegal status code", + credsErr: status.Errorf(codes.NotFound, "this error is bad"), + want: status.Errorf(codes.Internal, "this error is bad"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + errChan := make(chan error, 1) + creds := &testPerRPCCredentials{errChan: errChan} + + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + errChan <- tc.credsErr + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(creds)); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want) + } + }) + } +} diff --git a/test/creds_test.go b/test/creds_test.go index d886220d8a4..5323affa790 100644 --- a/test/creds_test.go +++ b/test/creds_test.go @@ -68,7 +68,7 @@ func (c *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials { if c.mode == bundleTLSOnly { return nil } - return testPerRPCCredentials{} + return testPerRPCCredentials{authdata: authdata} } func (c *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { @@ -284,10 +284,17 @@ var ( } ) -type testPerRPCCredentials struct{} +type testPerRPCCredentials struct { + authdata map[string]string + errChan chan error +} func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - return authdata, nil + var err error + if cr.errChan != nil { + err = <-cr.errChan + } + return cr.authdata, err } func (cr testPerRPCCredentials) RequireTransportSecurity() bool { @@ -320,7 +327,7 @@ func (s) TestPerRPCCredentialsViaDialOptions(t *testing.T) { func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) { te := newTest(t, e) te.tapHandle = authHandle - te.perRPCCreds = testPerRPCCredentials{} + te.perRPCCreds = testPerRPCCredentials{authdata: authdata} te.startServer(&testServer{security: e.security}) defer te.tearDown() @@ -349,7 +356,7 @@ func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { t.Fatalf("Test failed. Reason: %v", err) } } @@ -362,7 +369,7 @@ func (s) TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) { func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { te := newTest(t, e) - te.perRPCCreds = testPerRPCCredentials{} + te.perRPCCreds = testPerRPCCredentials{authdata: authdata} // When credentials are provided via both dial options and call options, // we apply both sets. te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { @@ -391,7 +398,7 @@ func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{authdata: authdata})); err != nil { t.Fatalf("Test failed. Reason: %v", err) } }