Handle constraints being applied to schemas that don't accept it (#6951)
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
3 people committed Aug 10, 2023
1 parent efc8a69 commit 998a6ba
Showing 9 changed files with 348 additions and 81 deletions.
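For orientation, here is a heavily simplified, self-contained sketch of the strategy this commit introduces in `pydantic/_internal/_known_annotated_metadata.py`: a table maps each constraint to the core schema types that accept it natively; supported constraints are written into the schema in place, and unsupported ones fall back to a wrapper validator or raise. The names mirror the diff below, but the sets are abbreviated and the real code also handles `pattern`, `allow_inf_nan`, string transforms via chained schemas, and several other constraints.

```python
from __future__ import annotations

from functools import partial
from typing import Any

# Abbreviated stand-in for the real CONSTRAINTS_TO_ALLOWED_SCHEMAS mapping.
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = {
    'gt': {'int', 'float', 'date', 'time', 'datetime', 'timedelta'},
    'max_length': {'str', 'bytes', 'list', 'tuple', 'set', 'frozenset', 'generator'},
}


def greater_than_validator(value: Any, gt: Any) -> Any:
    # Stand-in for _validators.greater_than_validator.
    if not value > gt:
        raise ValueError(f'Input should be greater than {gt}')
    return value


def apply_constraint(schema: dict[str, Any], constraint: str, value: Any) -> dict[str, Any]:
    allowed = CONSTRAINTS_TO_ALLOWED_SCHEMAS.get(constraint)
    if allowed is None:
        raise ValueError(f'Unknown constraint {constraint}')
    if schema['type'] in allowed:
        # The core schema understands this constraint natively: set it in place.
        return {**schema, constraint: value}
    if constraint == 'gt':
        # Otherwise wrap the schema in an after-validator (stand-in for
        # cs.no_info_after_validator_function in the real code).
        return {'type': 'function-after', 'schema': schema, 'function': partial(greater_than_validator, gt=value)}
    raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema["type"]}')


print(apply_constraint({'type': 'int'}, 'gt', 5))  # applied in place
print(apply_constraint({'type': 'str'}, 'gt', 5))  # wrapped in an after-validator
```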
5 changes: 0 additions & 5 deletions pydantic/_internal/_generate_schema.py
@@ -1543,11 +1543,6 @@ def _apply_annotations(
# expand annotations before we start processing them so that `__prepare_pydantic_annotations` can consume
# individual items from GroupedMetadata
annotations = list(_known_annotated_metadata.expand_grouped_metadata(annotations))
non_field_infos, field_infos = [a for a in annotations if not isinstance(a, FieldInfo)], [
a for a in annotations if isinstance(a, FieldInfo)
]
if field_infos:
annotations = [*non_field_infos, FieldInfo.merge_field_infos(*field_infos)]
idx = -1
prepare = getattr(source_type, '__prepare_pydantic_annotations__', None)
if prepare:
210 changes: 161 additions & 49 deletions pydantic/_internal/_known_annotated_metadata.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
from collections import defaultdict
from copy import copy
from functools import partial
from typing import Any, Iterable
@@ -17,7 +18,7 @@
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}

STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper'}
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}

LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
@@ -28,15 +29,57 @@

FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
BOOL_CONSTRAINTS = STRICT

DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}

URL_CONSTRAINTS = {
'max_length',
'allowed_schemes',
'host_required',
'default_host',
'default_port',
'default_path',
}

TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url')
SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES)
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')

CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
for constraint in STR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES)
for constraint in BYTES_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',))
for constraint in LIST_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',))
for constraint in TUPLE_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',))
for constraint in SET_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset'))
for constraint in DICT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',))
for constraint in GENERATOR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',))
for constraint in FLOAT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',))
for constraint in INT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',))
for constraint in DATE_TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime'))
for constraint in TIMEDELTA_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',))
for constraint in TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
for constraint in URL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
for constraint in BOOL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',))


def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
"""Expand the annotations.
@@ -96,82 +139,149 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None:
PydanticCustomError: If `Predicate` fails.
"""
schema = schema.copy()
schema_update, _ = collect_known_metadata([annotation])
if isinstance(annotation, at.Gt):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
schema_update, other_metadata = collect_known_metadata([annotation])
schema_type = schema['type']
for constraint, value in schema_update.items():
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
raise ValueError(f'Unknown constraint {constraint}')
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]

if schema_type in allowed_schemas:
schema[constraint] = value
continue

if constraint == 'allow_inf_nan' and value is False:
return cs.no_info_after_validator_function(
_validators.forbid_inf_nan_check,
schema,
)
elif constraint == 'pattern':
# insert a str schema to make sure the regex engine matches
return cs.chain_schema(
[
schema,
cs.str_schema(pattern=value),
]
)
elif constraint == 'gt':
return cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=value),
schema,
)
elif constraint == 'ge':
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=value),
schema,
)
elif constraint == 'lt':
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=value),
schema,
)
elif constraint == 'le':
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=value),
schema,
)
elif constraint == 'multiple_of':
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=value),
schema,
)
elif constraint == 'min_length':
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=value),
schema,
)
elif constraint == 'max_length':
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=value),
schema,
)
elif constraint == 'strip_whitespace':
return cs.chain_schema(
[
schema,
cs.str_schema(strip_whitespace=True),
]
)
elif constraint == 'to_lower':
return cs.chain_schema(
[
schema,
cs.str_schema(to_lower=True),
]
)
elif constraint == 'to_upper':
return cs.chain_schema(
[
schema,
cs.str_schema(to_upper=True),
]
)
elif constraint == 'min_length':
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif constraint == 'max_length':
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
else:
raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')

for annotation in other_metadata:
if isinstance(annotation, at.Gt):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=annotation.gt),
schema,
)
elif isinstance(annotation, at.Ge):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Ge):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
schema,
)
elif isinstance(annotation, at.Lt):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Lt):
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=annotation.lt),
schema,
)
elif isinstance(annotation, at.Le):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Le):
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=annotation.le),
schema,
)
elif isinstance(annotation, at.MultipleOf):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.MultipleOf):
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
schema,
)
elif isinstance(annotation, at.MinLen):
if schema['type'] in SEQUENCE_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.MinLen):
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif isinstance(annotation, at.MaxLen):
if schema['type'] in SEQUENCE_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.MaxLen):
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
elif isinstance(annotation, at.Predicate):
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''

def val_func(v: Any) -> Any:
# annotation.func may also raise an exception, let it pass through
if not annotation.func(v):
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name}failed', # type: ignore
{},
)
return v

return cs.no_info_after_validator_function(val_func, schema)
elif schema_update:
# for all other annotations just update the schema
# this includes things like `strict` which apply to pretty much every schema
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Predicate):
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''

def val_func(v: Any) -> Any:
# annotation.func may also raise an exception, let it pass through
if not annotation.func(v):
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name}failed', # type: ignore
)
return v

return cs.no_info_after_validator_function(val_func, schema)
# ignore any other unknown metadata
return None

return schema
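To make the control flow above concrete, here is a rough illustration of calling the internal helper directly; the expected results are inferred from the code in this hunk, not verified output:

```python
import annotated_types as at
from pydantic_core import core_schema as cs

from pydantic._internal._known_annotated_metadata import apply_known_metadata

# 'gt' is in the allowed set for int schemas, so it is written into the schema:
print(apply_known_metadata(at.Gt(5), cs.int_schema()))
# expected: roughly {'type': 'int', 'gt': 5}

# 'gt' is not allowed on str schemas, so the schema should instead be wrapped
# in an after-validator:
print(apply_known_metadata(at.Gt(5), cs.str_schema())['type'])
# expected: 'function-after'
```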
@@ -206,7 +316,9 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
# But it seems dangerous!
if isinstance(annotation, PydanticGeneralMetadata):
res.update(annotation.__dict__)
elif isinstance(annotation, (at.BaseMetadata, PydanticMetadata)):
elif isinstance(annotation, PydanticMetadata):
res.update(dataclasses.asdict(annotation)) # type: ignore[call-overload]
elif isinstance(annotation, (at.MinLen, at.MaxLen, at.Gt, at.Ge, at.Lt, at.Le, at.MultipleOf)):
res.update(dataclasses.asdict(annotation)) # type: ignore[call-overload]
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
# also support PydanticMetadata classes being used without initialisation,
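The practical effect of narrowing the isinstance check is that only the explicitly listed annotated-types classes are folded into the constraint dict, while anything else (notably `at.Predicate`) is returned as "other" metadata for `apply_known_metadata` to handle. A hedged illustration:

```python
import annotated_types as at

from pydantic._internal._known_annotated_metadata import collect_known_metadata

known, other = collect_known_metadata([at.Gt(0), at.Predicate(str.islower)])
print(known)  # expected: {'gt': 0}
print(other)  # expected: a list containing the Predicate, left for apply_known_metadata
```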
7 changes: 7 additions & 0 deletions pydantic/_internal/_validators.py
@@ -5,6 +5,7 @@

from __future__ import annotations as _annotations

import math
import re
import typing
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
@@ -269,3 +270,9 @@ def max_length_validator(x: Any, max_length: Any) -> Any:
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
)
return x


def forbid_inf_nan_check(x: Any) -> Any:
if not math.isfinite(x):
raise PydanticKnownError('finite_number')
return x
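A quick, illustrative use of the new helper; the error type comes from the `'finite_number'` code used above:

```python
from pydantic_core import PydanticKnownError

from pydantic._internal._validators import forbid_inf_nan_check

assert forbid_inf_nan_check(1.5) == 1.5  # finite values pass through unchanged
try:
    forbid_inf_nan_check(float('inf'))
except PydanticKnownError as err:
    print(err)  # expected: a 'finite_number' error
```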
16 changes: 14 additions & 2 deletions pydantic/fields.py
@@ -285,7 +285,13 @@ class MyModel(pydantic.BaseModel):
new_field_info = copy(field_info)
new_field_info.annotation = first_arg
new_field_info.frozen = final or field_info.frozen
new_field_info.metadata += [a for a in extra_args if not isinstance(a, FieldInfo)]
metadata: list[Any] = []
for a in extra_args:
if not isinstance(a, FieldInfo):
metadata.append(a)
else:
metadata.extend(a.metadata)
new_field_info.metadata = metadata
return new_field_info

return cls(annotation=annotation, frozen=final or None)
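The change above (and the matching one in the next hunk) flattens any `FieldInfo` found among the `Annotated` extras into the resulting field's metadata list instead of nesting the `FieldInfo` object itself. A hedged example of what that should look like from the outside (exact reprs may vary by version):

```python
from typing import Annotated

import annotated_types as at
from pydantic import BaseModel, Field


class Model(BaseModel):
    x: Annotated[int, at.Ge(0), Field(le=100)]


print(Model.model_fields['x'].metadata)
# expected: a flat list along the lines of [Ge(ge=0), Le(le=100)],
# with no FieldInfo nested inside it
```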
@@ -356,7 +362,13 @@ class MyModel(pydantic.BaseModel):
first_arg, *extra_args = typing_extensions.get_args(annotation)
field_infos = [a for a in extra_args if isinstance(a, FieldInfo)]
field_info = cls.merge_field_infos(*field_infos, annotation=first_arg, default=default)
field_info.metadata += [a for a in extra_args if not isinstance(a, FieldInfo)]
metadata: list[Any] = []
for a in extra_args:
if not isinstance(a, FieldInfo):
metadata.append(a)
else:
metadata.extend(a.metadata)
field_info.metadata = metadata
return field_info

return cls(annotation=annotation, default=default, frozen=final or None)

0 comments on commit 998a6ba

Please sign in to comment.