Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix defer_build behavior with TypeAdapter #7736

Merged
merged 2 commits into from Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions pydantic/_internal/_dataclasses.py
Expand Up @@ -27,7 +27,7 @@
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mock_validator
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._utils import is_valid_identifier

Expand Down Expand Up @@ -153,14 +153,14 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_dataclass_mock_validator(cls, cls.__name__, f'`{e.name}`')
set_dataclass_mocks(cls, cls.__name__, f'`{e.name}`')
return False

core_config = config_wrapper.core_config(cls)

schema = gen_schema.collect_definitions(schema)
if collect_invalid_schemas(schema):
set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types')
set_dataclass_mocks(cls, cls.__name__, 'all referenced types')
return False

schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
Expand Down
38 changes: 30 additions & 8 deletions pydantic/_internal/_mock_val_ser.py
Expand Up @@ -70,7 +70,7 @@ def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = '
)

def attempt_rebuild_validator() -> SchemaValidator | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_validator__
else:
return None
Expand All @@ -83,7 +83,7 @@ def attempt_rebuild_validator() -> SchemaValidator | None:
)

def attempt_rebuild_serializer() -> SchemaSerializer | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_serializer__
else:
return None
Expand All @@ -96,16 +96,25 @@ def attempt_rebuild_serializer() -> SchemaSerializer | None:
)


def set_dataclass_mock_validator(cls: type[PydanticDataclass], cls_name: str, undefined_name: str) -> None:
def set_dataclass_mocks(
cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types'
) -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.

Args:
cls: The model class to set the mocks on
cls_name: Name of the model class, used in error messages
undefined_name: Name of the undefined thing, used in error messages
"""
from ..dataclasses import rebuild_dataclass

undefined_type_error_message = (
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
)

def attempt_rebuild() -> SchemaValidator | None:
from ..dataclasses import rebuild_dataclass

if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5):
def attempt_rebuild_validator() -> SchemaValidator | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_validator__
else:
return None
Expand All @@ -114,5 +123,18 @@ def attempt_rebuild() -> SchemaValidator | None:
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild,
attempt_rebuild=attempt_rebuild_validator,
)

def attempt_rebuild_serializer() -> SchemaSerializer | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_serializer__
else:
return None

cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_serializer,
)
19 changes: 19 additions & 0 deletions tests/test_config.py
Expand Up @@ -687,6 +687,25 @@ class MyModel(BaseModel, defer_build=True):
assert isinstance(MyModel.__pydantic_serializer__, SchemaSerializer)


def test_config_type_adapter_defer_build():
class MyModel(BaseModel, defer_build=True):
x: int

ta = TypeAdapter(MyModel)

assert isinstance(ta.validator, MockValSer)
assert isinstance(ta.serializer, MockValSer)

m = ta.validate_python({'x': 1})
assert m.x == 1
m2 = ta.validate_python({'x': 2})
assert m2.x == 2

# in the future, can reassign said validators to the TypeAdapter
assert isinstance(MyModel.__pydantic_validator__, SchemaValidator)
assert isinstance(MyModel.__pydantic_serializer__, SchemaSerializer)


def test_config_model_defer_build_nested():
class MyNestedModel(BaseModel, defer_build=True):
x: int
Expand Down