Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials: fix PerRPCCredentials w/RequireTransportSecurity and security levels #3995

Merged
merged 5 commits into from Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 7 additions & 8 deletions credentials/credentials.go
Expand Up @@ -58,9 +58,9 @@ type PerRPCCredentials interface {
type SecurityLevel int

const (
// Invalid indicates an invalid security level.
// InvalidSecurityLevel indicates an invalid security level.
// The zero SecurityLevel value is invalid for backward compatibility.
Invalid SecurityLevel = iota
InvalidSecurityLevel SecurityLevel = iota
// NoSecurity indicates a connection is insecure.
NoSecurity
// IntegrityOnly indicates a connection only provides integrity protection.
Expand Down Expand Up @@ -229,17 +229,16 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
func CheckSecurityLevel(ai AuthInfo, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
return errors.New("unable to obtain SecurityLevel from context")
if ai == nil {
return errors.New("AuthInfo is nil")
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
if ci, ok := ai.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == Invalid {
if ci.GetCommonAuthInfo().SecurityLevel == InvalidSecurityLevel {
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
Expand Down
28 changes: 4 additions & 24 deletions credentials/credentials_test.go
Expand Up @@ -26,7 +26,6 @@ import (
"testing"
"time"

"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/testdata"
)
Expand Down Expand Up @@ -57,17 +56,6 @@ func (ta testAuthInfo) AuthType() string {
return "testAuthInfo"
}

func createTestContext(s SecurityLevel) context.Context {
auth := &testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: s}}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
return internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri)
}

func (s) TestCheckSecurityLevel(t *testing.T) {
testCases := []struct {
authLevel SecurityLevel
Expand All @@ -90,18 +78,18 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
want: true,
},
{
authLevel: Invalid,
authLevel: InvalidSecurityLevel,
testLevel: IntegrityOnly,
want: true,
},
{
authLevel: Invalid,
authLevel: InvalidSecurityLevel,
testLevel: PrivacyAndIntegrity,
want: true,
},
}
for _, tc := range testCases {
err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel)
err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel)
if tc.want && (err != nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
} else if !tc.want && (err == nil) {
Expand All @@ -112,15 +100,7 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
}

func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
auth := &testAuthInfoNoGetCommonAuthInfoMethod{}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri)
if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil {
if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil {
t.Fatalf("CheckSeurityLevel() returned failure but want success")
}
}
Expand Down
2 changes: 1 addition & 1 deletion credentials/local/local.go
Expand Up @@ -70,7 +70,7 @@ func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) {
return credentials.PrivacyAndIntegrity, nil
// Not a local connection and should fail
default:
return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
return credentials.InvalidSecurityLevel, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
}
}

Expand Down
38 changes: 21 additions & 17 deletions credentials/local/local_test.go
Expand Up @@ -65,7 +65,7 @@ func (s) TestGetSecurityLevel(t *testing.T) {
{
testNetwork: "tcp",
testAddr: "192.168.0.1:10000",
want: credentials.Invalid,
want: credentials.InvalidSecurityLevel,
},
}
for _, tc := range testCases {
Expand All @@ -78,6 +78,15 @@ func (s) TestGetSecurityLevel(t *testing.T) {

type serverHandshake func(net.Conn) (credentials.AuthInfo, error)

func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel {
if c, ok := ai.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
return c.GetCommonAuthInfo().SecurityLevel
}
return credentials.InvalidSecurityLevel
}

// Server local handshake implementation.
func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
cred := NewCredentials()
Expand Down Expand Up @@ -140,31 +149,26 @@ func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, erro
defer lis.Close()
clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
if err != nil {
return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err)
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err)
}
select {
case <-timer.C:
return credentials.Invalid, fmt.Errorf("Test didn't finish in time")
return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time")
case serverHandleResult := <-done:
if serverHandleResult.err != nil {
return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
}
var clientSecLevel, serverSecLevel credentials.SecurityLevel
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if info, ok := clientAuthInfo.(internalInfo); ok {
clientSecLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return credentials.Invalid, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo)
serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo)

if clientSecLevel == credentials.InvalidSecurityLevel {
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
}
if info, ok := (serverHandleResult.authInfo).(internalInfo); ok {
serverSecLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return credentials.Invalid, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
if serverSecLevel == credentials.InvalidSecurityLevel {
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
}
if clientSecLevel != serverSecLevel {
return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
}
return clientSecLevel, nil
}
Expand Down
12 changes: 8 additions & 4 deletions credentials/oauth/oauth.go
Expand Up @@ -42,7 +42,8 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down Expand Up @@ -84,7 +85,8 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err)
}
return map[string]string{
Expand All @@ -107,7 +109,8 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
}

func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down Expand Up @@ -144,7 +147,8 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string)
return nil, err
}
}
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down
3 changes: 2 additions & 1 deletion credentials/sts/sts.go
Expand Up @@ -151,7 +151,8 @@ type callCreds struct {
// GetRequestMetadata returns the cached accessToken, if available and valid, or
// fetches a new one by performing an STS token exchange.
func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/transport/http2_client.go
Expand Up @@ -241,7 +241,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if cd.RequireTransportSecurity() {
if ci, ok := authInfo.(internalInfo); ok {
secLevel := ci.GetCommonAuthInfo().SecurityLevel
if secLevel != credentials.Invalid && secLevel < credentials.PrivacyAndIntegrity {
if secLevel != credentials.InvalidSecurityLevel && secLevel < credentials.PrivacyAndIntegrity {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking we could use CheckSecurityeLevel here, but I see this is different from CheckSecurityLevel because Invalid is not okay here. LGTM, but please move internalInfo so it's defined immediately above to limit its scope (or use it inline).

return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection")
}
}
Expand Down Expand Up @@ -571,7 +571,8 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if callCreds.RequireTransportSecurity() {
if !t.isSecure || credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if !t.isSecure || credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity) != nil {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
}
Expand Down
18 changes: 11 additions & 7 deletions test/insecure_creds_test.go
Expand Up @@ -48,6 +48,15 @@ func (cr testLegacyPerRPCCredentials) RequireTransportSecurity() bool {
return true
}

func getSecurityLevel(ai credentials.AuthInfo) credentials.SecurityLevel {
if c, ok := ai.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
return c.GetCommonAuthInfo().SecurityLevel
}
return credentials.InvalidSecurityLevel
}

// TestInsecureCreds tests the use of insecure creds on the server and client
// side, and verifies that expect security level and auth info are returned.
// Also verifies that this credential can interop with existing `WithInsecure`
Expand Down Expand Up @@ -86,13 +95,8 @@ func (s) TestInsecureCreds(t *testing.T) {
return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx")
}
// Check security level.
var secLevel credentials.SecurityLevel
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if info, ok := pr.AuthInfo.(internalInfo); ok {
secLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
secLevel := getSecurityLevel(pr.AuthInfo)
if secLevel == credentials.InvalidSecurityLevel {
return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()")
}
if secLevel != credentials.NoSecurity {
Expand Down