Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix import mechanism for task modules. #373

Merged
merged 16 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 2 additions & 8 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import sys
import time
from importlib import util as importlib_util
from pathlib import Path
from typing import Any
from typing import Generator
Expand All @@ -27,6 +26,7 @@
from _pytask.outcomes import CollectionOutcome
from _pytask.outcomes import count_outcomes
from _pytask.path import find_case_sensitive_path
from _pytask.path import import_path
from _pytask.report import CollectionReport
from _pytask.session import Session
from _pytask.shared import find_duplicates
Expand Down Expand Up @@ -111,13 +111,7 @@ def pytask_collect_file(
) -> list[CollectionReport] | None:
"""Collect a file."""
if any(path.match(pattern) for pattern in session.config["task_files"]):
spec = importlib_util.spec_from_file_location(path.stem, str(path))

if spec is None:
raise ImportError(f"Can't find module {path.stem!r} at location {path}.")

mod = importlib_util.module_from_spec(spec)
spec.loader.exec_module(mod)
mod = import_path(path, session.config["root"])

collected_reports = []
for name, obj in inspect.getmembers(mod):
Expand Down
76 changes: 76 additions & 0 deletions src/_pytask/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from __future__ import annotations

import functools
import importlib.util
import os
import sys
from pathlib import Path
from types import ModuleType
from typing import Sequence


Expand Down Expand Up @@ -120,3 +123,76 @@ def find_case_sensitive_path(path: Path, platform: str) -> Path:
"""
out = path.resolve() if platform == "win32" else path
return out


def import_path(path: Path, root: Path) -> ModuleType:
"""Import and return a module from the given path.

The function is taken from pytest when the import mode is set to ``importlib``. It
pytest's recommended import mode for new projects although the default is set to
``prepend``. More discussion and information can be found in :gh:`373`.

"""
module_name = _module_name_from_path(path, root)

spec = importlib.util.spec_from_file_location(module_name, str(path))

if spec is None:
raise ImportError(f"Can't find module {module_name!r} at location {path}.")

mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod)
_insert_missing_modules(sys.modules, module_name)
return mod


def _module_name_from_path(path: Path, root: Path) -> str:
"""Return a dotted module name based on the given path, anchored on root.

For example: path="projects/src/tests/test_foo.py" and root="/projects", the
tobiasraabe marked this conversation as resolved.
Show resolved Hide resolved
resulting module name will be "src.tests.test_foo".

"""
path = path.with_suffix("")
try:
relative_path = path.relative_to(root)
except ValueError:
# If we can't get a relative path to root, use the full path, except for the
# first part ("d:\\" or "/" depending on the platform, for example).
path_parts = path.parts[1:]
else:
# Use the parts for the relative path to the root path.
path_parts = relative_path.parts

return ".".join(path_parts)


def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None:
"""Insert missing modules when importing modules with :func:`import_path`.

When we want to import a module as ``src.tests.test_foo`` for example, we need to
tobiasraabe marked this conversation as resolved.
Show resolved Hide resolved
create empty modules ``src`` and ``src.tests`` after inserting
``src.tests.test_foo``, otherwise ``src.tests.test_foo`` is not importable by
``__import__``.

