Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of optionals in mypy plugin #9008

Merged
merged 8 commits into from Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
78 changes: 33 additions & 45 deletions pydantic/mypy.py
Expand Up @@ -126,10 +126,6 @@ def plugin(version: str) -> type[Plugin]:
return PydanticPlugin


class _DeferAnalysis(Exception):
pass


class PydanticPlugin(Plugin):
"""The Pydantic mypy plugin."""

Expand Down Expand Up @@ -366,11 +362,8 @@ def expand_type(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterfa
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
expanded_type = expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
if isinstance(self.type, UnionType) and not isinstance(expanded_type, UnionType):
if not api.final_iteration:
raise _DeferAnalysis()
return expanded_type
with state.strict_optional_set(api.options.strict_optional):
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
return self.type

def to_var(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, use_alias: bool) -> Var:
Expand Down Expand Up @@ -402,12 +395,13 @@ def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPlugin
typ = deserialize_and_fixup_type(data.pop('type'), api)
return cls(type=typ, info=info, **data)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
def expand_typevar_from_subtype(self, sub_type: TypeInfo, api: SemanticAnalyzerPluginInterface) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type.
"""
if self.type is not None:
self.type = map_type_from_supertype(self.type, sub_type, self.info)
with state.strict_optional_set(api.options.strict_optional):
self.type = map_type_from_supertype(self.type, sub_type, self.info)


class PydanticModelClassVar:
Expand Down Expand Up @@ -485,13 +479,9 @@ def transform(self) -> bool:
return False

is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
try:
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, self._api, frozen=config.frozen is True)
except _DeferAnalysis:
if not self._api.final_iteration:
self._api.defer()
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, self._api, frozen=config.frozen is True)

self.adjust_decorator_signatures()

Expand Down Expand Up @@ -639,8 +629,7 @@ def collect_fields_and_class_vars(
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
with state.strict_optional_set(self._api.options.strict_optional):
field.expand_typevar_from_subtype(cls.info)
field.expand_typevar_from_subtype(cls.info, self._api)
found_fields[name] = field

sym_node = cls.info.names.get(name)
Expand Down Expand Up @@ -862,32 +851,31 @@ def add_initializer(
typed = self.plugin_config.init_typed
use_alias = config.populate_by_name is not True
requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name)
with state.strict_optional_set(self._api.options.strict_optional):
args = self.get_field_arguments(
fields,
typed=typed,
requires_dynamic_aliases=requires_dynamic_aliases,
use_alias=use_alias,
is_settings=is_settings,
)
args = self.get_field_arguments(
fields,
typed=typed,
requires_dynamic_aliases=requires_dynamic_aliases,
use_alias=use_alias,
is_settings=is_settings,
)

if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1):
# convert root argument to positional argument
# This is needed because mypy support for `dataclass_transform` isn't complete on 1.0.1
args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT

if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
base_settings_init_node = base_settings_node.names['__init__'].node
if base_settings_init_node is not None and base_settings_init_node.type is not None:
func_type = base_settings_init_node.type
for arg_idx, arg_name in enumerate(func_type.arg_names):
if arg_name.startswith('__') or not arg_name.startswith('_'):
continue
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
variable = Var(arg_name, analyzed_variable_type)
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))
if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1):
# convert root argument to positional argument
# This is needed because mypy support for `dataclass_transform` isn't complete on 1.0.1
args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT

if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
base_settings_init_node = base_settings_node.names['__init__'].node
if base_settings_init_node is not None and base_settings_init_node.type is not None:
func_type = base_settings_init_node.type
for arg_idx, arg_name in enumerate(func_type.arg_names):
if arg_name.startswith('__') or not arg_name.startswith('_'):
continue
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
variable = Var(arg_name, analyzed_variable_type)
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))

if not self.should_init_forbid_extra(fields, config):
var = Var('kwargs')
Expand Down
22 changes: 22 additions & 0 deletions tests/mypy/modules/plugin_optional_inheritance.py
@@ -0,0 +1,22 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]
1 change: 1 addition & 0 deletions tests/mypy/test_mypy.py
Expand Up @@ -87,6 +87,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
'plugin_fail.py',
'plugin_success_baseConfig.py',
'plugin_fail_baseConfig.py',
'plugin_optional_inheritance.py',
'pydantic_settings.py',
],
).build()
Expand Down