Skip to content

Commit

Permalink
Fix mishandling of unions while freezing types in the mypy plugin (#7411
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dmontagu committed Nov 13, 2023
1 parent 7df25da commit 4a2a5bb
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 2 deletions.
15 changes: 13 additions & 2 deletions pydantic/mypy.py
Expand Up @@ -126,6 +126,10 @@ def plugin(version: str) -> type[Plugin]:
return PydanticPlugin


class _DeferAnalysis(Exception):
pass


class PydanticPlugin(Plugin):
"""The Pydantic mypy plugin."""

Expand Down Expand Up @@ -353,7 +357,10 @@ def expand_type(self, current_info: TypeInfo) -> Type | None:
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
expanded_type = expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
if isinstance(self.type, UnionType) and not isinstance(expanded_type, UnionType):
raise _DeferAnalysis()
return expanded_type
return self.type

def to_var(self, current_info: TypeInfo, use_alias: bool) -> Var:
Expand Down Expand Up @@ -445,7 +452,11 @@ def transform(self) -> bool:
is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
self.add_initializer(fields, config, is_settings, is_root_model)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, frozen=config.frozen is True)
try:
self.set_frozen(fields, frozen=config.frozen is True)
except _DeferAnalysis:
if not self._api.final_iteration:
self._api.defer()

self.adjust_decorator_signatures()

Expand Down
1 change: 1 addition & 0 deletions tests/mypy/configs/mypy-plugin-strict-no-any.ini
@@ -1,6 +1,7 @@
[mypy]
plugins = pydantic.mypy

warn_unreachable = true
follow_imports = silent
strict_optional = True
warn_redundant_casts = True
Expand Down
14 changes: 14 additions & 0 deletions tests/mypy/modules/plugin_success.py
Expand Up @@ -300,3 +300,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
14 changes: 14 additions & 0 deletions tests/mypy/outputs/1.0.1/mypy-default_ini/plugin_success.py
Expand Up @@ -306,3 +306,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
14 changes: 14 additions & 0 deletions tests/mypy/outputs/1.0.1/mypy-plugin-strict_ini/plugin_success.py
Expand Up @@ -302,3 +302,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
14 changes: 14 additions & 0 deletions tests/mypy/outputs/1.0.1/mypy-plugin_ini/plugin_success.py
Expand Up @@ -300,3 +300,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
Expand Up @@ -302,3 +302,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
14 changes: 14 additions & 0 deletions tests/mypy/outputs/1.0.1/pyproject-plugin_toml/plugin_success.py
Expand Up @@ -300,3 +300,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
14 changes: 14 additions & 0 deletions tests/mypy/outputs/1.1.1/mypy-default_ini/plugin_success.py
Expand Up @@ -307,3 +307,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
14 changes: 14 additions & 0 deletions tests/mypy/outputs/1.2.0/mypy-default_ini/plugin_success.py
Expand Up @@ -305,3 +305,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
14 changes: 14 additions & 0 deletions tests/mypy/outputs/1.4.1/mypy-default_ini/plugin_success.py
Expand Up @@ -305,3 +305,17 @@ def validate_after(self) -> Self:
return self

MyModel(number=2)


class InnerModel(BaseModel):
my_var: Union[str, None] = Field(default=None)


class OuterModel(InnerModel):
pass


m = OuterModel()
if m.my_var is None:
# In https://github.com/pydantic/pydantic/issues/7399, this was unreachable
print('not unreachable')
1 change: 1 addition & 0 deletions tests/mypy/test_mypy.py
Expand Up @@ -155,6 +155,7 @@ def _convert_to_output_path(v: str) -> Path:
return MypyTestConfig(existing, current)


@pytest.mark.filterwarnings('ignore:ast.:DeprecationWarning') # these are produced by mypy in python 3.12
@pytest.mark.parametrize('config_filename,python_filename', cases)
def test_mypy_results(config_filename: str, python_filename: str, request: pytest.FixtureRequest) -> None:
input_path = PYDANTIC_ROOT / 'tests/mypy/modules' / python_filename
Expand Down

0 comments on commit 4a2a5bb

Please sign in to comment.