Skip to content

Commit

Permalink
credentials: fix PerRPCCredentials w/RequireTransportSecurity and sec…
Browse files Browse the repository at this point in the history
…urity levels (grpc#3995)
  • Loading branch information
yihuazhang authored and davidkhala committed Dec 7, 2020
1 parent 511e52d commit 14db5e4
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 69 deletions.
19 changes: 9 additions & 10 deletions credentials/credentials.go
Expand Up @@ -58,9 +58,9 @@ type PerRPCCredentials interface {
type SecurityLevel int

const (
// Invalid indicates an invalid security level.
// InvalidSecurityLevel indicates an invalid security level.
// The zero SecurityLevel value is invalid for backward compatibility.
Invalid SecurityLevel = iota
InvalidSecurityLevel SecurityLevel = iota
// NoSecurity indicates a connection is insecure.
NoSecurity
// IntegrityOnly indicates a connection only provides integrity protection.
Expand Down Expand Up @@ -92,7 +92,7 @@ type CommonAuthInfo struct {
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo {
func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo {
return c
}

Expand Down Expand Up @@ -229,17 +229,16 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
func CheckSecurityLevel(ai AuthInfo, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
GetCommonAuthInfo() CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
return errors.New("unable to obtain SecurityLevel from context")
if ai == nil {
return errors.New("AuthInfo is nil")
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
if ci, ok := ai.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == Invalid {
if ci.GetCommonAuthInfo().SecurityLevel == InvalidSecurityLevel {
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
Expand Down
28 changes: 4 additions & 24 deletions credentials/credentials_test.go
Expand Up @@ -26,7 +26,6 @@ import (
"testing"
"time"

"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/testdata"
)
Expand Down Expand Up @@ -57,17 +56,6 @@ func (ta testAuthInfo) AuthType() string {
return "testAuthInfo"
}

func createTestContext(s SecurityLevel) context.Context {
auth := &testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: s}}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
return internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri)
}

func (s) TestCheckSecurityLevel(t *testing.T) {
testCases := []struct {
authLevel SecurityLevel
Expand All @@ -90,18 +78,18 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
want: true,
},
{
authLevel: Invalid,
authLevel: InvalidSecurityLevel,
testLevel: IntegrityOnly,
want: true,
},
{
authLevel: Invalid,
authLevel: InvalidSecurityLevel,
testLevel: PrivacyAndIntegrity,
want: true,
},
}
for _, tc := range testCases {
err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel)
err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel)
if tc.want && (err != nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
} else if !tc.want && (err == nil) {
Expand All @@ -112,15 +100,7 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
}

func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
auth := &testAuthInfoNoGetCommonAuthInfoMethod{}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri)
if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil {
if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil {
t.Fatalf("CheckSeurityLevel() returned failure but want success")
}
}
Expand Down
12 changes: 6 additions & 6 deletions credentials/insecure/insecure.go
Expand Up @@ -43,11 +43,11 @@ func NewCredentials() credentials.TransportCredentials {
type insecureTC struct{}

func (insecureTC) ClientHandshake(ctx context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) Info() credentials.ProtocolInfo {
Expand All @@ -62,13 +62,13 @@ func (insecureTC) OverrideServerName(string) error {
return nil
}

// Info contains the auth information for an insecure connection.
// info contains the auth information for an insecure connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "insecure"
}
14 changes: 7 additions & 7 deletions credentials/local/local.go
Expand Up @@ -38,14 +38,14 @@ import (
"google.golang.org/grpc/credentials"
)

// Info contains the auth information for a local connection.
// info contains the auth information for a local connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "local"
}

Expand All @@ -70,7 +70,7 @@ func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) {
return credentials.PrivacyAndIntegrity, nil
// Not a local connection and should fail
default:
return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
return credentials.InvalidSecurityLevel, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
}
}

Expand All @@ -79,15 +79,15 @@ func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

// NewCredentials returns a local credential implementing credentials.TransportCredentials.
Expand Down
32 changes: 23 additions & 9 deletions credentials/local/local_test.go
Expand Up @@ -65,7 +65,7 @@ func (s) TestGetSecurityLevel(t *testing.T) {
{
testNetwork: "tcp",
testAddr: "192.168.0.1:10000",
want: credentials.Invalid,
want: credentials.InvalidSecurityLevel,
},
}
for _, tc := range testCases {
Expand All @@ -78,6 +78,15 @@ func (s) TestGetSecurityLevel(t *testing.T) {

type serverHandshake func(net.Conn) (credentials.AuthInfo, error)

func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel {
if c, ok := ai.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
return c.GetCommonAuthInfo().SecurityLevel
}
return credentials.InvalidSecurityLevel
}

// Server local handshake implementation.
func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
cred := NewCredentials()
Expand Down Expand Up @@ -140,21 +149,26 @@ func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, erro
defer lis.Close()
clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
if err != nil {
return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err)
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err)
}
select {
case <-timer.C:
return credentials.Invalid, fmt.Errorf("Test didn't finish in time")
return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time")
case serverHandleResult := <-done:
if serverHandleResult.err != nil {
return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
}
clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo)
serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo)

if clientSecLevel == credentials.InvalidSecurityLevel {
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
}
if serverSecLevel == credentials.InvalidSecurityLevel {
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
}
clientLocal, _ := clientAuthInfo.(Info)
serverLocal, _ := serverHandleResult.authInfo.(Info)
clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel
serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel
if clientSecLevel != serverSecLevel {
return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
}
return clientSecLevel, nil
}
Expand Down
12 changes: 8 additions & 4 deletions credentials/oauth/oauth.go
Expand Up @@ -42,7 +42,8 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down Expand Up @@ -84,7 +85,8 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err)
}
return map[string]string{
Expand All @@ -107,7 +109,8 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
}

func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down Expand Up @@ -144,7 +147,8 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string)
return nil, err
}
}
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down
3 changes: 2 additions & 1 deletion credentials/sts/sts.go
Expand Up @@ -151,7 +151,8 @@ type callCreds struct {
// GetRequestMetadata returns the cached accessToken, if available and valid, or
// fetches a new one by performing an STS token exchange.
func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err)
}

Expand Down
19 changes: 17 additions & 2 deletions internal/transport/http2_client.go
Expand Up @@ -234,6 +234,18 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
for _, cd := range perRPCCreds {
if cd.RequireTransportSecurity() {
if ci, ok := authInfo.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
secLevel := ci.GetCommonAuthInfo().SecurityLevel
if secLevel != credentials.InvalidSecurityLevel && secLevel < credentials.PrivacyAndIntegrity {
return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection")
}
}
}
}
isSecure = true
if transportCreds.Info().SecurityProtocol == "tls" {
scheme = "https"
Expand Down Expand Up @@ -557,8 +569,11 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
if callCreds.RequireTransportSecurity() {
ri, _ := credentials.RequestInfoFromContext(ctx)
if !t.isSecure || credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity) != nil {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
Expand Down

0 comments on commit 14db5e4

Please sign in to comment.