Skip to content

Commit

Permalink
Allow validation against max_digits and decimals to pass if norma…
Browse files Browse the repository at this point in the history
…lized or non-normalized input is valid (#1049)
  • Loading branch information
sydney-runkle committed Oct 31, 2023
1 parent dd75669 commit ef3e813
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 53 deletions.
129 changes: 76 additions & 53 deletions src/validators/decimal.rs
Expand Up @@ -83,6 +83,41 @@ impl_py_gc_traverse!(DecimalValidator {
gt
});

fn extract_decimal_digits_info<'data>(
decimal: &PyAny,
normalized: bool,
py: Python<'data>,
) -> ValResult<'data, (u64, u64)> {
let mut normalized_decimal: Option<&PyAny> = None;
if normalized {
normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal));
}
let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = normalized_decimal
.unwrap_or(decimal)
.call_method0(intern!(py, "as_tuple"))?
.extract()?;

// finite values have numeric exponent, we checked is_finite above
let exponent: i64 = exponent.extract()?;
let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?;
let decimals;
if exponent >= 0 {
// A positive exponent adds that many trailing zeros.
digits += exponent as u64;
decimals = 0;
} else {
// If the absolute value of the negative exponent is larger than the
// number of digits, then it's the same as the number of digits,
// because it'll consume all the digits in digit_tuple and then
// add abs(exponent) - len(digit_tuple) leading zeros after the
// decimal point.
decimals = exponent.unsigned_abs();
digits = digits.max(decimals);
}

Ok((decimals, digits))
}

impl Validator for DecimalValidator {
fn validate<'data>(
&self,
Expand All @@ -98,65 +133,53 @@ impl Validator for DecimalValidator {
}

if self.check_digits {
let normalized_value = decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal);
let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) =
normalized_value.call_method0(intern!(py, "as_tuple"))?.extract()?;
if let Ok((normalized_decimals, normalized_digits)) = extract_decimal_digits_info(decimal, true, py) {
if let Ok((decimals, digits)) = extract_decimal_digits_info(decimal, false, py) {
if let Some(max_digits) = self.max_digits {
if (digits > max_digits) & (normalized_digits > max_digits) {
return Err(ValError::new(
ErrorType::DecimalMaxDigits {
max_digits,
context: None,
},
input,
));
}
}

// finite values have numeric exponent, we checked is_finite above
let exponent: i64 = exponent.extract()?;
let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?;
let decimals;
if exponent >= 0 {
// A positive exponent adds that many trailing zeros.
digits += exponent as u64;
decimals = 0;
} else {
// If the absolute value of the negative exponent is larger than the
// number of digits, then it's the same as the number of digits,
// because it'll consume all the digits in digit_tuple and then
// add abs(exponent) - len(digit_tuple) leading zeros after the
// decimal point.
decimals = exponent.unsigned_abs();
digits = digits.max(decimals);
}
if let Some(decimal_places) = self.decimal_places {
if (decimals > decimal_places) & (normalized_decimals > decimal_places) {
return Err(ValError::new(
ErrorType::DecimalMaxPlaces {
decimal_places,
context: None,
},
input,
));
}

if let Some(max_digits) = self.max_digits {
if digits > max_digits {
return Err(ValError::new(
ErrorType::DecimalMaxDigits {
max_digits,
context: None,
},
input,
));
}
}
if let Some(max_digits) = self.max_digits {
let whole_digits = digits.saturating_sub(decimals);
let max_whole_digits = max_digits.saturating_sub(decimal_places);

if let Some(decimal_places) = self.decimal_places {
if decimals > decimal_places {
return Err(ValError::new(
ErrorType::DecimalMaxPlaces {
decimal_places,
context: None,
},
input,
));
}
let normalized_whole_digits = normalized_digits.saturating_sub(normalized_decimals);
let normalized_max_whole_digits = max_digits.saturating_sub(decimal_places);

if let Some(max_digits) = self.max_digits {
let whole_digits = digits.saturating_sub(decimals);
let max_whole_digits = max_digits.saturating_sub(decimal_places);
if whole_digits > max_whole_digits {
return Err(ValError::new(
ErrorType::DecimalWholeDigits {
whole_digits: max_whole_digits,
context: None,
},
input,
));
if (whole_digits > max_whole_digits)
& (normalized_whole_digits > normalized_max_whole_digits)
{
return Err(ValError::new(
ErrorType::DecimalWholeDigits {
whole_digits: max_whole_digits,
context: None,
},
input,
));
}
}
}
}
}
};
}
}

Expand Down
28 changes: 28 additions & 0 deletions tests/validators/test_decimal.py
Expand Up @@ -437,3 +437,31 @@ def test_non_finite_constrained_decimal_values(input_value, allow_inf_nan, expec
def test_validate_scientific_notation_from_json(input_value, expected):
v = SchemaValidator({'type': 'decimal'})
assert v.validate_json(input_value) == expected


def test_validate_max_digits_and_decimal_places() -> None:
v = SchemaValidator({'type': 'decimal', 'max_digits': 5, 'decimal_places': 2})

# valid inputs
assert v.validate_json('1.23') == Decimal('1.23')
assert v.validate_json('123.45') == Decimal('123.45')
assert v.validate_json('-123.45') == Decimal('-123.45')

# invalid inputs
with pytest.raises(ValidationError):
v.validate_json('1234.56') # too many digits
with pytest.raises(ValidationError):
v.validate_json('123.456') # too many decimal places
with pytest.raises(ValidationError):
v.validate_json('123456') # too many digits
with pytest.raises(ValidationError):
v.validate_json('abc') # not a valid decimal


def test_validate_max_digits_and_decimal_places_edge_case() -> None:
v = SchemaValidator({'type': 'decimal', 'max_digits': 34, 'decimal_places': 18})

# valid inputs
assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal(
'9999999999999999.999999999999999999'
)

0 comments on commit ef3e813

Please sign in to comment.