Skip to content

Commit

Permalink
Add toLowerCase / toUpperCase support in Mongo backend
Browse files Browse the repository at this point in the history
Create special visitor for `$expr` (pipeline / aggregation) type of
expressions
  • Loading branch information
asereda-gs committed May 5, 2020
1 parent b0b0a8a commit 6e4dd31
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 24 deletions.
Expand Up @@ -406,6 +406,10 @@ protected void upperLowerCase() {
ids(string.value.toUpperCase().in("A", "BC")).hasContentInAnyOrder("id2", "id3", "id4");
ids(string.value.toLowerCase().in("", "a")).hasContentInAnyOrder("id1", "id2", "id3");
ids(string.value.toUpperCase().notIn("BC", "A")).hasContentInAnyOrder("id1");

// chain upper/lower
ids(string.value.toUpperCase().toLowerCase().in("A", "BC")).isEmpty();
ids(string.value.toLowerCase().toUpperCase().in("A", "BC")).hasContentInAnyOrder("id2", "id3", "id4");
}

/**
Expand Down
Expand Up @@ -58,7 +58,7 @@ class AggregationQuery {
private final PathNaming pathNaming;
private final ExpressionNaming projectionNaming;
private final BiMap<Expression, String> naming;
private final CodecRegistry registry;
private final CodecRegistry codecRegistry;

AggregationQuery(Query query, PathNaming pathNaming) {
this.query = maybeRewriteDistinctToGroupBy(query);
Expand All @@ -75,7 +75,7 @@ class AggregationQuery {

this.projectionNaming = ExpressionNaming.from(UniqueCachedNaming.of(query.projections()));
this.naming = ImmutableBiMap.copyOf(biMap);
this.registry = MongoClientSettings.getDefaultCodecRegistry();
this.codecRegistry = MongoClientSettings.getDefaultCodecRegistry();
}

/**
Expand Down Expand Up @@ -147,7 +147,7 @@ private class MatchPipeline implements Pipeline {
@Override
public void process(Consumer<Bson> consumer) {
query.filter().ifPresent(expr -> {
Bson filter = expr.accept(new FindVisitor(pathNaming));
Bson filter = expr.accept(new FindVisitor(pathNaming, codecRegistry));
Objects.requireNonNull(filter, "null filter");
consumer.accept(Aggregates.match(filter));
});
Expand Down Expand Up @@ -250,7 +250,7 @@ public void process(Consumer<Bson> consumer) {

BsonDocument sort = new BsonDocument();
for (Collation collation: query.collations()) {
sort.putAll(toSortFn.apply(collation).toBsonDocument(BsonDocument.class, registry));
sort.putAll(toSortFn.apply(collation).toBsonDocument(BsonDocument.class, codecRegistry));
}

if (!sort.isEmpty()) {
Expand Down
143 changes: 133 additions & 10 deletions criteria/mongo/src/org/immutables/criteria/mongo/FindVisitor.java
Expand Up @@ -19,18 +19,26 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.mongodb.client.model.Filters;
import org.bson.BsonArray;
import org.bson.BsonDocument;
import org.bson.BsonNull;
import org.bson.BsonString;
import org.bson.BsonValue;
import org.bson.Document;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.conversions.Bson;
import org.immutables.criteria.backend.PathNaming;
import org.immutables.criteria.expression.AbstractExpressionVisitor;
import org.immutables.criteria.expression.Call;
import org.immutables.criteria.expression.ComparableOperators;
import org.immutables.criteria.expression.Constant;
import org.immutables.criteria.expression.Expression;
import org.immutables.criteria.expression.Expressions;
import org.immutables.criteria.expression.IterableOperators;
import org.immutables.criteria.expression.Operator;
import org.immutables.criteria.expression.Operators;
import org.immutables.criteria.expression.OptionalOperators;
import org.immutables.criteria.expression.Path;
import org.immutables.criteria.expression.StringOperators;
import org.immutables.criteria.expression.Visitors;

Expand All @@ -49,10 +57,12 @@
class FindVisitor extends AbstractExpressionVisitor<Bson> {

private final PathNaming naming;
private final CodecRegistry codecRegistry;

FindVisitor(PathNaming naming) {
FindVisitor(PathNaming pathNaming, CodecRegistry codecRegistry) {
super(e -> { throw new UnsupportedOperationException(); });
this.naming = Objects.requireNonNull(naming, "pathNaming");
this.naming = Objects.requireNonNull(pathNaming, "pathNaming");
this.codecRegistry = Objects.requireNonNull(codecRegistry, "codecRegistry");
}

@Override
Expand Down Expand Up @@ -98,8 +108,16 @@ public Bson visit(Call call) {
private Bson binaryCall(Call call) {
Preconditions.checkArgument(call.operator().arity() == Operator.Arity.BINARY, "%s is not binary", call.operator());
final Operator op = call.operator();
final String field = naming.name(Visitors.toPath(call.arguments().get(0)));
final Object value = Visitors.toConstant(call.arguments().get(1)).value();
Expression left = call.arguments().get(0);
Expression right = call.arguments().get(1);

if (!(left instanceof Path && right instanceof Constant)) {
// special case when $expr has to be used
return call.accept(new MongoExpr(naming, codecRegistry)).asDocument();
}

final String field = naming.name(Visitors.toPath(left));
final Object value = Visitors.toConstant(right).value();
if (op == Operators.EQUAL || op == Operators.NOT_EQUAL) {
if ("".equals(value) && op == Operators.NOT_EQUAL) {
// special case for empty string. string != "" should not return missing strings
Expand All @@ -123,12 +141,12 @@ private Bson binaryCall(Call call) {
}

if (op == Operators.IN || op == Operators.NOT_IN) {
final Collection<Object> values = ImmutableSet.copyOf(Visitors.toConstant(call.arguments().get(1)).values());
final Collection<Object> values = ImmutableSet.copyOf(Visitors.toConstant(right).values());
Preconditions.checkNotNull(values, "not expected to be null for %s", op);
if (values.size() == 1) {
// optimization: convert IN, NIN (where argument is a list with single element) into EQ / NE
Operators newOperator = op == Operators.IN ? Operators.EQUAL : Operators.NOT_EQUAL;
Call newCall = Expressions.call(newOperator, call.arguments().get(0), Expressions.constant(values.iterator().next()));
Call newCall = Expressions.call(newOperator, left, Expressions.constant(values.iterator().next()));
return binaryCall(newCall);
}
return op == Operators.IN ? Filters.in(field, values) : Filters.nin(field, values);
Expand Down Expand Up @@ -173,10 +191,6 @@ private Bson binaryCall(Call call) {
return Filters.regex(field, Pattern.compile(pattern));
}

if (op == StringOperators.TO_LOWER_CASE || op == StringOperators.TO_UPPER_CASE) {
// $expr:{$eq:[ $toUpper: "$field", "value"]}
// probably needs special check for wrapping expression
}

throw new UnsupportedOperationException(String.format("Unsupported binary call %s", call));
}
Expand Down Expand Up @@ -215,4 +229,113 @@ private Bson negate(Expression expression) {
return Filters.not(notCall.accept(this));
}

/**
* Visitor used when special {@code $expr} needs to be generated like {@code field1 == field2}
* in mongo it would look like:
*
* <pre>
* {@code
* $expr: {
* $eq: [
* "$field1",
* "$field2"
* ]
* }
* }
* </pre>
* @see <a href="https://docs.mongodb.com/manual/reference/operator/query/expr/">$expr</a>
*/
private static class MongoExpr extends AbstractExpressionVisitor<BsonValue> {
private final PathNaming pathNaming;
private final CodecRegistry codecRegistry;

private MongoExpr(PathNaming pathNaming, CodecRegistry codecRegistry) {
super(e -> { throw new UnsupportedOperationException(); });
this.pathNaming = pathNaming;
this.codecRegistry = codecRegistry;
}

@Override
public BsonValue visit(Call call) {
Operator op = call.operator();
if (op.arity() == Operator.Arity.BINARY) {
return visitBinary(call, call.arguments().get(0), call.arguments().get(1));
}

if (op.arity() == Operator.Arity.UNARY) {
return visitUnary(call, call.arguments().get(0));
}


throw new UnsupportedOperationException("Don't know how to handle " + call);
}

private BsonValue visitBinary(Call call, Expression left, Expression right) {
Operator op = call.operator();

String mongoOp;
if (op == Operators.EQUAL) {
mongoOp = "$eq";
} else if (op == Operators.NOT_EQUAL) {
mongoOp = "$ne";
} else if (op == Operators.IN) {
mongoOp = "$in";
} else if (op == Operators.NOT_IN) {
mongoOp = "$in"; // will be wrapped in $not: {$not: {$in: ... }}
} else {
throw new UnsupportedOperationException(String.format("Unknown operator %s for call %s", op, call));
}

BsonArray args = new BsonArray();
args.add(left.accept(this));
args.add(right.accept(this));

BsonDocument expr = new BsonDocument(mongoOp, args);
if (op == Operators.NOT_IN) {
// for aggregations $nin does not work
// use {$not: {$in: ... }} instead
expr = new BsonDocument("$not", expr);
}

return Filters.expr(expr).toBsonDocument(BsonDocument.class, codecRegistry);
}

private BsonValue visitUnary(Call call, Expression arg) {
Operator op = call.operator();

if (op == StringOperators.TO_LOWER_CASE || op == StringOperators.TO_UPPER_CASE) {
String key = op == StringOperators.TO_LOWER_CASE ? "$toLower" : "$toUpper";
BsonValue value = arg.accept(this);
return new BsonDocument(key, value);
}

throw new UnsupportedOperationException("Unknown unary call " + call);
}

@Override
public BsonValue visit(Path path) {
// in mongo expressions fields are referenced as $field
return new BsonString('$' + pathNaming.name(path));
}

@Override
public BsonValue visit(Constant constant) {
Object value = constant.value();
if (value == null) {
return BsonNull.VALUE;
}

if (value instanceof Iterable) {
return Filters.in("ignore", (Iterable<?>) value)
.toBsonDocument(BsonDocument.class, codecRegistry)
.get("ignore").asDocument()
.get("$in").asArray();
}

return Filters.eq("ignore", value)
.toBsonDocument(BsonDocument.class, codecRegistry)
.get("ignore");
}
}

}
Expand Up @@ -84,7 +84,7 @@ class MongoSession implements Backend.Session {
pathNaming = new MongoPathNaming(idProperty, pathNaming);
}
this.pathNaming = pathNaming;
this.converter = Mongos.converter(this.pathNaming);
this.converter = Mongos.converter(this.pathNaming, collection.getCodecRegistry());
}

