Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle constraints being applied to schemas that don't accept it #6951

Merged
merged 14 commits into from Aug 10, 2023
5 changes: 0 additions & 5 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -1516,11 +1516,6 @@ def _apply_annotations(
# expand annotations before we start processing them so that `__prepare_pydantic_annotations` can consume
# individual items from GroupedMetadata
annotations = list(_known_annotated_metadata.expand_grouped_metadata(annotations))
non_field_infos, field_infos = [a for a in annotations if not isinstance(a, FieldInfo)], [
a for a in annotations if isinstance(a, FieldInfo)
]
if field_infos:
annotations = [*non_field_infos, FieldInfo.merge_field_infos(*field_infos)]
idx = -1
prepare = getattr(source_type, '__prepare_pydantic_annotations__', None)
if prepare:
Expand Down
211 changes: 162 additions & 49 deletions pydantic/_internal/_known_annotated_metadata.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
from collections import defaultdict
from copy import copy
from functools import partial
from typing import Any, Iterable
Expand All @@ -17,7 +18,7 @@
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}

STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper'}
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}

LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
Expand All @@ -28,15 +29,57 @@

FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
BOOL_CONSTRAINTS = STRICT

DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}

URL_CONSTRAINTS = {
'max_length',
'allowed_schemes',
'host_required',
'default_host',
'default_port',
'default_path',
}

TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url')
SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES)
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')

# Invert the per-type constraint sets above into a single lookup table:
# constraint name -> set of core schema `type` strings that natively accept it.
# Built once at import time; it is just a handful of dict/set updates.
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
for _constraint_set, _allowed_schema_types in (
    (STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
    (BYTES_CONSTRAINTS, ('bytes',)),
    (LIST_CONSTRAINTS, ('list',)),
    (TUPLE_CONSTRAINTS, ('tuple',)),
    (SET_CONSTRAINTS, ('set', 'frozenset')),
    (DICT_CONSTRAINTS, ('dict',)),
    (GENERATOR_CONSTRAINTS, ('generator',)),
    (FLOAT_CONSTRAINTS, ('float',)),
    (INT_CONSTRAINTS, ('int',)),
    (DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime')),
    (TIMEDELTA_CONSTRAINTS, ('timedelta',)),
    (TIME_CONSTRAINTS, ('time',)),
    (URL_CONSTRAINTS, ('url', 'multi-host-url')),
    (BOOL_CONSTRAINTS, ('bool',)),
):
    for constraint in _constraint_set:
        CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(_allowed_schema_types)
# `strict` is special-cased: it applies to (nearly) every schema type.
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
    CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this get run on `from pydantic import X`? If so we might need to worry about import time; if it only happens later / when required, it shouldn't be a problem.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it is imported immediately. It's not all that complex though, just assigning items to a dict. I could hide it behind an lru_cache function or something, but that would probably make schema generation slower (one more dict lookup), and that seems to be a bigger problem right now than import time.



def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
"""Expand the annotations.
Expand Down Expand Up @@ -96,82 +139,150 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
PydanticCustomError: If `Predicate` fails.
"""
schema = schema.copy()
schema_update, _ = collect_known_metadata([annotation])
if isinstance(annotation, at.Gt):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
schema_update, other_metadata = collect_known_metadata([annotation])
schema_type = schema['type']
for constraint, value in schema_update.items():
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
raise ValueError(f'Unknown constraint {constraint}')
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]

if schema_type in allowed_schemas:
schema[constraint] = value
continue

if constraint == 'allow_inf_nan' and value is False:
return cs.no_info_after_validator_function(
_validators.forbid_inf_nan_check,
schema,
)
elif constraint == 'pattern':
# insert a str schema to make sure the regex engine matches
return cs.chain_schema(
[
schema,
cs.str_schema(pattern=value),
]
)
adriangb marked this conversation as resolved.
Show resolved Hide resolved
elif constraint == 'gt':
return cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=value),
schema,
)
elif constraint == 'ge':
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=value),
schema,
)
elif constraint == 'lt':
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=value),
schema,
)
elif constraint == 'le':
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=value),
schema,
)
elif constraint == 'multiple_of':
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=value),
schema,
)
elif constraint == 'min_length':
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=value),
schema,
)
elif constraint == 'max_length':
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=value),
schema,
)
elif constraint == 'strip_whitespace':
return cs.chain_schema(
[
schema,
cs.str_schema(strip_whitespace=True),
]
)
elif constraint == 'to_lower':
return cs.chain_schema(
[
schema,
cs.str_schema(to_lower=True),
]
)
elif constraint == 'to_upper':
return cs.chain_schema(
[
schema,
cs.str_schema(to_upper=True),
]
)
elif constraint == 'min_length':
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif constraint == 'max_length':
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
else:
raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')

