Skip to content

Commit

Permalink
fix(set): fix random in SRANDMEMBER and SPOP commands
Browse files Browse the repository at this point in the history
fixes dragonflydb#3018

Signed-off-by: Stepan Bagritsevich <sbagritsevich@quantumbrains.com>
  • Loading branch information
Stepan Bagritsevich committed May 7, 2024
1 parent f27506e commit 0f19aa7
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 170 deletions.
31 changes: 31 additions & 0 deletions src/server/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,4 +427,35 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) {
return os << GlobalStateName(state);
}

NonUniquePicksGenerator::NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) {
CHECK_GT(max_range, RandomPick(0));
}

// Draws one value with absl::Uniform between 0 and max_range_ (upper bound
// excluded under absl::Uniform's default interval). Every call is an independent
// draw, so the same value can come back more than once.
RandomPick NonUniquePicksGenerator::Generate() {
  const RandomPick pick = absl::Uniform(bitgen_, 0u, max_range_);
  return pick;
}

// Prepares a generator that hands out `picks_count` distinct values below
// `max_range`. The first draw is bounded by `max_range - picks_count`; Generate()
// raises that bound by one on each call. The hash set is pre-sized for
// `picks_count` entries. Requesting more picks than the range holds is a CHECK failure.
UniquePicksGenerator::UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range)
    : current_random_limit_(max_range - picks_count),
      remaining_picks_count_(picks_count),
      picked_indexes_(picks_count) {
  CHECK_GE(max_range, picks_count);
}

// Produces the next distinct pick (one step of Floyd's sampling algorithm).
//
// Each call draws a candidate from 0..upper inclusive, where `upper` is the
// current limit, and then raises the limit by one for the next call. If the
// candidate was already handed out, `upper` itself is returned instead.
RandomPick UniquePicksGenerator::Generate() {
  DCHECK_GT(remaining_picks_count_, 0u);
  --remaining_picks_count_;

  const RandomPick upper = current_random_limit_;
  ++current_random_limit_;

  // absl::Uniform excludes the upper bound, hence the +1 to make `upper` reachable.
  const RandomPick candidate = absl::Uniform(bitgen_, 0u, upper + 1u);

  const bool newly_inserted = picked_indexes_.emplace(candidate).second;
  if (!newly_inserted) {
    // Collision: fall back to `upper`. Earlier calls were bounded by smaller
    // limits, so `upper` cannot be in the set yet.
    picked_indexes_.insert(upper);
    return upper;
  }

  return candidate;
}

} // namespace dfly
42 changes: 42 additions & 0 deletions src/server/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#pragma once

