Skip to content

Commit

Permalink
ringhash: allow setting request hash key explicitly
Browse files Browse the repository at this point in the history
Implement part 1 of A76: the ability to set the request hash key, to extract the
hash from a header. This allows using ring hash without xDS.
  • Loading branch information
atollena committed Apr 26, 2024
1 parent 1e8b9b7 commit 54382d4
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 37 deletions.
8 changes: 8 additions & 0 deletions xds/internal/balancer/ringhash/config.go
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"

"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/serviceconfig"
)

Expand All @@ -32,6 +33,8 @@ type LBConfig struct {

MinRingSize uint64 `json:"minRingSize,omitempty"`
MaxRingSize uint64 `json:"maxRingSize,omitempty"`

RequestMetadataKey string `json:"request_metadata_key,omitempty"`
}

const (
Expand Down Expand Up @@ -66,5 +69,10 @@ func parseConfig(c json.RawMessage) (*LBConfig, error) {
if cfg.MaxRingSize > envconfig.RingHashCap {
cfg.MaxRingSize = envconfig.RingHashCap
}
if cfg.RequestMetadataKey != "" {
if err := metadata.ValidatePair(cfg.RequestMetadataKey, ""); err != nil {
return nil, fmt.Errorf("invalid request_metadata_key %q: %w", cfg.RequestMetadataKey, err)
}
}
return &cfg, nil
}
15 changes: 15 additions & 0 deletions xds/internal/balancer/ringhash/config_test.go
Expand Up @@ -94,6 +94,21 @@ func (s) TestParseConfig(t *testing.T) {
want: nil,
wantErr: true,
},
{
name: "request metadata key set",
js: `{"request_metadata_key": "x-foo"}`,
want: &LBConfig{
MinRingSize: defaultMinSize,
MaxRingSize: defaultMaxSize,
RequestMetadataKey: "x-foo",
},
},
{
name: "invalid request metadata keys",
js: `{"request_metadata_key": "!invalid"}`,
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
52 changes: 40 additions & 12 deletions xds/internal/balancer/ringhash/picker.go
Expand Up @@ -25,21 +25,24 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/status"
)

// picker picks a SubConn from the hash ring for each call to Pick.
//
// subConnStates is filled in by newPicker from effectiveState() of every ring
// entry's subConn, and handleRICS looks states up in that map, so a picker
// works from the states recorded when it was constructed.
type picker struct {
ring *ring
logger *grpclog.PrefixLogger
subConnStates map[*subConn]connectivity.State
ring *ring
logger *grpclog.PrefixLogger
subConnStates map[*subConn]connectivity.State
// requestHashKey is handed to getRequestHash by Pick. The balancer passes
// its config's RequestMetadataKey here when it regenerates the picker.
requestHashKey string
// randuint64 supplies the hash in Pick when getRequestHash reports that a
// random hash is in use; newPicker sets it to grpcrand.Uint64.
randuint64 func() uint64 // overridable for testing
}

// newPicker returns a picker for ring. It records effectiveState() of every
// ring entry's subConn in subConnStates, keeps requestHashKey for later use by
// Pick, and sets randuint64 to grpcrand.Uint64 (tests may replace it).
func newPicker(ring *ring, logger *grpclog.PrefixLogger) *picker {
func newPicker(ring *ring, requestHashKey string, logger *grpclog.PrefixLogger) *picker {
states := make(map[*subConn]connectivity.State)
for _, e := range ring.items {
// Snapshot the state now; handleRICS reads this map rather than the subConn.
states[e.sc] = e.sc.effectiveState()
}
return &picker{ring: ring, logger: logger, subConnStates: states}
return &picker{ring: ring, logger: logger, subConnStates: states, requestHashKey: requestHashKey, randuint64: grpcrand.Uint64}
}

// handleRICSResult is the return type of handleRICS. It's needed to wrap the
Expand All @@ -55,16 +58,24 @@ type handleRICSResult struct {
// or Shutdown. TransientFailure will be handled specifically after this
// function returns.
//
// The first return value indicates if the state is in Ready, Idle, Connecting
// The second return value indicates if the state is in Ready, Idle, Connecting
// or Shutdown. If it's true, the PickResult and error should be returned from
// Pick() as is.
func (p *picker) handleRICS(e *ringEntry) (handleRICSResult, bool) {
func (p *picker) handleRICS(e *ringEntry, usingRandomHash bool) (handleRICSResult, bool) {
switch state := p.subConnStates[e.sc]; state {
case connectivity.Ready:
return handleRICSResult{pr: balancer.PickResult{SubConn: e.sc.sc}}, true
case connectivity.Idle:
// Trigger Connect() and queue the pick.
e.sc.queueConnect()
if usingRandomHash {
// "If the use of this random hash triggers a connection attempt
// (...), then before queuing the pick, the picker will scan forward
// searching for a subchannel in `READY` state. If it finds a
// subchannel in `READY` state, the picker returns it." - A76
p, err := p.returnNextReadySubConn(e)
return handleRICSResult{pr: p, err: err}, true
}
return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true
case connectivity.Connecting:
return handleRICSResult{err: balancer.ErrNoSubConnAvailable}, true
Expand All @@ -84,15 +95,19 @@ func (p *picker) handleRICS(e *ringEntry) (handleRICSResult, bool) {
}

// Pick chooses a SubConn for the RPC described by info.
//
// The request hash comes from getRequestHash, called with the RPC context and
// the picker's requestHashKey. When getRequestHash reports usingRandomHash,
// the hash is replaced by the value of p.randuint64(). The ring entry selected
// for the hash is handed to handleRICS; if handleRICS reports that it handled
// the entry's state, its result is returned as is. Otherwise the entry is
// treated as being in transient failure and handleTransientFailure decides
// the outcome.
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
e := p.ring.pick(getRequestHash(info.Ctx))
if hr, ok := p.handleRICS(e); ok {
h, usingRandomHash := getRequestHash(info.Ctx, p.requestHashKey)
if usingRandomHash {
// Overwrite whatever hash was returned with a random one.
h = p.randuint64()
}
e := p.ring.pick(h)
if hr, ok := p.handleRICS(e, usingRandomHash); ok {
return hr.pr, hr.err
}
// ok was false, the entry is in transient failure.
return p.handleTransientFailure(e)
return p.handleTransientFailure(e, usingRandomHash)
}

func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, error) {
func (p *picker) handleTransientFailure(e *ringEntry, usingRandomHash bool) (balancer.PickResult, error) {
// Queue a connect on the first picked SubConn.
e.sc.queueConnect()

Expand All @@ -105,7 +120,7 @@ func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, erro

// For the second SubConn, also check Ready/Idle/Connecting as if it's the
// first entry.
if hr, ok := p.handleRICS(e2); ok {
if hr, ok := p.handleRICS(e2, usingRandomHash); ok {
return hr.pr, hr.err
}

Expand Down Expand Up @@ -148,6 +163,19 @@ func (p *picker) handleTransientFailure(e *ringEntry) (balancer.PickResult, erro
return balancer.PickResult{}, fmt.Errorf("no connection is Ready")
}

// returnNextReadySubConn scans the ring forward starting at e (e itself is
// examined first), wrapping around so that every ring entry is visited at most
// once. It returns a PickResult for the first entry whose SubConn is READY. If
// no entry is READY it returns balancer.ErrNoSubConnAvailable.
//
// Readiness is read from p.subConnStates, the map newPicker populates when the
// picker is built and the same source handleRICS consults. This keeps a single
// Pick working from one consistent view of the SubConn states instead of
// mixing that map with the subConn's own state field.
func (p *picker) returnNextReadySubConn(e *ringEntry) (balancer.PickResult, error) {
	items := p.ring.items
	for offset := range items {
		// The modulo wraps the scan from the end of the ring to its start.
		candidate := items[(e.idx+offset)%len(items)]
		if p.subConnStates[candidate.sc] == connectivity.Ready {
			return balancer.PickResult{SubConn: candidate.sc.sc}, nil
		}
	}
	return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}

// nextSkippingDuplicates finds the next entry in the ring, with a different
// subconn from the given entry.
func nextSkippingDuplicates(ring *ring, entry *ringEntry) *ringEntry {
Expand Down
69 changes: 60 additions & 9 deletions xds/internal/balancer/ringhash/picker_test.go
Expand Up @@ -30,6 +30,7 @@ import (
"google.golang.org/grpc/grpclog"
igrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
)

var testSubConns []*testutils.TestSubConn
Expand Down Expand Up @@ -57,7 +58,7 @@ func newTestRing(cStats []connectivity.State) *ring {
return &ring{items: items}
}

func (s) TestPickerPickFirstTwo(t *testing.T) {
func (s) TestXdsPickerPickFirstTwo(t *testing.T) {
tests := []struct {
name string
ring *ring
Expand Down Expand Up @@ -107,9 +108,9 @@ func (s) TestPickerPickFirstTwo(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := newPicker(tt.ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
p := newPicker(tt.ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
got, err := p.Pick(balancer.PickInfo{
Ctx: SetRequestHash(context.Background(), tt.hash),
Ctx: SetXDSRequestHash(context.Background(), tt.hash),
})
if err != tt.wantErr {
t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
Expand All @@ -129,6 +130,56 @@ func (s) TestPickerPickFirstTwo(t *testing.T) {
}
}

// TestPickerWithRequestHashKey checks Pick on a picker that was built with an
// explicit request hash key. Each case attaches the given values (possibly
// none) to the outgoing metadata under that key and compares the picked
// SubConn and error with the expectation.
func (s) TestPickerWithRequestHashKey(t *testing.T) {
	const requestHashKey = "test-key"

	type testCase struct {
		name    string
		values  []string // metadata values sent under requestHashKey
		ring    *ring
		wantSC  balancer.SubConn
		wantErr error
	}
	testCases := []testCase{
		{
			// No values are supplied for the key in this case.
			name:   "hash key is not set, pick the first ready SubConn",
			ring:   newTestRing([]connectivity.State{connectivity.Idle, connectivity.TransientFailure, connectivity.Connecting, connectivity.Ready}),
			wantSC: testSubConns[3],
		},
		{
			name:    "hash key is not set, no subchannel ready",
			ring:    newTestRing([]connectivity.State{connectivity.Idle, connectivity.TransientFailure, connectivity.Connecting, connectivity.Shutdown}),
			wantErr: balancer.ErrNoSubConnAvailable,
		},
		{
			name: "hash key is set to a single value, connect and queue the pick",
			// Per the original author's note, this value hashes to the end
			// of the test ring, so endpoint 1 is the expected entry.
			values:  []string{"test-value"},
			ring:    newTestRing([]connectivity.State{connectivity.Idle, connectivity.TransientFailure, connectivity.Connecting, connectivity.Ready}),
			wantErr: balancer.ErrNoSubConnAvailable,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			logger := igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test")
			p := newPicker(tc.ring, requestHashKey, logger)
			// Pin the random hash so the outcome is deterministic.
			p.randuint64 = func() uint64 { return 5 }

			md := metadata.New(nil)
			md.Set(requestHashKey, tc.values...)
			ctx := metadata.NewOutgoingContext(context.Background(), md)

			got, err := p.Pick(balancer.PickInfo{Ctx: ctx})
			switch {
			case err != tc.wantErr:
				// The SubConn is not checked when the error is unexpected.
				t.Errorf("Pick() error = %v, wantErr %v", err, tc.wantErr)
			case got.SubConn != tc.wantSC:
				t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tc.wantSC)
			}
		})
	}
}

// TestPickerPickTriggerTFConnect covers that if the picked SubConn is
// TransientFailures, all SubConns until a non-TransientFailure are queued for
// Connect().
Expand All @@ -137,8 +188,8 @@ func (s) TestPickerPickTriggerTFConnect(t *testing.T) {
connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure,
connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure,
})
p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
p := newPicker(ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
_, err := p.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(context.Background(), 5)})
if err == nil {
t.Fatalf("Pick() error = %v, want non-nil", err)
}
Expand Down Expand Up @@ -167,8 +218,8 @@ func (s) TestPickerPickTriggerTFReturnReady(t *testing.T) {
ring := newTestRing([]connectivity.State{
connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Ready,
})
p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
pr, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
p := newPicker(ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
pr, err := p.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(context.Background(), 5)})
if err != nil {
t.Fatalf("Pick() error = %v, want nil", err)
}
Expand All @@ -193,8 +244,8 @@ func (s) TestPickerPickTriggerTFWithIdle(t *testing.T) {
ring := newTestRing([]connectivity.State{
connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure,
})
p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
p := newPicker(ring, "", igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
_, err := p.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(context.Background(), 5)})
if err == balancer.ErrNoSubConnAvailable {
t.Fatalf("Pick() error = %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
Expand Down
2 changes: 1 addition & 1 deletion xds/internal/balancer/ringhash/ringhash.go
Expand Up @@ -434,7 +434,7 @@ func (b *ringhashBalancer) regeneratePicker() {
b.picker = base.NewErrPicker(b.mergeErrors())
return
}
b.picker = newPicker(b.ring, b.logger)
b.picker = newPicker(b.ring, b.config.RequestMetadataKey, b.logger)
}

func (b *ringhashBalancer) Close() {
Expand Down
48 changes: 47 additions & 1 deletion xds/internal/balancer/ringhash/ringhash_test.go
Expand Up @@ -63,7 +63,7 @@ func init() {
}

// ctxWithHash returns a context derived from context.Background() with the
// request hash h attached through the package's request-hash setter.
func ctxWithHash(h uint64) context.Context {
return SetRequestHash(context.Background(), h)
return SetXDSRequestHash(context.Background(), h)
}

// setupTest creates the balancer, and does an initial sanity check.
Expand Down Expand Up @@ -480,6 +480,52 @@ func (s) TestSubConnToConnectWhenOverallTransientFailure(t *testing.T) {
}
}

// TestRequestHashKeyChanged tests that the ringhash balancer produces a new
// picker when it receives a config with a different request hash key, and
// that sending the same config again does not produce another picker.
func (s) TestRequestHashKeyChanged(t *testing.T) {
	wantAddrs := []resolver.Address{
		{Addr: testBackendAddrStrs[0]},
		{Addr: testBackendAddrStrs[1]},
		{Addr: testBackendAddrStrs[2]},
	}
	cc, b, p0 := setupTest(t, wantAddrs)
	ring0 := p0.(*picker).ring

	// Same addresses, but the config now carries a request metadata key.
	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: resolver.State{Addresses: wantAddrs},
		BalancerConfig: &LBConfig{
			RequestMetadataKey: "test-key",
		},
	}); err != nil {
		t.Fatalf("UpdateClientConnState returned err: %v", err)
	}
	var p1 balancer.Picker
	select {
	case p1 = <-cc.NewPickerCh:
	case <-time.After(defaultTestTimeout):
		t.Fatalf("timeout waiting for picker after UpdateClientConnState with a new request hash key")
	}
	// NOTE(review): this only checks that the ring pointer differs; it does
	// not inspect the new picker's requestHashKey. Confirm that this is
	// sufficient to show the key change was applied.
	ring1 := p1.(*picker).ring
	if ring1 == ring0 {
		t.Fatalf("new picker after changing request hash key has the same ring as before, want different")
	}

	// Same config as the previous update: there should be no new picker.
	if err := b.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: resolver.State{Addresses: wantAddrs},
		BalancerConfig: &LBConfig{
			RequestMetadataKey: "test-key",
		},
	}); err != nil {
		t.Fatalf("UpdateClientConnState returned err: %v", err)
	}
	select {
	case <-cc.NewPickerCh:
		t.Fatalf("unexpected picker after UpdateClientConnState with an unchanged config")
	case <-time.After(defaultTestShortTimeout):
	}
}

func (s) TestConnectivityStateEvaluatorRecordTransition(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit 54382d4

Please sign in to comment.