From 386b26661b5e3a29fdd189c312973fefe3cffdeb Mon Sep 17 00:00:00 2001 From: yihuaz Date: Fri, 30 Oct 2020 15:49:06 -0700 Subject: [PATCH] add per-connection and per-rpc check --- credentials/credentials.go | 4 +- internal/transport/http2_client.go | 19 ++++++- test/insecure_creds_test.go | 85 ++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 4 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 02766443ae74..ef1d3d65e0ed 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -92,7 +92,7 @@ type CommonAuthInfo struct { } // GetCommonAuthInfo returns the pointer to CommonAuthInfo struct. -func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo { +func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo { return c } @@ -231,7 +231,7 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo { // This API is experimental. func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error { type internalInfo interface { - GetCommonAuthInfo() *CommonAuthInfo + GetCommonAuthInfo() CommonAuthInfo } ri, _ := RequestInfoFromContext(ctx) if ri.AuthInfo == nil { diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 0364df53f868..b966327c1077 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -234,6 +234,19 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } + type internalInfo interface { + GetCommonAuthInfo() credentials.CommonAuthInfo + } + if ci, ok := authInfo.(internalInfo); ok { + level := ci.GetCommonAuthInfo().SecurityLevel + for _, cd := range perRPCCreds { + if cd.RequireTransportSecurity() { + if level != credentials.Invalid && level < credentials.PrivacyAndIntegrity { + return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection") + } + } + } + } isSecure = true if transportCreds.Info().SecurityProtocol == "tls" { scheme = "https" @@ -557,8 +570,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // Note: if these credentials are provided both via dial options and call // options, then both sets of credentials will be applied. if callCreds := callHdr.Creds; callCreds != nil { - if !t.isSecure && callCreds.RequireTransportSecurity() { - return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + if callCreds.RequireTransportSecurity() { + if !t.isSecure || credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) != nil { + return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + } } data, err := callCreds.GetRequestMetadata(ctx, audience) if err != nil { diff --git a/test/insecure_creds_test.go b/test/insecure_creds_test.go index 81d5a5ba5d0b..61aaf68bbae4 100644 --- a/test/insecure_creds_test.go +++ b/test/insecure_creds_test.go @@ -21,6 +21,7 @@ package test import ( "context" "net" + "strings" "testing" "time" @@ -30,11 +31,27 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" ) const defaultTestTimeout = 5 * time.Second +// testLegacyPerRPCCredentials is a PerRPCCredentials that has yet incorporated security level. +type testLegacyPerRPCCredentials struct{} + +func (cr testLegacyPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return nil, nil +} + +func (cr testLegacyPerRPCCredentials) RequireTransportSecurity() bool { + return true +} + +func isExpectedError(got, want error) bool { + return status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message()) +} + // TestInsecureCreds tests the use of insecure creds on the server and client // side, and verifies that expect security level and auth info are returned. // Also verifies that this credential can interop with existing `WithInsecure` @@ -122,3 +139,71 @@ func (s) TestInsecureCreds(t *testing.T) { }) } } + +func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) { + tests := []struct { + desc string + perRPCCredsViaDialOptions bool + perRPCCredsViaCallOptions bool + wantErr string + }{ + { + desc: "send PerRPCCredentials via DialOptions", + perRPCCredsViaDialOptions: true, + perRPCCredsViaCallOptions: false, + wantErr: "context deadline exceeded", + }, + { + desc: "send PerRPCCredentials via CallOptions", + perRPCCredsViaDialOptions: false, + perRPCCredsViaCallOptions: true, + wantErr: "transport: cannot send secure credentials on an insecure connection", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ss := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + sOpts := []grpc.ServerOption{} + sOpts = append(sOpts, grpc.Creds(insecure.NewCredentials())) + s := grpc.NewServer(sOpts...) + defer s.Stop() + + testpb.RegisterTestServiceServer(s, ss) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err) + } + + go s.Serve(lis) + + addr := lis.Addr().String() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cOpts := []grpc.DialOption{grpc.WithBlock()} + cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + + if test.perRPCCredsViaDialOptions { + cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{})) + if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr) + } + } + + if test.perRPCCredsViaCallOptions { + cc, err := grpc.DialContext(ctx, addr, cOpts...) + defer cc.Close() + + c := testpb.NewTestServiceClient(cc) + if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr) + } + } + }) + } +}