Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn if a class inherits from Generic before BaseModel #7891

Merged
merged 2 commits into from Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion pydantic/_internal/_model_construction.py
Expand Up @@ -15,7 +15,7 @@
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import (
ComputedFieldInfo,
Expand Down Expand Up @@ -128,6 +128,16 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:

from ..main import BaseModel

mro = cls.__mro__
if Generic in mro and mro.index(Generic) < mro.index(BaseModel):
warnings.warn(
GenericBeforeBaseModelWarning(
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
'for pydantic generics to work properly.'
),
stacklevel=2,
)

cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'

Expand Down
4 changes: 4 additions & 0 deletions pydantic/warnings.py
Expand Up @@ -45,3 +45,7 @@ class PydanticDeprecatedSince20(PydanticDeprecationWarning):

def __init__(self, message: str, *args: object) -> None:
super().__init__(message, *args, since=(2, 0), expected_removal=(3, 0))


class GenericBeforeBaseModelWarning(Warning):
pass
2 changes: 1 addition & 1 deletion tests/test_annotated.py
Expand Up @@ -247,7 +247,7 @@ class _(BaseModel):

T = TypeVar('T')

class GenericModel(Generic[T], BaseModel):
class GenericModel(BaseModel, Generic[T]):
y: T

class _(BaseModel):
Expand Down
31 changes: 23 additions & 8 deletions tests/test_generics.py
Expand Up @@ -69,6 +69,7 @@
recursively_defined_type_refs,
replace_types,
)
from pydantic.warnings import GenericBeforeBaseModelWarning


@pytest.fixture()
Expand Down Expand Up @@ -2034,9 +2035,14 @@ def test_generic_subclass_with_extra_type_with_hint_message():
E = TypeVar('E', bound=BaseModel)
D = TypeVar('D')

class BaseGenericClass(Generic[E, D], BaseModel):
uid: str
name: str
with pytest.warns(
GenericBeforeBaseModelWarning,
match='Classes should inherit from `BaseModel` before generic classes',
):

class BaseGenericClass(Generic[E, D], BaseModel):
uid: str
name: str

with pytest.raises(
TypeError,
Expand All @@ -2046,9 +2052,13 @@ class BaseGenericClass(Generic[E, D], BaseModel):
' `class ChildGenericClass(BaseGenericClass, typing.Generic[~E, ~D]): ...`'
),
):
with pytest.warns(
GenericBeforeBaseModelWarning,
match='Classes should inherit from `BaseModel` before generic classes',
):

class ChildGenericClass(BaseGenericClass[E, Dict[str, Any]]):
...
class ChildGenericClass(BaseGenericClass[E, Dict[str, Any]]):
...


def test_multi_inheritance_generic_binding():
Expand Down Expand Up @@ -2626,9 +2636,14 @@ class C(B[int]):
def test_reverse_order_generic_hashability():
T = TypeVar('T')

class Model(Generic[T], BaseModel):
x: T
model_config = dict(frozen=True)
with pytest.warns(
GenericBeforeBaseModelWarning,
match='Classes should inherit from `BaseModel` before generic classes',
):

class Model(Generic[T], BaseModel):
x: T
model_config = dict(frozen=True)

m1 = Model[int](x=1)
m2 = Model[int](x=1)
Expand Down