Skip to content

Commit

Permalink
Warn if a class inherits from Generic before BaseModel (#7891)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki committed Oct 29, 2023
1 parent 6ed90da commit 46f24ce
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 10 deletions.
12 changes: 11 additions & 1 deletion pydantic/_internal/_model_construction.py
Expand Up @@ -14,7 +14,7 @@

from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
Expand Down Expand Up @@ -115,6 +115,16 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:

from ..main import BaseModel

mro = cls.__mro__
if Generic in mro and mro.index(Generic) < mro.index(BaseModel):
warnings.warn(
GenericBeforeBaseModelWarning(
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
'for pydantic generics to work properly.'
),
stacklevel=2,
)

cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'

Expand Down
4 changes: 4 additions & 0 deletions pydantic/warnings.py
Expand Up @@ -45,3 +45,7 @@ class PydanticDeprecatedSince20(PydanticDeprecationWarning):

def __init__(self, message: str, *args: object) -> None:
super().__init__(message, *args, since=(2, 0), expected_removal=(3, 0))


class GenericBeforeBaseModelWarning(Warning):
pass
2 changes: 1 addition & 1 deletion tests/test_annotated.py
Expand Up @@ -247,7 +247,7 @@ class _(BaseModel):

T = TypeVar('T')

class GenericModel(Generic[T], BaseModel):
class GenericModel(BaseModel, Generic[T]):
y: T

class _(BaseModel):
Expand Down
31 changes: 23 additions & 8 deletions tests/test_generics.py
Expand Up @@ -72,6 +72,7 @@
recursively_defined_type_refs,
replace_types,
)
from pydantic.warnings import GenericBeforeBaseModelWarning


@pytest.fixture()
Expand Down Expand Up @@ -2053,9 +2054,14 @@ def test_generic_subclass_with_extra_type_with_hint_message():
E = TypeVar('E', bound=BaseModel)
D = TypeVar('D')

class BaseGenericClass(Generic[E, D], BaseModel):
uid: str
name: str
with pytest.warns(
GenericBeforeBaseModelWarning,
match='Classes should inherit from `BaseModel` before generic classes',
):

class BaseGenericClass(Generic[E, D], BaseModel):
uid: str
name: str

with pytest.raises(
TypeError,
Expand All @@ -2065,9 +2071,13 @@ class BaseGenericClass(Generic[E, D], BaseModel):
' `class ChildGenericClass(BaseGenericClass, typing.Generic[~E, ~D]): ...`'
),
):
with pytest.warns(
GenericBeforeBaseModelWarning,
match='Classes should inherit from `BaseModel` before generic classes',
):

class ChildGenericClass(BaseGenericClass[E, Dict[str, Any]]):
...
class ChildGenericClass(BaseGenericClass[E, Dict[str, Any]]):
...


def test_multi_inheritance_generic_binding():
Expand Down Expand Up @@ -2645,9 +2655,14 @@ class C(B[int]):
def test_reverse_order_generic_hashability():
T = TypeVar('T')

class Model(Generic[T], BaseModel):
x: T
model_config = dict(frozen=True)
with pytest.warns(
GenericBeforeBaseModelWarning,
match='Classes should inherit from `BaseModel` before generic classes',
):

class Model(Generic[T], BaseModel):
x: T
model_config = dict(frozen=True)

m1 = Model[int](x=1)
m2 = Model[int](x=1)
Expand Down

0 comments on commit 46f24ce

Please sign in to comment.