Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update create_model() to support typing.Annotated as input #8947

Merged
merged 11 commits into from Mar 18, 2024
21 changes: 20 additions & 1 deletion pydantic/main.py
Expand Up @@ -1433,7 +1433,8 @@ def create_model( # noqa: C901
__cls_kwargs__: A dictionary of keyword arguments for class creation, such as `metaclass`.
__slots__: Deprecated. Should not be passed to `create_model`.
**field_definitions: Attributes of the new model. They should be passed in the format:
`<name>=(<type>, <default value>)` or `<name>=(<type>, <FieldInfo>)`.
`<name>=(<type>, <default value>)`, `<name>=(<type>, <FieldInfo>)`, or `typing.Annotated[<type>, <default value>]`, `typing.Annotated[<type>, <FieldInfo>]`.
Any additional metadata in `typing.Annotated[<type>, <default value | FieldInfo>, ...]` will be ignored.

Returns:
The new [model][pydantic.BaseModel].
Expand Down Expand Up @@ -1473,6 +1474,24 @@ def create_model( # noqa: C901
'Field definitions should be a `(<type>, <default>)`.',
code='create-model-field-definitions',
) from e

elif sys.version_info >= (3, 9) and _typing_extra.is_annotated(
f_def
): # typing.Annotated introduced on python3.9
try:
f_annotation = (
f_def.__origin__
) # __origin__ represents the annotation_type in Annotated[annotate_type, ...]

f_value = f_def.__metadata__[
0
] # Python Annotated requires at least one python variable to represent the annotation
except ValueError as e:
raise PydanticUserError(
'Field definitions should be a `typing.Annotation[type, definition]`.',
code='create-model-field-definitions',
wannieman98 marked this conversation as resolved.
Show resolved Hide resolved
) from e

else:
f_annotation, f_value = None, f_def

Expand Down
49 changes: 49 additions & 0 deletions tests/test_create_model.py
@@ -1,5 +1,7 @@
import platform
import re
import sys
import typing
from typing import Generic, Optional, Tuple, TypeVar

import pytest
Expand Down Expand Up @@ -516,6 +518,53 @@ def test_create_model_non_annotated():
create_model('FooModel', foo=(str, ...), bar=123)


@pytest.mark.skipif(sys.version_info < (3, 9), reason='Annotated is introduced after python3.9')
@pytest.mark.parametrize(
'annotation_type,field_info',
[
(bool, Field(alias='foo_bool_alias', description='foo boolean')),
(str, Field(alias='foo_str_alis', description='foo string')),
],
)
def test_create_model_typing_annotated_field_info(annotation_type, field_info):
annotated_foo = typing.Annotated[annotation_type, field_info]
model = create_model('FooModel', foo=annotated_foo, bar=(int, 123))

assert model.model_fields.keys() == {'foo', 'bar'}

foo = model.model_fields.get('foo')

assert foo is not None
assert foo.annotation == annotation_type
assert foo.alias == field_info.alias
assert foo.description == field_info.description


@pytest.mark.skipif(sys.version_info < (3, 9), reason='Annotated is introduced after python3.9')
def test_create_model_typing_annotated_field_uses_first_type():
annotated_foo = typing.Annotated[int, 10, str]
wannieman98 marked this conversation as resolved.
Show resolved Hide resolved
model = create_model('FooModel', foo=annotated_foo)

foo = model.model_fields.get('foo')

assert foo is not None
assert foo.annotation == int
assert foo.default == 10


@pytest.mark.skipif(sys.version_info < (3, 9), reason='Annotated is introduced after python3.9')
@pytest.mark.parametrize('annotation_type,default_value', [(bool, False), (str, 'default_value')])
def test_create_model_typing_annotated_field_default(annotation_type, default_value):
annotated_foo = typing.Annotated[annotation_type, default_value]
model = create_model('FooModel', foo=annotated_foo)

foo = model.model_fields.get('foo')

assert foo is not None
assert foo.annotation == annotation_type
assert foo.default == default_value


def test_create_model_tuple():
model = create_model('FooModel', foo=(Tuple[int, int], (1, 2)))
assert model().foo == (1, 2)
Expand Down