#include <absl/random/random.h>
#include <absl/strings/ascii.h>
#include <absl/strings/str_cat.h>
#include <absl/types/span.h>
Expand Down Expand Up @@ -309,4 +310,45 @@ struct MemoryBytesFlag {
// Abseil custom-flag hook: parses the text `in` into `*flag`. Returns false on
// failure, with `*err` presumably carrying the reason (per the Abseil flag
// convention; the definition is not visible here).
bool AbslParseFlag(std::string_view in, dfly::MemoryBytesFlag* flag, std::string* err);
// Abseil custom-flag hook: converts `flag` back to its textual form.
std::string AbslUnparseFlag(const dfly::MemoryBytesFlag& flag);

// Integral type of the values returned by PicksGenerator::Generate().
using RandomPick = std::uint32_t;

// Abstract source of random picks. Concrete generators decide how each value
// returned by Generate() is chosen.
class PicksGenerator {
 public:
  virtual ~PicksGenerator() = default;

  // Returns the next pick.
  virtual RandomPick Generate() = 0;
};

// Generator whose picks are independent draws below `max_range`, so the same
// value may be returned by several Generate() calls.
class NonUniquePicksGenerator : public PicksGenerator {
 public:
  // `max_range` is the exclusive upper bound for picks.
  // Marked explicit so that a plain integer never converts silently into a generator.
  explicit NonUniquePicksGenerator(RandomPick max_range);

  // Returns one random pick; repeats across calls are possible.
  RandomPick Generate() override;

 private:
  const RandomPick max_range_;
  absl::BitGen bitgen_{};
};

/*
 * Generates pairwise-distinct picks below max_range.
 *
 * picks_count is the number of picks that will be requested, i.e. the number of
 * times Generate() is going to be called; it must not exceed max_range.
 * Calling Generate() more than picks_count times is a contract violation
 * (guarded by a DCHECK in the implementation).
 *
 * Each Generate() call performs one random draw and at most two hash-set
 * insertions, so a pick costs O(1) on average.
 *
 * The class uses Robert Floyd's sampling algorithm
 * https://dl.acm.org/doi/pdf/10.1145/30401.315746
 */
class UniquePicksGenerator : public PicksGenerator {
 public:
  UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range);

  // Returns the next pick; never repeats a value returned by an earlier call.
  RandomPick Generate() override;

 private:
  // Inclusive upper bound for the next draw; grows by one on every Generate().
  RandomPick current_random_limit_;
  // Number of Generate() calls still allowed.
  std::uint32_t remaining_picks_count_;
  // Values already handed out.
  std::unordered_set<RandomPick> picked_indexes_;
  absl::BitGen bitgen_{};
};

} // namespace dfly
154 changes: 66 additions & 88 deletions src/server/set_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,19 +284,32 @@ void InterStrSet(const DbContext& db_context, const vector<SetType>& vec, String
}
}

StringVec PopStrSet(const DbContext& db_context, unsigned count, const SetType& st) {
StringVec RandMemberStrSet(const DbContext& db_context, const SetType& st,
const std::unique_ptr<PicksGenerator>& generator,
std::size_t picks_count) {
std::unordered_map<RandomPick, std::uint32_t> times_index_is_picked;
for (std::size_t i = 0; i < picks_count; i++) {
times_index_is_picked[generator->Generate()]++;
}

StringVec result;
result.reserve(picks_count);

if (true) {
StringSet* ss = (StringSet*)st.first;
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));
StringSet* ss = static_cast<StringSet*>(st.first);
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));

// TODO: this loop is inefficient because Pop searches again and again an occupied bucket.
for (unsigned i = 0; i < count && !ss->Empty(); ++i) {
result.push_back(ss->Pop().value());
std::uint32_t ss_entry_index = 0;
for (const sds ptr : *ss) {
std::uint32_t t = times_index_is_picked[ss_entry_index++];
while (t--) {
result.emplace_back(ptr, sdslen(ptr));
}
}

/* Members in result appear in set-iteration order, so shuffle them to randomize the output order. */
absl::BitGen gen;
std::shuffle(result.begin(), result.end(), gen);

return result;
}

Expand Down Expand Up @@ -819,70 +832,63 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
return result;
}

// count - how many elements to pop.
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count) {
auto& db_slice = op_args.shard->db_slice();
auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
OpResult<StringVec> OpRandMember(const OpArgs& op_args, std::string_view key, int count) {
auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET);
if (!find_res)
return find_res.status();

StringVec result;
if (count == 0)
return result;
const CompactObj& co = find_res.value()->second;

auto it = find_res->it;
size_t slen = it->second.Size();
const std::size_t size = co.Size();
const bool picks_are_unique = count >= 0;
const std::size_t picks_count =
picks_are_unique ? std::min(static_cast<std::size_t>(count), size) : std::abs(count);

/* CASE 1:
* The number of requested elements is greater than or equal to
* the number of elements inside the set: simply return the whole set. */
if (count >= slen) {
PrimeValue& pv = it->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms));
}
std::unique_ptr<PicksGenerator> generator =
picks_are_unique ? static_cast<std::unique_ptr<PicksGenerator>>(
std::make_unique<UniquePicksGenerator>(picks_count, size))
: std::make_unique<NonUniquePicksGenerator>(size);

