Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy issue with subclasses of RootModel #7677

Merged
merged 5 commits into from Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 19 additions & 2 deletions pydantic/mypy.py
Expand Up @@ -440,7 +440,8 @@ def transform(self) -> bool:
return False

is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
self.add_initializer(fields, config, is_settings)
is_root_model = any('pydantic.root_model.RootModel' in base.fullname for base in info.mro[:-1])
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, frozen=config.frozen is True)

Expand Down Expand Up @@ -780,7 +781,9 @@ def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, conte

return default

def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None:
def add_initializer(
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool
) -> None:
"""Adds a fields-aware `__init__` method to the class.

The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
Expand All @@ -799,6 +802,14 @@ def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigD
use_alias=use_alias,
is_settings=is_settings,
)
if is_root_model:
if len(fields) > 1:
pass
# error_extra_fields_on_root_model(self._api)
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved

# convert root argument to positional argument
args[0].kind = ARG_POS
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved

if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
Expand Down Expand Up @@ -1048,6 +1059,7 @@ def setdefault(self, key: str, value: Any) -> None:
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic')


def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
Expand Down Expand Up @@ -1084,6 +1096,11 @@ def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context)
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)


def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when there is more than just a root field defined for a subclass of RootModel."""
api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL)


def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when `Field` has both `default` and `default_factory` together."""
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
Expand Down
19 changes: 19 additions & 0 deletions tests/mypy/modules/root_models.py
@@ -0,0 +1,19 @@
from typing import List

from pydantic import RootModel


class Pets1(RootModel[List[str]]):
pass


Pets2 = RootModel[List[str]]


class Pets3(RootModel):
root: List[str]


pets1 = Pets1(['dog', 'cat'])
pets2 = Pets2(['dog', 'cat'])
pets3 = Pets3(['dog', 'cat'])
1 change: 1 addition & 0 deletions tests/mypy/test_mypy.py
Expand Up @@ -98,6 +98,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
+ [
('mypy-plugin.ini', 'custom_constructor.py'),
('mypy-plugin.ini', 'generics.py'),
('mypy-plugin.ini', 'root_models.py'),
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
('mypy-plugin-strict.ini', 'plugin_default_factory.py'),
('mypy-plugin-strict-no-any.ini', 'dataclass_no_any.py'),
('mypy-plugin-very-strict.ini', 'metaclass_args.py'),
Expand Down