Skip to content

Commit

Permalink
grpclb: add target_field to service config (#4847)
Browse files Browse the repository at this point in the history
  • Loading branch information
easwars committed Oct 11, 2021
1 parent 49f6388 commit 6c56e21
Show file tree
Hide file tree
Showing 6 changed files with 450 additions and 309 deletions.
32 changes: 29 additions & 3 deletions balancer/grpclb/grpclb.go
Expand Up @@ -135,6 +135,7 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal

lb := &lbBalancer{
cc: newLBCacheClientConn(cc),
dialTarget: opt.Target.Endpoint,
target: opt.Target.Endpoint,
opt: opt,
fallbackTimeout: b.fallbackTimeout,
Expand Down Expand Up @@ -164,9 +165,10 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal
}

type lbBalancer struct {
cc *lbCacheClientConn
target string
opt balancer.BuildOptions
cc *lbCacheClientConn
dialTarget string // user's dial target
target string // same as dialTarget unless overridden in service config
opt balancer.BuildOptions

usePickFirst bool

Expand Down Expand Up @@ -398,6 +400,30 @@ func (lb *lbBalancer) handleServiceConfig(gc *grpclbServiceConfig) {
lb.mu.Lock()
defer lb.mu.Unlock()

// grpclb uses the user's dial target to populate the `Name` field of the
// `InitialLoadBalanceRequest` message sent to the remote balancer. But when
// grpclb is used a child policy in the context of RLS, we want the `Name`
// field to be populated with the value received from the RLS server. To
// support this use case, an optional "target_name" field has been added to
// the grpclb LB policy's config. If specified, it overrides the name of
// the target to be sent to the remote balancer; if not, the target to be
// sent to the balancer will continue to be obtained from the target URI
// passed to the gRPC client channel. Whenever that target to be sent to the
// balancer is updated, we need to restart the stream to the balancer as
// this target is sent in the first message on the stream.
if gc != nil {
target := lb.dialTarget
if gc.TargetName != "" {
target = gc.TargetName
}
if target != lb.target {
lb.target = target
if lb.ccRemoteLB != nil {
lb.ccRemoteLB.cancelRemoteBalancerCall()
}
}
}

newUsePickFirst := childIsPickFirst(gc)
if lb.usePickFirst == newUsePickFirst {
return
Expand Down
1 change: 1 addition & 0 deletions balancer/grpclb/grpclb_config.go
Expand Up @@ -34,6 +34,7 @@ const (
type grpclbServiceConfig struct {
serviceconfig.LoadBalancingConfig
ChildPolicy *[]map[string]json.RawMessage
TargetName string
}

func (b *lbBuilder) ParseConfig(lbConfig json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
Expand Down
40 changes: 28 additions & 12 deletions balancer/grpclb/grpclb_config_test.go
Expand Up @@ -20,52 +20,68 @@ package grpclb

import (
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/serviceconfig"
)

func (s) TestParse(t *testing.T) {
tests := []struct {
name string
s string
sc string
want serviceconfig.LoadBalancingConfig
wantErr error
wantErr bool
}{
{
name: "empty",
s: "",
sc: "",
want: nil,
wantErr: errors.New("unexpected end of JSON input"),
wantErr: true,
},
{
name: "success1",
s: `{"childPolicy":[{"pick_first":{}}]}`,
sc: `
{
"childPolicy": [
{"pick_first":{}}
],
"targetName": "foo-service"
}`,
want: &grpclbServiceConfig{
ChildPolicy: &[]map[string]json.RawMessage{
{"pick_first": json.RawMessage("{}")},
},
TargetName: "foo-service",
},
},
{
name: "success2",
s: `{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`,
sc: `
{
"childPolicy": [
{"round_robin":{}},
{"pick_first":{}}
],
"targetName": "foo-service"
}`,
want: &grpclbServiceConfig{
ChildPolicy: &[]map[string]json.RawMessage{
{"round_robin": json.RawMessage("{}")},
{"pick_first": json.RawMessage("{}")},
},
TargetName: "foo-service",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.s)); !reflect.DeepEqual(got, tt.want) || !strings.Contains(fmt.Sprint(err), fmt.Sprint(tt.wantErr)) {
t.Errorf("parseFullServiceConfig() = %+v, %+v, want %+v, <contains %q>", got, err, tt.want, tt.wantErr)
got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.sc))
if (err != nil) != (tt.wantErr) {
t.Fatalf("ParseConfig(%q) returned error: %v, wantErr: %v", tt.sc, err, tt.wantErr)
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatalf("ParseConfig(%q) returned unexpected difference (-want +got):\n%s", tt.sc, diff)
}
})
}
Expand Down
43 changes: 38 additions & 5 deletions balancer/grpclb/grpclb_remote_balancer.go
Expand Up @@ -206,6 +206,9 @@ type remoteBalancerCCWrapper struct {
backoff backoff.Strategy
done chan struct{}

streamMu sync.Mutex
streamCancel func()

// waitgroup to wait for all goroutines to exit.
wg sync.WaitGroup
}
Expand Down Expand Up @@ -319,10 +322,8 @@ func (ccw *remoteBalancerCCWrapper) sendLoadReport(s *balanceLoadClientStream, i
}
}

func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error) {
func (ccw *remoteBalancerCCWrapper) callRemoteBalancer(ctx context.Context) (backoff bool, _ error) {
lbClient := &loadBalancerClient{cc: ccw.cc}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbClient.BalanceLoad(ctx, grpc.WaitForReady(true))
if err != nil {
return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
Expand Down Expand Up @@ -362,11 +363,43 @@ func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error)
return false, ccw.readServerList(stream)
}

// cancelRemoteBalancerCall cancels the context used by the stream to the remote
// balancer. watchRemoteBalancer() takes care of restarting this call after the
// stream fails.
func (ccw *remoteBalancerCCWrapper) cancelRemoteBalancerCall() {
ccw.streamMu.Lock()
if ccw.streamCancel != nil {
ccw.streamCancel()
ccw.streamCancel = nil
}
ccw.streamMu.Unlock()
}

func (ccw *remoteBalancerCCWrapper) watchRemoteBalancer() {
defer ccw.wg.Done()
defer func() {
ccw.wg.Done()
ccw.streamMu.Lock()
if ccw.streamCancel != nil {
// This is to make sure that we don't leak the context when we are
// directly returning from inside of the below `for` loop.
ccw.streamCancel()
ccw.streamCancel = nil
}
ccw.streamMu.Unlock()
}()

var retryCount int
var ctx context.Context
for {
doBackoff, err := ccw.callRemoteBalancer()
ccw.streamMu.Lock()
if ccw.streamCancel != nil {
ccw.streamCancel()
ccw.streamCancel = nil
}
ctx, ccw.streamCancel = context.WithCancel(context.Background())
ccw.streamMu.Unlock()

doBackoff, err := ccw.callRemoteBalancer(ctx)
select {
case <-ccw.done:
return
Expand Down

0 comments on commit 6c56e21

Please sign in to comment.