diff --git a/internal/wrr/random.go b/internal/wrr/random.go index ccf5113e9f3..6d5eb7d4620 100644 --- a/internal/wrr/random.go +++ b/internal/wrr/random.go @@ -19,6 +19,7 @@ package wrr import ( "fmt" + "sort" "sync" "google.golang.org/grpc/internal/grpcrand" @@ -26,8 +27,9 @@ import ( // weightedItem is a wrapped weighted item that is used to implement weighted random algorithm. type weightedItem struct { - Item interface{} - Weight int64 + item interface{} + weight int64 + accumulatedWeight int64 } func (w *weightedItem) String() string { @@ -36,9 +38,10 @@ func (w *weightedItem) String() string { // randomWRR is a struct that contains weighted items implement weighted random algorithm. type randomWRR struct { - mu sync.RWMutex - items []*weightedItem - sumOfWeights int64 + mu sync.RWMutex + items []*weightedItem + // Are all item's weights equal + equalWeights bool } // NewRandom creates a new WRR with random. @@ -51,27 +54,36 @@ var grpcrandInt63n = grpcrand.Int63n func (rw *randomWRR) Next() (item interface{}) { rw.mu.RLock() defer rw.mu.RUnlock() - if rw.sumOfWeights == 0 { + if len(rw.items) == 0 { return nil } - // Random number in [0, sum). - randomWeight := grpcrandInt63n(rw.sumOfWeights) - for _, item := range rw.items { - randomWeight = randomWeight - item.Weight - if randomWeight < 0 { - return item.Item - } + if rw.equalWeights { + return rw.items[grpcrandInt63n(int64(len(rw.items)))].item } - return rw.items[len(rw.items)-1].Item + sumOfWeights := rw.items[len(rw.items)-1].accumulatedWeight + // Random number in [0, sumOfWeights). + randomWeight := grpcrandInt63n(sumOfWeights) + // Item's accumulated weights are in ascending order, because item's weight >= 0. + // Binary search rw.items to find first item whose accumulatedWeight > randomWeight + // The return i is guaranteed to be in range [0, len(rw.items)) because randomWeight < last item's accumulatedWeight + i := sort.Search(len(rw.items), func(i int) bool { return rw.items[i].accumulatedWeight > randomWeight }) + return rw.items[i].item } func (rw *randomWRR) Add(item interface{}, weight int64) { rw.mu.Lock() defer rw.mu.Unlock() - rItem := &weightedItem{Item: item, Weight: weight} + accumulatedWeight := weight + equalWeights := true + if len(rw.items) > 0 { + lastItem := rw.items[len(rw.items)-1] + accumulatedWeight = lastItem.accumulatedWeight + weight + equalWeights = rw.equalWeights && weight == lastItem.weight + } + rw.equalWeights = equalWeights + rItem := &weightedItem{item: item, weight: weight, accumulatedWeight: accumulatedWeight} rw.items = append(rw.items, rItem) - rw.sumOfWeights += weight } func (rw *randomWRR) String() string { diff --git a/internal/wrr/wrr_test.go b/internal/wrr/wrr_test.go index 4565e34ffb9..ce4f5e507a2 100644 --- a/internal/wrr/wrr_test.go +++ b/internal/wrr/wrr_test.go @@ -21,6 +21,7 @@ import ( "errors" "math" "math/rand" + "strconv" "testing" "github.com/google/go-cmp/cmp" @@ -70,12 +71,22 @@ func testWRRNext(t *testing.T, newWRR func() WRR) { name: "17-23-37", weights: []int64{17, 23, 37}, }, + { + name: "no items", + weights: []int64{}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var sumOfWeights int64 - w := newWRR() + if len(tt.weights) == 0 { + if next := w.Next(); next != nil { + t.Fatalf("w.Next returns non nil value:%v when there is no item", next) + } + return + } + + var sumOfWeights int64 for i, weight := range tt.weights { w.Add(i, weight) sumOfWeights += weight @@ -112,6 +123,70 @@ func (s) TestEdfWrrNext(t *testing.T) { testWRRNext(t, NewEDF) } +func BenchmarkRandomWRRNext(b *testing.B) { + for _, n := range []int{100, 500, 1000} { + b.Run("equal-weights-"+strconv.Itoa(n)+"-items", func(b *testing.B) { + w := NewRandom() + sumOfWeights := n + for i := 0; i < n; i++ { + w.Add(i, 1) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < sumOfWeights; i++ { + w.Next() + } + } + }) + } + + var maxWeight int64 = 1024 + for _, n := range []int{100, 500, 1000} { + b.Run("random-weights-"+strconv.Itoa(n)+"-items", func(b *testing.B) { + w := NewRandom() + var sumOfWeights int64 + for i := 0; i < n; i++ { + weight := rand.Int63n(maxWeight + 1) + w.Add(i, weight) + sumOfWeights += weight + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < int(sumOfWeights); i++ { + w.Next() + } + } + }) + } + + itemsNum := 200 + heavyWeight := int64(itemsNum) + lightWeight := int64(1) + heavyIndices := []int{0, itemsNum / 2, itemsNum - 1} + for _, heavyIndex := range heavyIndices { + b.Run("skew-weights-heavy-index-"+strconv.Itoa(heavyIndex), func(b *testing.B) { + w := NewRandom() + var sumOfWeights int64 + for i := 0; i < itemsNum; i++ { + var weight int64 + if i == heavyIndex { + weight = heavyWeight + } else { + weight = lightWeight + } + sumOfWeights += weight + w.Add(i, weight) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + for i := 0; i < int(sumOfWeights); i++ { + w.Next() + } + } + }) + } +} + func init() { r := rand.New(rand.NewSource(0)) grpcrandInt63n = r.Int63n