From 54382d46ccf49b165cfbac699a6b569f56cc207e Mon Sep 17 00:00:00 2001
From: Antoine Tollenaere
Date: Fri, 26 Apr 2024 16:01:10 +0200
Subject: [PATCH] ringhash: allow setting request hash key explicitly

Implement part 1 of gRFC A76: allow setting the request hash key
explicitly, so that the request hash is extracted from a metadata
header. This makes it possible to use ring hash without xDS.
---
 xds/internal/balancer/ringhash/config.go      |  8 +++
 xds/internal/balancer/ringhash/config_test.go | 15 ++++
 xds/internal/balancer/ringhash/picker.go      | 52 ++++++++++----
 xds/internal/balancer/ringhash/picker_test.go | 69 ++++++++++++++++---
 xds/internal/balancer/ringhash/ringhash.go    |  2 +-
 .../balancer/ringhash/ringhash_test.go        | 48 ++++++++++++-
 xds/internal/balancer/ringhash/util.go        | 46 +++++++++----
 xds/internal/balancer/ringhash/util_test.go   | 65 +++++++++++++++++
 xds/internal/resolver/serviceconfig.go        |  2 +-
 xds/internal/resolver/xds_resolver_test.go    |  2 +-
 10 files changed, 272 insertions(+), 37 deletions(-)
 create mode 100644 xds/internal/balancer/ringhash/util_test.go

diff --git a/xds/internal/balancer/ringhash/config.go b/xds/internal/balancer/ringhash/config.go
index b4afcf10013..23368fe97f6 100644
--- a/xds/internal/balancer/ringhash/config.go
+++ b/xds/internal/balancer/ringhash/config.go
@@ -23,6 +23,7 @@ import (
 	"fmt"
 
 	"google.golang.org/grpc/internal/envconfig"
+	"google.golang.org/grpc/internal/metadata"
 	"google.golang.org/grpc/serviceconfig"
 )
 
@@ -32,6 +33,8 @@ type LBConfig struct {
 
 	MinRingSize uint64 `json:"minRingSize,omitempty"`
 	MaxRingSize uint64 `json:"maxRingSize,omitempty"`
+
+	RequestMetadataKey string `json:"request_metadata_key,omitempty"`
 }
 
 const (
@@ -66,5 +69,10 @@ func parseConfig(c json.RawMessage) (*LBConfig, error) {
 	if cfg.MaxRingSize > envconfig.RingHashCap {
 		cfg.MaxRingSize = envconfig.RingHashCap
 	}
+	if cfg.RequestMetadataKey != "" {
+		if err := metadata.ValidatePair(cfg.RequestMetadataKey, ""); err != nil {
+			return nil, fmt.Errorf("invalid request_metadata_key %q: %w", cfg.RequestMetadataKey, err)
+		}
+	}
 	return &cfg, nil
 }
diff --git a/xds/internal/balancer/ringhash/config_test.go b/xds/internal/balancer/ringhash/config_test.go
index 1077d3e7daf..107163b827f 100644
--- a/xds/internal/balancer/ringhash/config_test.go
+++ b/xds/internal/balancer/ringhash/config_test.go
@@ -94,6 +94,21 @@ func (s) TestParseConfig(t *testing.T) {
 			want:    nil,
 			wantErr: true,
 		},
+		{
+			name: "request metadata key set",
+			js:   `{"request_metadata_key": "x-foo"}`,
+			want: &LBConfig{
+				MinRingSize:        defaultMinSize,
+				MaxRingSize:        defaultMaxSize,
+				RequestMetadataKey: "x-foo",
+			},
+		},
+		{
+			name:    "invalid request metadata key",
+			js:      `{"request_metadata_key": "!invalid"}`,
+			want:    nil,
+			wantErr: true,
+		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
diff --git a/xds/internal/balancer/ringhash/picker.go b/xds/internal/balancer/ringhash/picker.go
index b450716fa0f..7446721333b 100644
--- a/xds/internal/balancer/ringhash/picker.go
+++ b/xds/internal/balancer/ringhash/picker.go
@@ -25,21 +25,24 @@ import (
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/internal/grpclog"
+	"google.golang.org/grpc/internal/grpcrand"
 	"google.golang.org/grpc/status"
 )
 
 type picker struct {
-	ring          *ring
-	logger        *grpclog.PrefixLogger
-	subConnStates map[*subConn]connectivity.State
+	ring           *ring
+	logger         *grpclog.PrefixLogger
+	subConnStates  map[*subConn]connectivity.State
+	requestHashKey string
+	randuint64     func() uint64 // overridable for testing
 }
 
-func newPicker(ring *ring, logger *grpclog.PrefixLogger) *picker {
+func newPicker(ring *ring, requestHashKey string, logger *grpclog.PrefixLogger) *picker {
 	states := make(map[*subConn]connectivity.State)
 	for _, e := range ring.items {
 		states[e.sc] = e.sc.effectiveState()
 	}
-	return &picker{ring: ring, logger: logger, subConnStates: states}
+	return &picker{ring: ring, logger: logger, subConnStates: states, requestHashKey: requestHashKey, randuint64: grpcrand.Uint64}
 }
 
 // handleRICSResult is the return type of handleRICS. It's needed to wrap the
@@ -55,16 +58,24 @@ type handleRICSResult struct {
 // or Shutdown. TransientFailure will be handled specifically after this
 // function returns.
 //
-// The first return value indicates if the state is in Ready, Idle, Connecting
+// The second return value indicates if the state is in Ready, Idle, Connecting
 // or Shutdown. If it's true, the PickResult and error should be returned from
 // Pick() as is.
-func (p *picker) handleRICS(e *ringEntry) (handleRICSResult, bool) {
+func (p *picker) handleRICS(e *ringEntry, usingRandomHash bool) (handleRICSResult, bool) {
 	switch state := p.subConnStates[e.sc]; state {
 	case connectivity.Ready:
 		return handleRICSResult{pr: balancer.PickResult{SubConn: e.sc.sc}}, true
 	case connectivity.Idle:
 		// Trigger Connect() and queue the pick.
 		e.sc.queueConnect()
+		if usingRandomHash {
+			// "If the use of this random hash triggers a connection attempt
+			// (...), then before queuing the pick, the picker will scan forward
+			// searching for a subchannel in `READY` state. If it finds a
+			// subchannel in `READY` state, the picker returns it." - A76
+			pr, err := p.returnNextReadySubConn(e)
+			return handleRICSResult{pr: pr, err: err}, true
+		}
 		return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true
 	case connectivity.Connecting:
 		return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true
@@ -84,15 +95,19 @@
 }
 
 func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
-	e := p.ring.pick(getRequestHash(info.Ctx))
-	if hr, ok := p.handleRICS(e); ok {
+	h, usingRandomHash := getRequestHash(info.Ctx, p.requestHashKey)
+	if usingRandomHash {
+		h = p.randuint64()
+	}
+	e := p.ring.pick(h)
+	if hr, ok := p.handleRICS(e, usingRandomHash); ok {
 		return hr.pr, hr.err
 	}
 	// ok was false, the entry is in transient failure.
-	return p.handleTransientFailure(e)
+	return p.handleTransientFailure(e, usingRandomHash)
 }
 
-func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, error) {
+func (p *picker) handleTransientFailure(e *ringEntry, usingRandomHash bool) (balancer.PickResult, error) {
 	// Queue a connect on the first picked SubConn.
 	e.sc.queueConnect()
 
@@ -105,7 +120,7 @@ func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, erro
 
 	// For the second SubConn, also check Ready/Idle/Connecting as if it's the
 	// first entry.
-	if hr, ok := p.handleRICS(e2); ok {
+	if hr, ok := p.handleRICS(e2, usingRandomHash); ok {
 		return hr.pr, hr.err
 	}
 
@@ -148,6 +163,19 @@ func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, erro
 	return balancer.PickResult{}, fmt.Errorf("no connection is Ready")
 }
 
+// returnNextReadySubConn returns the first entry at or after e that has its
+// subconn in READY state. If no such entry is found, it returns
+// balancer.ErrNoSubConnAvailable.
+func (p *picker) returnNextReadySubConn(e *ringEntry) (balancer.PickResult, error) {
+	for i := range p.ring.items {
+		entry := p.ring.items[(e.idx+i)%len(p.ring.items)]
+		if entry.sc.state == connectivity.Ready {
+			return balancer.PickResult{SubConn: entry.sc.sc}, nil
+		}
+	}
+	return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
+}
+
 // nextSkippingDuplicates finds the next entry in the ring, with a different
 // subconn from the given entry.
 func nextSkippingDuplicates(ring *ring, entry *ringEntry) *ringEntry {
diff --git a/xds/internal/balancer/ringhash/picker_test.go b/xds/internal/balancer/ringhash/picker_test.go
index f1dbaf2e5ed..a60b5a06bd4 100644
--- a/xds/internal/balancer/ringhash/picker_test.go
+++ b/xds/internal/balancer/ringhash/picker_test.go
@@ -30,6 +30,7 @@ import (
 	"google.golang.org/grpc/grpclog"
 	igrpclog "google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/internal/testutils"
+	"google.golang.org/grpc/metadata"
 )
 
 var testSubConns []*testutils.TestSubConn
@@ -57,7 +58,7 @@ func newTestRing(cStats []connectivity.State) *ring {
 	return &ring{items: items}
 }
 
-func (s) TestPickerPickFirstTwo(t *testing.T) {
+func (s) TestXdsPickerPickFirstTwo(t *testing.T) {
 	tests := []struct {
 		name            string
 		ring            *ring
@@ -107,9 +108,9 @@
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			p := newPicker(tt.ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
+			p := newPicker(tt.ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
 			got, err := p.Pick(balancer.PickInfo{
-				Ctx: SetRequestHash(context.Background(), tt.hash),
+				Ctx: SetXDSRequestHash(context.Background(), tt.hash),
 			})
 			if err != tt.wantErr {
 				t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
@@ -129,6 +130,56 @@
 	}
 }
 
+// TestPickerWithRequestHashKey tests that if an explicit request hash key is
+// set, it will be used to pick a SubConn.
+func (s) TestPickerWithRequestHashKey(t *testing.T) {
+	tests := []struct {
+		name    string
+		values  []string
+		ring    *ring
+		wantSC  balancer.SubConn
+		wantErr error
+	}{
+		{
+			name:   "hash key is not set, pick the first ready SubConn",
+			ring:   newTestRing([]connectivity.State{connectivity.Idle, connectivity.TransientFailure, connectivity.Connecting, connectivity.Ready}),
+			wantSC: testSubConns[3],
+		},
+		{
+			name:    "hash key is not set, no subchannel ready",
+			ring:    newTestRing([]connectivity.State{connectivity.Idle, connectivity.TransientFailure, connectivity.Connecting, connectivity.Shutdown}),
+			wantErr: balancer.ErrNoSubConnAvailable,
+		},
+		{
+			name:    "hash key is set to a single value, connect and queue the pick",
+			values:  []string{"test-value"}, // hashes to the end of the test ring, wrapping to the idle SubConn: connect is triggered and the pick is queued.
+			ring:    newTestRing([]connectivity.State{connectivity.Idle, connectivity.TransientFailure, connectivity.Connecting, connectivity.Ready}),
+			wantErr: balancer.ErrNoSubConnAvailable,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			requestHashKey := "test-key"
+			ring := tt.ring
+			p := newPicker(ring, requestHashKey, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
+			p.randuint64 = func() uint64 { return 5 }
+
+			md := metadata.New(nil)
+			md.Set("test-key", tt.values...)
+			got, err := p.Pick(balancer.PickInfo{
+				Ctx: metadata.NewOutgoingContext(context.Background(), md),
+			})
+			if err != tt.wantErr {
+				t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if got.SubConn != tt.wantSC {
+				t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tt.wantSC)
+			}
+		})
+	}
+}
+
 // TestPickerPickTriggerTFConnect covers that if the picked SubConn is
 // TransientFailures, all SubConns until a non-TransientFailure are queued for
 // Connect().
@@ -137,8 +188,8 @@ func (s) TestPickerPickTriggerTFConnect(t *testing.T) {
 		connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure,
 		connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure,
 	})
-	p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
-	_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
+	p := newPicker(ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
+	_, err := p.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(context.Background(), 5)})
 	if err == nil {
 		t.Fatalf("Pick() error = %v, want non-nil", err)
 	}
@@ -167,8 +218,8 @@ func (s) TestPickerPickTriggerTFReturnReady(t *testing.T) {
 	ring := newTestRing([]connectivity.State{
 		connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Ready,
 	})
-	p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
-	pr, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
+	p := newPicker(ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
+	pr, err := p.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(context.Background(), 5)})
 	if err != nil {
 		t.Fatalf("Pick() error = %v, want nil", err)
 	}
@@ -193,8 +244,8 @@ func (s) TestPickerPickTriggerTFWithIdle(t *testing.T) {
 	ring := newTestRing([]connectivity.State{
 		connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure,
 	})
-	p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
-	_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
+	p := newPicker(ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
+	_, err := p.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(context.Background(), 5)})
 	if err == balancer.ErrNoSubConnAvailable {
 		t.Fatalf("Pick() error = %v, want %v", err, balancer.ErrNoSubConnAvailable)
 	}
diff --git a/xds/internal/balancer/ringhash/ringhash.go b/xds/internal/balancer/ringhash/ringhash.go
index e63c6f65390..d465b3f7ed3 100644
--- a/xds/internal/balancer/ringhash/ringhash.go
+++ b/xds/internal/balancer/ringhash/ringhash.go
@@ -434,7 +434,7 @@ func (b *ringhashBalancer) regeneratePicker() {
 		b.picker = base.NewErrPicker(b.mergeErrors())
 		return
 	}
-	b.picker = newPicker(b.ring, b.logger)
+	b.picker = newPicker(b.ring, b.config.RequestMetadataKey, b.logger)
 }
 
 func (b *ringhashBalancer) Close() {
diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go
index a1edfe5d228..633690bd556 100644
--- a/xds/internal/balancer/ringhash/ringhash_test.go
+++ b/xds/internal/balancer/ringhash/ringhash_test.go
@@ -63,7 +63,7 @@ func init() {
 }
 
 func ctxWithHash(h uint64) context.Context {
-	return SetRequestHash(context.Background(), h)
+	return SetXDSRequestHash(context.Background(), h)
 }
 
 // setupTest creates the balancer, and does an initial sanity check.
@@ -480,6 +480,52 @@ func (s) TestSubConnToConnectWhenOverallTransientFailure(t *testing.T) {
 	}
 }
 
+// TestRequestHashKeyChanged tests that a new picker is generated when the
+// request hash key in the balancer configuration changes.
+func (s) TestRequestHashKeyChanged(t *testing.T) {
+	wantAddrs := []resolver.Address{
+		{Addr: testBackendAddrStrs[0]},
+		{Addr: testBackendAddrStrs[1]},
+		{Addr: testBackendAddrStrs[2]},
+	}
+	cc, b, p0 := setupTest(t, wantAddrs)
+	ring0 := p0.(*picker).ring
+
+	if err := b.UpdateClientConnState(balancer.ClientConnState{
+		ResolverState:  resolver.State{Addresses: wantAddrs},
+		BalancerConfig: &LBConfig{
+			RequestMetadataKey: "test-key",
+		},
+	}); err != nil {
+		t.Fatalf("UpdateClientConnState returned err: %v", err)
+	}
+	var p1 balancer.Picker
+	select {
+	case p1 = <-cc.NewPickerCh:
+	case <-time.After(defaultTestTimeout):
+		t.Fatalf("timeout waiting for picker after UpdateClientConnState with a new request hash key")
+	}
+	ring1 := p1.(*picker).ring
+	if ring1 == ring0 {
+		t.Fatalf("new picker after changing request hash key has the same ring as before, want different")
+	}
+
+	// Same config, there should be no new picker.
+	if err := b.UpdateClientConnState(balancer.ClientConnState{
+		ResolverState:  resolver.State{Addresses: wantAddrs},
+		BalancerConfig: &LBConfig{
+			RequestMetadataKey: "test-key",
+		},
+	}); err != nil {
+		t.Fatalf("UpdateClientConnState returned err: %v", err)
+	}
+	select {
+	case <-cc.NewPickerCh:
+		t.Fatalf("unexpected picker after UpdateClientConnState with the same config")
+	case <-time.After(defaultTestShortTimeout):
+	}
+}
+
 func (s) TestConnectivityStateEvaluatorRecordTransition(t *testing.T) {
 	tests := []struct {
 		name string
diff --git a/xds/internal/balancer/ringhash/util.go b/xds/internal/balancer/ringhash/util.go
index 92bb3ae5b79..617e935d850 100644
--- a/xds/internal/balancer/ringhash/util.go
+++ b/xds/internal/balancer/ringhash/util.go
@@ -18,23 +18,45 @@
 
 package ringhash
 
-import "context"
+import (
+	"context"
+	"strings"
 
-type clusterKey struct{}
+	"github.com/cespare/xxhash/v2"
+	"google.golang.org/grpc/metadata"
+)
 
-func getRequestHash(ctx context.Context) uint64 {
-	requestHash, _ := ctx.Value(clusterKey{}).(uint64)
-	return requestHash
+type xdsHashKey struct{}
+
+// getRequestHash returns the request hash to use for this pick, and whether
+// a random hash should be used instead.
+func getRequestHash(ctx context.Context, requestMetadataKey string) (uint64, bool) {
+	if requestMetadataKey == "" {
+		// No explicit request metadata key, use the hash set by the xDS
+		// resolver.
+		requestHash, _ := ctx.Value(xdsHashKey{}).(uint64)
+		return requestHash, false
+	}
+	md, _ := metadata.FromOutgoingContext(ctx)
+	values := md.Get(requestMetadataKey)
+	if len(values) == 0 || len(values) == 1 && values[0] == "" {
+		// If the header is not present, generate a random hash.
+		return 0, true
+	}
+	joinedValues := strings.Join(values, ",")
+	return xxhash.Sum64String(joinedValues), false
 }
 
-// GetRequestHashForTesting returns the request hash in the context; to be used
+// GetXDSRequestHashForTesting returns the request hash in the context; to be used
 // for testing only.
-func GetRequestHashForTesting(ctx context.Context) uint64 {
-	return getRequestHash(ctx)
+func GetXDSRequestHashForTesting(ctx context.Context) uint64 {
+	// For xDS, the random hash is never generated in the picker.
+	h, _ := getRequestHash(ctx, "")
+	return h
 }
 
-// SetRequestHash adds the request hash to the context for use in Ring Hash Load
-// Balancing.
-func SetRequestHash(ctx context.Context, requestHash uint64) context.Context {
-	return context.WithValue(ctx, clusterKey{}, requestHash)
+// SetXDSRequestHash adds the request hash to the context for use in Ring Hash
+// Load Balancing using xDS route hash_policy.
+func SetXDSRequestHash(ctx context.Context, requestHash uint64) context.Context {
+	return context.WithValue(ctx, xdsHashKey{}, requestHash)
 }
diff --git a/xds/internal/balancer/ringhash/util_test.go b/xds/internal/balancer/ringhash/util_test.go
new file mode 100644
index 00000000000..5202752a828
--- /dev/null
+++ b/xds/internal/balancer/ringhash/util_test.go
@@ -0,0 +1,65 @@
+package ringhash
+
+import (
+	"context"
+	"testing"
+
+	"github.com/cespare/xxhash/v2"
+	"google.golang.org/grpc/metadata"
+)
+
+func (s) TestGetRequestHash(t *testing.T) {
+	tests := []struct {
+		name               string
+		requestMetadataKey string
+		xdsValue           uint64
+		explicitValue      []string
+		wantHash           uint64
+		wantRandom         bool
+	}{
+		{
+			name:     "xds hash",
+			xdsValue: 123,
+			wantHash: 123,
+		},
+		{
+			name:               "explicit key, no value",
+			requestMetadataKey: "test-key",
+			wantRandom:         true,
+		},
+		{
+			name:               "explicit key, empty value",
+			requestMetadataKey: "test-key",
+			explicitValue:      []string{""},
+			wantRandom:         true,
+		},
+		{
+			name:               "explicit key, non empty value",
+			requestMetadataKey: "test-key",
+			explicitValue:      []string{"test-value"},
+			wantHash:           xxhash.Sum64String("test-value"),
+		},
+		{
+			name:               "explicit key, multiple values",
+			requestMetadataKey: "test-key",
+			explicitValue:      []string{"test-value", "test-value-2"},
+			wantHash:           xxhash.Sum64String("test-value,test-value-2"),
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := context.Background()
+			if tt.explicitValue != nil {
+				ctx = metadata.NewOutgoingContext(context.Background(), metadata.MD{"test-key": tt.explicitValue})
+			}
+			if tt.xdsValue != 0 {
+				ctx = SetXDSRequestHash(context.Background(), tt.xdsValue)
+			}
+			gotHash, gotRandom := getRequestHash(ctx, tt.requestMetadataKey)
+
+			if gotHash != tt.wantHash || gotRandom != tt.wantRandom {
+				t.Errorf("getRequestHash(%v) = (%v, %v), want (%v, %v)", tt.explicitValue, gotHash, gotRandom, tt.wantHash, tt.wantRandom)
+			}
+		})
+	}
+}
diff --git a/xds/internal/resolver/serviceconfig.go b/xds/internal/resolver/serviceconfig.go
index 88cb1d2a1fd..960e087a15e 100644
--- a/xds/internal/resolver/serviceconfig.go
+++ b/xds/internal/resolver/serviceconfig.go
@@ -171,7 +171,7 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP
 	}
 
 	lbCtx := clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name)
-	lbCtx = ringhash.SetRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies))
+	lbCtx = ringhash.SetXDSRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies))
 
 	config := &iresolver.RPCConfig{
 		// Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy.
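For reference, the hashing rule implemented by getRequestHash above can be reproduced standalone: all values of the request metadata key are joined with "," and hashed with xxHash64, and a random hash is used when the key is absent or empty. A minimal sketch (not part of the patch; the sample values are made up):

package main

import (
	"fmt"
	"strings"

	"github.com/cespare/xxhash/v2"
)

func main() {
	// Multiple metadata values are joined with "," before hashing, so a
	// request with values ["v1", "v2"] hashes identically to a request
	// with the single value "v1,v2" and lands on the same ring position.
	values := []string{"v1", "v2"}
	h := xxhash.Sum64String(strings.Join(values, ","))
	fmt.Println(h == xxhash.Sum64String("v1,v2")) // true
}
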
diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go
index 1a0ca4ed5b4..1b18da6332b 100644
--- a/xds/internal/resolver/xds_resolver_test.go
+++ b/xds/internal/resolver/xds_resolver_test.go
@@ -488,7 +488,7 @@ func (s) TestResolverRequestHash(t *testing.T) {
 	if err != nil {
 		t.Fatalf("cs.SelectConfig(): %v", err)
 	}
-	gotHash := ringhash.GetRequestHashForTesting(res.Context)
+	gotHash := ringhash.GetXDSRequestHashForTesting(res.Context)
 	wantHash := xxhash.Sum64String("/products")
 	if gotHash != wantHash {
 		t.Fatalf("Got request hash: %v, want: %v", gotHash, wantHash)
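Taken together, these changes let a client select ring hash through service config alone and steer affinity with a metadata key. A hypothetical usage sketch, assuming the ring_hash_experimental policy is linked into the binary (it currently registers itself via the xds import chain) and using the request_metadata_key field added in config.go; the target, key name, and value are illustrative:

package main

import (
	"context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"

	_ "google.golang.org/grpc/xds" // assumption: pulls in ring_hash_experimental registration
)

func main() {
	// Service config selecting ring hash with an explicit request hash key.
	const svcCfg = `{
	  "loadBalancingConfig": [
	    {"ring_hash_experimental": {"request_metadata_key": "session-id"}}
	  ]
	}`
	conn, err := grpc.Dial("dns:///backend.example.com:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(svcCfg))
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// RPCs carrying the same session-id value hash to the same backend;
	// RPCs without the header are spread randomly across the ring.
	ctx := metadata.AppendToOutgoingContext(context.Background(), "session-id", "user-42")
	_ = ctx // issue RPCs on conn with ctx here
}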