Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message for improper RootModel subclasses #8857

Merged
merged 5 commits into from Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 33 additions & 15 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -137,22 +137,40 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
parameters = getattr(cls, '__parameters__', None) or parent_parameters
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
from ..root_model import RootModelRootType

missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
# This is a special case where the user has subclassed `RootModel`, but has not parametrized
# RootModel with the generic type identifiers being used. Ex:
# class MyModel(RootModel, Generic[T]):
# root: T
# Should instead just be:
# class MyModel(RootModel[T]):
# root: T
parameters_str = ', '.join([x.__name__ for x in missing_parameters])
error_message = (
f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
f'{parameters_str} in its parameters. '
f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
)
else:
combined_parameters = parent_parameters + missing_parameters
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
)
raise TypeError(error_message)

cls.__pydantic_generic_metadata__ = {
Expand Down
13 changes: 11 additions & 2 deletions tests/test_root_model.py
@@ -1,11 +1,11 @@
import pickle
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Generic, List, Optional, Union

import pytest
from pydantic_core import CoreSchema
from pydantic_core.core_schema import SerializerFunctionWrapHandler
from typing_extensions import Annotated, Literal
from typing_extensions import Annotated, Literal, TypeVar

from pydantic import (
Base64Str,
Expand Down Expand Up @@ -648,3 +648,12 @@ def test_model_validate_strings(root_type, input_value, expected, raises_match,
Model.model_validate_strings(input_value, strict=strict)
else:
assert Model.model_validate_strings(input_value, strict=strict).root == expected


def test_model_construction_with_invalid_generic_specification() -> None:
T_ = TypeVar('T_', bound=BaseModel)

with pytest.raises(TypeError, match='You should parametrize RootModel directly'):

class GenericRootModel(RootModel, Generic[T_]):
root: Union[T_, int]