Skip to content

Commit

Permalink
fix(set): fix random in SRANDMEMBER and SPOP commands
Browse files Browse the repository at this point in the history
fixes dragonflydb#3018

Signed-off-by: Stepan Bagritsevich <sbagritsevich@quantumbrains.com>
  • Loading branch information
Stepan Bagritsevich committed May 8, 2024
1 parent 3dd6c49 commit 74b498d
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 154 deletions.
31 changes: 31 additions & 0 deletions src/server/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,4 +427,35 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) {
return os << GlobalStateName(state);
}

NonUniquePicksGenerator::NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) {
CHECK_GT(max_range, RandomPick(0));
}

RandomPick NonUniquePicksGenerator::Generate() {
return absl::Uniform(bitgen_, 0u, max_range_);
}

UniquePicksGenerator::UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range)
: remaining_picks_count_(picks_count), picked_indexes_(picks_count) {
CHECK_GE(max_range, picks_count);
current_random_limit_ = max_range - picks_count;
}

RandomPick UniquePicksGenerator::Generate() {
DCHECK_GT(remaining_picks_count_, 0u);

remaining_picks_count_--;

const RandomPick max_index = current_random_limit_++;
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);

const bool random_index_is_picked = picked_indexes_.emplace(random_index).second;
if (random_index_is_picked) {
return random_index;
}

picked_indexes_.insert(max_index);
return max_index;
}

} // namespace dfly
44 changes: 44 additions & 0 deletions src/server/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#pragma once

#include <absl/random/random.h>
#include <absl/strings/ascii.h>
#include <absl/strings/str_cat.h>
#include <absl/types/span.h>
Expand Down Expand Up @@ -309,4 +310,47 @@ struct MemoryBytesFlag {
bool AbslParseFlag(std::string_view in, dfly::MemoryBytesFlag* flag, std::string* err);
std::string AbslUnparseFlag(const dfly::MemoryBytesFlag& flag);

using RandomPick = std::uint32_t;

class PicksGenerator {
public:
virtual RandomPick Generate() = 0;
virtual ~PicksGenerator() = default;
};

class NonUniquePicksGenerator : public PicksGenerator {
public:
/* The generated value will be within the closed-open interval [0, max_range) */
NonUniquePicksGenerator(RandomPick max_range);

RandomPick Generate() override;

private:
const RandomPick max_range_;
absl::BitGen bitgen_{};
};

/*
* Generates unique index in O(1).
*
* picks_count specifies the number of random indexes to be generated.
* In other words, this is the number of times the Generate() function is called.
*
* The class uses Robert Floyd's sampling algorithm
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
* */
class UniquePicksGenerator : public PicksGenerator {
public:
/* The generated value will be within the closed-open interval [0, max_range) */
UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range);

RandomPick Generate() override;

private:
RandomPick current_random_limit_;
std::uint32_t remaining_picks_count_;
std::unordered_set<RandomPick> picked_indexes_;
absl::BitGen bitgen_{};
};

} // namespace dfly
180 changes: 108 additions & 72 deletions src/server/set_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,22 +284,64 @@ void InterStrSet(const DbContext& db_context, const vector<SetType>& vec, String
}
}

StringVec PopStrSet(const DbContext& db_context, unsigned count, const SetType& st) {
StringVec RandMemberStrSet(const DbContext& db_context, const SetType& st,
PicksGenerator& generator, std::size_t picks_count) {
std::unordered_map<RandomPick, std::uint32_t> times_index_is_picked;
for (std::size_t i = 0; i < picks_count; i++) {
times_index_is_picked[generator.Generate()]++;
}

StringVec result;
result.reserve(picks_count);

if (true) {
StringSet* ss = (StringSet*)st.first;
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));
StringSet* ss = static_cast<StringSet*>(st.first);
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));

// TODO: this loop is inefficient because Pop searches again and again an occupied bucket.
for (unsigned i = 0; i < count && !ss->Empty(); ++i) {
result.push_back(ss->Pop().value());
std::uint32_t ss_entry_index = 0;
for (const sds ptr : *ss) {
auto it = times_index_is_picked.find(ss_entry_index++);
if (it == times_index_is_picked.end()) {
continue;
}

std::uint32_t t = it->second;
while (t--) {
result.emplace_back(ptr, sdslen(ptr));
}
}

/* Equal elements in the result are always successive. So, it is necessary to shuffle them */
absl::BitGen gen;
std::shuffle(result.begin(), result.end(), gen);

return result;
}

StringVec RandMemberSet(const DbContext& db_context, const CompactObj& co,
PicksGenerator& generator, std::size_t picks_count) {
SetType st{co.RObjPtr(), co.Encoding()};

if (st.second == kEncodingIntSet) {
intset* is = static_cast<intset*>(st.first);

StringVec result;
result.reserve(picks_count);

for (std::size_t i = 0; i < picks_count; i++) {
const std::size_t picked_index = generator.Generate();

int64_t value = 0;
CHECK_GT(intsetGet(is, picked_index, &value), std::uint8_t(0));

result.push_back(absl::StrCat(value));
}
return result;
}

CHECK(IsDenseEncoding(co));
return RandMemberStrSet(db_context, st, generator, picks_count);
}

