Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy request and response decode in CallList #92

Merged
merged 7 commits into from
Oct 15, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 48 additions & 11 deletions respx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from httpcore import AsyncByteStream, SyncByteStream

if TYPE_CHECKING:
from unittest.mock import _CallList # pragma: nocover
from unittest.mock import _Call, _CallList # pragma: nocover


URL = Tuple[bytes, bytes, Optional[int], bytes]
Expand Down Expand Up @@ -103,17 +103,53 @@ class Call(NamedTuple):
response: Optional[httpx.Response]


class CallList(list):
class CallList:
def __init__(self, call_list: "_CallList"):
self._call_list = call_list

def __iter__(self) -> Generator[Call, None, None]:
yield from super().__iter__()
for raw_call in self._call_list:
yield self._extract_call(raw_call)

@classmethod
def from_unittest_call_list(cls, call_list: "_CallList") -> "CallList":
return cls(Call(request, response) for (request, response), _ in call_list)
def __getitem__(self, item: int) -> Call:
raw_call = self._call_list[item]
return self._extract_call(raw_call)

def __len__(self) -> int:
return len(self._call_list)

@property
def last(self) -> Optional[Call]:
return self[-1] if self else None
return self[-1] if self._call_list else None

def set_new_call_list(self, call_list: "_CallList") -> None:
self._call_list = call_list

@classmethod
def _extract_call(cls, raw_call: "_Call") -> Call:
if isinstance(raw_call.decoded_call, Call):
return raw_call.decoded_call

request, response = raw_call[0]

decoded_call = cls._decode_call(request, response)
raw_call.decoded_call = decoded_call # type: ignore
return decoded_call

@staticmethod
def _decode_call(raw_request: Request, raw_response: Response) -> Call:
# Decode raw request/response as HTTPX models
request = decode_request(raw_request)
response = decode_response(raw_response, request=request)

# Pre-read request/response, but only if mocked, not for pass-through streams
if response and not isinstance(
response.stream, (SyncByteStream, AsyncByteStream)
):
request.read()
response.read()

return Call(request=request, response=response)


class ResponseTemplate:
Expand Down Expand Up @@ -307,6 +343,7 @@ def __init__(
self.response = response or ResponseTemplate()
self.alias = alias
self.stats = mock.MagicMock()
self.calls = CallList(self.stats.call_args_list)

@property
def called(self) -> bool:
Expand All @@ -316,10 +353,6 @@ def called(self) -> bool:
def call_count(self) -> int:
return self.stats.call_count

@property
def calls(self) -> CallList:
return CallList.from_unittest_call_list(self.stats.call_args_list)

def get_url(self) -> Optional[URLPatternTypes]:
return self._url

Expand Down Expand Up @@ -398,3 +431,7 @@ def match(self, request: Request) -> Optional[Union[Request, ResponseTemplate]]:
return self.response.clone(context={"request": _request, **url_params})

return None

def reset(self) -> None:
self.stats.reset_mock()
self.calls.set_new_call_list(self.stats.call_args_list)
36 changes: 12 additions & 24 deletions respx/transports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union, overload
from typing import Callable, Dict, List, Optional, Pattern, Tuple, Union, overload
from unittest import mock

from httpcore import (
Expand All @@ -18,10 +18,9 @@
HeaderTypes,
Request,
RequestPattern,
Response,
ResponseTemplate,
SyncResponse,
decode_request,
decode_response,
)


Expand All @@ -40,7 +39,7 @@ def __init__(
self.aliases: Dict[str, RequestPattern] = {}

self.stats = mock.MagicMock()
self.calls: CallList = CallList()
self.calls: CallList = CallList(self.stats.call_args_list)

def clear(self):
"""
Expand All @@ -53,10 +52,11 @@ def reset(self) -> None:
"""
Resets call stats.
"""
self.calls.clear()
self.stats.reset_mock()
self.calls.set_new_call_list(self.stats.call_args_list)

for pattern in self.patterns:
pattern.stats.reset_mock()
pattern.reset()

def __getitem__(self, alias: str) -> Optional[RequestPattern]:
return self.aliases.get(alias)
Expand Down Expand Up @@ -271,29 +271,17 @@ def options(

def record(
self,
request: Any,
response: Optional[Any],
request: Request,
response: Optional[Response],
pattern: Optional[RequestPattern] = None,
) -> None:
# Decode raw request/response as HTTPX models
request = decode_request(request)
response = decode_response(response, request=request)

# TODO: Skip recording stats for pass_through requests?
# Pre-read request/response, but only if mocked, not for pass-through streams
if response and not isinstance(
response.stream, (SyncByteStream, AsyncByteStream)
):
request.read()
response.read()

self.stats(request, response)
if pattern:
pattern.stats(request, response)

self.stats(request, response)

# Copy stats due to unwanted use of property refs in the high-level api
self.calls[:] = CallList.from_unittest_call_list(self.stats.call_args_list)
# Decoded request will be lazy written into the call object,
# so it should be the same object for both lists
pattern.stats.call_args_list[-1] = self.stats.call_args_list[-1]

def assert_all_called(self) -> None:
assert all(
Expand Down