Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

capture: add type annotations to CaptureFixture #7631

Merged
merged 2 commits into from
Aug 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
100 changes: 79 additions & 21 deletions src/_pytest/capture.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Per-test stdout/stderr capturing mechanism."""
import collections
import contextlib
import functools
import io
import os
import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile
from typing import Any
from typing import AnyStr
from typing import Generator
from typing import Generic
from typing import Iterator
from typing import Optional
from typing import TextIO
from typing import Tuple
Expand Down Expand Up @@ -488,10 +492,64 @@ def writeorg(self, data):

# MultiCapture

CaptureResult = collections.namedtuple("CaptureResult", ["out", "err"])

# This class was a namedtuple, but due to mypy limitation[0] it could not be
# made generic, so was replaced by a regular class which tries to emulate the
# pertinent parts of a namedtuple. If the mypy limitation is ever lifted, can
# make it a namedtuple again.
# [0]: https://github.com/python/mypy/issues/685
@functools.total_ordering
class CaptureResult(Generic[AnyStr]):
"""The result of :method:`CaptureFixture.readouterr`."""

class MultiCapture:
# Can't use slots in Python<3.5.3 due to https://bugs.python.org/issue31272
if sys.version_info >= (3, 5, 3):
__slots__ = ("out", "err")

def __init__(self, out: AnyStr, err: AnyStr) -> None:
self.out = out # type: AnyStr
self.err = err # type: AnyStr

def __len__(self) -> int:
return 2

def __iter__(self) -> Iterator[AnyStr]:
return iter((self.out, self.err))

def __getitem__(self, item: int) -> AnyStr:
return tuple(self)[item]

def _replace(
self, *, out: Optional[AnyStr] = None, err: Optional[AnyStr] = None
) -> "CaptureResult[AnyStr]":
return CaptureResult(
out=self.out if out is None else out, err=self.err if err is None else err
)

def count(self, value: AnyStr) -> int:
return tuple(self).count(value)

def index(self, value) -> int:
return tuple(self).index(value)

def __eq__(self, other: object) -> bool:
if not isinstance(other, (CaptureResult, tuple)):
return NotImplemented
return tuple(self) == tuple(other)

def __hash__(self) -> int:
return hash(tuple(self))

def __lt__(self, other: object) -> bool:
if not isinstance(other, (CaptureResult, tuple)):
return NotImplemented
return tuple(self) < tuple(other)

def __repr__(self) -> str:
return "CaptureResult(out={!r}, err={!r})".format(self.out, self.err)


class MultiCapture(Generic[AnyStr]):
_state = None
_in_suspended = False

Expand All @@ -514,7 +572,7 @@ def start_capturing(self) -> None:
if self.err:
self.err.start()

def pop_outerr_to_orig(self):
def pop_outerr_to_orig(self) -> Tuple[AnyStr, AnyStr]:
"""Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr()
if out:
Expand Down Expand Up @@ -555,7 +613,7 @@ def stop_capturing(self) -> None:
if self.in_:
self.in_.done()

def readouterr(self) -> CaptureResult:
def readouterr(self) -> CaptureResult[AnyStr]:
if self.out:
out = self.out.snap()
else:
Expand All @@ -567,7 +625,7 @@ def readouterr(self) -> CaptureResult:
return CaptureResult(out, err)


def _get_multicapture(method: "_CaptureMethod") -> MultiCapture:
def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
if method == "fd":
return MultiCapture(in_=FDCapture(0), out=FDCapture(1), err=FDCapture(2))
elif method == "sys":
Expand Down Expand Up @@ -605,8 +663,8 @@ class CaptureManager:

def __init__(self, method: "_CaptureMethod") -> None:
self._method = method
self._global_capturing = None # type: Optional[MultiCapture]
self._capture_fixture = None # type: Optional[CaptureFixture]
self._global_capturing = None # type: Optional[MultiCapture[str]]
self._capture_fixture = None # type: Optional[CaptureFixture[Any]]

def __repr__(self) -> str:
return "<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>".format(
Expand Down Expand Up @@ -655,13 +713,13 @@ def resume(self) -> None:
self.resume_global_capture()
self.resume_fixture()

def read_global_capture(self):
def read_global_capture(self) -> CaptureResult[str]:
assert self._global_capturing is not None
return self._global_capturing.readouterr()

# Fixture Control

def set_fixture(self, capture_fixture: "CaptureFixture") -> None:
def set_fixture(self, capture_fixture: "CaptureFixture[Any]") -> None:
if self._capture_fixture:
current_fixture = self._capture_fixture.request.fixturename
requested_fixture = capture_fixture.request.fixturename
Expand Down Expand Up @@ -760,14 +818,14 @@ def pytest_internalerror(self) -> None:
self.stop_global_capturing()


class CaptureFixture:
class CaptureFixture(Generic[AnyStr]):
"""Object returned by the :py:func:`capsys`, :py:func:`capsysbinary`,
:py:func:`capfd` and :py:func:`capfdbinary` fixtures."""

def __init__(self, captureclass, request: SubRequest) -> None:
self.captureclass = captureclass
self.request = request
self._capture = None # type: Optional[MultiCapture]
self._capture = None # type: Optional[MultiCapture[AnyStr]]
self._captured_out = self.captureclass.EMPTY_BUFFER
self._captured_err = self.captureclass.EMPTY_BUFFER

Expand All @@ -786,7 +844,7 @@ def close(self) -> None:
self._capture.stop_capturing()
self._capture = None

def readouterr(self):
def readouterr(self) -> CaptureResult[AnyStr]:
"""Read and return the captured output so far, resetting the internal
buffer.

Expand Down Expand Up @@ -825,15 +883,15 @@ def disabled(self) -> Generator[None, None, None]:


@pytest.fixture
def capsys(request: SubRequest) -> Generator[CaptureFixture, None, None]:
def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
"""Enable text capturing of writes to ``sys.stdout`` and ``sys.stderr``.

The captured output is made available via ``capsys.readouterr()`` method
calls, which return a ``(out, err)`` namedtuple.
``out`` and ``err`` will be ``text`` objects.
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(SysCapture, request)
capture_fixture = CaptureFixture[str](SysCapture, request)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
Expand All @@ -842,15 +900,15 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture, None, None]:


