From 14db5e4ce18d99536d4e4931cf1eca2cd309eb7a Mon Sep 17 00:00:00 2001 From: yihuaz Date: Mon, 9 Nov 2020 15:33:53 -0800 Subject: [PATCH] credentials: fix PerRPCCredentials w/RequireTransportSecurity and security levels (#3995) --- credentials/credentials.go | 19 +++--- credentials/credentials_test.go | 28 ++------ credentials/insecure/insecure.go | 12 ++-- credentials/local/local.go | 14 ++-- credentials/local/local_test.go | 32 ++++++--- credentials/oauth/oauth.go | 12 ++-- credentials/sts/sts.go | 3 +- internal/transport/http2_client.go | 19 +++++- test/insecure_creds_test.go | 101 +++++++++++++++++++++++++++-- test/local_creds_test.go | 11 +++- 10 files changed, 182 insertions(+), 69 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 02766443ae74..e69562e78786 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -58,9 +58,9 @@ type PerRPCCredentials interface { type SecurityLevel int const ( - // Invalid indicates an invalid security level. + // InvalidSecurityLevel indicates an invalid security level. // The zero SecurityLevel value is invalid for backward compatibility. - Invalid SecurityLevel = iota + InvalidSecurityLevel SecurityLevel = iota // NoSecurity indicates a connection is insecure. NoSecurity // IntegrityOnly indicates a connection only provides integrity protection. @@ -92,7 +92,7 @@ type CommonAuthInfo struct { } // GetCommonAuthInfo returns the pointer to CommonAuthInfo struct. -func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo { +func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo { return c } @@ -229,17 +229,16 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo { // or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility. // // This API is experimental. -func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error { +func CheckSecurityLevel(ai AuthInfo, level SecurityLevel) error { type internalInfo interface { - GetCommonAuthInfo() *CommonAuthInfo + GetCommonAuthInfo() CommonAuthInfo } - ri, _ := RequestInfoFromContext(ctx) - if ri.AuthInfo == nil { - return errors.New("unable to obtain SecurityLevel from context") + if ai == nil { + return errors.New("AuthInfo is nil") } - if ci, ok := ri.AuthInfo.(internalInfo); ok { + if ci, ok := ai.(internalInfo); ok { // CommonAuthInfo.SecurityLevel has an invalid value. - if ci.GetCommonAuthInfo().SecurityLevel == Invalid { + if ci.GetCommonAuthInfo().SecurityLevel == InvalidSecurityLevel { return nil } if ci.GetCommonAuthInfo().SecurityLevel < level { diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index dee0f2ca8304..08eb9c430ffc 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -26,7 +26,6 @@ import ( "testing" "time" - "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/testdata" ) @@ -57,17 +56,6 @@ func (ta testAuthInfo) AuthType() string { return "testAuthInfo" } -func createTestContext(s SecurityLevel) context.Context { - auth := &testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: s}} - ri := RequestInfo{ - Method: "testInfo", - AuthInfo: auth, - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - return internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri) -} - func (s) TestCheckSecurityLevel(t *testing.T) { testCases := []struct { authLevel SecurityLevel @@ -90,18 +78,18 @@ func (s) TestCheckSecurityLevel(t *testing.T) { want: true, }, { - authLevel: Invalid, + authLevel: InvalidSecurityLevel, testLevel: IntegrityOnly, want: true, }, { - authLevel: Invalid, + authLevel: InvalidSecurityLevel, testLevel: PrivacyAndIntegrity, want: true, }, } for _, tc := range testCases { - err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel) + err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel) if tc.want && (err != nil) { t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String()) } else if !tc.want && (err == nil) { @@ -112,15 +100,7 @@ func (s) TestCheckSecurityLevel(t *testing.T) { } func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) { - auth := &testAuthInfoNoGetCommonAuthInfoMethod{} - ri := RequestInfo{ - Method: "testInfo", - AuthInfo: auth, - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri) - if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil { + if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil { t.Fatalf("CheckSeurityLevel() returned failure but want success") } } diff --git a/credentials/insecure/insecure.go b/credentials/insecure/insecure.go index 7fc11717f765..c4fa27c920da 100644 --- a/credentials/insecure/insecure.go +++ b/credentials/insecure/insecure.go @@ -43,11 +43,11 @@ func NewCredentials() credentials.TransportCredentials { type insecureTC struct{} func (insecureTC) ClientHandshake(ctx context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil } func (insecureTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil } func (insecureTC) Info() credentials.ProtocolInfo { @@ -62,13 +62,13 @@ func (insecureTC) OverrideServerName(string) error { return nil } -// Info contains the auth information for an insecure connection. +// info contains the auth information for an insecure connection. // It implements the AuthInfo interface. -type Info struct { +type info struct { credentials.CommonAuthInfo } -// AuthType returns the type of Info as a string. -func (Info) AuthType() string { +// AuthType returns the type of info as a string. +func (info) AuthType() string { return "insecure" } diff --git a/credentials/local/local.go b/credentials/local/local.go index a9d446ecaa92..f772bc1307b2 100644 --- a/credentials/local/local.go +++ b/credentials/local/local.go @@ -38,14 +38,14 @@ import ( "google.golang.org/grpc/credentials" ) -// Info contains the auth information for a local connection. +// info contains the auth information for a local connection. // It implements the AuthInfo interface. -type Info struct { +type info struct { credentials.CommonAuthInfo } -// AuthType returns the type of Info as a string. -func (Info) AuthType() string { +// AuthType returns the type of info as a string. +func (info) AuthType() string { return "local" } @@ -70,7 +70,7 @@ func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) { return credentials.PrivacyAndIntegrity, nil // Not a local connection and should fail default: - return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr) + return credentials.InvalidSecurityLevel, fmt.Errorf("local credentials rejected connection to non-local address %q", addr) } } @@ -79,7 +79,7 @@ func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net. if err != nil { return nil, nil, err } - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil } func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { @@ -87,7 +87,7 @@ func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, if err != nil { return nil, nil, err } - return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil + return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil } // NewCredentials returns a local credential implementing credentials.TransportCredentials. diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go index 64e2ec3e7fc6..00ae39f07e56 100644 --- a/credentials/local/local_test.go +++ b/credentials/local/local_test.go @@ -65,7 +65,7 @@ func (s) TestGetSecurityLevel(t *testing.T) { { testNetwork: "tcp", testAddr: "192.168.0.1:10000", - want: credentials.Invalid, + want: credentials.InvalidSecurityLevel, }, } for _, tc := range testCases { @@ -78,6 +78,15 @@ func (s) TestGetSecurityLevel(t *testing.T) { type serverHandshake func(net.Conn) (credentials.AuthInfo, error) +func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel { + if c, ok := ai.(interface { + GetCommonAuthInfo() credentials.CommonAuthInfo + }); ok { + return c.GetCommonAuthInfo().SecurityLevel + } + return credentials.InvalidSecurityLevel +} + // Server local handshake implementation. func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) { cred := NewCredentials() @@ -140,21 +149,26 @@ func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, erro defer lis.Close() clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String()) if err != nil { - return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err) + return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err) } select { case <-timer.C: - return credentials.Invalid, fmt.Errorf("Test didn't finish in time") + return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time") case serverHandleResult := <-done: if serverHandleResult.err != nil { - return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err) + return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err) + } + clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo) + serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo) + + if clientSecLevel == credentials.InvalidSecurityLevel { + return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()") + } + if serverSecLevel == credentials.InvalidSecurityLevel { + return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()") } - clientLocal, _ := clientAuthInfo.(Info) - serverLocal, _ := serverHandleResult.authInfo.(Info) - clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel - serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel if clientSecLevel != serverSecLevel { - return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String()) + return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String()) } return clientSecLevel, nil } diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go index 6657055d6609..852ae375cfc7 100644 --- a/credentials/oauth/oauth.go +++ b/credentials/oauth/oauth.go @@ -42,7 +42,8 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma if err != nil { return nil, err } - if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + ri, _ := credentials.RequestInfoFromContext(ctx) + if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err) } return map[string]string{ @@ -84,7 +85,8 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s if err != nil { return nil, err } - if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + ri, _ := credentials.RequestInfoFromContext(ctx) + if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err) } return map[string]string{ @@ -107,7 +109,8 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials { } func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + ri, _ := credentials.RequestInfoFromContext(ctx) + if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err) } return map[string]string{ @@ -144,7 +147,8 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string) return nil, err } } - if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + ri, _ := credentials.RequestInfoFromContext(ctx) + if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err) } return map[string]string{ diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go index f4d58011be86..9285192a8eba 100644 --- a/credentials/sts/sts.go +++ b/credentials/sts/sts.go @@ -151,7 +151,8 @@ type callCreds struct { // GetRequestMetadata returns the cached accessToken, if available and valid, or // fetches a new one by performing an STS token exchange. func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { - if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil { + ri, _ := credentials.RequestInfoFromContext(ctx) + if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err) } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 0364df53f868..fef365c0d281 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -234,6 +234,18 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } + for _, cd := range perRPCCreds { + if cd.RequireTransportSecurity() { + if ci, ok := authInfo.(interface { + GetCommonAuthInfo() credentials.CommonAuthInfo + }); ok { + secLevel := ci.GetCommonAuthInfo().SecurityLevel + if secLevel != credentials.InvalidSecurityLevel && secLevel < credentials.PrivacyAndIntegrity { + return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection") + } + } + } + } isSecure = true if transportCreds.Info().SecurityProtocol == "tls" { scheme = "https" @@ -557,8 +569,11 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // Note: if these credentials are provided both via dial options and call // options, then both sets of credentials will be applied. if callCreds := callHdr.Creds; callCreds != nil { - if !t.isSecure && callCreds.RequireTransportSecurity() { - return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + if callCreds.RequireTransportSecurity() { + ri, _ := credentials.RequestInfoFromContext(ctx) + if !t.isSecure || credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity) != nil { + return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + } } data, err := callCreds.GetRequestMetadata(ctx, audience) if err != nil { diff --git a/test/insecure_creds_test.go b/test/insecure_creds_test.go index 81d5a5ba5d0b..dd56a8c46ee5 100644 --- a/test/insecure_creds_test.go +++ b/test/insecure_creds_test.go @@ -21,6 +21,7 @@ package test import ( "context" "net" + "strings" "testing" "time" @@ -30,11 +31,32 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" ) const defaultTestTimeout = 5 * time.Second +// testLegacyPerRPCCredentials is a PerRPCCredentials that has yet incorporated security level. +type testLegacyPerRPCCredentials struct{} + +func (cr testLegacyPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return nil, nil +} + +func (cr testLegacyPerRPCCredentials) RequireTransportSecurity() bool { + return true +} + +func getSecurityLevel(ai credentials.AuthInfo) credentials.SecurityLevel { + if c, ok := ai.(interface { + GetCommonAuthInfo() credentials.CommonAuthInfo + }); ok { + return c.GetCommonAuthInfo().SecurityLevel + } + return credentials.InvalidSecurityLevel +} + // TestInsecureCreds tests the use of insecure creds on the server and client // side, and verifies that expect security level and auth info are returned. // Also verifies that this credential can interop with existing `WithInsecure` @@ -73,11 +95,11 @@ func (s) TestInsecureCreds(t *testing.T) { return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx") } // Check security level. - info := pr.AuthInfo.(insecure.Info) - if at := info.AuthType(); at != "insecure" { - return nil, status.Errorf(codes.Unauthenticated, "Wrong AuthType: got %q, want insecure", at) + secLevel := getSecurityLevel(pr.AuthInfo) + if secLevel == credentials.InvalidSecurityLevel { + return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()") } - if secLevel := info.CommonAuthInfo.SecurityLevel; secLevel != credentials.NoSecurity { + if secLevel != credentials.NoSecurity { return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity) } return &testpb.Empty{}, nil @@ -122,3 +144,74 @@ func (s) TestInsecureCreds(t *testing.T) { }) } } + +func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) { + tests := []struct { + desc string + perRPCCredsViaDialOptions bool + perRPCCredsViaCallOptions bool + wantErr string + }{ + { + desc: "send PerRPCCredentials via DialOptions", + perRPCCredsViaDialOptions: true, + perRPCCredsViaCallOptions: false, + wantErr: "context deadline exceeded", + }, + { + desc: "send PerRPCCredentials via CallOptions", + perRPCCredsViaDialOptions: false, + perRPCCredsViaCallOptions: true, + wantErr: "transport: cannot send secure credentials on an insecure connection", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ss := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + sOpts := []grpc.ServerOption{} + sOpts = append(sOpts, grpc.Creds(insecure.NewCredentials())) + s := grpc.NewServer(sOpts...) + defer s.Stop() + + testpb.RegisterTestServiceServer(s, ss) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err) + } + + go s.Serve(lis) + + addr := lis.Addr().String() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cOpts := []grpc.DialOption{grpc.WithBlock()} + cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + + if test.perRPCCredsViaDialOptions { + cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{})) + if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr) + } + } + + if test.perRPCCredsViaCallOptions { + cc, err := grpc.DialContext(ctx, addr, cOpts...) + if err != nil { + t.Fatalf("grpc.Dial(%q) failed: %v", addr, err) + } + defer cc.Close() + + c := testpb.NewTestServiceClient(cc) + if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr) + } + } + }) + } +} diff --git a/test/local_creds_test.go b/test/local_creds_test.go index b55b73bdcbce..b9115f0d5ac8 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -43,9 +43,16 @@ func testLocalCredsE2ESucceed(network, address string) error { if !ok { return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx") } + type internalInfo interface { + GetCommonAuthInfo() credentials.CommonAuthInfo + } + var secLevel credentials.SecurityLevel + if info, ok := (pr.AuthInfo).(internalInfo); ok { + secLevel = info.GetCommonAuthInfo().SecurityLevel + } else { + return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()") + } // Check security level - info := pr.AuthInfo.(local.Info) - secLevel := info.CommonAuthInfo.SecurityLevel switch network { case "unix": if secLevel != credentials.PrivacyAndIntegrity {