Skip to content

Commit

Permalink
add per-connection and per-rpc check
Browse files Browse the repository at this point in the history
  • Loading branch information
yihuazhang committed Oct 30, 2020
1 parent f4d9cca commit 386b266
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 4 deletions.
4 changes: 2 additions & 2 deletions credentials/credentials.go
Expand Up @@ -92,7 +92,7 @@ type CommonAuthInfo struct {
}

// GetCommonAuthInfo returns the pointer to CommonAuthInfo struct.
func (c *CommonAuthInfo) GetCommonAuthInfo() *CommonAuthInfo {
func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo {
return c
}

Expand Down Expand Up @@ -231,7 +231,7 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
// This API is experimental.
func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
type internalInfo interface {
GetCommonAuthInfo() *CommonAuthInfo
GetCommonAuthInfo() CommonAuthInfo
}
ri, _ := RequestInfoFromContext(ctx)
if ri.AuthInfo == nil {
Expand Down
19 changes: 17 additions & 2 deletions internal/transport/http2_client.go
Expand Up @@ -234,6 +234,19 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
type internalInfo interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}
if ci, ok := authInfo.(internalInfo); ok {
level := ci.GetCommonAuthInfo().SecurityLevel
for _, cd := range perRPCCreds {
if cd.RequireTransportSecurity() {
if level != credentials.Invalid && level < credentials.PrivacyAndIntegrity {
return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection")
}
}
}
}
isSecure = true
if transportCreds.Info().SecurityProtocol == "tls" {
scheme = "https"
Expand Down Expand Up @@ -557,8 +570,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
if callCreds := callHdr.Creds; callCreds != nil {
if !t.isSecure && callCreds.RequireTransportSecurity() {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
if callCreds.RequireTransportSecurity() {
if !t.isSecure || credentials.CheckSecurityLevel(ctx, credentials.PrivacyAndIntegrity) != nil {
return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection")
}
}
data, err := callCreds.GetRequestMetadata(ctx, audience)
if err != nil {
Expand Down
85 changes: 85 additions & 0 deletions test/insecure_creds_test.go
Expand Up @@ -21,6 +21,7 @@ package test
import (
"context"
"net"
"strings"
"testing"
"time"

Expand All @@ -30,11 +31,27 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"

testpb "google.golang.org/grpc/test/grpc_testing"
)

const defaultTestTimeout = 5 * time.Second

// testLegacyPerRPCCredentials is a PerRPCCredentials that has yet incorporated security level.
type testLegacyPerRPCCredentials struct{}

func (cr testLegacyPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return nil, nil
}

func (cr testLegacyPerRPCCredentials) RequireTransportSecurity() bool {
return true
}

func isExpectedError(got, want error) bool {
return status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message())
}

// TestInsecureCreds tests the use of insecure creds on the server and client
// side, and verifies that expect security level and auth info are returned.
// Also verifies that this credential can interop with existing `WithInsecure`
Expand Down Expand Up @@ -122,3 +139,71 @@ func (s) TestInsecureCreds(t *testing.T) {
})
}
}

func (s) TestInsecureCredsWithPerRPCCredentials(t *testing.T) {
tests := []struct {
desc string
perRPCCredsViaDialOptions bool
perRPCCredsViaCallOptions bool
wantErr string
}{
{
desc: "send PerRPCCredentials via DialOptions",
perRPCCredsViaDialOptions: true,
perRPCCredsViaCallOptions: false,
wantErr: "context deadline exceeded",
},
{
desc: "send PerRPCCredentials via CallOptions",
perRPCCredsViaDialOptions: false,
perRPCCredsViaCallOptions: true,
wantErr: "transport: cannot send secure credentials on an insecure connection",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
},
}

sOpts := []grpc.ServerOption{}
sOpts = append(sOpts, grpc.Creds(insecure.NewCredentials()))
s := grpc.NewServer(sOpts...)
defer s.Stop()

testpb.RegisterTestServiceServer(s, ss)

lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(tcp, localhost:0) failed: %v", err)
}

go s.Serve(lis)

addr := lis.Addr().String()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cOpts := []grpc.DialOption{grpc.WithBlock()}
cOpts = append(cOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))

if test.perRPCCredsViaDialOptions {
cOpts = append(cOpts, grpc.WithPerRPCCredentials(testLegacyPerRPCCredentials{}))
if _, err := grpc.DialContext(ctx, addr, cOpts...); !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_DialOptions = %v; want %s", err, test.wantErr)
}
}

if test.perRPCCredsViaCallOptions {
cc, err := grpc.DialContext(ctx, addr, cOpts...)
defer cc.Close()

c := testpb.NewTestServiceClient(cc)
if _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.PerRPCCredentials(testLegacyPerRPCCredentials{})); !strings.Contains(err.Error(), test.wantErr) {
t.Fatalf("InsecureCredsWithPerRPCCredentials/send_PerRPCCredentials_via_CallOptions = %v; want %s", err, test.wantErr)
}
}
})
}
}

0 comments on commit 386b266

Please sign in to comment.