Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials: fix PerRPCCredentials w/RequireTransportSecurity and security levels #3995

Merged
merged 5 commits into from Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 9 additions & 10 deletions credentials/credentials.go
Expand Up @@ -58,9 +58,9 @@ type PerRPCCredentials interface {
type SecurityLevel int

const (
// Invalid indicates an invalid security level.
// InvalidSecurityLevel indicates an invalid security level.
// The zero SecurityLevel value is invalid for backward compatibility.
Invalid SecurityLevel = iota
InvalidSecurityLevel SecurityLevel = iota
// NoSecurity indicates a connection is insecure.
NoSecurity
// IntegrityOnly indicates a connection only provides integrity protection.
Expand Down Expand Up @@ -92,7 +92,7 @@ type CommonAuthInfo struct {
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo {
func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo {
easwars marked this conversation as resolved.
Show resolved Hide resolved
return c
}

Expand Down Expand Up @@ -229,17 +229,16 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
//
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
func CheckSecurityLevel(ai AuthInfo, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
GetCommonAuthInfo() CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
return errors.New("unable to obtain SecurityLevel from context")
if ai == nil {
return errors.New("AuthInfo is nil")
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
if ci, ok := ai.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == Invalid {
if ci.GetCommonAuthInfo().SecurityLevel == InvalidSecurityLevel {
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
Expand Down
28 changes: 4 additions & 24 deletions credentials/credentials_test.go
Expand Up @@ -26,7 +26,6 @@ import (
"testing"
"time"

"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/testdata"
)
Expand Down Expand Up @@ -57,17 +56,6 @@ func (ta testAuthInfo) AuthType() string {
return "testAuthInfo"
}

func createTestContext(s SecurityLevel) context.Context {
auth := &testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: s}}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
return internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri)
}

func (s) TestCheckSecurityLevel(t *testing.T) {
testCases := []struct {
authLevel SecurityLevel
Expand All @@ -90,18 +78,18 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
want: true,
},
{
authLevel: Invalid,
authLevel: InvalidSecurityLevel,
testLevel: IntegrityOnly,
want: true,
},
{
authLevel: Invalid,
authLevel: InvalidSecurityLevel,
testLevel: PrivacyAndIntegrity,
want: true,
},
}
for _, tc := range testCases {
err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel)
err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel)
if tc.want && (err != nil) {
t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
} else if !tc.want && (err == nil) {
Expand All @@ -112,15 +100,7 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
}

func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
auth := &testAuthInfoNoGetCommonAuthInfoMethod{}
ri := RequestInfo{
Method: "testInfo",
AuthInfo: auth,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(ctx, ri)
if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil {
if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil {
t.Fatalf("CheckSeurityLevel() returned failure but want success")
}
}
Expand Down
12 changes: 6 additions & 6 deletions credentials/insecure/insecure.go
Expand Up @@ -43,11 +43,11 @@ func NewCredentials() credentials.TransportCredentials {
type insecureTC struct{}

func (insecureTC) ClientHandshake(ctx context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) Info() credentials.ProtocolInfo {
Expand All @@ -62,13 +62,13 @@ func (insecureTC) OverrideServerName(string) error {
return nil
}

// Info contains the auth information for an insecure connection.
// info contains the auth information for an insecure connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "insecure"
}
14 changes: 7 additions & 7 deletions credentials/local/local.go
Expand Up @@ -38,14 +38,14 @@ import (
"google.golang.org/grpc/credentials"
)

// Info contains the auth information for a local connection.
// info contains the auth information for a local connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "local"
}

Expand All @@ -70,7 +70,7 @@ func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) {
return credentials.PrivacyAndIntegrity, nil
// Not a local connection and should fail
default:
return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
return credentials.InvalidSecurityLevel, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
}
}

Expand All @@ -79,15 +79,15 @@ func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

// NewCredentials returns a local credential implementing credentials.TransportCredentials.
Expand Down
32 changes: 23 additions & 9 deletions credentials/local/local_test.go
Expand Up @@ -65,7 +65,7 @@ func (s) TestGetSecurityLevel(t *testing.T) {
{
testNetwork: "tcp",
testAddr: "192.168.0.1:10000",
want: credentials.Invalid,
want: credentials.InvalidSecurityLevel,
},
}
for _, tc := range testCases {
Expand All @@ -78,6 +78,15 @@ func (s) TestGetSecurityLevel(t *testing.T) {

type serverHandshake func(net.Conn) (credentials.AuthInfo, error)

func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel {
if c, ok := ai.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
return c.GetCommonAuthInfo().SecurityLevel
}
return credentials.InvalidSecurityLevel
}

// Server local handshake implementation.
func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
cred := NewCredentials()
Expand Down Expand Up @@ -140,21 +149,26 @@ func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, erro
defer lis.Close()
clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
if err != nil {
return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err)
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err)
}
select {
case <-timer.C:
return credentials.Invalid, fmt.Errorf("Test didn't finish in time")
return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time")
case serverHandleResult := <-done:
if serverHandleResult.err != nil {
return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
}
clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo)
serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo)

if clientSecLevel == credentials.InvalidSecurityLevel {
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
}
if serverSecLevel == credentials.InvalidSecurityLevel {
return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
}
clientLocal, _ := clientAuthInfo.(Info)
serverLocal, _ := serverHandleResult.authInfo.(Info)
clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel
serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel
if clientSecLevel != serverSecLevel {
return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
}
return clientSecLevel, nil
}
Expand Down
12 changes: 8 additions & 4 deletions credentials/oauth/oauth.go
Expand Up @@ -42,7 +42,8 @@ func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (ma
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer TokenSource PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down Expand Up @@ -84,7 +85,8 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[s
if err != nil {
return nil, err
}
if err = credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err = credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer jwtAccess PerRPCCredentials: %v", err)
}
return map[string]string{
Expand All @@ -107,7 +109,8 @@ func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
}

func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer oauthAccess PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down Expand Up @@ -144,7 +147,8 @@ func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string)
return nil, err
}
}
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer serviceAccount PerRPCCredentials: %v", err)
}
return map[string]string{
Expand Down
3 changes: 2 additions & 1 deletion credentials/sts/sts.go
Expand Up @@ -151,7 +151,8 @@ type callCreds struct {
// GetRequestMetadata returns the cached accessToken, if available and valid, or
// fetches a new one by performing an STS token exchange.
func (c *callCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
if err := credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity); err != nil {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer STS PerRPCCredentials: %v", err)
}

Expand Down
19 changes: 17 additions & 2 deletions internal/transport/http2_client.go
Expand Up @@ -234,6 +234,18 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
for _, cd := range perRPCCreds {
if cd.RequireTransportSecurity() {
if ci, ok := authInfo.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
secLevel := ci.GetCommonAuthInfo().SecurityLevel
if secLevel != credentials.InvalidSecurityLevel && secLevel < credentials.PrivacyAndIntegrity {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking we could use CheckSecurityeLevel here, but I see this is different from CheckSecurityLevel because Invalid is not okay here. LGTM, but please move internalInfo so it's defined immediately above to limit its scope (or use it inline).

return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection")
}
}
}
}
isSecure = true
if transportCreds.Info().SecurityProtocol == "tls" {
scheme = "https"
Expand Down Expand Up @@ -557,8 +569,11 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
if callCreds.RequireTransportSecurity() {
ri, _ := credentials.RequestInfoFromContext(ctx)
if !t.isSecure || credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity) != nil {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
Expand Down