Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grpclb: send custom user-agent #4011

Merged
merged 1 commit into from Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions balancer/balancer.go
Expand Up @@ -174,6 +174,10 @@ type BuildOptions struct {
Dialer func(context.Context, string) (net.Conn, error)
// ChannelzParentID is the entity parent's channelz unique identification number.
ChannelzParentID int64
// CustomUserAgent is the custom user agent set on the parent ClientConn.
// The balancer should set the same custom user agent if it creates a
// ClientConn.
CustomUserAgent string
// Target contains the parsed address info of the dial target. It is the same resolver.Target as
// passed to the resolver.
// See the documentation for the resolver.Target type for details about what it contains.
Expand Down
3 changes: 3 additions & 0 deletions balancer/grpclb/grpclb_remote_balancer.go
Expand Up @@ -224,6 +224,9 @@ func (lb *lbBalancer) newRemoteBalancerCCWrapper() {
if lb.opt.Dialer != nil {
dopts = append(dopts, grpc.WithContextDialer(lb.opt.Dialer))
}
if lb.opt.CustomUserAgent != "" {
dopts = append(dopts, grpc.WithUserAgent(lb.opt.CustomUserAgent))
}
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"pick_first"}`))
dopts = append(dopts, grpc.WithResolvers(lb.manualResolver))
Expand Down
52 changes: 34 additions & 18 deletions balancer/grpclb/grpclb_test.go
Expand Up @@ -194,15 +194,18 @@ type remoteBalancer struct {
stats *rpcStats
statsChan chan *lbpb.ClientStats
fbChan chan struct{}

customUserAgent string
}

func newRemoteBalancer(intervals []time.Duration, statsChan chan *lbpb.ClientStats) *remoteBalancer {
func newRemoteBalancer(customUserAgent string, statsChan chan *lbpb.ClientStats) *remoteBalancer {
return &remoteBalancer{
sls: make(chan *lbpb.ServerList, 1),
done: make(chan struct{}),
stats: newRPCStats(),
statsChan: statsChan,
fbChan: make(chan struct{}),
sls: make(chan *lbpb.ServerList, 1),
done: make(chan struct{}),
stats: newRPCStats(),
statsChan: statsChan,
fbChan: make(chan struct{}),
customUserAgent: customUserAgent,
}
}

Expand All @@ -216,6 +219,17 @@ func (b *remoteBalancer) fallbackNow() {
}

func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if !ok {
return status.Error(codes.Internal, "failed to receive metadata")
}
if b.customUserAgent != "" {
ua := md["user-agent"]
if len(ua) == 0 || !strings.HasPrefix(ua[0], b.customUserAgent) {
return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.customUserAgent)
}
}

req, err := stream.Recv()
if err != nil {
return err
Expand Down Expand Up @@ -333,7 +347,7 @@ type testServers struct {
beListeners []net.Listener
}

func newLoadBalancer(numberOfBackends int, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) {
func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) {
var (
beListeners []net.Listener
ls *remoteBalancer
Expand Down Expand Up @@ -366,7 +380,7 @@ func newLoadBalancer(numberOfBackends int, statsChan chan *lbpb.ClientStats) (ts
sn: lbServerName,
}
lb = grpc.NewServer(grpc.Creds(lbCreds))
ls = newRemoteBalancer(nil, statsChan)
ls = newRemoteBalancer(customUserAgent, statsChan)
lbgrpc.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
Expand Down Expand Up @@ -398,7 +412,8 @@ var grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}`
func (s) TestGRPCLB(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, nil)
const testUserAgent = "test-user-agent"
tss, cleanup, err := newLoadBalancer(1, testUserAgent, nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand All @@ -419,7 +434,8 @@ func (s) TestGRPCLB(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer),
grpc.WithUserAgent(testUserAgent))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
Expand All @@ -445,7 +461,7 @@ func (s) TestGRPCLB(t *testing.T) {
func (s) TestGRPCLBWeighted(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(2, nil)
tss, cleanup, err := newLoadBalancer(2, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -510,7 +526,7 @@ func (s) TestGRPCLBWeighted(t *testing.T) {
func (s) TestDropRequest(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(2, nil)
tss, cleanup, err := newLoadBalancer(2, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -665,7 +681,7 @@ func (s) TestBalancerDisconnects(t *testing.T) {
lbs []*grpc.Server
)
for i := 0; i < 2; i++ {
tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -737,7 +753,7 @@ func (s) TestFallback(t *testing.T) {

r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -864,7 +880,7 @@ func (s) TestFallback(t *testing.T) {
func (s) TestExplicitFallback(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -974,7 +990,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) {
resolveNowCh <- struct{}{}
}

tss, cleanup, err := newLoadBalancer(1, nil)
tss, cleanup, err := newLoadBalancer(1, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -1085,7 +1101,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) {
func (s) TestGRPCLBPickFirst(t *testing.T) {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(3, nil)
tss, cleanup, err := newLoadBalancer(3, "", nil)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down Expand Up @@ -1239,7 +1255,7 @@ func checkStats(stats, expected *rpcStats) error {
func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats, runRPCs func(*grpc.ClientConn), statsWant *rpcStats) error {
r := manual.NewBuilderWithScheme("whatever")

tss, cleanup, err := newLoadBalancer(1, statsChan)
tss, cleanup, err := newLoadBalancer(1, "", statsChan)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
Expand Down
1 change: 1 addition & 0 deletions clientconn.go
Expand Up @@ -288,6 +288,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
DialCreds: credsClone,
CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer,
CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParentID: cc.channelzID,
Target: cc.parsedTarget,
}
Expand Down