diff --git a/internal/wrr/random.go b/internal/wrr/random.go index ccf5113e9f3..4df9bd8fb9c 100644 --- a/internal/wrr/random.go +++ b/internal/wrr/random.go @@ -19,6 +19,7 @@ package wrr import ( "fmt" + "sort" "sync" "google.golang.org/grpc/internal/grpcrand" @@ -26,8 +27,13 @@ import ( // weightedItem is a wrapped weighted item that is used to implement weighted random algorithm. type weightedItem struct { - Item interface{} - Weight int64 + Item interface{} + // TODO Delete Weight? This field is not necessary for randomWRR to work. + // But without this field, if we want to know an item's weight in randomWRR.Add , we have to + // calculate it (i.e. weight = items.AccumulatedWeight - previousItem.AccumulatedWeight) + // which is a bit less concise than items.Weight + Weight int64 + AccumulatedWeight int64 } func (w *weightedItem) String() string { @@ -38,7 +44,7 @@ func (w *weightedItem) String() string { type randomWRR struct { mu sync.RWMutex items []*weightedItem - sumOfWeights int64 + equalWeights bool } // NewRandom creates a new WRR with random. @@ -47,31 +53,40 @@ func NewRandom() WRR { } var grpcrandInt63n = grpcrand.Int63n +var grpcrandIntn = grpcrand.Intn func (rw *randomWRR) Next() (item interface{}) { rw.mu.RLock() defer rw.mu.RUnlock() - if rw.sumOfWeights == 0 { + sumOfWeights := rw.items[len(rw.items)-1].AccumulatedWeight + if sumOfWeights == 0 { return nil } - // Random number in [0, sum). - randomWeight := grpcrandInt63n(rw.sumOfWeights) - for _, item := range rw.items { - randomWeight = randomWeight - item.Weight - if randomWeight < 0 { - return item.Item - } + if rw.equalWeights { + return rw.items[grpcrandIntn(len(rw.items))].Item } - - return rw.items[len(rw.items)-1].Item + // Random number in [0, sumOfWeights). + randomWeight := grpcrandInt63n(sumOfWeights) + // Item's accumulated weights are in ascending order, because item's weight >= 0. + // Binary search rw.items to find first item whose AccumulatedWeight > randomWeight + // The return i is guaranteed to be in range [0, len(rw.items)) because randomWeight < last item's AccumulatedWeight + i := sort.Search(len(rw.items), func(i int) bool { return rw.items[i].AccumulatedWeight > randomWeight }) + return rw.items[i].Item } func (rw *randomWRR) Add(item interface{}, weight int64) { rw.mu.Lock() defer rw.mu.Unlock() - rItem := &weightedItem{Item: item, Weight: weight} + accumulatedWeight := weight + equalWeights := true + if len(rw.items) > 0 { + lastItem := rw.items[len(rw.items)-1] + accumulatedWeight = lastItem.AccumulatedWeight + weight + equalWeights = rw.equalWeights && weight == lastItem.Weight + } + rw.equalWeights = equalWeights + rItem := &weightedItem{Item: item, Weight: weight, AccumulatedWeight: accumulatedWeight} rw.items = append(rw.items, rItem) - rw.sumOfWeights += weight } func (rw *randomWRR) String() string { diff --git a/internal/wrr/wrr_test.go b/internal/wrr/wrr_test.go index 4565e34ffb9..159e7091c54 100644 --- a/internal/wrr/wrr_test.go +++ b/internal/wrr/wrr_test.go @@ -115,4 +115,5 @@ func (s) TestEdfWrrNext(t *testing.T) { func init() { r := rand.New(rand.NewSource(0)) grpcrandInt63n = r.Int63n + grpcrandIntn = r.Intn }