container_utils::IterateSet(it->second, [&result](container_utils::ContainerEntry ce) {
result.push_back(ce.ToString());
return true;
});
SetType st{co.RObjPtr(), co.Encoding()};
if (st.second == kEncodingIntSet) {
intset* is = static_cast<intset*>(st.first);

// Delete the set as it is now empty
find_res->post_updater.Run();
CHECK(db_slice.Del(op_args.db_cntx.db_index, it));
StringVec result;
result.reserve(picks_count);

// Replicate as DEL.
if (op_args.shard->journal()) {
RecordJournal(op_args, "DEL"sv, ArgSlice{key});
for (std::size_t i = 0; i < picks_count; i++) {
  const std::size_t picked_index = generator->Generate();

  // intsetGet must run unconditionally. Arguments of DCHECK_* are not evaluated
  // in builds where debug checks are compiled out, so calling it inside the
  // macro would leave `value` unset there and return garbage to the client.
  int64_t value = 0;
  [[maybe_unused]] const std::uint8_t found = intsetGet(is, picked_index, &value);
  DCHECK_GT(found, std::uint8_t(0));

  result.push_back(absl::StrCat(value));
}
return result;
} else {
SetType st{it->second.RObjPtr(), it->second.Encoding()};
if (st.second == kEncodingIntSet) {
intset* is = (intset*)st.first;
int64_t val = 0;

// copy last count values.
for (uint32_t i = slen - count; i < slen; ++i) {
intsetGet(is, i, &val);
result.push_back(absl::StrCat(val));
}
CHECK(IsDenseEncoding(co));
return RandMemberStrSet(op_args.db_cntx, st, generator, picks_count);
}
}

is = intsetTrimTail(is, count); // now remove last count items
it->second.SetRObjPtr(is);
} else {
result = PopStrSet(op_args.db_cntx, count, st);
}
// count - how many elements to pop.
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, int count) {
auto rand_members_result = OpRandMember(op_args, key, count);
if (!rand_members_result) {
return rand_members_result.status();
}

// Replicate as SREM with removed keys, because SPOP is not deterministic.
if (op_args.shard->journal()) {
vector<string_view> mapped(result.size() + 1);
mapped[0] = key;
std::copy(result.begin(), result.end(), mapped.begin() + 1);
RecordJournal(op_args, "SREM"sv, mapped);
}
StringVec rand_members = rand_members_result.value();

std::vector<std::string_view> members_to_remove{rand_members.begin(), rand_members.end()};
ArgSlice span{members_to_remove.data(), members_to_remove.size()};

auto rem_members_result = OpRem(op_args, key, span, false);
if (!rem_members_result) {
return rem_members_result.status();
}
return result;

return rand_members;
}

OpResult<StringVec> OpScan(const OpArgs& op_args, string_view key, uint64_t* cursor,
Expand Down Expand Up @@ -1055,7 +1061,7 @@ void SCard(CmdArgList args, ConnectionContext* cntx) {

void SPop(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
unsigned count = 1;
int count = 1;
if (args.size() > 1) {
string_view arg = ArgS(args, 1);
if (!absl::SimpleAtoi(arg, &count)) {
Expand Down Expand Up @@ -1204,41 +1210,13 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) {
if (auto err = parser.Error(); err)
return cntx->SendError(err->MakeReply());

const unsigned ucount = std::abs(count);

const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
StringVec result;
auto find_res = shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
if (!find_res) {
return find_res.status();
}

const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms));
}

container_utils::IterateSet(find_res.value()->second,
[&result, ucount](container_utils::ContainerEntry ce) {
if (result.size() < ucount) {
result.push_back(ce.ToString());
return true;
}
return false;
});
return result;
return OpRandMember(t->GetOpArgs(shard), key, count);
};

OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(cb);
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
if (count < 0 && !result->empty()) {
for (auto i = result->size(); i < ucount; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
rb->SendStringArr(*result, RedisReplyBuilder::SET);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {
Expand Down

0 comments on commit 0f19aa7

Please sign in to comment.