Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow validation against max_digits and decimals to pass if normalized or non-normalized input is valid #1049

Merged
merged 2 commits into from Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
129 changes: 76 additions & 53 deletions src/validators/decimal.rs
Expand Up @@ -83,6 +83,41 @@ impl_py_gc_traverse!(DecimalValidator {
gt
});

fn extract_decimal_digits_info<'data>(
decimal: &PyAny,
normalized: bool,
py: Python<'data>,
) -> ValResult<'data, (u64, u64)> {
let mut normalized_decimal: Option<&PyAny> = None;
if normalized {
normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal));
}
let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = normalized_decimal
.unwrap_or(decimal)
.call_method0(intern!(py, "as_tuple"))?
.extract()?;

// finite values have numeric exponent, we checked is_finite above
let exponent: i64 = exponent.extract()?;
let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?;
let decimals;
if exponent >= 0 {
// A positive exponent adds that many trailing zeros.
digits += exponent as u64;
decimals = 0;
} else {
// If the absolute value of the negative exponent is larger than the
// number of digits, then it's the same as the number of digits,
// because it'll consume all the digits in digit_tuple and then
// add abs(exponent) - len(digit_tuple) leading zeros after the
// decimal point.
decimals = exponent.unsigned_abs();
digits = digits.max(decimals);
}

Ok((decimals, digits))
}

impl Validator for DecimalValidator {
fn validate<'data>(
&self,
Expand All @@ -98,65 +133,53 @@ impl Validator for DecimalValidator {
}

if self.check_digits {
let normalized_value = decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal);
let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) =
normalized_value.call_method0(intern!(py, "as_tuple"))?.extract()?;
if let Ok((normalized_decimals, normalized_digits)) = extract_decimal_digits_info(decimal, true, py) {
if let Ok((decimals, digits)) = extract_decimal_digits_info(decimal, false, py) {
if let Some(max_digits) = self.max_digits {
if (digits > max_digits) & (normalized_digits > max_digits) {
return Err(ValError::new(
ErrorType::DecimalMaxDigits {
max_digits,
context: None,
},
input,
));
}
}

// finite values have numeric exponent, we checked is_finite above
let exponent: i64 = exponent.extract()?;
let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?;
let decimals;
if exponent >= 0 {
// A positive exponent adds that many trailing zeros.
digits += exponent as u64;
decimals = 0;
} else {
// If the absolute value of the negative exponent is larger than the
// number of digits, then it's the same as the number of digits,
// because it'll consume all the digits in digit_tuple and then
// add abs(exponent) - len(digit_tuple) leading zeros after the
// decimal point.
decimals = exponent.unsigned_abs();
digits = digits.max(decimals);
}
if let Some(decimal_places) = self.decimal_places {
if (decimals > decimal_places) & (normalized_decimals > decimal_places) {
return Err(ValError::new(
ErrorType::DecimalMaxPlaces {
decimal_places,
context: None,
},
input,
));
}

if let Some(max_digits) = self.max_digits {
if digits > max_digits {
return Err(ValError::new(
ErrorType::DecimalMaxDigits {
max_digits,
context: None,
},
input,
));
}
}
if let Some(max_digits) = self.max_digits {
let whole_digits = digits.saturating_sub(decimals);
let max_whole_digits = max_digits.saturating_sub(decimal_places);

if let Some(decimal_places) = self.decimal_places {
if decimals > decimal_places {
return Err(ValError::new(
ErrorType::DecimalMaxPlaces {
decimal_places,
context: None,
},
input,
));
}
let normalized_whole_digits = normalized_digits.saturating_sub(normalized_decimals);
let normalized_max_whole_digits = max_digits.saturating_sub(decimal_places);

if let Some(max_digits) = self.max_digits {
let whole_digits = digits.saturating_sub(decimals);
let max_whole_digits = max_digits.saturating_sub(decimal_places);
if whole_digits > max_whole_digits {
return Err(ValError::new(
ErrorType::DecimalWholeDigits {
whole_digits: max_whole_digits,
context: None,
},
input,
));
if (whole_digits > max_whole_digits)
& (normalized_whole_digits > normalized_max_whole_digits)
{
return Err(ValError::new(
ErrorType::DecimalWholeDigits {
whole_digits: max_whole_digits,
context: None,
},
input,
));
}
}
}
}
}
};
}
}

Expand Down
28 changes: 28 additions & 0 deletions tests/validators/test_decimal.py
Expand Up @@ -437,3 +437,31 @@ def test_non_finite_constrained_decimal_values(input_value, allow_inf_nan, expec
def test_validate_scientific_notation_from_json(input_value, expected):
v = SchemaValidator({'type': 'decimal'})
assert v.validate_json(input_value) == expected


def test_validate_max_digits_and_decimal_places() -> None:
v = SchemaValidator({'type': 'decimal', 'max_digits': 5, 'decimal_places': 2})

# valid inputs
assert v.validate_json('1.23') == Decimal('1.23')
assert v.validate_json('123.45') == Decimal('123.45')
assert v.validate_json('-123.45') == Decimal('-123.45')

# invalid inputs
with pytest.raises(ValidationError):
v.validate_json('1234.56') # too many digits
with pytest.raises(ValidationError):
v.validate_json('123.456') # too many decimal places
with pytest.raises(ValidationError):
v.validate_json('123456') # too many digits
with pytest.raises(ValidationError):
v.validate_json('abc') # not a valid decimal


def test_validate_max_digits_and_decimal_places_edge_case() -> None:
v = SchemaValidator({'type': 'decimal', 'max_digits': 34, 'decimal_places': 18})

# valid inputs
assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal(
'9999999999999999.999999999999999999'
)