Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grpc: restrict status codes from control plane (gRFC A54) #5653

Merged
merged 3 commits into from Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 10 additions & 10 deletions credentials/credentials.go
Expand Up @@ -36,16 +36,16 @@ import (
// PerRPCCredentials defines the common interface for the credentials which need to
// attach security information to every RPC (e.g., oauth2).
type PerRPCCredentials interface {
// GetRequestMetadata gets the current request metadata, refreshing
// tokens if required. This should be called by the transport layer on
// each request, and the data should be populated in headers or other
// context. If a status code is returned, it will be used as the status
// for the RPC. uri is the URI of the entry point for the request.
// When supported by the underlying implementation, ctx can be used for
// timeout and cancellation. Additionally, RequestInfo data will be
// available via ctx to this call.
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
// it as an arbitrary string.
// GetRequestMetadata gets the current request metadata, refreshing tokens
// if required. This should be called by the transport layer on each
// request, and the data should be populated in headers or other
// context. If a status code is returned, it will be used as the status for
// the RPC (restricted to an allowable set of codes as defined by gRFC
// A54). uri is the URI of the entry point for the request. When supported
// by the underlying implementation, ctx can be used for timeout and
// cancellation. Additionally, RequestInfo data will be available via ctx
// to this call. TODO(zhaoq): Define the set of the qualified keys instead
// of leaving it as an arbitrary string.
GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
// RequireTransportSecurity indicates whether the credentials requires
// transport security.
Expand Down
10 changes: 10 additions & 0 deletions internal/status/status.go
Expand Up @@ -164,3 +164,13 @@ func (e *Error) Is(target error) bool {
}
return proto.Equal(e.s.s, tse.s.s)
}

// IsRestrictedControlPlaneCode returns whether the status includes a code
// restricted for control plane usage as defined by gRFC A54.
func IsRestrictedControlPlaneCode(s *Status) bool {
switch s.Code() {
case codes.InvalidArgument, codes.NotFound, codes.AlreadyExists, codes.FailedPrecondition, codes.Aborted, codes.OutOfRange, codes.DataLoss:
return true
}
return false
}
16 changes: 14 additions & 2 deletions internal/transport/http2_client.go
Expand Up @@ -40,6 +40,7 @@ import (
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/keepalive"
Expand Down Expand Up @@ -589,7 +590,11 @@ func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[s
for _, c := range t.perRPCCreds {
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
if _, ok := status.FromError(err); ok {
if st, ok := status.FromError(err); ok {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little concerned that we have to do this in multiple places. How can we guarantee that we haven't missed a place (in this PR), or that we don't miss one in the future when we are adding something.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A54 specifies the places that should be covered, and this PR should cover all of them. I've also reached out to Eric about whether other plugins (e.g. compressors and encoders) should be considered.

// Restrict the code to the list allowed by gRFC A54.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err)
}
return nil, err
}

Expand Down Expand Up @@ -618,7 +623,14 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
return nil, status.Errorf(codes.Internal, "transport: %v", err)
if st, ok := status.FromError(err); ok {
// Restrict the code to the list allowed by gRFC A54.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "transport: received per-RPC creds error with illegal status: %v", err)
}
return nil, err
}
return nil, status.Errorf(codes.Internal, "transport: per-RPC creds failed due to error: %v", err)
}
callAuthData = make(map[string]string, len(data))
for k, v := range data {
Expand Down
7 changes: 6 additions & 1 deletion picker_wrapper.go
Expand Up @@ -26,6 +26,7 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/channelz"
istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -129,8 +130,12 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if err == balancer.ErrNoSubConnAvailable {
continue
}
if _, ok := status.FromError(err); ok {
if st, ok := status.FromError(err); ok {
// Status error: end the RPC unconditionally with this status.
// First restrict the code to the list allowed by gRFC A54.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err)
}
return nil, nil, dropError{error: err}
}
// For all other errors, wait for ready RPCs should block and other
Expand Down
8 changes: 8 additions & 0 deletions stream.go
Expand Up @@ -39,6 +39,7 @@ import (
imetadata "google.golang.org/grpc/internal/metadata"
iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/serviceconfig"
istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
Expand Down Expand Up @@ -195,6 +196,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
rpcInfo := iresolver.RPCInfo{Context: ctx, Method: method}
rpcConfig, err := cc.safeConfigSelector.SelectConfig(rpcInfo)
if err != nil {
if st, ok := status.FromError(err); ok {
// Restrict the code to the list allowed by gRFC A54.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "config selector returned illegal status: %v", err)
}
return nil, err
}
return nil, toRPCErr(err)
}

Expand Down
234 changes: 234 additions & 0 deletions test/control_plane_status_test.go
@@ -0,0 +1,234 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package test

import (
"context"
"strings"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/balancer/stub"
iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)

func (s) TestConfigSelectorStatusCodes(t *testing.T) {
testCases := []struct {
name string
csErr error
want error
}{{
name: "legal status code",
csErr: status.Errorf(codes.Unavailable, "this error is fine"),
want: status.Errorf(codes.Unavailable, "this error is fine"),
}, {
name: "illegal status code",
csErr: status.Errorf(codes.NotFound, "this error is bad"),
want: status.Errorf(codes.Internal, "this error is bad"),
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}
ss.R = manual.NewBuilderWithScheme("confSel")

if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

state := iresolver.SetConfigSelector(resolver.State{
Addresses: []resolver.Address{{Addr: ss.Address}},
ServiceConfig: parseServiceConfig(t, ss.R, "{}"),
}, funcConfigSelector{
f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) {
return nil, tc.csErr
},
})
ss.R.UpdateState(state) // Blocks until config selector is applied

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) {
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want)
}
})
}
}

