Skip to content

Commit

Permalink
support ZINTER
Browse files Browse the repository at this point in the history
  • Loading branch information
alicebob committed Oct 11, 2023
1 parent dda3a52 commit feeb189
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 119 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@

### v2.16.1

- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang)
- fix ZINTERSTORE with sets (thanks @lingjl2010 and @okhowang)
- fix exclusive ranges in XRANGE (thanks @joseotoro)


Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ Implemented commands:
- ZCARD
- ZCOUNT
- ZINCRBY
- ZINTER
- ZINTERSTORE
- ZLEXCOUNT
- ZPOPMIN
Expand Down
291 changes: 174 additions & 117 deletions cmd_sorted_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ func commandsSortedSet(m *Miniredis) {
m.srv.Register("ZCARD", m.cmdZcard)
m.srv.Register("ZCOUNT", m.cmdZcount)
m.srv.Register("ZINCRBY", m.cmdZincrby)
m.srv.Register("ZINTERSTORE", m.cmdZinterstore)
m.srv.Register("ZINTER", m.makeCmdZinter(false))
m.srv.Register("ZINTERSTORE", m.makeCmdZinter(true))
m.srv.Register("ZLEXCOUNT", m.cmdZlexcount)
m.srv.Register("ZRANGE", m.cmdZrange)
m.srv.Register("ZRANGEBYLEX", m.makeCmdZrangebylex(false))
Expand Down Expand Up @@ -324,145 +325,192 @@ func (m *Miniredis) cmdZincrby(c *server.Peer, cmd string, args []string) {
})
}

