Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSHARP-4985: Verify that operands to numeric operators in LINQ expressions are represented as numbers on the server. #1294

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using MongoDB.Bson.Serialization.Options;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;

namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
{
Expand All @@ -35,22 +36,18 @@ public static void EnsureRepresentationIsArray(Expression expression, IBsonSeria
}
}

public static void EnsureRepresentationIsNumeric(Expression expression, AggregationExpression translation)
{
EnsureRepresentationIsNumeric(expression, translation);
}

public static void EnsureRepresentationIsNumeric(Expression expression, IBsonSerializer serializer)
{
var representation = GetRepresentation(serializer);
if (!IsNumericRepresentation(representation))
{
throw new ExpressionNotSupportedException(expression, because: $"serializer for type {serializer.ValueType} uses a non-numeric representation: {representation}");
}

static bool IsNumericRepresentation(BsonType representation)
{
return representation switch
{
BsonType.Decimal128 or BsonType.Double or BsonType.Int32 or BsonType.Int64 => true,
_ => false
};
}
}

public static BsonType GetRepresentation(IBsonSerializer serializer)
Expand All @@ -65,6 +62,11 @@ public static BsonType GetRepresentation(IBsonSerializer serializer)
return GetRepresentation(downcastingSerializer.DerivedSerializer);
}

if (serializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer)
{
return GetRepresentation(enumUnderlyingTypeSerializer.EnumSerializer);
}

if (serializer is IImpliedImplementationInterfaceSerializer impliedImplementationSerializer)
{
return GetRepresentation(impliedImplementationSerializer.ImplementationSerializer);
Expand All @@ -91,6 +93,11 @@ public static BsonType GetRepresentation(IBsonSerializer serializer)
return keyValuePairSerializer.Representation;
}

if (serializer is INullableSerializer nullableSerializer)
{
return GetRepresentation(nullableSerializer.ValueSerializer);
}

// for backward compatibility assume that any remaining implementers of IBsonDocumentSerializer are represented as documents
if (serializer is IBsonDocumentSerializer)
{
Expand All @@ -106,11 +113,52 @@ public static BsonType GetRepresentation(IBsonSerializer serializer)
return BsonType.Undefined;
}

public static bool IsIntegerRepresentation(BsonType representation)
{
return representation switch
{
BsonType.Int32 or BsonType.Int64 => true,
_ => false
};
}

public static bool IsNumericRepresentation(BsonType representation)
{
return representation switch
{
BsonType.Decimal128 or BsonType.Double or BsonType.Int32 or BsonType.Int64 => true,
_ => false
};
}

public static bool IsRepresentedAsDocument(IBsonSerializer serializer)
{
return SerializationHelper.GetRepresentation(serializer) == BsonType.Document;
}

public static bool IsRepresentedAsInteger(IBsonSerializer serializer)
{
var representation = GetRepresentation(serializer);
return IsIntegerRepresentation(representation);
}

public static bool IsRepresentedAsIntegerOrNullableInteger(AggregationExpression translation)
{
return IsRepresentedAsIntegerOrNullableInteger(translation.Serializer);
}

public static bool IsRepresentedAsIntegerOrNullableInteger(IBsonSerializer serializer)
{
if (serializer is INullableSerializer nullableSerializer)
{
return IsRepresentedAsInteger(nullableSerializer.ValueSerializer);
}
else
{
return IsRepresentedAsInteger(serializer);
}
}

public static BsonValue SerializeValue(IBsonSerializer serializer, ConstantExpression constantExpression, Expression containingExpression)
{
var value = constantExpression.Value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@

using System;
using System.Linq.Expressions;
using MongoDB.Bson;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
using MongoDB.Driver.Support;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
Expand Down Expand Up @@ -76,6 +74,12 @@ public static AggregationExpression Translate(TranslationContext context, Binary
rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);
}

