diff --git a/clientconn.go b/clientconn.go index 1e551a7041a..020cdae57e4 100644 --- a/clientconn.go +++ b/clientconn.go @@ -259,6 +259,8 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * cc.authority = cc.dopts.authority } else if strings.HasPrefix(cc.target, "unix:") { cc.authority = "localhost" + } else if strings.HasPrefix(cc.parsedTarget.Endpoint, ":") { + cc.authority = "localhost" + cc.parsedTarget.Endpoint } else { // Use endpoint from "scheme://authority/endpoint" as the default // authority for ClientConn. diff --git a/test/authority_test.go b/test/authority_test.go index 11c3c4c1af8..af289bc69ed 100644 --- a/test/authority_test.go +++ b/test/authority_test.go @@ -23,6 +23,7 @@ import ( "fmt" "net" "os" + "sync" "testing" "time" @@ -33,27 +34,31 @@ import ( testpb "google.golang.org/grpc/test/grpc_testing" ) +func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.InvalidArgument, "failed to parse metadata") + } + auths, ok := md[":authority"] + if !ok { + return nil, status.Error(codes.InvalidArgument, "no authority header") + } + if len(auths) != 1 { + return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths)) + } + if auths[0] != expectedAuthority { + return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority)) + } + return &testpb.Empty{}, nil +} + func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) { if err := os.RemoveAll(address); err != nil { t.Fatalf("Error removing socket file %v: %v\n", address, err) } - us := &stubServer{ - emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Error(codes.InvalidArgument, "failed to parse metadata") - } - auths, ok := md[":authority"] - if !ok { - return nil, status.Error(codes.InvalidArgument, "no authority header") - } - if len(auths) < 1 { - return nil, status.Error(codes.InvalidArgument, "no authority header") - } - if auths[0] != expectedAuthority { - return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority)) - } - return &testpb.Empty{}, nil + ss := &stubServer{ + emptyCall: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + return authorityChecker(ctx, expectedAuthority) }, network: "unix", address: address, @@ -63,14 +68,13 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer if dialer != nil { opts = append(opts, grpc.WithContextDialer(dialer)) } - if err := us.Start(nil, opts...); err != nil { + if err := ss.Start(nil, opts...); err != nil { t.Fatalf("Error starting endpoint server: %v", err) - return } - defer us.Stop() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer ss.Stop() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - _, err := us.client.EmptyCall(ctx, &testpb.Empty{}) + _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}) if err != nil { t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err) } @@ -147,3 +151,40 @@ func (s) TestUnixCustomDialer(t *testing.T) { }) } } + +func (s) TestColonPortAuthority(t *testing.T) { + expectedAuthority := "" + var authorityMu sync.Mutex + ss := &stubServer{ + emptyCall: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + authorityMu.Lock() + defer authorityMu.Unlock() + return authorityChecker(ctx, expectedAuthority) + }, + network: "tcp", + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + _, port, err := net.SplitHostPort(ss.address) + if err != nil { + t.Fatalf("Failed splitting host from post: %v", err) + } + authorityMu.Lock() + expectedAuthority = "localhost:" + port + authorityMu.Unlock() + // ss.Start dials, but not the ":[port]" target that is being tested here. + // Dial again, with ":[port]" as the target. + cc, err := grpc.Dial(":"+port, grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", ss.target, err) + } + defer cc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, err = testpb.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err) + } +}