Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials: fix PerRPCCredentials w/RequireTransportSecurity and security levels #3995

Merged
merged 5 commits into from Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions credentials/credentials.go
Expand Up @@ -92,7 +92,7 @@ type CommonAuthInfo struct {
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo {
func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo {
easwars marked this conversation as resolved.
Show resolved Hide resolved
return c
}

Expand Down Expand Up @@ -231,7 +231,7 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
GetCommonAuthInfo() CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
Expand Down
12 changes: 6 additions & 6 deletions credentials/insecure/insecure.go
Expand Up @@ -43,11 +43,11 @@ func NewCredentials() credentials.TransportCredentials {
type insecureTC struct{}

func (insecureTC) ClientHandshake(ctx context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}

func (insecureTC) Info() credentials.ProtocolInfo {
Expand All @@ -62,13 +62,13 @@ func (insecureTC) OverrideServerName(string) error {
return nil
}

// Info contains the auth information for an insecure connection.
// info contains the auth information for an insecure connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "insecure"
}
12 changes: 6 additions & 6 deletions credentials/local/local.go
Expand Up @@ -38,14 +38,14 @@ import (
"google.golang.org/grpc/credentials"
)

// Info contains the auth information for a local connection.
// info contains the auth information for a local connection.
// It implements the AuthInfo interface.
type Info struct {
type info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "local"
}

Expand Down Expand Up @@ -79,15 +79,15 @@ func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

// NewCredentials returns a local credential implementing credentials.TransportCredentials.
Expand Down
18 changes: 14 additions & 4 deletions credentials/local/local_test.go
Expand Up @@ -144,10 +144,20 @@ func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, erro
if serverHandleResult.err != nil {
return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
}
clientLocal, _ := clientAuthInfo.(Info)
serverLocal, _ := serverHandleResult.authInfo.(Info)
clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel
serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel
var clientSecLevel, serverSecLevel credentials.SecurityLevel
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if info, ok := clientAuthInfo.(internalInfo); ok {
clientSecLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return credentials.Invalid, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say make a helper:

func getSecurityLevel(ai AuthInfo) credentials.SecurityLevel {
	if c, ok := ai.(interface { GetCommonAuthInfo() credentials.CommonAuthInfo }); ok {
		return c.GetCommonAuthInfo().SecurityLevel
	}
	return credentials.InvalidSecurityLevel // Let's rename this enum value; it's not specific enough to be exported
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah that's much more concise. Done!

if info, ok := (serverHandleResult.authInfo).(internalInfo); ok {
serverSecLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return credentials.Invalid, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
}
if clientSecLevel != serverSecLevel {
return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
}
Expand Down
19 changes: 17 additions & 2 deletions internal/transport/http2_client.go
Expand Up @@ -231,9 +231,22 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
type internalInfo interface {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmmmm, maybe CheckSecurityLevel should accept an AuthInfo instead of a Context. Does this cause too many changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, consider it done!

GetCommonAuthInfo() credentials.CommonAuthInfo
}
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
for _, cd := range perRPCCreds {
if cd.RequireTransportSecurity() {
if ci, ok := authInfo.(internalInfo); ok {
secLevel := ci.GetCommonAuthInfo().SecurityLevel
if secLevel != credentials.Invalid && secLevel < credentials.PrivacyAndIntegrity {
return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection")
}
}
}
}
isSecure = true
if transportCreds.Info().SecurityProtocol == "tls" {
scheme = "https"
Expand Down Expand Up @@ -557,8 +570,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
if callCreds.RequireTransportSecurity() {
if !t.isSecure || credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) != nil {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
Expand Down
97 changes: 93 additions & 4 deletions test/insecure_creds_test.go
Expand Up @@ -21,6 +21,7 @@ package test
import (
"context"
"net"
"strings"
"testing"
"time"

Expand All @@ -30,11 +31,23 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"

testpb "google.golang.org/grpc/test/grpc_testing"
)

const defaultTestTimeout = 5 * time.Second

// testLegacyPerRPCCredentials is a PerRPCCredentials that has yet incorporated security level.
type testLegacyPerRPCCredentials struct{}

func (cr testLegacyPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return nil, nil
}

func (cr testLegacyPerRPCCredentials) RequireTransportSecurity() bool {
return true
}

// TestInsecureCreds tests the use of insecure creds on the server and client
// side, and verifies that expect security level and auth info are returned.
// Also verifies that this credential can interop with existing `WithInsecure`
Expand Down Expand Up @@ -73,11 +86,16 @@ func (s) TestInsecureCreds(t *testing.T) {
return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx")
}
// Check security level.
info := pr.AuthInfo.(insecure.Info)
if at := info.AuthType(); at != "insecure" {
return nil, status.Errorf(codes.Unauthenticated, "Wrong AuthType: got %q, want insecure", at)
var secLevel credentials.SecurityLevel
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if secLevel := info.CommonAuthInfo.SecurityLevel; secLevel != credentials.NoSecurity {
if info, ok := pr.AuthInfo.(internalInfo); ok {
secLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()")
}
if secLevel != credentials.NoSecurity {
return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity)
}
return &testpb.Empty{}, nil
Expand Down Expand Up @@ -122,3 +140,74 @@ func (s) TestInsecureCreds(t *testing.T) {
})
}
}

func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) {
tests := []struct {
desc string
perRPCCredsViaDialOptions bool
perRPCCredsViaCallOptions bool
wantErr string
}{
{
desc: "send PerRPCCredentials via DialOptions",
perRPCCredsViaDialOptions: true,
perRPCCredsViaCallOptions: false,
wantErr: "context deadline exceeded",
},
{
desc: "send PerRPCCredentials via CallOptions",
perRPCCredsViaDialOptions: false,
perRPCCredsViaCallOptions: true,
wantErr: "transport: cannot send secure credentials on an insecure connection",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}

sOpts := []grpc.ServerOption{}
sOpts = append(sOpts, grpc.Creds(insecure.NewCredentials()))
s := grpc.NewServer(sOpts...)
defer s.Stop()

testpb.RegisterTestServiceServer(s, ss)

lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err)
}

go s.Serve(lis)

addr := lis.Addr().String()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cOpts := []grpc.DialOption{grpc.WithBlock()}
cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))

if test.perRPCCredsViaDialOptions {
cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{}))
if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr)
}
}

if test.perRPCCredsViaCallOptions {
cc, err := grpc.DialContext(ctx, addr, cOpts...)
if err != nil {
t.Fatalf("grpc.Dial(%q) failed: %v", addr, err)
}
defer cc.Close()

c := testpb.NewTestServiceClient(cc)
if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr)
}
}
})
}
}
11 changes: 9 additions & 2 deletions test/local_creds_test.go
Expand Up @@ -43,9 +43,16 @@ func testLocalCredsE2ESucceed(network, address string) error {
if !ok {
return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx")
}
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
var secLevel credentials.SecurityLevel
if info, ok := (pr.AuthInfo).(internalInfo); ok {
secLevel = info.GetCommonAuthInfo().SecurityLevel
} else {
return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()")
}
// Check security level
info := pr.AuthInfo.(local.Info)
secLevel := info.CommonAuthInfo.SecurityLevel
switch network {
case "unix":
if secLevel != credentials.PrivacyAndIntegrity {
Expand Down