Skip to content

Commit

Permalink
Fix import mechanism for task modules. (#373)
Browse files Browse the repository at this point in the history
Co-authored-by: Nick Crews <nicholas.b.crews@gmail.com>
Co-authored-by: Tobias Raabe <raabe@posteo.de>
  • Loading branch information
NickCrews and tobiasraabe committed May 23, 2023
1 parent 836518a commit 9e6be35
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 9 deletions.
10 changes: 2 additions & 8 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import time
import warnings
from importlib import util as importlib_util
from pathlib import Path
from typing import Any
from typing import Generator
Expand All @@ -28,6 +27,7 @@
from _pytask.outcomes import CollectionOutcome
from _pytask.outcomes import count_outcomes
from _pytask.path import find_case_sensitive_path
from _pytask.path import import_path
from _pytask.report import CollectionReport
from _pytask.session import Session
from _pytask.shared import find_duplicates
Expand Down Expand Up @@ -121,13 +121,7 @@ def pytask_collect_file(
) -> list[CollectionReport] | None:
"""Collect a file."""
if any(path.match(pattern) for pattern in session.config["task_files"]):
spec = importlib_util.spec_from_file_location(path.stem, str(path))

if spec is None:
raise ImportError(f"Can't find module {path.stem!r} at location {path}.")

mod = importlib_util.module_from_spec(spec)
spec.loader.exec_module(mod)
mod = import_path(path, session.config["root"])

collected_reports = []
for name, obj in inspect.getmembers(mod):
Expand Down
76 changes: 76 additions & 0 deletions src/_pytask/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from __future__ import annotations

import functools
import importlib.util
import os
import sys
from pathlib import Path
from types import ModuleType
from typing import Sequence


Expand Down Expand Up @@ -120,3 +123,76 @@ def find_case_sensitive_path(path: Path, platform: str) -> Path:
"""
out = path.resolve() if platform == "win32" else path
return out


def import_path(path: Path, root: Path) -> ModuleType:
    """Import and return a module from the given path.

    The function is taken from pytest when the import mode is set to ``importlib``. It
    is pytest's recommended import mode for new projects although the default is set
    to ``prepend``. More discussion and information can be found in :gh:`373`.

    The module name is derived from ``path`` relative to ``root``. The module is
    registered in ``sys.modules`` under that name before it is executed, and missing
    parent modules are inserted afterwards. ``sys.path`` is not modified.

    An :class:`ImportError` is raised if no import spec or loader can be found for
    ``path``. Any exception raised while executing the module is propagated after the
    partially initialised module has been removed from ``sys.modules`` again.

    """
    module_name = _module_name_from_path(path, root)

    spec = importlib.util.spec_from_file_location(module_name, str(path))

    if spec is None or spec.loader is None:
        raise ImportError(f"Can't find module {module_name!r} at location {path}.")

    mod = importlib.util.module_from_spec(spec)
    # The module has to be in sys.modules while it is executed, otherwise code which
    # looks up its own module by name during execution fails.
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Mirror the regular import system: a module whose execution failed must not
        # stay behind in sys.modules. Only remove the entry if it is still ours.
        if sys.modules.get(module_name) is mod:
            del sys.modules[module_name]
        raise
    _insert_missing_modules(sys.modules, module_name)
    return mod


def _module_name_from_path(path: Path, root: Path) -> str:
"""Return a dotted module name based on the given path, anchored on root.
For example: path="projects/src/project/task_foo.py" and root="/projects", the
resulting module name will be "src.project.task_foo".
"""
path = path.with_suffix("")
try:
relative_path = path.relative_to(root)
except ValueError:
# If we can't get a relative path to root, use the full path, except for the
# first part ("d:\\" or "/" depending on the platform, for example).
path_parts = path.parts[1:]
else:
# Use the parts for the relative path to the root path.
path_parts = relative_path.parts

return ".".join(path_parts)


def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) -> None:
"""Insert missing modules when importing modules with :func:`import_path`.
When we want to import a module as ``src.project.task_foo`` for example, we need to
create empty modules ``src`` and ``src.project`` after inserting
``src.project.task_foo``, otherwise ``src.project.task_foo`` is not importable by
``__import__``.
"""
module_parts = module_name.split(".")
while module_name:
if module_name not in modules:
try:
# If sys.meta_path is empty, calling import_module will issue a warning
# and raise ModuleNotFoundError. To avoid the warning, we check
# sys.meta_path explicitly and raise the error ourselves to fall back to
# creating a dummy module.
if not sys.meta_path:
raise ModuleNotFoundError
importlib.import_module(module_name)
except ModuleNotFoundError:
module = ModuleType(
module_name,
doc="Empty module created by pytask.",
)
modules[module_name] = module
module_parts.pop(-1)
module_name = ".".join(module_parts)
40 changes: 40 additions & 0 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,46 @@ def task_write_text(depends_on, produces):
assert tmp_path.joinpath("out.txt").read_text() == "Relative paths work."


@pytest.mark.end_to_end()
def test_collect_tasks_from_modules_with_the_same_name(tmp_path):
    """Collect task modules which share the same file name. See #373 and #374."""
    for folder in ("a", "b"):
        directory = tmp_path / folder
        directory.mkdir()
        (directory / "task_module.py").write_text("def task_a(): pass")

    session = main({"paths": tmp_path})
    reports = session.collection_reports

    assert len(reports) == 2
    assert all(report.outcome == CollectionOutcome.SUCCESS for report in reports)

    # Each task must report the module name derived from its own folder.
    module_names = {report.node.function.__module__ for report in reports}
    assert module_names == {"a.task_module", "b.task_module"}


@pytest.mark.end_to_end()
def test_collect_module_name(tmp_path):
    """Task modules must be added to sys.modules. See #373 and #374."""
    source = """
    # without this import, everything works fine
    from __future__ import annotations
    import dataclasses
    @dataclasses.dataclass
    class Data:
        x: int
    def task_my_task():
        pass
    """
    module_file = tmp_path / "task_module.py"
    module_file.write_text(textwrap.dedent(source))

    session = main({"paths": tmp_path})

    first_report = session.collection_reports[0]
    assert first_report.outcome == CollectionOutcome.SUCCESS


@pytest.mark.end_to_end()
def test_collect_filepathnode_with_unknown_type(tmp_path):
"""If a node cannot be parsed because unknown type, raise an error."""
Expand Down
4 changes: 3 additions & 1 deletion tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def task_1():
x = 3
print("hello18")
assert count_continue == 2, "unexpected_failure: %d != 2" % count_continue
raise Exception("expected_failure")
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

Expand Down Expand Up @@ -399,7 +400,8 @@ def task_1():
assert "1" in rest
assert "failed" in rest
assert "Failed" in rest
assert "AssertionError: unexpected_failure" in rest
assert "AssertionError: unexpected_failure" not in rest
assert "expected_failure" in rest
_flush(child)


Expand Down
185 changes: 185 additions & 0 deletions tests/test_path.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from __future__ import annotations

import sys
import textwrap
from contextlib import ExitStack as does_not_raise # noqa: N813
from pathlib import Path
from pathlib import PurePosixPath
from pathlib import PureWindowsPath
from types import ModuleType
from typing import Any

import pytest
from _pytask.path import _insert_missing_modules
from _pytask.path import _module_name_from_path
from _pytask.path import find_case_sensitive_path
from _pytask.path import find_closest_ancestor
from _pytask.path import find_common_ancestor
from _pytask.path import import_path
from _pytask.path import relative_to


Expand Down Expand Up @@ -117,3 +123,182 @@ def test_find_case_sensitive_path(tmp_path, path, existing_paths, expected):

result = find_case_sensitive_path(tmp_path / path, sys.platform)
assert result == tmp_path / expected


@pytest.fixture()
def simple_module(tmp_path: Path) -> Path:
    """Create ``_src/project/mymod.py`` below ``tmp_path`` and return its path."""
    module_file = tmp_path.joinpath("_src", "project", "mymod.py")
    module_file.parent.mkdir(parents=True)
    module_file.write_text("def foo(x): return 40 + x")
    return module_file


def test_importmode_importlib(simple_module: Path, tmp_path: Path) -> None:
    """Importing registers the module and its parents without touching sys.path."""
    imported = import_path(simple_module, root=tmp_path)

    assert imported.foo(2) == 42  # type: ignore[attr-defined]
    assert str(simple_module.parent) not in sys.path

    name = imported.__name__
    assert name in sys.modules
    assert name == "_src.project.mymod"

    # The parent packages are inserted as well.
    for parent_name in ("_src", "_src.project"):
        assert parent_name in sys.modules


def test_importmode_twice_is_different_module(
    simple_module: Path, tmp_path: Path
) -> None:
    """Importing the same file twice yields two distinct module objects."""
    first, second = (import_path(simple_module, root=tmp_path) for _ in range(2))
    assert first is not second


def test_no_meta_path_found(
    simple_module: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Importing must still succeed when ``sys.meta_path`` is empty."""
    monkeypatch.setattr(sys, "meta_path", [])
    imported = import_path(simple_module, root=tmp_path)
    assert imported.foo(2) == 42  # type: ignore[attr-defined]

    # If no spec can be created for the file, the import has to fail.
    import importlib.util

    def _no_spec(*args):  # noqa: ARG001
        return None

    monkeypatch.setattr(importlib.util, "spec_from_file_location", _no_spec)
    with pytest.raises(ImportError):
        import_path(simple_module, root=tmp_path)


def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None:
    """Modules which define dataclasses can be imported (#373, pytest#7856)."""
    source = textwrap.dedent(
        """
        from dataclasses import dataclass
        @dataclass
        class Data:
            value: str
        """
    )
    module_file = tmp_path / "_src" / "project" / "task_dataclass.py"
    module_file.parent.mkdir(parents=True)
    module_file.write_text(source)

    module = import_path(module_file, root=tmp_path)
    Data: Any = module.Data  # noqa: N806
    instance = Data(value="foo")

    assert instance.value == "foo"
    assert instance.__module__ == "_src.project.task_dataclass"


def test_importmode_importlib_with_pickle(tmp_path: Path) -> None:
    """Functions of an imported module survive pickling (#373, pytest#7859)."""
    source = textwrap.dedent(
        """
        import pickle
        def _action():
            return 42
        def round_trip():
            s = pickle.dumps(_action)
            return pickle.loads(s)
        """
    )
    module_file = tmp_path / "_src" / "project" / "task_pickle.py"
    module_file.parent.mkdir(parents=True)
    module_file.write_text(source)

    module = import_path(module_file, root=tmp_path)

    # The round trip happens inside the imported module and returns ``_action``.
    restored_action = module.round_trip()
    assert restored_action() == 42


def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None:
    """Pickles of look-alike classes from separate modules can be loaded.

    Both task modules define a dataclass called ``Data``. Instances of either class
    must compare equal to themselves after a pickle round trip, and each class must
    report the module it was defined in.

    """
    import pickle

    first_file = tmp_path / "_src" / "m1" / "project" / "task.py"
    first_file.parent.mkdir(parents=True)
    first_file.write_text(
        textwrap.dedent(
            """
            import dataclasses
            import pickle
            @dataclasses.dataclass
            class Data:
                x: int = 42
            """
        )
    )

    second_file = tmp_path / "_src" / "m2" / "project" / "task.py"
    second_file.parent.mkdir(parents=True)
    second_file.write_text(
        textwrap.dedent(
            """
            import dataclasses
            import pickle
            @dataclasses.dataclass
            class Data:
                x: str = ""
            """
        )
    )

    def round_trip(obj):
        return pickle.loads(pickle.dumps(obj))  # noqa: S301

    Data1 = import_path(first_file, root=tmp_path).Data  # noqa: N806
    Data2 = import_path(second_file, root=tmp_path).Data  # noqa: N806

    assert round_trip(Data1(20)) == Data1(20)
    assert round_trip(Data2("hello")) == Data2("hello")
    assert Data1.__module__ == "_src.m1.project.task"
    assert Data2.__module__ == "_src.m2.project.task"


def test_module_name_from_path(tmp_path: Path) -> None:
    """Module names are built relative to root, or from the full path otherwise."""
    inside_root = tmp_path / "src/project/task_foo.py"
    assert _module_name_from_path(inside_root, tmp_path) == "src.project.task_foo"

    # A path outside of the root directory falls back to the full path.
    outside_root = Path("/home/foo/task_foo.py")
    assert _module_name_from_path(outside_root, Path("/bar")) == "home.foo.task_foo"


def test_insert_missing_modules(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Missing parents are created while existing entries are left untouched."""
    monkeypatch.chdir(tmp_path)

    # 'xxx' and 'xxy' serve as parent names because they are unlikely to exist as
    # real modules and therefore don't end up being imported.
    leaf_name = "xxx.project.foo"
    modules = {leaf_name: ModuleType(leaf_name)}
    _insert_missing_modules(modules, leaf_name)
    assert sorted(modules) == ["xxx", "xxx.project", "xxx.project.foo"]

    existing = ModuleType("mod", doc="My Module")
    modules = {"xxy": existing}
    _insert_missing_modules(modules, "xxy")
    assert modules == {"xxy": existing}

    # An empty module name must not insert anything.
    modules = {}
    _insert_missing_modules(modules, "")
    assert not modules

0 comments on commit 9e6be35

Please sign in to comment.