Skip to content

Commit

Permalink
Batch big IN/NOT IN selections
Browse files Browse the repository at this point in the history
  • Loading branch information
Julius de Bruijn committed Mar 11, 2020
1 parent 7f34f22 commit 54f27e3
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 27 deletions.
20 changes: 6 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -22,7 +22,7 @@ user-facing-errors = { path = "../../../libs/user-facing-errors", features = ["s
tracing = "0.1.10"
tracing-futures = "0.2.0"
tokio = { version = "0.2", features = ["rt-threaded", "time"] }
once_cell = "1.3.1"
once_cell = "1.3"

[dependencies.quaint]
git = "https://github.com/prisma/quaint"
Expand All @@ -32,5 +32,4 @@ features = ["single"]
barrel = { version = "0.6.5-alpha.0", features = ["sqlite3", "mysql", "pg"] }
test-macros = { path = "../../../libs/test-macros" }
test-setup = { path = "../../../libs/test-setup" }
once_cell = "1.2.0"
pretty_assertions = "0.6.1"
2 changes: 1 addition & 1 deletion libs/prisma-models/Cargo.toml
Expand Up @@ -9,7 +9,7 @@ default = []
sql-ext = ["quaint"]

[dependencies]
once_cell = "1.2"
once_cell = "1.3"
serde_derive = "1.0"
serde_json = "1.0"
serde = "1.0"
Expand Down
2 changes: 1 addition & 1 deletion libs/user-facing-error-macros/Cargo.toml
Expand Up @@ -14,5 +14,5 @@ darling = "0.10.1"
syn = "1.0.5"
quote = "1.0.2"
proc-macro2 = "1.0.6"
once_cell = "1.2.0"
once_cell = "1.3"
regex = "1.3.1"
2 changes: 1 addition & 1 deletion query-engine/connectors/query-connector/Cargo.toml
Expand Up @@ -7,7 +7,7 @@ edition = "2018"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
once_cell = "1.2"
once_cell = "1.3"
prisma-models = { path = "../../../libs/prisma-models" }
prisma-value = { path = "../../../libs/prisma-value" }
failure = { version = "0.1", features = ["derive"] }
Expand Down
69 changes: 69 additions & 0 deletions query-engine/connectors/query-connector/src/filter/mod.rs
Expand Up @@ -58,6 +58,75 @@ impl Filter {
_ => 1,
}
}

/// Returns `true` if this filter can be split into multiple batched queries.
///
/// A scalar filter decides for itself (large `IN` / `NOT IN` lists). An
/// `AND`/`OR` group is batchable only when *exactly one* of its children is
/// batchable — with more than one, splitting would multiply the batches and
/// change the combined semantics.
pub fn can_batch(&self) -> bool {
    match self {
        Self::Scalar(sf) => sf.can_batch(),
        Self::And(filters) | Self::Or(filters) => {
            filters.iter().filter(|f| f.can_batch()).count() == 1
        }
        _ => false,
    }
}

/// Splits this filter into a list of filters, each safe to execute as a
/// single query. Filters that are not batchable (see [`Filter::can_batch`])
/// are returned unchanged as a one-element vector.
pub fn batched(self) -> Vec<Filter> {
    // Splits the children of an AND/OR group: the single batchable child is
    // expanded into its batches, and each batch is recombined with the
    // non-batchable children via `f` (the group constructor).
    fn inner<F>(mut filters: Vec<Filter>, f: F) -> Vec<Filter>
    where
        F: Fn(Vec<Filter>) -> Filter,
    {
        let mut other = Vec::new();
        let mut batched = Vec::new();

        // Popping traverses in reverse; operand order within an AND/OR
        // group does not affect the filter's meaning.
        while let Some(filter) = filters.pop() {
            if filter.can_batch() {
                batched.extend(filter.batched());
            } else {
                other.push(filter);
            }
        }

        if batched.is_empty() {
            vec![f(other)]
        } else {
            batched
                .into_iter()
                .map(|batch| {
                    let mut group = other.clone();
                    group.push(batch);
                    f(group)
                })
                .collect()
        }
    }

    match self {
        Self::Scalar(sf) => sf.batched().into_iter().map(Self::Scalar).collect(),
        Self::And(filters) => inner(filters, Filter::And),
        Self::Or(filters) => inner(filters, Filter::Or),
        _ => vec![self],
    }
}
}

impl From<ScalarFilter> for Filter {
Expand Down
49 changes: 49 additions & 0 deletions query-engine/connectors/query-connector/src/filter/scalar.rs
Expand Up @@ -15,6 +15,55 @@ pub struct ScalarFilter {
pub condition: ScalarCondition,
}

const BATCH_SIZE: usize = 5000;

impl ScalarFilter {
    /// Returns `true` if the condition is an `IN` / `NOT IN` whose value
    /// list exceeds `BATCH_SIZE` and therefore must be split before being
    /// sent to the database (many databases cap the number of bind
    /// parameters per statement).
    pub fn can_batch(&self) -> bool {
        match self.condition {
            ScalarCondition::In(ref l) | ScalarCondition::NotIn(ref l) => l.len() > BATCH_SIZE,
            _ => false,
        }
    }