@pytest.fixture
def capsysbinary(request: SubRequest) -> Generator[CaptureFixture, None, None]:
def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, None]:
"""Enable bytes capturing of writes to ``sys.stdout`` and ``sys.stderr``.

The captured output is made available via ``capsysbinary.readouterr()``
method calls, which return a ``(out, err)`` namedtuple.
``out`` and ``err`` will be ``bytes`` objects.
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(SysCaptureBinary, request)
capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
Expand All @@ -859,15 +917,15 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture, None, None]:


@pytest.fixture
def capfd(request: SubRequest) -> Generator[CaptureFixture, None, None]:
def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
"""Enable text capturing of writes to file descriptors ``1`` and ``2``.

The captured output is made available via ``capfd.readouterr()`` method
calls, which return a ``(out, err)`` namedtuple.
``out`` and ``err`` will be ``text`` objects.
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(FDCapture, request)
capture_fixture = CaptureFixture[str](FDCapture, request)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
Expand All @@ -876,15 +934,15 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture, None, None]:


@pytest.fixture
def capfdbinary(request: SubRequest) -> Generator[CaptureFixture, None, None]:
def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, None]:
"""Enable bytes capturing of writes to file descriptors ``1`` and ``2``.

The captured output is made available via ``capfd.readouterr()`` method
calls, which return a ``(out, err)`` namedtuple.
``out`` and ``err`` will be ``byte`` objects.
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(FDCaptureBinary, request)
capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
Expand Down
43 changes: 40 additions & 3 deletions testing/test_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,37 @@
from _pytest import capture
from _pytest.capture import _get_multicapture
from _pytest.capture import CaptureManager
from _pytest.capture import CaptureResult
from _pytest.capture import MultiCapture
from _pytest.config import ExitCode

# note: py.io capture tests where copied from
# pylib 1.4.20.dev2 (rev 13d9af95547e)


def StdCaptureFD(out: bool = True, err: bool = True, in_: bool = True) -> MultiCapture:
def StdCaptureFD(
out: bool = True, err: bool = True, in_: bool = True
) -> MultiCapture[str]:
return capture.MultiCapture(
in_=capture.FDCapture(0) if in_ else None,
out=capture.FDCapture(1) if out else None,
err=capture.FDCapture(2) if err else None,
)


def StdCapture(out: bool = True, err: bool = True, in_: bool = True) -> MultiCapture:
def StdCapture(
out: bool = True, err: bool = True, in_: bool = True
) -> MultiCapture[str]:
return capture.MultiCapture(
in_=capture.SysCapture(0) if in_ else None,
out=capture.SysCapture(1) if out else None,
err=capture.SysCapture(2) if err else None,
)


def TeeStdCapture(out: bool = True, err: bool = True, in_: bool = True) -> MultiCapture:
def TeeStdCapture(
out: bool = True, err: bool = True, in_: bool = True
) -> MultiCapture[str]:
return capture.MultiCapture(
in_=capture.SysCapture(0, tee=True) if in_ else None,
out=capture.SysCapture(1, tee=True) if out else None,
Expand Down Expand Up @@ -856,6 +863,36 @@ def test_dontreadfrominput():
f.close() # just for completeness


def test_captureresult() -> None:
cr = CaptureResult("out", "err")
assert len(cr) == 2
assert cr.out == "out"
assert cr.err == "err"
out, err = cr
assert out == "out"
assert err == "err"
assert cr[0] == "out"
assert cr[1] == "err"
assert cr == cr
assert cr == CaptureResult("out", "err")
assert cr != CaptureResult("wrong", "err")
assert cr == ("out", "err")
assert cr != ("out", "wrong")
assert hash(cr) == hash(CaptureResult("out", "err"))
assert hash(cr) == hash(("out", "err"))
assert hash(cr) != hash(("out", "wrong"))
assert cr < ("z",)
assert cr < ("z", "b")
assert cr < ("z", "b", "c")
assert cr.count("err") == 1
assert cr.count("wrong") == 0
assert cr.index("err") == 1
with pytest.raises(ValueError):
assert cr.index("wrong") == 0
assert next(iter(cr)) == "out"
assert cr._replace(err="replaced") == ("out", "replaced")


@pytest.fixture
def tmpfile(testdir) -> Generator[BinaryIO, None, None]:
f = testdir.makepyfile("").open("wb+")
Expand Down