Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy issue with subclasses of RootModel #7677

Merged
merged 5 commits into from Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 23 additions & 6 deletions pydantic/mypy.py
Expand Up @@ -89,6 +89,7 @@
METADATA_KEY = 'pydantic-mypy-metadata'
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings'
ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel'
MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass'
FIELD_FULLNAME = 'pydantic.fields.Field'
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
Expand Down Expand Up @@ -430,8 +431,9 @@ def transform(self) -> bool:
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
"""
info = self._cls.info
is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1])
config = self.collect_config()
fields = self.collect_fields(config)
fields = self.collect_fields(config, is_root_model)
if fields is None:
# Some definitions are not ready. We need another pass.
return False
Expand All @@ -440,7 +442,7 @@ def transform(self) -> bool:
return False

is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
self.add_initializer(fields, config, is_settings)
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, frozen=config.frozen is True)

Expand Down Expand Up @@ -556,7 +558,7 @@ def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity)
config.setdefault(name, value)
return config

def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelField] | None:
def collect_fields(self, model_config: ModelConfigData, is_root_model: bool) -> list[PydanticModelField] | None:
"""Collects the fields for the model, accounting for parent classes."""
cls = self._cls

Expand Down Expand Up @@ -603,8 +605,11 @@ def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelFie
maybe_field = self.collect_field_from_stmt(stmt, model_config)
if maybe_field is not None:
lhs = stmt.lvalues[0]
current_field_names.add(lhs.name)
found_fields[lhs.name] = maybe_field
if is_root_model and lhs.name != 'root':
error_extra_fields_on_root_model(self._api, stmt)
else:
current_field_names.add(lhs.name)
found_fields[lhs.name] = maybe_field

return list(found_fields.values())

Expand Down Expand Up @@ -780,7 +785,9 @@ def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, conte

return default

def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None:
def add_initializer(
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool
) -> None:
"""Adds a fields-aware `__init__` method to the class.

The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
Expand All @@ -799,6 +806,10 @@ def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigD
use_alias=use_alias,
is_settings=is_settings,
)
if is_root_model:
# convert root argument to positional argument
args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT

if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
Expand Down Expand Up @@ -1048,6 +1059,7 @@ def setdefault(self, key: str, value: Any) -> None:
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic')


def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
Expand Down Expand Up @@ -1084,6 +1096,11 @@ def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context)
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)


def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when there is more than just a root field defined for a subclass of RootModel."""
api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL)


def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when `Field` has both `default` and `default_factory` together."""
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
Expand Down
23 changes: 23 additions & 0 deletions tests/mypy/modules/root_models.py
@@ -0,0 +1,23 @@
from typing import List

from pydantic import RootModel


class Pets1(RootModel[List[str]]):
pass


Pets2 = RootModel[List[str]]


class Pets3(RootModel):
root: List[str]


pets1 = Pets1(['dog', 'cat'])
pets2 = Pets2(['dog', 'cat'])
pets3 = Pets3(['dog', 'cat'])


class Pets4(RootModel[List[str]]):
pets: List[str]
25 changes: 25 additions & 0 deletions tests/mypy/outputs/1.0.1/mypy-plugin_ini/root_models.py
@@ -0,0 +1,25 @@
from typing import List

from pydantic import RootModel


class Pets1(RootModel[List[str]]):
pass


Pets2 = RootModel[List[str]]


class Pets3(RootModel):
# MYPY: error: Missing type parameters for generic type "RootModel" [type-arg]
root: List[str]


pets1 = Pets1(['dog', 'cat'])
pets2 = Pets2(['dog', 'cat'])
pets3 = Pets3(['dog', 'cat'])


class Pets4(RootModel[List[str]]):
pets: List[str]
# MYPY: error: Only `root` is allowed as a field of a `RootModel` [pydantic-field]
1 change: 1 addition & 0 deletions tests/mypy/test_mypy.py
Expand Up @@ -98,6 +98,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
+ [
('mypy-plugin.ini', 'custom_constructor.py'),
('mypy-plugin.ini', 'generics.py'),
('mypy-plugin.ini', 'root_models.py'),
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
('mypy-plugin-strict.ini', 'plugin_default_factory.py'),
('mypy-plugin-strict-no-any.ini', 'dataclass_no_any.py'),
('mypy-plugin-very-strict.ini', 'metaclass_args.py'),
Expand Down