Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy error on untyped ClassVar #8138

Merged
merged 2 commits into from Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 52 additions & 13 deletions pydantic/mypy.py
Expand Up @@ -400,6 +400,29 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
self.type = map_type_from_supertype(self.type, sub_type, self.info)


class PydanticModelClassVar:
"""Class vars are stored to be ignored by subclasses.

Attributes:
name: the class var name
"""

def __init__(self, name):
self.name = name

@classmethod
def deserialize(cls, data: JsonDict) -> PydanticModelClassVar:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize."""
data = data.copy()
return cls(**data)

def serialize(self) -> JsonDict:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
return {
'name': self.name,
}


class PydanticModelTransformer:
"""Transform the BaseModel subclass according to the plugin settings.

Expand Down Expand Up @@ -441,8 +464,8 @@ def transform(self) -> bool:
info = self._cls.info
is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1])
config = self.collect_config()
fields = self.collect_fields(config, is_root_model)
if fields is None:
fields, classvars = self.collect_fields_and_classvar(config, is_root_model)
if fields is None or classvars is None:
# Some definitions are not ready. We need another pass.
return False
for field in fields:
Expand All @@ -462,6 +485,7 @@ def transform(self) -> bool:

info.metadata[METADATA_KEY] = {
'fields': {field.name: field.serialize() for field in fields},
'classvars': {classvar.name: classvar.serialize() for classvar in classvars},
'config': config.get_values_dict(),
}

Expand Down Expand Up @@ -570,11 +594,13 @@ def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity)
config.setdefault(name, value)
return config

def collect_fields(self, model_config: ModelConfigData, is_root_model: bool) -> list[PydanticModelField] | None:
def collect_fields_and_classvar(
self, model_config: ModelConfigData, is_root_model: bool
) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]:
"""Collects the fields for the model, accounting for parent classes."""
cls = self._cls

# First, collect fields belonging to any class in the MRO, ignoring duplicates.
# First, collect fields and classvars belonging to any class in the MRO, ignoring duplicates.
#
# We iterate through the MRO in reverse because attrs defined in the parent must appear
# earlier in the attributes list than attrs defined in the child. See:
Expand All @@ -584,10 +610,11 @@ def collect_fields(self, model_config: ModelConfigData, is_root_model: bool) ->
# in the parent. We can implement this via a dict without disrupting the attr order
# because dicts preserve insertion order in Python 3.7+.
found_fields: dict[str, PydanticModelField] = {}
found_classvars: dict[str, PydanticModelClassVar] = {}
for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object
# if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata:
# # We haven't processed the base class yet. Need another pass.
# return None
# return None, None
if METADATA_KEY not in info.metadata:
continue

Expand All @@ -610,20 +637,28 @@ def collect_fields(self, model_config: ModelConfigData, is_root_model: bool) ->
'BaseModel field may only be overridden by another field',
sym_node.node,
)
# Collect classvars
for name, data in info.metadata[METADATA_KEY]['classvars'].items():
found_classvars[name] = PydanticModelClassVar.deserialize(data)

# Second, collect fields belonging to the current class.
# Second, collect fields and classvars belonging to the current class.
current_field_names: set[str] = set()
current_classvars_names: set[str] = set()
for stmt in self._get_assignment_statements_from_block(cls.defs):
maybe_field = self.collect_field_from_stmt(stmt, model_config)
if maybe_field is not None:
maybe_field = self.collect_field_and_classvars_from_stmt(stmt, model_config, found_classvars)
if isinstance(maybe_field, PydanticModelField):
lhs = stmt.lvalues[0]
if is_root_model and lhs.name != 'root':
error_extra_fields_on_root_model(self._api, stmt)
else:
current_field_names.add(lhs.name)
found_fields[lhs.name] = maybe_field
elif isinstance(maybe_field, PydanticModelClassVar):
lhs = stmt.lvalues[0]
current_classvars_names.add(lhs.name)
found_classvars[lhs.name] = maybe_field

return list(found_fields.values())
return list(found_fields.values()), list(found_classvars.values())