func (s) TestPickerStatusCodes(t *testing.T) {
testCases := []struct {
name string
pickerErr error
want error
}{{
name: "legal status code",
pickerErr: status.Errorf(codes.Unavailable, "this error is fine"),
want: status.Errorf(codes.Unavailable, "this error is fine"),
}, {
name: "illegal status code",
pickerErr: status.Errorf(codes.NotFound, "this error is bad"),
want: status.Errorf(codes.Internal, "this error is bad"),
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}

if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

// Create a stub balancer that creates a picker that always returns
// an error.
sbf := stub.BalancerFuncs{
UpdateClientConnState: func(d *stub.BalancerData, _ balancer.ClientConnState) error {
d.ClientConn.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: base.NewErrPicker(tc.pickerErr),
})
return nil
},
}
stub.Register("testPickerStatusCodesBalancer", sbf)

ss.NewServiceConfig(`{"loadBalancingConfig": [{"testPickerStatusCodesBalancer":{}}] }`)

// Make calls until pickerErr is received.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

var lastErr error
for ctx.Err() == nil {
if _, lastErr = ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(lastErr) == status.Code(tc.want) && strings.Contains(lastErr.Error(), status.Convert(tc.want).Message()) {
// Success!
return
}
time.Sleep(time.Millisecond)
}

t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", lastErr, tc.want)
})
}
}

func (s) TestCallCredsFromDialOptionsStatusCodes(t *testing.T) {
testCases := []struct {
name string
credsErr error
want error
}{{
name: "legal status code",
credsErr: status.Errorf(codes.Unavailable, "this error is fine"),
want: status.Errorf(codes.Unavailable, "this error is fine"),
}, {
name: "illegal status code",
credsErr: status.Errorf(codes.NotFound, "this error is bad"),
want: status.Errorf(codes.Internal, "this error is bad"),
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}

errChan := make(chan error, 1)
creds := &testPerRPCCredentials{errChan: errChan}

if err := ss.Start(nil, grpc.WithPerRPCCredentials(creds)); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

errChan <- tc.credsErr

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) {
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want)
}
})
}
}

func (s) TestCallCredsFromCallOptionsStatusCodes(t *testing.T) {
testCases := []struct {
name string
credsErr error
want error
}{{
name: "legal status code",
credsErr: status.Errorf(codes.Unavailable, "this error is fine"),
want: status.Errorf(codes.Unavailable, "this error is fine"),
}, {
name: "illegal status code",
credsErr: status.Errorf(codes.NotFound, "this error is bad"),
want: status.Errorf(codes.Internal, "this error is bad"),
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}

errChan := make(chan error, 1)
creds := &testPerRPCCredentials{errChan: errChan}

if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

errChan <- tc.credsErr

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(creds)); status.Code(err) != status.Code(tc.want) || !strings.Contains(err.Error(), status.Convert(tc.want).Message()) {
t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.want)
}
})
}
}