if (IsArithmeticExpression(expression))
{
SerializationHelper.EnsureRepresentationIsNumeric(leftExpression, leftTranslation);
SerializationHelper.EnsureRepresentationIsNumeric(rightExpression, rightTranslation);
}

var ast = expression.NodeType switch
{
ExpressionType.Add => IsStringConcatenationExpression(expression) ?
Expand Down Expand Up @@ -163,7 +167,7 @@ private static bool IsAddOrSubtractExpression(Expression expression)

private static bool IsArithmeticExpression(BinaryExpression expression)
{
return expression.Type.IsNumeric() && IsArithmeticOperator(expression.NodeType);
return expression.Type.IsNumericOrNullableNumeric() && IsArithmeticOperator(expression.NodeType);
}

private static bool IsArithmeticOperator(ExpressionType nodeType)
Expand Down Expand Up @@ -291,31 +295,29 @@ private static AggregationExpression TranslateEnumExpression(TranslationContext
leftTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, leftExpression);
rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);

AggregationExpression enumTranslation, operandTranslation;
if (IsEnumOrConvertEnumToUnderlyingType(leftExpression))
{
serializer = leftTranslation.Serializer;
enumTranslation = leftTranslation;
operandTranslation = rightTranslation;
}
else
{
serializer = rightTranslation.Serializer;
enumTranslation = rightTranslation;
operandTranslation = leftTranslation;
}

var representation = BsonType.Int32; // assume an integer representation unless we can determine otherwise
var valueSerializer = serializer;
if (valueSerializer is INullableSerializer nullableSerializer)
if (!SerializationHelper.IsRepresentedAsIntegerOrNullableInteger(enumTranslation))
{
valueSerializer = nullableSerializer.ValueSerializer;
}
if (valueSerializer is IEnumUnderlyingTypeSerializer enumUnderlyingTypeSerializer &&
enumUnderlyingTypeSerializer.EnumSerializer is IHasRepresentationSerializer withRepresentationSerializer)
{
representation = withRepresentationSerializer.Representation;
throw new ExpressionNotSupportedException(expression, because: "arithmetic on enums is only allowed when the enum is represented as an integer");
}

if (representation != BsonType.Int32 && representation != BsonType.Int64)
if (!SerializationHelper.IsRepresentedAsIntegerOrNullableInteger(operandTranslation))
{
throw new ExpressionNotSupportedException(expression, because: "arithmetic on enums is only allowed when the enum is represented as an integer");
throw new ExpressionNotSupportedException(expression, because: "the value being added to or subtracted from an enum must be represented as an integer");
}

