From 4e0b2131f34ca634581a1ad5ce455d6cdbe7650e Mon Sep 17 00:00:00 2001 From: yihuaz Date: Sun, 1 Nov 2020 18:18:00 -0800 Subject: [PATCH] fix PerRPCCredentials w/RequireTransportSecurity and security levels --- credentials/credentials.go | 39 ++++++++----- credentials/insecure/insecure.go | 12 ++-- credentials/local/local.go | 12 ++-- credentials/local/local_test.go | 6 +- internal/transport/http2_client.go | 14 ++++- test/insecure_creds_test.go | 90 ++++++++++++++++++++++++++++-- test/local_creds_test.go | 3 +- 7 files changed, 138 insertions(+), 38 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 02766443ae74..6ada58385a02 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -92,7 +92,7 @@ type CommonAuthInfo struct { } // GetCommonAuthInfo returns the pointer to CommonAuthInfo struct. -func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo { +func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo { return c } @@ -224,29 +224,42 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo { return chi } +// GetSecurityLevel returns a connection's security level. It returns 1) an erorr if authInfo is nil or 2) Invalid SecurityLevel +// if authInfo does not implement GetCommonAuthInfo() method. +// +// This API is experimental. +func GetSecurityLevel(authInfo AuthInfo) (SecurityLevel, error) { + type internalInfo interface { + GetCommonAuthInfo() CommonAuthInfo + } + if authInfo == nil { + return -1, errors.New("authInfo is nil") + } + if ci, ok := authInfo.(internalInfo); ok { + return ci.GetCommonAuthInfo().SecurityLevel, nil + } + return Invalid, nil +} + // CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one. // It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method // or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility. // // This API is experimental. func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error { - type internalInfo interface { - GetCommonAuthInfo() *CommonAuthInfo - } ri, _ := RequestInfoFromContext(ctx) if ri.AuthInfo == nil { return errors.New("unable to obtain SecurityLevel from context") } - if ci, ok := ri.AuthInfo.(internalInfo); ok { - // CommonAuthInfo.SecurityLevel has an invalid value. - if ci.GetCommonAuthInfo().SecurityLevel == Invalid { - return nil - } - if ci.GetCommonAuthInfo().SecurityLevel < level { - return fmt.Errorf("requires SecurityLevel %v; connection has %v", level, ci.GetCommonAuthInfo().SecurityLevel) - } + secLevel, _ := GetSecurityLevel(ri.AuthInfo) + // CommonAuthInfo.SecurityLevel has an invalid value or AuthInfo struct does not implement GetCommonAuthInfo() method. + if secLevel == Invalid { + return nil + } + if secLevel < level { + return fmt.Errorf("requires SecurityLevel %v; connection has %v", level, secLevel) } - // The condition is satisfied or AuthInfo struct does not implement GetCommonAuthInfo() method. + // condition is satisfied. return nil } diff --git a/credentials/insecure/insecure.go b/credentials/insecure/insecure.go index 7fc11717f765..c4fa27c920da 100644 --- a/credentials/insecure/insecure.go +++ b/credentials/insecure/insecure.go @@ -43,11 +43,11 @@ func NewCredentials() credentials.TransportCredentials { type insecureTC struct{} func (insecureTC) ClientHandshake(ctx context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil } func (insecureTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil } func (insecureTC) Info() credentials.ProtocolInfo { @@ -62,13 +62,13 @@ func (insecureTC) OverrideServerName(string) error { return nil } -// Info contains the auth information for an insecure connection. +// info contains the auth information for an insecure connection. // It implements the AuthInfo interface. -type Info struct { +type info struct { credentials.CommonAuthInfo } -// AuthType returns the type of Info as a string. -func (Info) AuthType() string { +// AuthType returns the type of info as a string. +func (info) AuthType() string { return "insecure" } diff --git a/credentials/local/local.go b/credentials/local/local.go index a9d446ecaa92..469090142558 100644 --- a/credentials/local/local.go +++ b/credentials/local/local.go @@ -38,14 +38,14 @@ import ( "google.golang.org/grpc/credentials" ) -// Info contains the auth information for a local connection. +// info contains the auth information for a local connection. // It implements the AuthInfo interface. -type Info struct { +type info struct { credentials.CommonAuthInfo } -// AuthType returns the type of Info as a string. -func (Info) AuthType() string { +// AuthType returns the type of info as a string. +func (info) AuthType() string { return "local" } @@ -79,7 +79,7 @@ func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net. if err != nil { return nil, nil, err } - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil } func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { @@ -87,7 +87,7 @@ func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, if err != nil { return nil, nil, err } - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil } // NewCredentials returns a local credential implementing credentials.TransportCredentials. diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go index 3c65010e8b2a..60866a51ae9e 100644 --- a/credentials/local/local_test.go +++ b/credentials/local/local_test.go @@ -144,10 +144,8 @@ func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, erro if serverHandleResult.err != nil { return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err) } - clientLocal, _ := clientAuthInfo.(Info) - serverLocal, _ := serverHandleResult.authInfo.(Info) - clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel - serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel + clientSecLevel, _ := credentials.GetSecurityLevel(clientAuthInfo) + serverSecLevel, _ := credentials.GetSecurityLevel(serverHandleResult.authInfo) if clientSecLevel != serverSecLevel { return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String()) } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 0364df53f868..db9075204d7c 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -234,6 +234,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } + for _, cd := range perRPCCreds { + if cd.RequireTransportSecurity() { + secLevel, _ := credentials.GetSecurityLevel(authInfo) + if secLevel != credentials.Invalid && secLevel < credentials.PrivacyAndIntegrity { + return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection") + } + } + } isSecure = true if transportCreds.Info().SecurityProtocol == "tls" { scheme = "https" @@ -557,8 +565,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // Note: if these credentials are provided both via dial options and call // options, then both sets of credentials will be applied. if callCreds := callHdr.Creds; callCreds != nil { - if !t.isSecure && callCreds.RequireTransportSecurity() { - return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + if callCreds.RequireTransportSecurity() { + if !t.isSecure || credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) != nil { + return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + } } data, err := callCreds.GetRequestMetadata(ctx, audience) if err != nil { diff --git a/test/insecure_creds_test.go b/test/insecure_creds_test.go index 81d5a5ba5d0b..a5ac90e8f132 100644 --- a/test/insecure_creds_test.go +++ b/test/insecure_creds_test.go @@ -21,6 +21,7 @@ package test import ( "context" "net" + "strings" "testing" "time" @@ -30,11 +31,23 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" ) const defaultTestTimeout = 5 * time.Second +// testLegacyPerRPCCredentials is a PerRPCCredentials that has yet incorporated security level. +type testLegacyPerRPCCredentials struct{} + +func (cr testLegacyPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return nil, nil +} + +func (cr testLegacyPerRPCCredentials) RequireTransportSecurity() bool { + return true +} + // TestInsecureCreds tests the use of insecure creds on the server and client // side, and verifies that expect security level and auth info are returned. // Also verifies that this credential can interop with existing `WithInsecure` @@ -73,11 +86,7 @@ func (s) TestInsecureCreds(t *testing.T) { return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx") } // Check security level. - info := pr.AuthInfo.(insecure.Info) - if at := info.AuthType(); at != "insecure" { - return nil, status.Errorf(codes.Unauthenticated, "Wrong AuthType: got %q, want insecure", at) - } - if secLevel := info.CommonAuthInfo.SecurityLevel; secLevel != credentials.NoSecurity { + if secLevel, _ := credentials.GetSecurityLevel(pr.AuthInfo); secLevel != credentials.NoSecurity { return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity) } return &testpb.Empty{}, nil @@ -122,3 +131,74 @@ func (s) TestInsecureCreds(t *testing.T) { }) } } + +func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) { + tests := []struct { + desc string + perRPCCredsViaDialOptions bool + perRPCCredsViaCallOptions bool + wantErr string + }{ + { + desc: "send PerRPCCredentials via DialOptions", + perRPCCredsViaDialOptions: true, + perRPCCredsViaCallOptions: false, + wantErr: "context deadline exceeded", + }, + { + desc: "send PerRPCCredentials via CallOptions", + perRPCCredsViaDialOptions: false, + perRPCCredsViaCallOptions: true, + wantErr: "transport: cannot send secure credentials on an insecure connection", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ss := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + sOpts := []grpc.ServerOption{} + sOpts = append(sOpts, grpc.Creds(insecure.NewCredentials())) + s := grpc.NewServer(sOpts...) + defer s.Stop() + + testpb.RegisterTestServiceServer(s, ss) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err) + } + + go s.Serve(lis) + + addr := lis.Addr().String() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cOpts := []grpc.DialOption{grpc.WithBlock()} + cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + + if test.perRPCCredsViaDialOptions { + cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{})) + if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr) + } + } + + if test.perRPCCredsViaCallOptions { + cc, err := grpc.DialContext(ctx, addr, cOpts...) + if err != nil { + t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) + } + defer cc.Close() + + c := testpb.NewTestServiceClient(cc) + if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr) + } + } + }) + } +} diff --git a/test/local_creds_test.go b/test/local_creds_test.go index b55b73bdcbce..1f7e4089f7ab 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -44,8 +44,7 @@ func testLocalCredsE2ESucceed(network, address string) error { return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx") } // Check security level - info := pr.AuthInfo.(local.Info) - secLevel := info.CommonAuthInfo.SecurityLevel + secLevel, _ := credentials.GetSecurityLevel(pr.AuthInfo) switch network { case "unix": if secLevel != credentials.PrivacyAndIntegrity {