// ZINTERSTORE
func (m *Miniredis) cmdZinterstore(c *server.Peer, cmd string, args []string) {
if len(args) < 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
// ZINTERSTORE and ZINTER
func (m *Miniredis) makeCmdZinter(store bool) func(c *server.Peer, cmd string, args []string) {
return func(c *server.Peer, cmd string, args []string) {
minArgs := 2
if store {
minArgs++
}
if len(args) < minArgs {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}

destination := args[0]
numKeys, err := strconv.Atoi(args[1])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
args = args[2:]
if len(args) < numKeys {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
if numKeys <= 0 {
setDirty(c)
c.WriteError("ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE")
return
}
keys := args[:numKeys]
args = args[numKeys:]
var opts = struct {
Store bool // if true this is ZINTERSTORE
Destination string // only relevant if $store is true
Keys []string
Aggregate string
WithWeights bool
Weights []float64
WithScores bool // only for ZINTER
}{
Store: store,
Aggregate: "sum",
}

withWeights := false
weights := []float64{}
aggregate := "sum"
for len(args) > 0 {
switch strings.ToLower(args[0]) {
case "weights":
if len(args) < numKeys+1 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
for i := 0; i < numKeys; i++ {
f, err := strconv.ParseFloat(args[i+1], 64)
if err != nil {
if store {
opts.Destination = args[0]
args = args[1:]
}
numKeys, err := strconv.Atoi(args[0])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
args = args[1:]
if len(args) < numKeys {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
if numKeys <= 0 {
setDirty(c)
c.WriteError("ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE")
return
}
opts.Keys = args[:numKeys]
args = args[numKeys:]

for len(args) > 0 {
switch strings.ToLower(args[0]) {
case "weights":
if len(args) < numKeys+1 {
setDirty(c)
c.WriteError("ERR weight value is not a float")
c.WriteError(msgSyntaxError)
return
}
weights = append(weights, f)
}
withWeights = true
args = args[numKeys+1:]
case "aggregate":
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
aggregate = strings.ToLower(args[1])
switch aggregate {
case "sum", "min", "max":
for i := 0; i < numKeys; i++ {
f, err := strconv.ParseFloat(args[i+1], 64)
if err != nil {
setDirty(c)
c.WriteError("ERR weight value is not a float")
return
}
opts.Weights = append(opts.Weights, f)
}
opts.WithWeights = true
args = args[numKeys+1:]
case "aggregate":
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
aggregate := strings.ToLower(args[1])
switch aggregate {
case "sum", "min", "max":
opts.Aggregate = aggregate
default:
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
args = args[2:]
case "withscores":
if store {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
opts.WithScores = true
args = args[1:]
default:
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
args = args[2:]
default:
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
}

withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
db.del(destination, true)

// We collect everything and remove all keys which turned out not to be
// present in every set.
sset := map[string]float64{}
counts := map[string]int{}
for i, key := range keys {
if !db.exists(key) {
continue
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if opts.Store {
db.del(opts.Destination, true)
}

var set map[string]float64
switch db.t(key) {
case "set":
set = map[string]float64{}
for elem := range db.setKeys[key] {
set[elem] = 1.0
}
case "zset":
set = db.sortedSet(key)
default:
c.WriteError(msgWrongType)
return
}
for member, score := range set {
if withWeights {
score *= weights[i]
}
counts[member]++
old, ok := sset[member]
if !ok {
sset[member] = score
// We collect everything and remove all keys which turned out not to be
// present in every set.
sset := map[string]float64{}
counts := map[string]int{}
for i, key := range opts.Keys {
if !db.exists(key) {
continue
}
switch aggregate {

var set map[string]float64
switch db.t(key) {
case "set":
set = map[string]float64{}
for elem := range db.setKeys[key] {
set[elem] = 1.0
}
case "zset":
set = db.sortedSet(key)
default:
panic("Invalid aggregate")
case "sum":
sset[member] += score
case "min":
if score < old {
sset[member] = score
c.WriteError(msgWrongType)
return
}
for member, score := range set {
if opts.WithWeights {
score *= opts.Weights[i]
}
case "max":
if score > old {
counts[member]++
old, ok := sset[member]
if !ok {
sset[member] = score
continue
}
switch opts.Aggregate {
default:
panic("Invalid aggregate")
case "sum":
sset[member] += score
case "min":
if score < old {
sset[member] = score
}
case "max":
if score > old {
sset[member] = score
}
}
}
}
}
for key, count := range counts {
if count != numKeys {
delete(sset, key)
for key, count := range counts {
if count != numKeys {
delete(sset, key)
}
}
}
db.ssetSet(destination, sset)
c.WriteInt(len(sset))
})

if opts.Store {
// ZINTERSTORE mode
db.ssetSet(opts.Destination, sset)
c.WriteInt(len(sset))
return
}
// ZINTER mode
size := len(sset)
if opts.WithScores {
size *= 2
}
c.WriteLen(size)
for _, l := range sortedKeys(sset) {
c.WriteBulk(l)
if opts.WithScores {
c.WriteFloat(sset[l])
}
}
})
}
}

// ZLEXCOUNT
Expand Down Expand Up @@ -1947,3 +1995,12 @@ func parseLexrange(s string) (string, bool, error) {
return "", false, errors.New(msgInvalidRangeItem)
}
}

func sortedKeys(m map[string]float64) []string {
var keys []string
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
30 changes: 30 additions & 0 deletions cmd_sorted_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,36 @@ func TestZunion(t *testing.T) {
})
}

func TestZinter(t *testing.T) {
s, err := Run()
ok(t, err)
defer s.Close()
c, err := proto.Dial(s.Addr())
ok(t, err)
defer c.Close()

s.ZAdd("h1", 1.0, "field1")
s.ZAdd("h1", 2.0, "field2")
s.ZAdd("h1", 3.0, "field3")
s.ZAdd("h2", 1.0, "field1")
s.ZAdd("h2", 2.0, "field2")
s.ZAdd("h2", 4.0, "field4")
s.SAdd("s2", "field1")

// Simple case
{
mustDo(t, c,
"ZINTER", "2", "h1", "h2",
proto.Strings("field1", "field2"),
)
mustDo(t, c,
"ZINTER", "2", "h1", "h2", "WITHSCORES",
proto.Strings("field1", "2", "field2", "4"),
)
}
// it's the same code as ZINTERSTORE, so see TestZinterstore()
}

func TestZinterstore(t *testing.T) {
s, err := Run()
ok(t, err)
Expand Down

0 comments on commit feeb189

Please sign in to comment.