Skip to content

Commit

Permalink
Improve error message for improper RootModel subclasses (#8857)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Feb 20, 2024
1 parent 5d09b47 commit 812516d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 17 deletions.
48 changes: 33 additions & 15 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -137,22 +137,40 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
parameters = getattr(cls, '__parameters__', None) or parent_parameters
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
from ..root_model import RootModelRootType

missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
# This is a special case where the user has subclassed `RootModel`, but has not parametrized
# RootModel with the generic type identifiers being used. Ex:
# class MyModel(RootModel, Generic[T]):
# root: T
# Should instead just be:
# class MyModel(RootModel[T]):
# root: T
parameters_str = ', '.join([x.__name__ for x in missing_parameters])
error_message = (
f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
f'{parameters_str} in its parameters. '
f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
)
else:
combined_parameters = parent_parameters + missing_parameters
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
)
raise TypeError(error_message)

cls.__pydantic_generic_metadata__ = {
Expand Down
13 changes: 11 additions & 2 deletions tests/test_root_model.py
@@ -1,11 +1,11 @@
import pickle
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Generic, List, Optional, Union

import pytest
from pydantic_core import CoreSchema
from pydantic_core.core_schema import SerializerFunctionWrapHandler
from typing_extensions import Annotated, Literal
from typing_extensions import Annotated, Literal, TypeVar

from pydantic import (
Base64Str,
Expand Down Expand Up @@ -648,3 +648,12 @@ def test_model_validate_strings(root_type, input_value, expected, raises_match,
Model.model_validate_strings(input_value, strict=strict)
else:
assert Model.model_validate_strings(input_value, strict=strict).root == expected


def test_model_construction_with_invalid_generic_specification() -> None:
T_ = TypeVar('T_', bound=BaseModel)

with pytest.raises(TypeError, match='You should parametrize RootModel directly'):

class GenericRootModel(RootModel, Generic[T_]):
root: Union[T_, int]

0 comments on commit 812516d

Please sign in to comment.