private Bson toBsonFilter(Query query) {
Expand Down
5 changes: 3 additions & 2 deletions criteria/mongo/src/org/immutables/criteria/mongo/Mongos.java
Expand Up @@ -16,6 +16,7 @@

package org.immutables.criteria.mongo;

import org.bson.codecs.configuration.CodecRegistry;
import org.bson.conversions.Bson;
import org.immutables.criteria.backend.PathNaming;
import org.immutables.criteria.expression.ExpressionConverter;
Expand All @@ -37,8 +38,8 @@ private Mongos() {}
/**
* Convert existing expression to Bson
*/
static ExpressionConverter<Bson> converter(PathNaming pathNaming) {
return expression -> expression.accept(new FindVisitor(pathNaming));
static ExpressionConverter<Bson> converter(PathNaming pathNaming, CodecRegistry codecRegistry) {
return expression -> expression.accept(new FindVisitor(pathNaming, codecRegistry));
}

/**
Expand Down
Expand Up @@ -71,6 +71,25 @@ void emptyString() {
check(criteria.value.notEmpty()).matches("{value: {$nin: ['', null], $exists: true}}");
}

@Test
void upperLower() {
check(criteria.value.toUpperCase().is("A")).matches("{$expr: {$eq:[{$toUpper: '$value'}, 'A']}}");
check(criteria.value.toLowerCase().is("a")).matches("{$expr: {$eq:[{$toLower: '$value'}, 'a']}}");
check(criteria.value.toLowerCase().isNot("a")).matches("{$expr: {$ne:[{$toLower: '$value'}, 'a']}}");
check(criteria.value.toLowerCase().in("a", "b")).matches("{$expr: {$in:[{$toLower: '$value'}, ['a', 'b']]}}");
// for aggregations / $expr, $nin does not work
// use {$not: {$in: ... }} instead
check(criteria.value.toUpperCase().notIn("a", "b")).matches("{$expr: {$not: {$in:[{$toUpper: '$value'}, ['a', 'b']]}}}");

// chain toUpper.toLower.toUpper
check(criteria.value.toUpperCase().toLowerCase().is("A"))
.matches("{$expr: {$eq:[{$toLower: {$toUpper: '$value'}}, 'A']}}");

check(criteria.value.toLowerCase().toUpperCase().is("A"))
.matches("{$expr: {$eq:[{$toUpper: {$toLower: '$value'}}, 'A']}}");

}

private static QueryAssertion check(StringHolderCriteria criteria) {
return QueryAssertion.ofFilter(Criterias.toQuery(criteria));
}
Expand Down
Expand Up @@ -35,7 +35,6 @@
import org.immutables.criteria.typemodel.StringTemplate;
import org.immutables.criteria.typemodel.UpdateByQueryTemplate;
import org.immutables.criteria.typemodel.WriteTemplate;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.extension.ExtendWith;

Expand All @@ -53,11 +52,6 @@ class String extends StringTemplate {
private String() {
super(backend);
}

@Disabled("Doesn't work in mongo yet")
@Override
protected void upperLowerCase() {
}
}

@Nested
Expand Down
Expand Up @@ -43,8 +43,8 @@ class QueryAssertion {
this.query = Objects.requireNonNull(query, "query");
Path idPath = Visitors.toPath(KeyExtractor.defaultFactory().create(query.entityClass()).metadata().keys().get(0));
PathNaming pathNaming = new MongoPathNaming(idPath, PathNaming.defaultNaming());
FindVisitor visitor = new FindVisitor(pathNaming);
CodecRegistry codecRegistry = MongoClientSettings.getDefaultCodecRegistry();
FindVisitor visitor = new FindVisitor(pathNaming, codecRegistry);
if (query.hasAggregations() || pipeline) {
AggregationQuery agg = new AggregationQuery(query, pathNaming);
this.actual = agg.toPipeline().stream()
Expand Down

0 comments on commit 6e4dd31

Please sign in to comment.