Skip to content

Commit

Permalink
Add eval_type_backport to handle union operator and builtin generic s…
Browse files Browse the repository at this point in the history
…ubscripting in older Pythons (#8209)
  • Loading branch information
alexmojaki committed Jan 16, 2024
1 parent 5de67f7 commit 8060fa1
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 95 deletions.
17 changes: 13 additions & 4 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -674,7 +674,7 @@ def _resolve_forward_ref(self, obj: Any) -> Any:
# class Model(BaseModel):
# x: SomeImportedTypeAliasWithAForwardReference
try:
obj = _typing_extra.evaluate_fwd_ref(obj, globalns=self._types_namespace)
obj = _typing_extra.eval_type_backport(obj, globalns=self._types_namespace)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e

Expand Down Expand Up @@ -1014,7 +1014,7 @@ def _common_field_schema( # C901
# Ensure that typevars get mapped to their concrete types:
types_namespace.update({k.__name__: v for k, v in self._typevars_map.items()})

evaluated = _typing_extra.eval_type_lenient(field_info.annotation, types_namespace, None)
evaluated = _typing_extra.eval_type_lenient(field_info.annotation, types_namespace)
if evaluated is not field_info.annotation and not has_instance_in_type(evaluated, PydanticRecursiveRef):
new_field_info = FieldInfo.from_annotation(evaluated)
field_info.annotation = new_field_info.annotation
Expand Down Expand Up @@ -1143,8 +1143,9 @@ def _type_alias_type_schema(

annotation = origin.__value__
typevars_map = get_standard_typevars_map(obj)

with self._types_namespace_stack.push(origin):
annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None)
annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace)
annotation = replace_types(annotation, typevars_map)
schema = self.generate_schema(annotation)
assert schema['type'] != 'definitions'
Expand Down
71 changes: 47 additions & 24 deletions pydantic/_internal/_typing_extra.py
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Callable
from functools import partial
from types import GetSetDescriptorType
from typing import TYPE_CHECKING, Any, Final, ForwardRef
from typing import TYPE_CHECKING, Any, Final

from typing_extensions import Annotated, Literal, TypeAliasType, TypeGuard, get_args, get_origin

Expand Down Expand Up @@ -213,20 +213,54 @@ def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None)
return hints


def eval_type_lenient(value: Any, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any:
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None) -> Any:
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
if value is None:
value = NoneType
elif isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)

try:
return typing._eval_type(value, globalns, localns) # type: ignore
return eval_type_backport(value, globalns, localns)
except NameError:
# the point of this function is to be tolerant to this case
return value


def eval_type_backport(
value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
) -> Any:
"""Like `typing._eval_type`, but falls back to the `eval_type_backport` package if it's
installed to let older Python versions use newer typing features.
Specifically, this transforms `X | Y` into `typing.Union[X, Y]`
and `list[X]` into `typing.List[X]` etc. (for all the types made generic in PEP 585)
if the original syntax is not supported in the current Python version.
"""
try:
return typing._eval_type( # type: ignore
value, globalns, localns
)
except TypeError as e:
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
raise
try:
from eval_type_backport import eval_type_backport
except ImportError:
raise TypeError(
f'You have a type annotation {value.__forward_arg__!r} '
f'which makes use of newer typing features than are supported in your version of Python. '
f'To handle this error, you should either remove the use of new syntax '
f'or install the `eval_type_backport` package.'
) from e

return eval_type_backport(value, globalns, localns, try_default=False)


def is_backport_fixable_error(e: TypeError) -> bool:
msg = str(e)
return msg.startswith('unsupported operand type(s) for |: ') or "' object is not subscriptable" in msg


