diff --git a/nox/sessions.py b/nox/sessions.py index 043a9e28..56237783 100644 --- a/nox/sessions.py +++ b/nox/sessions.py @@ -21,6 +21,7 @@ import sys import unicodedata import warnings +from types import TracebackType from typing import ( Any, Callable, @@ -31,6 +32,7 @@ Optional, Sequence, Tuple, + Type, Union, ) @@ -210,10 +212,37 @@ def invoked_from(self) -> str: """ return self._runner.global_config.invoked_from - def chdir(self, dir: Union[str, os.PathLike]) -> None: - """Change the current working directory.""" + class _WorkingDirContext: + def __init__(self, dir: Union[str, os.PathLike]) -> None: + self._prev_working_dir = os.getcwd() + os.chdir(dir) + + def __enter__(self) -> "Session._WorkingDirContext": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + os.chdir(self._prev_working_dir) + + def chdir(self, dir: Union[str, os.PathLike]) -> "Session._WorkingDirContext": + """Change the current working directory. + + Can be used as a context manager to automatically restore the working directory:: + + with session.chdir("somewhere/deep/in/monorepo"): + # Runs in "/somewhere/deep/in/monorepo" + session.run("pytest") + + # Runs in original working directory + session.run("flake8") + + """ self.log(f"cd {dir}") - os.chdir(dir) + return Session._WorkingDirContext(dir) cd = chdir """An alias for :meth:`chdir`.""" diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 4f31a826..262d9f32 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -164,6 +164,19 @@ def test_chdir(self, tmpdir): assert os.getcwd() == cdto os.chdir(current_cwd) + def test_chdir_ctx(self, tmpdir): + cdto = str(tmpdir.join("cdbby").ensure(dir=True)) + current_cwd = os.getcwd() + + session, _ = self.make_session_and_runner() + + with session.chdir(cdto): + assert os.getcwd() == cdto + + assert os.getcwd() == current_cwd + + os.chdir(current_cwd) + def test_invoked_from(self, tmpdir): cdto = str(tmpdir.join("cdbby").ensure(dir=True)) current_cwd = os.getcwd()