Skip to content

Commit

Permalink
grpclb: send custom user-agent (grpc#4011)
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl authored and davidkhala committed Dec 7, 2020
1 parent 532eb39 commit 511e52d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
4 changes: 4 additions & 0 deletions balancer/balancer.go
Expand Up @@ -174,6 +174,10 @@ type BuildOptions struct {
Dialer func(context.Context, string) (net.Conn, error)
// ChannelzParentID is the entity parent's channelz unique identification number.
ChannelzParentID int64
// CustomUserAgent is the custom user agent set on the parent ClientConn.
// The balancer should set the same custom user agent if it creates a
// ClientConn.
CustomUserAgent string
// Target contains the parsed address info of the dial target. It is the same resolver.Target as
// passed to the resolver.
// See the documentation for the resolver.Target type for details about what it contains.
Expand Down
3 changes: 3 additions & 0 deletions balancer/grpclb/grpclb_remote_balancer.go
Expand Up @@ -224,6 +224,9 @@ func (lb *lbBalancer) newRemoteBalancerCCWrapper() {
if lb.opt.Dialer != nil {
dopts = append(dopts, grpc.WithContextDialer(lb.opt.Dialer))
}
if lb.opt.CustomUserAgent != "" {
dopts = append(dopts, grpc.WithUserAgent(lb.opt.CustomUserAgent))
}
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"pick_first"}`))
dopts = append(dopts, grpc.WithResolvers(lb.manualResolver))
Expand Down
52 changes: 34 additions & 18 deletions balancer/grpclb/grpclb_test.go
Expand Up @@ -194,15 +194,18 @@ type remoteBalancer struct {
stats *rpcStats
statsChan chan *lbpb.ClientStats
fbChan chan struct{}

customUserAgent string
}

func newRemoteBalancer(intervals []time.Duration, statsChan chan *lbpb.ClientStats) *remoteBalancer {
func newRemoteBalancer(customUserAgent string, statsChan chan *lbpb.ClientStats) *remoteBalancer {
return &remoteBalancer{
sls: make(chan *lbpb.ServerList, 1),
done: make(chan struct{}),
stats: newRPCStats(),
statsChan: statsChan,
fbChan: make(chan struct{}),
sls: make(chan *lbpb.ServerList, 1),
done: make(chan struct{}),
stats: newRPCStats(),
statsChan: statsChan,
fbChan: make(chan struct{}),
customUserAgent: customUserAgent,
}
}

Expand All @@ -216,6 +219,17 @@ func (b *remoteBalancer) fallbackNow() {
}

func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if !ok {
return status.Error(codes.Internal, "failed to receive metadata")
}
if b.customUserAgent != "" {
ua := md["user-agent"]
if len(ua) == 0 || !strings.HasPrefix(ua[0], b.customUserAgent) {
return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.customUserAgent)
}
}

req, err := stream.Recv()
if err != nil {
return err
Expand Down Expand Up @@ -333,7 +347,7 @@ type testServers struct {
beListeners []net.Listener
}

func newLoadBalancer(numberOfBackends int, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) {
func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) {
var (
beListeners []net.Listener
ls *remoteBalancer
Expand Down Expand Up @@ -366,7 +380,7 @@ func newLoadBalancer(numberOfBackends int, statsChan chan *lbpb.ClientStats) (ts
sn: lbServerName,
}
lb = grpc.NewServer(grpc.Creds(lbCreds))
ls = newRemoteBalancer(nil, statsChan)
ls = newRemoteBalancer(customUserAgent, statsChan)
lbgrpc.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
Expand Down Expand Up @@ -398,7 +412,8 @@ var grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}`
func (s) TestGRPCLB(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, nil)
const testUserAgent = "test-user-agent"
tss, cleanup, err := newLoadBalancer(1, testUserAgent, nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand All @@ -419,7 +434,8 @@ func (s) TestGRPCLB(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer),
grpc.WithUserAgent(testUserAgent))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
Expand All @@ -445,7 +461,7 @@ func (s) TestGRPCLB(t *testing.T) {
func (s) TestGRPCLBWeighted(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(2, nil)
tss, cleanup, err := newLoadBalancer(2, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -510,7 +526,7 @@ func (s) TestGRPCLBWeighted(t *testing.T) {
func (s) TestDropRequest(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(2, nil)
tss, cleanup, err := newLoadBalancer(2, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -665,7 +681,7 @@ func (s) TestBalancerDisconnects(t *testing.T) {
lbs []*grpc.Server
)
for i := 0; i < 2; i++ {
tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -737,7 +753,7 @@ func (s) TestFallback(t *testing.T) {

r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -864,7 +880,7 @@ func (s) TestFallback(t *testing.T) {
func (s) TestExplicitFallback(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -974,7 +990,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) {
resolveNowCh <- struct{}{}
}

tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -1085,7 +1101,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) {
func (s) TestGRPCLBPickFirst(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(3, nil)
tss, cleanup, err := newLoadBalancer(3, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -1239,7 +1255,7 @@ func checkStats(stats, expected *rpcStats) error {
func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats, runRPCs func(*grpc.ClientConn), statsWant *rpcStats) error {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, statsChan)
tss, cleanup, err := newLoadBalancer(1, "", statsChan)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down
1 change: 1 addition & 0 deletions clientconn.go
Expand Up @@ -288,6 +288,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
DialCreds: credsClone,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParentID: cc.channelzID,
Target: cc.parsedTarget,
}
Expand Down

0 comments on commit 511e52d

Please sign in to comment.