vector<string> ToVec(absl::flat_hash_set<string>&& set) {
vector<string> result(set.size());
size_t i = 0;
Expand Down Expand Up @@ -819,69 +861,91 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
return result;
}

OpResult<StringVec> OpRandMember(const OpArgs& op_args, std::string_view key, int count) {
auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET);
if (!find_res)
return find_res.status();

const CompactObj& co = find_res.value()->second;

const std::uint32_t size = co.Size();
const bool picks_are_unique = count >= 0;
const std::uint32_t picks_count =
picks_are_unique ? std::min(static_cast<std::uint32_t>(count), size) : std::abs(count);

auto generator = [picks_are_unique, picks_count, size]() -> std::unique_ptr<PicksGenerator> {
if (picks_are_unique) {
return std::make_unique<UniquePicksGenerator>(picks_count, size);
} else {
return std::make_unique<NonUniquePicksGenerator>(size);
}
}();

return RandMemberSet(op_args.db_cntx, co, *generator, picks_count);
}

// count - how many elements to pop.
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count) {
auto& db_cntx = op_args.db_cntx;
auto& db_slice = op_args.shard->db_slice();
auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
if (!find_res)
auto find_res = db_slice.FindMutable(db_cntx, key, OBJ_SET);
if (!find_res) {
return find_res.status();
}

StringVec result;
if (count == 0)
return result;
CompactObj& co = find_res->it->second;

auto it = find_res->it;
size_t slen = it->second.Size();
const std::uint32_t size = co.Size();
const std::uint32_t picks_count = std::min(count, size);

/* CASE 1:
* The number of requested elements is greater than or equal to
* the number of elements inside the set: simply return the whole set. */
if (count >= slen) {
PrimeValue& pv = it->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
if (count >= size) {
if (IsDenseEncoding(co)) {
StringSet* ss = (StringSet*)co.RObjPtr();
ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms));
}

container_utils::IterateSet(it->second, [&result](container_utils::ContainerEntry ce) {
StringVec result;
result.reserve(picks_count);

container_utils::IterateSet(co, [&result](container_utils::ContainerEntry ce) {
result.push_back(ce.ToString());
return true;
});

// Delete the set as it is now empty
find_res->post_updater.Run();
CHECK(db_slice.Del(op_args.db_cntx.db_index, it));
CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res->it));

// Replicate as DEL.
if (op_args.shard->journal()) {
RecordJournal(op_args, "DEL"sv, ArgSlice{key});
}
} else {
SetType st{it->second.RObjPtr(), it->second.Encoding()};
if (st.second == kEncodingIntSet) {
intset* is = (intset*)st.first;
int64_t val = 0;

// copy last count values.
for (uint32_t i = slen - count; i < slen; ++i) {
intsetGet(is, i, &val);
result.push_back(absl::StrCat(val));
}
return result;
}

is = intsetTrimTail(is, count); // now remove last count items
it->second.SetRObjPtr(is);
} else {
result = PopStrSet(op_args.db_cntx, count, st);
}
/* CASE 2:
* The number of requested elements is less than the number of elements inside the set.
* In this case, we need to select random members from the set and then remove them. */
UniquePicksGenerator generator{picks_count, size};

// Replicate as SREM with removed keys, because SPOP is not deterministic.
if (op_args.shard->journal()) {
vector<string_view> mapped(result.size() + 1);
mapped[0] = key;
std::copy(result.begin(), result.end(), mapped.begin() + 1);
RecordJournal(op_args, "SREM"sv, mapped);
}
// Select random members
StringVec result = RandMemberSet(db_cntx, co, generator, picks_count);

// Remove selected members
std::vector<std::string_view> members_to_remove{result.begin(), result.end()};
bool is_empty = RemoveSet(db_cntx, members_to_remove, &co).second;
find_res->post_updater.Run();

CHECK(!is_empty);

if (op_args.shard->journal()) {
members_to_remove.insert(members_to_remove.begin(), key);
RecordJournal(op_args, "SPOP"sv, members_to_remove);
}

return result;
}

Expand Down Expand Up @@ -1204,41 +1268,13 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) {
if (auto err = parser.Error(); err)
return cntx->SendError(err->MakeReply());

const unsigned ucount = std::abs(count);

const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
StringVec result;
auto find_res = shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
if (!find_res) {
return find_res.status();
}

const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms));
}

container_utils::IterateSet(find_res.value()->second,
[&result, ucount](container_utils::ContainerEntry ce) {
if (result.size() < ucount) {
result.push_back(ce.ToString());
return true;
}
return false;
});
return result;
return OpRandMember(t->GetOpArgs(shard), key, count);
};

OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(cb);
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
if (count < 0 && !result->empty()) {
for (auto i = result->size(); i < ucount; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
rb->SendStringArr(*result, RedisReplyBuilder::SET);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {
Expand Down

0 comments on commit 74b498d

Please sign in to comment.