Skip to content

Commit

Permalink
Fix mypy issue with subclasses of RootModel (#7677)
Browse files Browse the repository at this point in the history
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
sydney-runkle and dmontagu committed Sep 28, 2023
1 parent b4d2eac commit 8f3ac4e
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 6 deletions.
29 changes: 23 additions & 6 deletions pydantic/mypy.py
Expand Up @@ -89,6 +89,7 @@
METADATA_KEY = 'pydantic-mypy-metadata'
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings'
ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel'
MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass'
FIELD_FULLNAME = 'pydantic.fields.Field'
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
Expand Down Expand Up @@ -430,8 +431,9 @@ def transform(self) -> bool:
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
"""
info = self._cls.info
is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1])
config = self.collect_config()
fields = self.collect_fields(config)
fields = self.collect_fields(config, is_root_model)
if fields is None:
# Some definitions are not ready. We need another pass.
return False
Expand All @@ -440,7 +442,7 @@ def transform(self) -> bool:
return False

is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
self.add_initializer(fields, config, is_settings)
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, frozen=config.frozen is True)

Expand Down Expand Up @@ -556,7 +558,7 @@ def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity)
config.setdefault(name, value)
return config

def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelField] | None:
def collect_fields(self, model_config: ModelConfigData, is_root_model: bool) -> list[PydanticModelField] | None:
"""Collects the fields for the model, accounting for parent classes."""
cls = self._cls

Expand Down Expand Up @@ -603,8 +605,11 @@ def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelFie
maybe_field = self.collect_field_from_stmt(stmt, model_config)
if maybe_field is not None:
lhs = stmt.lvalues[0]
current_field_names.add(lhs.name)
found_fields[lhs.name] = maybe_field
if is_root_model and lhs.name != 'root':
error_extra_fields_on_root_model(self._api, stmt)
else:
current_field_names.add(lhs.name)
found_fields[lhs.name] = maybe_field

return list(found_fields.values())

Expand Down Expand Up @@ -780,7 +785,9 @@ def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, conte

return default

def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None:
def add_initializer(
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool
) -> None:
"""Adds a fields-aware `__init__` method to the class.
The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
Expand All @@ -799,6 +806,10 @@ def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigD
use_alias=use_alias,
is_settings=is_settings,
)
if is_root_model:
# convert root argument to positional argument
args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT

if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
Expand Down Expand Up @@ -1048,6 +1059,7 @@ def setdefault(self, key: str, value: Any) -> None:
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic')


def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
Expand Down Expand Up @@ -1084,6 +1096,11 @@ def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context)
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)


def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when there is more than just a root field defined for a subclass of RootModel."""
api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL)


def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when `Field` has both `default` and `default_factory` together."""
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
Expand Down
23 changes: 23 additions & 0 deletions tests/mypy/modules/root_models.py
@@ -0,0 +1,23 @@
from typing import List

from pydantic import RootModel


class Pets1(RootModel[List[str]]):
pass


Pets2 = RootModel[List[str]]


class Pets3(RootModel):
root: List[str]


pets1 = Pets1(['dog', 'cat'])
pets2 = Pets2(['dog', 'cat'])
pets3 = Pets3(['dog', 'cat'])


class Pets4(RootModel[List[str]]):
pets: List[str]
25 changes: 25 additions & 0 deletions tests/mypy/outputs/1.0.1/mypy-plugin_ini/root_models.py
@@ -0,0 +1,25 @@
from typing import List

from pydantic import RootModel


class Pets1(RootModel[List[str]]):
pass


Pets2 = RootModel[List[str]]


class Pets3(RootModel):
# MYPY: error: Missing type parameters for generic type "RootModel" [type-arg]
root: List[str]


pets1 = Pets1(['dog', 'cat'])
pets2 = Pets2(['dog', 'cat'])
pets3 = Pets3(['dog', 'cat'])


class Pets4(RootModel[List[str]]):
pets: List[str]
# MYPY: error: Only `root` is allowed as a field of a `RootModel` [pydantic-field]
1 change: 1 addition & 0 deletions tests/mypy/test_mypy.py
Expand Up @@ -98,6 +98,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
+ [
('mypy-plugin.ini', 'custom_constructor.py'),
('mypy-plugin.ini', 'generics.py'),
('mypy-plugin.ini', 'root_models.py'),
('mypy-plugin-strict.ini', 'plugin_default_factory.py'),
('mypy-plugin-strict-no-any.ini', 'dataclass_no_any.py'),
('mypy-plugin-very-strict.ini', 'metaclass_args.py'),
Expand Down

0 comments on commit 8f3ac4e

Please sign in to comment.