    /// Splits an `IN` / `NOT IN` filter into multiple filters whose value
    /// lists each hold at most `BATCH_SIZE` items. Any other condition is
    /// returned unchanged as a one-element vector.
    pub fn batched(self) -> Vec<ScalarFilter> {
        // Chunks `list` into sub-lists of at most BATCH_SIZE elements,
        // preserving order.
        fn inner(list: PrismaListValue) -> Vec<PrismaListValue> {
            // `/` (not `%`) yields the number of full batches; `+ 1` covers
            // the trailing partial batch.
            let mut batches = Vec::with_capacity(list.len() / BATCH_SIZE + 1);
            batches.push(Vec::with_capacity(BATCH_SIZE));

            for (idx, item) in list.into_iter().enumerate() {
                if idx != 0 && idx % BATCH_SIZE == 0 {
                    batches.push(Vec::with_capacity(BATCH_SIZE));
                }

                batches.last_mut().unwrap().push(item);
            }

            batches
        }

        match self.condition {
            ScalarCondition::In(list) => {
                let projection = self.projection;

                inner(list)
                    .into_iter()
                    .map(|batch| ScalarFilter {
                        projection: projection.clone(),
                        condition: ScalarCondition::In(batch),
                    })
                    .collect()
            }
            ScalarCondition::NotIn(list) => {
                let projection = self.projection;

                inner(list)
                    .into_iter()
                    .map(|batch| ScalarFilter {
                        projection: projection.clone(),
                        condition: ScalarCondition::NotIn(batch),
                    })
                    .collect()
            }
            _ => vec![self],
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScalarCondition {
Equals(PrismaValue),
Expand Down
33 changes: 33 additions & 0 deletions query-engine/connectors/query-connector/src/query_arguments.rs
Expand Up @@ -65,6 +65,39 @@ impl QueryArguments {
},
}
}

/// Returns `true` when the attached filter (if any) can be split into
/// multiple batched queries.
pub fn can_batch(&self) -> bool {
    self.filter.as_ref().map_or(false, |f| f.can_batch())
}

/// Splits these arguments into one `QueryArguments` per filter batch,
/// duplicating every other field (cursor, pagination, ordering) into each
/// copy. Without a filter, the arguments are returned as-is.
pub fn batched(self) -> Vec<Self> {
    match self.filter {
        None => vec![self],
        Some(filter) => {
            // Move the remaining fields out once; clone them per batch below.
            let after = self.after;
            let before = self.before;
            let skip = self.skip;
            let first = self.first;
            let last = self.last;
            let order_by = self.order_by;

            filter
                .batched()
                .into_iter()
                .map(|batch| QueryArguments {
                    after: after.clone(),
                    before: before.clone(),
                    skip: skip.clone(),
                    first: first.clone(),
                    last: last.clone(),
                    order_by: order_by.clone(),
                    filter: Some(batch),
                })
                .collect()
        }
    }
}
}

impl<T> From<T> for QueryArguments
Expand Down
1 change: 1 addition & 0 deletions query-engine/connectors/sql-query-connector/Cargo.toml
Expand Up @@ -15,6 +15,7 @@ log = "0.4"
async-trait = "0.1"
futures = "0.3"
rust_decimal = "=1.1.0"
tokio = "0.2"

[dependencies.quaint]
git = "https://github.com/prisma/quaint"
Expand Down
Expand Up @@ -5,6 +5,7 @@ use crate::{
use connector_interface::*;
use prisma_models::*;
use quaint::ast::*;
use futures::future;

pub async fn get_single_record(
conn: &dyn QueryExt,
Expand Down Expand Up @@ -36,14 +37,29 @@ pub async fn get_many_records(
) -> crate::Result<ManyRecords> {
let field_names = selected_fields.db_names().map(String::from).collect();
let idents: Vec<_> = selected_fields.types().collect();
let query = read::get_records(model, selected_fields.columns(), query_arguments);

let records = conn
.filter(query.into(), idents.as_slice())
.await?
.into_iter()
.map(Record::from)
.collect();
let mut records = Vec::new();

if query_arguments.can_batch() {
let batches = query_arguments.batched();
let mut futures = Vec::with_capacity(batches.len());

for args in batches.into_iter() {
let query = read::get_records(model, selected_fields.columns(), args);
futures.push(conn.filter(query.into(), idents.as_slice()));
}

for result in future::join_all(futures).await.into_iter() {
for item in result?.into_iter() {
records.push(Record::from(item))
}
}
} else {
let query = read::get_records(model, selected_fields.columns(), query_arguments);

for item in conn.filter(query.into(), idents.as_slice()).await?.into_iter() {
records.push(Record::from(item))
}
}

Ok(ManyRecords { records, field_names })
}
Expand Down

0 comments on commit 54f27e3

Please sign in to comment.