def get_function_type_hints(
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
) -> dict[str, Any]:
Expand All @@ -248,7 +282,7 @@ def get_function_type_hints(
elif isinstance(value, str):
value = _make_forward_ref(value)

type_hints[name] = typing._eval_type(value, globalns, types_namespace) # type: ignore
type_hints[name] = eval_type_backport(value, globalns, types_namespace)

return type_hints

Expand Down Expand Up @@ -363,11 +397,15 @@ def get_type_hints( # noqa: C901
if isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)

value = typing._eval_type(value, base_globals, base_locals) # type: ignore
value = eval_type_backport(value, base_globals, base_locals)
hints[name] = value
return (
hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
)
if not include_extras and hasattr(typing, '_strip_annotations'):
return {
k: typing._strip_annotations(t) # type: ignore
for k, t in hints.items()
}
else:
return hints

if globalns is None:
if isinstance(obj, types.ModuleType):
Expand Down Expand Up @@ -403,28 +441,13 @@ def get_type_hints( # noqa: C901
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
value = typing._eval_type(value, globalns, localns) # type: ignore
value = eval_type_backport(value, globalns, localns)
if name in defaults and defaults[name] is None:
value = typing.Optional[value]
hints[name] = value
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore


if sys.version_info < (3, 9):

def evaluate_fwd_ref(
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
) -> Any:
return ref._evaluate(globalns=globalns, localns=localns)

else:

def evaluate_fwd_ref(
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
) -> Any:
return ref._evaluate(globalns=globalns, localns=localns, recursive_guard=frozenset())


def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
# so I created this convenience function
Expand Down
2 changes: 1 addition & 1 deletion pydantic/fields.py
Expand Up @@ -546,7 +546,7 @@ def apply_typevars_map(self, typevars_map: dict[Any, Any] | None, types_namespac
pydantic._internal._generics.replace_types is used for replacing the typevars with
their concrete types.
"""
annotation = _typing_extra.eval_type_lenient(self.annotation, types_namespace, None)
annotation = _typing_extra.eval_type_lenient(self.annotation, types_namespace)
self.annotation = _generics.replace_types(annotation, typevars_map)

def __repr_args__(self) -> ReprArgs:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Expand Up @@ -91,13 +91,15 @@ docs = [
"pydantic-extra-types @ git+https://github.com/pydantic/pydantic-extra-types.git@main"
]
linting = [
"eval-type-backport>=0.1.3",
"ruff==0.1.3",
"mypy~=1.1.1",
]
testing = [
"cloudpickle",
"coverage[toml]",
"dirty-equals",
"eval-type-backport",
"pytest",
"pytest-mock",
"pytest-pretty",
Expand Down
5 changes: 2 additions & 3 deletions tests/test_config.py
Expand Up @@ -4,7 +4,7 @@
from contextlib import nullcontext as does_not_raise
from decimal import Decimal
from inspect import signature
from typing import Any, ContextManager, Iterable, NamedTuple, Optional, Type, Union, get_type_hints
from typing import Any, ContextManager, Iterable, NamedTuple, Optional, Type, Union

from dirty_equals import HasRepr, IsPartialDict
from pydantic_core import SchemaError, SchemaSerializer, SchemaValidator
Expand All @@ -24,6 +24,7 @@
)
from pydantic._internal._config import ConfigWrapper, config_defaults
from pydantic._internal._mock_val_ser import MockValSer
from pydantic._internal._typing_extra import get_type_hints
from pydantic.config import ConfigDict, JsonValue
from pydantic.dataclasses import dataclass as pydantic_dataclass
from pydantic.errors import PydanticUserError
Expand Down Expand Up @@ -523,7 +524,6 @@ class Child(Mixin, Parent):
assert Child.model_config.get('use_enum_values') is True


@pytest.mark.skipif(sys.version_info < (3, 10), reason='different on older versions')
def test_config_wrapper_match():
localns = {'_GenerateSchema': GenerateSchema, 'GenerateSchema': GenerateSchema, 'JsonValue': JsonValue}
config_dict_annotations = [(k, str(v)) for k, v in get_type_hints(ConfigDict, localns=localns).items()]
Expand Down Expand Up @@ -567,7 +567,6 @@ def check_foo(cls, v):
assert src_exc.__notes__[0] == '\nPydantic: cause of loc: foo'


@pytest.mark.skipif(sys.version_info < (3, 10), reason='different on older versions')
def test_config_defaults_match():
localns = {'_GenerateSchema': GenerateSchema, 'GenerateSchema': GenerateSchema}
config_dict_keys = sorted(list(get_type_hints(ConfigDict, localns=localns).keys()))
Expand Down
27 changes: 12 additions & 15 deletions tests/test_edge_cases.py
Expand Up @@ -303,7 +303,6 @@ class Model(BaseModel):
(dict, frozenset, list, set, tuple, type),
],
)
@pytest.mark.skipif(sys.version_info < (3, 9), reason='PEP585 generics only supported for python 3.9 and above')
def test_pep585_generic_types(dict_cls, frozenset_cls, list_cls, set_cls, tuple_cls, type_cls):
class Type1:
pass
Expand All @@ -313,19 +312,19 @@ class Type2:

class Model(BaseModel, arbitrary_types_allowed=True):
a: dict_cls
a1: dict_cls[str, int]
a1: 'dict_cls[str, int]'
b: frozenset_cls
b1: frozenset_cls[int]
b1: 'frozenset_cls[int]'
c: list_cls
c1: list_cls[int]
c1: 'list_cls[int]'
d: set_cls
d1: set_cls[int]
d1: 'set_cls[int]'
e: tuple_cls
e1: tuple_cls[int]
e2: tuple_cls[int, ...]
e3: tuple_cls[()]
e1: 'tuple_cls[int]'
e2: 'tuple_cls[int, ...]'
e3: 'tuple_cls[()]'
f: type_cls
f1: type_cls[Type1]
f1: 'type_cls[Type1]'

default_model_kwargs = dict(
a={},
Expand Down Expand Up @@ -361,7 +360,7 @@ class Model(BaseModel, arbitrary_types_allowed=True):
assert m.f1 == Type1

with pytest.raises(ValidationError) as exc_info:
Model(**(default_model_kwargs | {'e3': (1,)}))
Model(**{**default_model_kwargs, 'e3': (1,)})
# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{
Expand All @@ -373,10 +372,10 @@ class Model(BaseModel, arbitrary_types_allowed=True):
}
]

Model(**(default_model_kwargs | {'f': Type2}))
Model(**{**default_model_kwargs, 'f': Type2})

with pytest.raises(ValidationError) as exc_info:
Model(**(default_model_kwargs | {'f1': Type2}))
Model(**{**default_model_kwargs, 'f1': Type2})
# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{
Expand Down Expand Up @@ -2380,10 +2379,9 @@ class Square(AbstractSquare):
Square(side=1.0)


@pytest.mark.skipif(sys.version_info < (3, 9), reason='cannot use list.__class_getitem__ before 3.9')
def test_generic_wrapped_forwardref():
class Operation(BaseModel):
callbacks: list['PathItem']
callbacks: 'list[PathItem]'

class PathItem(BaseModel):
pass
Expand Down Expand Up @@ -2477,7 +2475,6 @@ class C(BaseModel):
]


@pytest.mark.skipif(sys.version_info < (3, 9), reason='cannot parametrize types before 3.9')
@pytest.mark.parametrize(
('sequence_type', 'input_data', 'expected_error_type', 'expected_error_msg', 'expected_error_ctx'),
[
Expand Down

0 comments on commit 8060fa1

Please sign in to comment.