def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]:
for body in stmt.body:
Expand All @@ -639,9 +674,9 @@ def _get_assignment_statements_from_block(self, block: Block) -> Iterator[Assign
elif isinstance(stmt, IfStmt):
yield from self._get_assignment_statements_from_if_statement(stmt)

def collect_field_from_stmt( # noqa C901
self, stmt: AssignmentStmt, model_config: ModelConfigData
) -> PydanticModelField | None:
def collect_field_and_classvars_from_stmt( # noqa C901
self, stmt: AssignmentStmt, model_config: ModelConfigData, classvars: dict[str, PydanticModelClassVar]
) -> PydanticModelField | PydanticModelClassVar | None:
"""Get pydantic model field from statement.

Args:
Expand Down Expand Up @@ -669,6 +704,10 @@ def collect_field_from_stmt( # noqa C901
# Eventually, we may want to attempt to respect model_config['ignored_types']
return None

if lhs.name in classvars:
# Class vars are not fields and are not required to be annotated
return None

# The assignment does not have an annotation, and it's not anything else we recognize
error_untyped_fields(self._api, stmt)
return None
Expand Down Expand Up @@ -713,7 +752,7 @@ def collect_field_from_stmt( # noqa C901

# x: ClassVar[int] is not a field
if node.is_classvar:
return None
return PydanticModelClassVar(lhs.name)

# x: InitVar[int] is not supported in BaseModel
node_type = get_proper_type(node.type)
Expand Down
10 changes: 9 additions & 1 deletion tests/mypy/modules/success.py
Expand Up @@ -6,7 +6,7 @@
import os
from datetime import date, datetime, timedelta, timezone
from pathlib import Path, PurePath
from typing import Any, Dict, ForwardRef, Generic, List, Optional, Type, TypeVar
from typing import Any, ClassVar, Dict, ForwardRef, Generic, List, Optional, Type, TypeVar
from uuid import UUID

from typing_extensions import Annotated, TypedDict
Expand Down Expand Up @@ -300,3 +300,11 @@ def double(value: Any, handler: Any) -> int:

class WrapValidatorModel(BaseModel):
x: Annotated[int, WrapValidator(double)]


class Abstract(BaseModel):
class_id: ClassVar


class Concrete(Abstract):
class_id = 1
10 changes: 9 additions & 1 deletion tests/mypy/outputs/1.0.1/mypy-default_ini/success.py
Expand Up @@ -6,7 +6,7 @@
import os
from datetime import date, datetime, timedelta, timezone
from pathlib import Path, PurePath
from typing import Any, Dict, ForwardRef, Generic, List, Optional, Type, TypeVar
from typing import Any, ClassVar, Dict, ForwardRef, Generic, List, Optional, Type, TypeVar
from uuid import UUID

from typing_extensions import Annotated, TypedDict
Expand Down Expand Up @@ -306,3 +306,11 @@ def double(value: Any, handler: Any) -> int:

class WrapValidatorModel(BaseModel):
x: Annotated[int, WrapValidator(double)]


class Abstract(BaseModel):
class_id: ClassVar


class Concrete(Abstract):
class_id = 1
10 changes: 9 additions & 1 deletion tests/mypy/outputs/1.0.1/pyproject-default_toml/success.py
Expand Up @@ -6,7 +6,7 @@
import os
from datetime import date, datetime, timedelta, timezone
from pathlib import Path, PurePath
from typing import Any, Dict, ForwardRef, Generic, List, Optional, Type, TypeVar
from typing import Any, ClassVar, Dict, ForwardRef, Generic, List, Optional, Type, TypeVar
from uuid import UUID

from typing_extensions import Annotated, TypedDict
Expand Down Expand Up @@ -306,3 +306,11 @@ def double(value: Any, handler: Any) -> int:

class WrapValidatorModel(BaseModel):
x: Annotated[int, WrapValidator(double)]


class Abstract(BaseModel):
class_id: ClassVar


class Concrete(Abstract):
class_id = 1