diff --git a/balancer/balancer.go b/balancer/balancer.go index 8bf359dbfda..788759bde4b 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -174,6 +174,10 @@ type BuildOptions struct { Dialer func(context.Context, string) (net.Conn, error) // ChannelzParentID is the entity parent's channelz unique identification number. ChannelzParentID int64 + // CustomUserAgent is the custom user agent set on the parent ClientConn. + // The balancer should set the same custom user agent if it creates a + // ClientConn. + CustomUserAgent string // Target contains the parsed address info of the dial target. It is the same resolver.Target as // passed to the resolver. // See the documentation for the resolver.Target type for details about what it contains. diff --git a/balancer/grpclb/grpclb_remote_balancer.go b/balancer/grpclb/grpclb_remote_balancer.go index 8eb45be28e3..8fdda09034f 100644 --- a/balancer/grpclb/grpclb_remote_balancer.go +++ b/balancer/grpclb/grpclb_remote_balancer.go @@ -224,6 +224,9 @@ func (lb *lbBalancer) newRemoteBalancerCCWrapper() { if lb.opt.Dialer != nil { dopts = append(dopts, grpc.WithContextDialer(lb.opt.Dialer)) } + if lb.opt.CustomUserAgent != "" { + dopts = append(dopts, grpc.WithUserAgent(lb.opt.CustomUserAgent)) + } // Explicitly set pickfirst as the balancer. dopts = append(dopts, grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"pick_first"}`)) dopts = append(dopts, grpc.WithResolvers(lb.manualResolver)) diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go index 5413613375a..dcc5235703a 100644 --- a/balancer/grpclb/grpclb_test.go +++ b/balancer/grpclb/grpclb_test.go @@ -194,15 +194,18 @@ type remoteBalancer struct { stats *rpcStats statsChan chan *lbpb.ClientStats fbChan chan struct{} + + customUserAgent string } -func newRemoteBalancer(intervals []time.Duration, statsChan chan *lbpb.ClientStats) *remoteBalancer { +func newRemoteBalancer(customUserAgent string, statsChan chan *lbpb.ClientStats) *remoteBalancer { return &remoteBalancer{ - sls: make(chan *lbpb.ServerList, 1), - done: make(chan struct{}), - stats: newRPCStats(), - statsChan: statsChan, - fbChan: make(chan struct{}), + sls: make(chan *lbpb.ServerList, 1), + done: make(chan struct{}), + stats: newRPCStats(), + statsChan: statsChan, + fbChan: make(chan struct{}), + customUserAgent: customUserAgent, } } @@ -216,6 +219,17 @@ func (b *remoteBalancer) fallbackNow() { } func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error { + md, ok := metadata.FromIncomingContext(stream.Context()) + if !ok { + return status.Error(codes.Internal, "failed to receive metadata") + } + if b.customUserAgent != "" { + ua := md["user-agent"] + if len(ua) == 0 || !strings.HasPrefix(ua[0], b.customUserAgent) { + return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.customUserAgent) + } + } + req, err := stream.Recv() if err != nil { return err @@ -333,7 +347,7 @@ type testServers struct { beListeners []net.Listener } -func newLoadBalancer(numberOfBackends int, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { +func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { var ( beListeners []net.Listener ls *remoteBalancer @@ -366,7 +380,7 @@ func newLoadBalancer(numberOfBackends int, statsChan chan *lbpb.ClientStats) (ts sn: lbServerName, } lb = grpc.NewServer(grpc.Creds(lbCreds)) - ls = newRemoteBalancer(nil, statsChan) + ls = newRemoteBalancer(customUserAgent, statsChan) lbgrpc.RegisterLoadBalancerServer(lb, ls) go func() { lb.Serve(lbLis) @@ -398,7 +412,8 @@ var grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` func (s) TestGRPCLB(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, nil) + const testUserAgent = "test-user-agent" + tss, cleanup, err := newLoadBalancer(1, testUserAgent, nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -419,7 +434,8 @@ func (s) TestGRPCLB(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer), + grpc.WithUserAgent(testUserAgent)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } @@ -445,7 +461,7 @@ func (s) TestGRPCLB(t *testing.T) { func (s) TestGRPCLBWeighted(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(2, nil) + tss, cleanup, err := newLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -510,7 +526,7 @@ func (s) TestGRPCLBWeighted(t *testing.T) { func (s) TestDropRequest(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(2, nil) + tss, cleanup, err := newLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -665,7 +681,7 @@ func (s) TestBalancerDisconnects(t *testing.T) { lbs []*grpc.Server ) for i := 0; i < 2; i++ { - tss, cleanup, err := newLoadBalancer(1, nil) + tss, cleanup, err := newLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -737,7 +753,7 @@ func (s) TestFallback(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, nil) + tss, cleanup, err := newLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -864,7 +880,7 @@ func (s) TestFallback(t *testing.T) { func (s) TestExplicitFallback(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, nil) + tss, cleanup, err := newLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -974,7 +990,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { resolveNowCh <- struct{}{} } - tss, cleanup, err := newLoadBalancer(1, nil) + tss, cleanup, err := newLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -1085,7 +1101,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { func (s) TestGRPCLBPickFirst(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(3, nil) + tss, cleanup, err := newLoadBalancer(3, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -1239,7 +1255,7 @@ func checkStats(stats, expected *rpcStats) error { func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats, runRPCs func(*grpc.ClientConn), statsWant *rpcStats) error { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, statsChan) + tss, cleanup, err := newLoadBalancer(1, "", statsChan) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } diff --git a/clientconn.go b/clientconn.go index 11ed4af607a..1e551a7041a 100644 --- a/clientconn.go +++ b/clientconn.go @@ -288,6 +288,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * DialCreds: credsClone, CredsBundle: cc.dopts.copts.CredsBundle, Dialer: cc.dopts.copts.Dialer, + CustomUserAgent: cc.dopts.copts.UserAgent, ChannelzParentID: cc.channelzID, Target: cc.parsedTarget, }