for annotation in other_metadata:
if isinstance(annotation, at.Gt):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=annotation.gt),
schema,
)
elif isinstance(annotation, at.Ge):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Ge):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
schema,
)
elif isinstance(annotation, at.Lt):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Lt):
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=annotation.lt),
schema,
)
elif isinstance(annotation, at.Le):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Le):
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=annotation.le),
schema,
)
elif isinstance(annotation, at.MultipleOf):
if schema['type'] in NUMERIC_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.MultipleOf):
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
schema,
)
elif isinstance(annotation, at.MinLen):
if schema['type'] in SEQUENCE_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.MinLen):
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif isinstance(annotation, at.MaxLen):
if schema['type'] in SEQUENCE_SCHEMA_TYPES:
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.MaxLen):
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
elif isinstance(annotation, at.Predicate):
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''

def val_func(v: Any) -> Any:
# annotation.func may also raise an exception, let it pass through
if not annotation.func(v):
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name}failed', # type: ignore
{},
)
return v

return cs.no_info_after_validator_function(val_func, schema)
elif schema_update:
# for all other annotations just update the schema
# this includes things like `strict` which apply to pretty much every schema
schema.update(schema_update) # type: ignore
else:
elif isinstance(annotation, at.Predicate):
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''

def val_func(v: Any) -> Any:
# annotation.func may also raise an exception, let it pass through
if not annotation.func(v):
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name} failed', # type: ignore
{},
)
adriangb marked this conversation as resolved.
Show resolved Hide resolved
return v

return cs.no_info_after_validator_function(val_func, schema)
# ignore any other unknown metadata
return None

return schema
Expand Down Expand Up @@ -206,7 +317,9 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
# But it seems dangerous!
if isinstance(annotation, PydanticGeneralMetadata):
res.update(annotation.__dict__)
elif isinstance(annotation, (at.BaseMetadata, PydanticMetadata)):
elif isinstance(annotation, PydanticMetadata):
res.update(dataclasses.asdict(annotation)) # type: ignore[call-overload]
elif isinstance(annotation, (at.MinLen, at.MaxLen, at.Gt, at.Ge, at.Lt, at.Le, at.MultipleOf)):
res.update(dataclasses.asdict(annotation)) # type: ignore[call-overload]
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
# also support PydanticMetadata classes being used without initialisation,
Expand Down
7 changes: 7 additions & 0 deletions pydantic/_internal/_validators.py
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations as _annotations

import math
import re
import typing
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
Expand Down Expand Up @@ -269,3 +270,9 @@ def max_length_validator(x: Any, max_length: Any) -> Any:
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
)
return x


def forbid_inf_nan_check(x: Any) -> Any:
    """Return *x* unchanged if it is a finite number.

    Raises the core ``finite_number`` error for ``inf``, ``-inf`` and ``nan``;
    used to enforce ``allow_inf_nan=False`` on schemas that don't support the
    constraint natively.
    """
    if math.isfinite(x):
        return x
    raise PydanticKnownError('finite_number')
Comment on lines +275 to +278
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for the error case here

16 changes: 14 additions & 2 deletions pydantic/fields.py
Expand Up @@ -285,7 +285,13 @@ class MyModel(pydantic.BaseModel):
new_field_info = copy(field_info)
new_field_info.annotation = first_arg
new_field_info.frozen = final or field_info.frozen
new_field_info.metadata += [a for a in extra_args if not isinstance(a, FieldInfo)]
metadata: list[Any] = []
for a in extra_args:
if not isinstance(a, FieldInfo):
metadata.append(a)
else:
metadata.extend(a.metadata)
new_field_info.metadata = metadata
return new_field_info

return cls(annotation=annotation, frozen=final or None)
Expand Down Expand Up @@ -356,7 +362,13 @@ class MyModel(pydantic.BaseModel):
first_arg, *extra_args = typing_extensions.get_args(annotation)
field_infos = [a for a in extra_args if isinstance(a, FieldInfo)]
field_info = cls.merge_field_infos(*field_infos, annotation=first_arg, default=default)
field_info.metadata += [a for a in extra_args if not isinstance(a, FieldInfo)]
metadata: list[Any] = []
for a in extra_args:
if not isinstance(a, FieldInfo):
metadata.append(a)
else:
metadata.extend(a.metadata)
field_info.metadata = metadata
return field_info

return cls(annotation=annotation, default=default, frozen=final or None)
Expand Down