"""
module_parts = module_name.split(".")
while module_name:
if module_name not in modules:
try:
# If sys.meta_path is empty, calling import_module will issue a warning
# and raise ModuleNotFoundError. To avoid the warning, we check
# sys.meta_path explicitly and raise the error ourselves to fall back to
# creating a dummy module.
if not sys.meta_path:
raise ModuleNotFoundError
importlib.import_module(module_name)
except ModuleNotFoundError:
module = ModuleType(
module_name,
doc="Empty module created by pytask.",
)
modules[module_name] = module
module_parts.pop(-1)
module_name = ".".join(module_parts)
22 changes: 22 additions & 0 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ def task_write_text(depends_on, produces):
assert tmp_path.joinpath("out.txt").read_text() == "Relative paths work."


@pytest.mark.end_to_end()
def test_collect_module_name_(tmp_path):
tobiasraabe marked this conversation as resolved.
Show resolved Hide resolved
"""We need to add a task module to the sys.modules. See #373 and #374."""
source = """
# without this import, everything works fine
from __future__ import annotations

import dataclasses

@dataclasses.dataclass
class Data:
x: int

def task_my_task():
pass
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
session = main({"paths": tmp_path})
outcome = session.collection_reports[0].outcome
assert outcome == CollectionOutcome.SUCCESS


@pytest.mark.end_to_end()
def test_collect_filepathnode_with_unknown_type(tmp_path):
"""If a node cannot be parsed because unknown type, raise an error."""
Expand Down
4 changes: 3 additions & 1 deletion tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def task_1():
x = 3
print("hello18")
assert count_continue == 2, "unexpected_failure: %d != 2" % count_continue
raise Exception("expected_failure")
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

Expand Down Expand Up @@ -399,7 +400,8 @@ def task_1():
assert "1" in rest
assert "failed" in rest
assert "Failed" in rest
assert "AssertionError: unexpected_failure" in rest
assert "AssertionError: unexpected_failure" not in rest
assert "expected_failure" in rest
_flush(child)


Expand Down
182 changes: 182 additions & 0 deletions tests/test_path.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from __future__ import annotations

import sys
import textwrap
from contextlib import ExitStack as does_not_raise # noqa: N813
from pathlib import Path
from pathlib import PurePosixPath
from pathlib import PureWindowsPath
from types import ModuleType
from typing import Any

import pytest
from _pytask.path import _insert_missing_modules
from _pytask.path import _module_name_from_path
from _pytask.path import find_case_sensitive_path
from _pytask.path import find_closest_ancestor
from _pytask.path import find_common_ancestor
from _pytask.path import import_path
from _pytask.path import relative_to


Expand Down Expand Up @@ -117,3 +123,179 @@ def test_find_case_sensitive_path(tmp_path, path, existing_paths, expected):

result = find_case_sensitive_path(tmp_path / path, sys.platform)
assert result == tmp_path / expected


@pytest.fixture()
def simple_module(tmp_path: Path) -> Path:
fn = tmp_path / "_src/tests/mymod.py"
tobiasraabe marked this conversation as resolved.
Show resolved Hide resolved
fn.parent.mkdir(parents=True)
fn.write_text("def foo(x): return 40 + x")
return fn


def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None:
"""`importlib` mode does not change sys.path."""
module = import_path(simple_module, root=tmp_path)
assert module.foo(2) == 42 # type: ignore[attr-defined]
assert str(simple_module.parent) not in sys.path
assert module.__name__ in sys.modules
assert module.__name__ == "_src.tests.mymod"
assert "_src" in sys.modules
assert "_src.tests" in sys.modules


def test_importmode_twice_is_different_module(
simple_module: Path, tmp_path: Path
) -> None:
"""`importlib` mode always returns a new module."""
module1 = import_path(simple_module, root=tmp_path)
module2 = import_path(simple_module, root=tmp_path)
assert module1 is not module2


def test_no_meta_path_found(
simple_module: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
"""Even without any meta_path should still import module."""
monkeypatch.setattr(sys, "meta_path", [])
module = import_path(simple_module, root=tmp_path)
assert module.foo(2) == 42 # type: ignore[attr-defined]

# mode='importlib' fails if no spec is found to load the module
import importlib.util

monkeypatch.setattr(
importlib.util, "spec_from_file_location", lambda *args: None # noqa: ARG005
)
with pytest.raises(ImportError):
import_path(simple_module, root=tmp_path)


def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None:
"""Ensure that importlib mode works with a module containing dataclasses (#7856)."""
tobiasraabe marked this conversation as resolved.
Show resolved Hide resolved
fn = tmp_path.joinpath("_src/tests/test_dataclass.py")
fn.parent.mkdir(parents=True)
fn.write_text(
textwrap.dedent(
"""
from dataclasses import dataclass

@dataclass
class Data:
value: str
"""
)
)

module = import_path(fn, root=tmp_path)
Data: Any = module.Data # noqa: N806
data = Data(value="foo")
assert data.value == "foo"
assert data.__module__ == "_src.tests.test_dataclass"


def test_importmode_importlib_with_pickle(tmp_path: Path) -> None:
"""Ensure that importlib mode works with pickle (#7859)."""
tobiasraabe marked this conversation as resolved.
Show resolved Hide resolved
fn = tmp_path.joinpath("_src/tests/test_pickle.py")
tobiasraabe marked this conversation as resolved.
Show resolved Hide resolved
fn.parent.mkdir(parents=True)
fn.write_text(
textwrap.dedent(
"""
import pickle

def _action():
return 42

def round_trip():
s = pickle.dumps(_action)
return pickle.loads(s)
"""
)
)

module = import_path(fn, root=tmp_path)
round_trip = module.round_trip
action = round_trip()
assert action() == 42


def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None:
"""
Ensure that importlib mode works can load pickles that look similar but are
defined in separate modules.
"""
fn1 = tmp_path.joinpath("_src/m1/tests/test.py")
fn1.parent.mkdir(parents=True)
fn1.write_text(
textwrap.dedent(
"""
import dataclasses
import pickle

@dataclasses.dataclass
class Data:
x: int = 42
"""
)
)

fn2 = tmp_path.joinpath("_src/m2/tests/test.py")
fn2.parent.mkdir(parents=True)
fn2.write_text(
textwrap.dedent(
"""
import dataclasses
import pickle

@dataclasses.dataclass
class Data:
x: str = ""
"""
)
)

import pickle

def round_trip(obj):
s = pickle.dumps(obj)
return pickle.loads(s) # noqa: S301

module = import_path(fn1, root=tmp_path)
Data1 = module.Data # noqa: N806

module = import_path(fn2, root=tmp_path)
Data2 = module.Data # noqa: N806

assert round_trip(Data1(20)) == Data1(20)
assert round_trip(Data2("hello")) == Data2("hello")
assert Data1.__module__ == "_src.m1.tests.test"
assert Data2.__module__ == "_src.m2.tests.test"


def test_module_name_from_path(tmp_path: Path) -> None:
result = _module_name_from_path(tmp_path / "src/tests/test_foo.py", tmp_path)
assert result == "src.tests.test_foo"

# Path is not relative to root dir: use the full path to obtain the module name.
result = _module_name_from_path(Path("/home/foo/test_foo.py"), Path("/bar"))
assert result == "home.foo.test_foo"


def test_insert_missing_modules(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
monkeypatch.chdir(tmp_path)
# Use 'xxx' and 'xxy' as parent names as they are unlikely to exist and
# don't end up being imported.
modules = {"xxx.tests.foo": ModuleType("xxx.tests.foo")}
_insert_missing_modules(modules, "xxx.tests.foo")
assert sorted(modules) == ["xxx", "xxx.tests", "xxx.tests.foo"]

mod = ModuleType("mod", doc="My Module")
modules = {"xxy": mod}
_insert_missing_modules(modules, "xxy")
assert modules == {"xxy": mod}

modules = {}
_insert_missing_modules(modules, "")
assert not modules