Skip to content

Commit

Permalink
Correct pydantic dataclasses import (#8027)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Nov 6, 2023
1 parent 3a15424 commit e01cad6
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pydantic/__init__.py
Expand Up @@ -365,7 +365,7 @@ def __getattr__(attr_name: str) -> object:
from importlib import import_module

if module_name == '__module__':
return import_module(attr_name, package=package)
return import_module(f'.{attr_name}', package=package)
else:
module = import_module(module_name, package=package)
return getattr(module, attr_name)
Expand Down
4 changes: 2 additions & 2 deletions pydantic/_migration.py
Expand Up @@ -271,7 +271,7 @@ def wrapper(name: str) -> object:
The object.
"""
if name == '__path__':
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
raise AttributeError(f'module {module!r} has no attribute {name!r}')

import warnings

Expand Down Expand Up @@ -303,6 +303,6 @@ def wrapper(name: str) -> object:
globals: Dict[str, Any] = sys.modules[module].__dict__
if name in globals:
return globals[name]
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
raise AttributeError(f'module {module!r} has no attribute {name!r}')

return wrapper
18 changes: 18 additions & 0 deletions tests/conftest.py
Expand Up @@ -3,9 +3,11 @@
import os
import re
import secrets
import subprocess
import sys
import textwrap
from dataclasses import dataclass
from pathlib import Path
from types import FunctionType
from typing import Any, Optional

Expand Down Expand Up @@ -83,6 +85,22 @@ def run(source_code_or_function, rewrite_assertions=True, module_name_prefix=Non
return run


@pytest.fixture
def subprocess_run_code(tmp_path: Path):
def run_code(source_code_or_function) -> str:
if isinstance(source_code_or_function, FunctionType):
source_code = _extract_source_code_from_function(source_code_or_function)
else:
source_code = source_code_or_function

py_file = tmp_path / 'test.py'
py_file.write_text(source_code)

return subprocess.check_output([sys.executable, str(py_file)], cwd=tmp_path, encoding='utf8')

return run_code


@dataclass
class Err:
message: str
Expand Down
53 changes: 42 additions & 11 deletions tests/test_exports.py
Expand Up @@ -2,7 +2,6 @@
import importlib.util
import json
import platform
import subprocess
import sys
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -74,11 +73,8 @@ def test_public_internal():
"""


def test_import_pydantic(tmp_path: Path):
py_file = tmp_path / 'test.py'
py_file.write_text(IMPORTED_PYDANTIC_CODE)

output = subprocess.check_output([sys.executable, str(py_file)], cwd=tmp_path)
def test_import_pydantic(subprocess_run_code):
output = subprocess_run_code(IMPORTED_PYDANTIC_CODE)
imported_modules = json.loads(output)
# debug(imported_modules)
assert 'pydantic' in imported_modules
Expand All @@ -97,14 +93,49 @@ def test_import_pydantic(tmp_path: Path):
"""


def test_import_base_model(tmp_path: Path):
py_file = tmp_path / 'test.py'
py_file.write_text(IMPORTED_BASEMODEL_CODE)

output = subprocess.check_output([sys.executable, str(py_file)], cwd=tmp_path)
def test_import_base_model(subprocess_run_code):
output = subprocess_run_code(IMPORTED_BASEMODEL_CODE)
imported_modules = json.loads(output)
# debug(sorted(imported_modules))
assert 'pydantic' in imported_modules
assert 'pydantic.fields' not in imported_modules
assert 'pydantic.types' not in imported_modules
assert 'annotated_types' not in imported_modules


def test_dataclass_import(subprocess_run_code):
@subprocess_run_code
def run_in_subprocess():
import pydantic

assert pydantic.dataclasses.__name__ == 'pydantic.dataclasses'

@pydantic.dataclasses.dataclass
class Foo:
a: int

try:
Foo('not an int')
except ValueError:
pass
else:
raise AssertionError('Should have raised a ValueError')


def test_dataclass_import2(subprocess_run_code):
@subprocess_run_code
def run_in_subprocess():
import pydantic.dataclasses

assert pydantic.dataclasses.__name__ == 'pydantic.dataclasses'

@pydantic.dataclasses.dataclass
class Foo:
a: int

try:
Foo('not an int')
except ValueError:
pass
else:
raise AssertionError('Should have raised a ValueError')
2 changes: 1 addition & 1 deletion tests/test_migration.py
Expand Up @@ -50,5 +50,5 @@ def test_getattr_migration():

assert callable(get_attr('test_getattr_migration')) is True

with pytest.raises(AttributeError, match="module 'pydantic._migration' has no attribute 'foo'"):
with pytest.raises(AttributeError, match="module 'tests.test_migration' has no attribute 'foo'"):
get_attr('foo')

0 comments on commit e01cad6

Please sign in to comment.