serializer = enumTranslation.Serializer;
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var valueExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
SerializationHelper.EnsureRepresentationIsNumeric(valueExpression, valueTranslation);
var ast = AstExpression.Abs(valueTranslation.Ast);
return new AggregationExpression(expression, ast, valueTranslation.Serializer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
var ast = AstExpression.Ceil(argumentTranslation.Ast);
var serializer = BsonSerializer.LookupSerializer(expression.Type);
return new AggregationExpression(expression, ast, serializer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,9 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
throw new ExpressionNotSupportedException(valueExpression, expression);
}
var representation = timeSpanSerializer.Representation;
SerializationHelper.EnsureRepresentationIsNumeric(valueExpression, timeSpanSerializer);

var serializerUnits = timeSpanSerializer.Units;
if (representation != BsonType.Int32 && representation != BsonType.Int64 && representation != BsonType.Double)
{
throw new ExpressionNotSupportedException(valueExpression, expression);
}
(unit, amount) = serializerUnits switch
{
TimeSpanUnits.Ticks => ("millisecond", AstExpression.Divide(valueTranslation.Ast, (double)TimeSpan.TicksPerMillisecond)),
Expand All @@ -174,6 +171,8 @@ public static AggregationExpression Translate(TranslationContext context, Method
else
{
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
SerializationHelper.EnsureRepresentationIsNumeric(valueExpression, valueTranslation);

(unit, amount) = method.Name switch
{
"AddTicks" => ("millisecond", AstExpression.Divide(valueTranslation.Ast, (double)TimeSpan.TicksPerMillisecond)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
var ast = AstExpression.Exp(argumentTranslation.Ast);
return new AggregationExpression(expression, ast, new DoubleSerializer());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
var ast = AstExpression.Floor(argumentTranslation.Ast);
var serializer = BsonSerializer.LookupSerializer(expression.Type);
return new AggregationExpression(expression, ast, serializer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ string TranslateAnyOf(ReadOnlyCollection<Expression> arguments)

var startIndexExpression = arguments[1];
var startIndexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, startIndexExpression);
SerializationHelper.EnsureRepresentationIsNumeric(startIndexExpression, startIndexTranslation);
return AstExpression.UseVarIfNotSimple("startIndex", startIndexTranslation.Ast);
}

Expand All @@ -127,6 +128,7 @@ string TranslateAnyOf(ReadOnlyCollection<Expression> arguments)

var countExpression = arguments[2];
var countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
SerializationHelper.EnsureRepresentationIsNumeric(countExpression, countTranslation);
return AstExpression.UseVarIfNotSimple("count", countTranslation.Ast);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,18 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression);
var valueTranslation = TranslateValue();
var startIndexTranslation = startIndexExpression == null ? null : ExpressionToAggregationExpressionTranslator.Translate(context, startIndexExpression);
var countTranslation = countExpression == null ? null : ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
AggregationExpression startIndexTranslation = null;
if (startIndexExpression != null)
{
startIndexTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, startIndexExpression);
SerializationHelper.EnsureRepresentationIsNumeric(startIndexExpression, startIndexTranslation);
}
AggregationExpression countTranslation = null;
if (countExpression != null)
{
countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
SerializationHelper.EnsureRepresentationIsNumeric(countExpression, countTranslation);
}
var ordinal = GetOrdinalFromComparisonType();

var endAst = CreateEndAst(startIndexTranslation?.Ast, countTranslation?.Ast);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
AstExpression ast;
if (method.Is(MathMethod.LogWithNewBase))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var xExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var xTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, xExpression);
SerializationHelper.EnsureRepresentationIsNumeric(xExpression, xTranslation);
var yExpression = ConvertHelper.RemoveWideningConvert(arguments[1]);
var yTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, yExpression);
SerializationHelper.EnsureRepresentationIsNumeric(yExpression, yTranslation);
var ast = AstExpression.Pow(xTranslation.Ast, yTranslation.Ast);
return new AggregationExpression(expression, ast, new DoubleSerializer());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var startExpression = arguments[0];
var startTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, startExpression);
var (startVar, startAst) = AstExpression.UseVarIfNotSimple("start", startTranslation.Ast);
SerializationHelper.EnsureRepresentationIsNumeric(startExpression, startTranslation);
var countExpression = arguments[1];
var countTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, countExpression);
SerializationHelper.EnsureRepresentationIsNumeric(countExpression, countTranslation);

var (startVar, startAst) = AstExpression.UseVarIfNotSimple("start", startTranslation.Ast);
var (countVar, countAst) = AstExpression.UseVarIfNotSimple("count", countTranslation.Ast);

var ast = AstExpression.Let(
startVar,
countVar,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);

AstExpression ast;
if (method.IsOneOf(__roundWithPlaceMethods))
{
var placeExpression = arguments[1];
var placeTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, placeExpression);
SerializationHelper.EnsureRepresentationIsNumeric(placeExpression, placeTranslation);
ast = AstExpression.Round(argumentTranslation.Ast, placeTranslation.Ast);
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
{
var argumentExpression = ConvertHelper.RemoveWideningConvert(arguments[0]);
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression);
SerializationHelper.EnsureRepresentationIsNumeric(argumentExpression, argumentTranslation);
var ast = AstExpression.Sqrt(argumentTranslation.Ast);
return new AggregationExpression(expression, ast, new DoubleSerializer());
}
Expand Down