Skip to content

Commit

Permalink
Fix handling of optionals in mypy plugin (#9008)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Mar 17, 2024
1 parent 22f86ad commit 325ddf6
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 45 deletions.
78 changes: 33 additions & 45 deletions pydantic/mypy.py
Expand Up @@ -126,10 +126,6 @@ def plugin(version: str) -> type[Plugin]:
return PydanticPlugin


class _DeferAnalysis(Exception):
pass


class PydanticPlugin(Plugin):
"""The Pydantic mypy plugin."""

Expand Down Expand Up @@ -366,11 +362,8 @@ def expand_type(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterfa
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
expanded_type = expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
if isinstance(self.type, UnionType) and not isinstance(expanded_type, UnionType):
if not api.final_iteration:
raise _DeferAnalysis()
return expanded_type
with state.strict_optional_set(api.options.strict_optional):
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
return self.type

def to_var(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, use_alias: bool) -> Var:
Expand Down Expand Up @@ -402,12 +395,13 @@ def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPlugin
typ = deserialize_and_fixup_type(data.pop('type'), api)
return cls(type=typ, info=info, **data)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
def expand_typevar_from_subtype(self, sub_type: TypeInfo, api: SemanticAnalyzerPluginInterface) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type.
"""
if self.type is not None:
self.type = map_type_from_supertype(self.type, sub_type, self.info)
with state.strict_optional_set(api.options.strict_optional):
self.type = map_type_from_supertype(self.type, sub_type, self.info)


class PydanticModelClassVar:
Expand Down Expand Up @@ -485,13 +479,9 @@ def transform(self) -> bool:
return False

is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
try:
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, self._api, frozen=config.frozen is True)
except _DeferAnalysis:
if not self._api.final_iteration:
self._api.defer()
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, self._api, frozen=config.frozen is True)

self.adjust_decorator_signatures()

Expand Down Expand Up @@ -639,8 +629,7 @@ def collect_fields_and_class_vars(
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
with state.strict_optional_set(self._api.options.strict_optional):
field.expand_typevar_from_subtype(cls.info)
field.expand_typevar_from_subtype(cls.info, self._api)
found_fields[name] = field

sym_node = cls.info.names.get(name)
Expand Down Expand Up @@ -862,32 +851,31 @@ def add_initializer(
typed = self.plugin_config.init_typed
use_alias = config.populate_by_name is not True
requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name)
with state.strict_optional_set(self._api.options.strict_optional):
args = self.get_field_arguments(
fields,
typed=typed,
requires_dynamic_aliases=requires_dynamic_aliases,
use_alias=use_alias,
is_settings=is_settings,
)
args = self.get_field_arguments(
fields,
typed=typed,
requires_dynamic_aliases=requires_dynamic_aliases,
use_alias=use_alias,
is_settings=is_settings,
)

if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1):
# convert root argument to positional argument
# This is needed because mypy support for `dataclass_transform` isn't complete on 1.0.1
args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT

if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
base_settings_init_node = base_settings_node.names['__init__'].node
if base_settings_init_node is not None and base_settings_init_node.type is not None:
func_type = base_settings_init_node.type
for arg_idx, arg_name in enumerate(func_type.arg_names):
if arg_name.startswith('__') or not arg_name.startswith('_'):
continue
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
variable = Var(arg_name, analyzed_variable_type)
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))
if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1):
# convert root argument to positional argument
# This is needed because mypy support for `dataclass_transform` isn't complete on 1.0.1
args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT

if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
base_settings_init_node = base_settings_node.names['__init__'].node
if base_settings_init_node is not None and base_settings_init_node.type is not None:
func_type = base_settings_init_node.type
for arg_idx, arg_name in enumerate(func_type.arg_names):
if arg_name.startswith('__') or not arg_name.startswith('_'):
continue
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
variable = Var(arg_name, analyzed_variable_type)
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))

if not self.should_init_forbid_extra(fields, config):
var = Var('kwargs')
Expand Down
22 changes: 22 additions & 0 deletions tests/mypy/modules/plugin_optional_inheritance.py
@@ -0,0 +1,22 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Optional[Foo]" has no attribute "id" [union-attr]
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]
@@ -0,0 +1,24 @@
from typing import Optional

from pydantic import BaseModel


class Foo(BaseModel):
id: Optional[int]


class Bar(BaseModel):
foo: Optional[Foo]


class Baz(Bar):
name: str


b = Bar(foo={'id': 1})
assert b.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]

z = Baz(foo={'id': 1}, name='test')
assert z.foo.id == 1
# MYPY: error: Item "None" of "Foo | None" has no attribute "id" [union-attr]
1 change: 1 addition & 0 deletions tests/mypy/test_mypy.py
Expand Up @@ -87,6 +87,7 @@ def build(self) -> List[Union[Tuple[str, str], Any]]:
'plugin_fail.py',
'plugin_success_baseConfig.py',
'plugin_fail_baseConfig.py',
'plugin_optional_inheritance.py',
'pydantic_settings.py',
],
).build()
Expand Down

0 comments on commit 325ddf6

Please sign in to comment.