Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement content, data and json patterns #106

Merged
merged 3 commits into from
Nov 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
123 changes: 112 additions & 11 deletions respx/patterns.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json as jsonlib
import operator
import re
from enum import Enum
Expand All @@ -7,11 +8,13 @@
Any,
Callable,
Dict,
List,
Optional,
Pattern as RegexPattern,
Sequence,
Set,
Tuple,
Type,
Union,
)
from urllib.parse import urljoin
Expand Down Expand Up @@ -53,15 +56,16 @@ def __repr__(self): # pragma: nocover

class Pattern:
lookups: Tuple[Lookup, ...] = (Lookup.EQUAL,)
lookup: Lookup
key: str

lookup: Lookup
base: Optional["Pattern"]
value: Any

def __init__(self, value: Any, lookup: Optional[Lookup] = None) -> None:
if lookup and lookup not in self.lookups:
raise NotImplementedError(
f"{lookup.value!r} is not a valid Lookup for {self.__class__.__name__!r}"
f"{self.key!r} pattern does not support {lookup.value!r} lookup"
)
self.lookup = lookup or self.lookups[0]
self.base = None
Expand Down Expand Up @@ -138,6 +142,16 @@ def _in(self, value: Any) -> Match:
return Match(value in self.value)


class PathPattern(Pattern):
path: Optional[str]

def __init__(
self, value: Any, lookup: Optional[Lookup] = None, *, path: Optional[str] = None
) -> None:
self.path = path
super().__init__(value, lookup)


class _And(Pattern):
value: Tuple[Pattern, Pattern]

Expand Down Expand Up @@ -417,9 +431,73 @@ def parse(self, request: RequestTypes) -> str:
return url


class ContentMixin:
def parse(self, request: RequestTypes) -> Any:
if not isinstance(request, httpx.Request):
method, url, headers, stream = request
request = httpx.Request(
method, httpx.URL(url), headers=headers, stream=stream
)
content = request.read()
return content


class Content(ContentMixin, Pattern):
lookups = (Lookup.EQUAL,)
key = "content"
value: bytes

def clean(self, value: Union[bytes, str]) -> bytes:
if isinstance(value, str):
return value.encode()
return value


class JSON(ContentMixin, PathPattern):
lookups = (Lookup.EQUAL,)
key = "json"
value: str

def clean(self, value: Union[str, List, Dict]) -> str:
return self.hash(value)

def parse(self, request: RequestTypes) -> str:
content = super().parse(request)
json = jsonlib.loads(content.decode("utf-8"))

if self.path:
value = json
for bit in self.path.split("__"):
key = int(bit) if bit.isdigit() else bit
try:
value = value[key]
except KeyError as e:
raise KeyError(f"{self.path!r} not in {json!r}") from e
except IndexError as e:
raise IndexError(f"{self.path!r} not in {json!r}") from e
else:
value = json

return self.hash(value)

def hash(self, value: Union[str, List, Dict]) -> str:
return jsonlib.dumps(value, sort_keys=True)


class Data(ContentMixin, Pattern):
lookups = (Lookup.EQUAL,)
key = "data"
value: bytes

def clean(self, value: Dict) -> bytes:
request = httpx.Request("POST", "/", data=value)
data = request.read()
return data


# TODO: Refactor to registration when subclassing Pattern
PATTERNS = {
P.key: P
PATTERNS: Dict[str, Type[Union[Pattern, PathPattern]]] = {
P.key: P # type: ignore
for P in (
Method,
Headers,
Expand All @@ -430,6 +508,9 @@ def parse(self, request: RequestTypes) -> str:
Path,
Params,
URL,
Content,
Data,
JSON,
)
}

Expand All @@ -441,22 +522,42 @@ def M(*patterns: Pattern, **lookups: Any) -> Pattern:
if not value:
continue

# Handle url pattern
if pattern__lookup == "url":
extras = parse_url_patterns(value)
continue

pattern_key, __, lookup_value = pattern__lookup.partition("__")
# Parse pattern key and lookup
pattern_key, __, rest = pattern__lookup.partition("__")
path, __, lookup_name = rest.rpartition("__")
if pattern_key not in PATTERNS:
raise KeyError(f"{pattern_key!r} is not a valid Pattern")

lookup = None if not lookup_value else Lookup(lookup_value)
pattern = PATTERNS[pattern_key](value, lookup=lookup)
# Get pattern class
P = PATTERNS[pattern_key]
pattern: Union[Pattern, PathPattern]

if issubclass(P, PathPattern):
# Make path supported pattern, i.e. JSON
try:
lookup = Lookup(lookup_name) if lookup_name else None
except ValueError:
lookup = None
path = rest
pattern = P(value, lookup=lookup, path=path)
else:
# Make regular pattern
lookup = Lookup(lookup_name) if lookup_name else None
pattern = P(value, lookup=lookup)

patterns += (pattern,)

pattern = combine(patterns)
# Combine and merge patterns
combined_pattern = combine(patterns)
if extras:
pattern = merge_patterns(pattern, **extras)
return pattern
combined_pattern = merge_patterns(combined_pattern, **extras)

return combined_pattern


def get_scheme_port(scheme: Optional[str]) -> Optional[int]:
Expand All @@ -480,7 +581,7 @@ def parse_url_patterns(
return bases

if isinstance(url, RegexPattern):
return {"url": URL(url, Lookup.REGEX)}
return {"url": URL(url, lookup=Lookup.REGEX)}

url = httpx.URL(url)
scheme_port = get_scheme_port(url.scheme)
Expand Down
12 changes: 12 additions & 0 deletions respx/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def request(
) -> SyncResponse:
raw_request = (method, url, headers, stream)
request = decode_request(raw_request)

# Pre-read request
request.read()
stream = request.stream # type: ignore

# Resolve response
response = self.resolve(request)

if response is None:
Expand All @@ -43,6 +49,12 @@ async def arequest(
) -> AsyncResponse:
raw_request = (method, url, headers, stream)
request = decode_request(raw_request)

# Pre-read request
await request.aread()
stream = request.stream # type: ignore

# Resolve response
response = self.resolve(request)

if response is None:
Expand Down
31 changes: 20 additions & 11 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ async def test_callable_content(client):
url_pattern = re.compile(r"https://foo.bar/(?P<slug>\w+)/")

def content_callback(request, slug):
request.read()
content = jsonlib.loads(request.content)
return respx.MockResponse(text=f"hello {slug}{content['x']}")

Expand All @@ -303,7 +302,6 @@ def content_callback(request, slug):
@pytest.mark.asyncio
async def test_request_callback(client):
def callback(request, name):
request.read()
if request.url.host == "foo.bar" and request.content == b'{"foo": "bar"}':
return respx.MockResponse(
202,
Expand Down Expand Up @@ -393,15 +391,7 @@ async def test_external_pass_through(client): # pragma: nocover
with respx.mock:
# Mock pass-through call
url = "https://httpbin.org/post"
route = respx.post(url).respond(content=b"").pass_through()

# Mock a non-matching callback pattern pre-reading request data
def callback(req):
req.read()
assert req.content == b'{"foo": "bar"}'
return None

respx.add(callback)
route = respx.post(url, json__foo="bar").pass_through()

# Make external pass-through call
assert route.call_count == 0
Expand Down Expand Up @@ -560,3 +550,22 @@ def test_respond():

with pytest.raises(ValueError, match="content can only be"):
route.respond(content=Exception())


@pytest.mark.asyncio
@pytest.mark.parametrize(
"kwargs",
[
{"content": b"foobar"},
{"content": "foobar"},
{"json": {"foo": "bar"}},
{"json": [{"foo": "bar", "ham": "spam"}, {"zoo": "apa", "egg": "yolk"}]},
{"data": {"animal": "Räv", "name": "Röda Räven"}},
],
)
async def test_async_post_content(kwargs):
async with respx.mock:
respx.post("https://foo.bar/", **kwargs) % 201
async with httpx.AsyncClient() as client:
response = await client.post("https://foo.bar/", **kwargs)
assert response.status_code == 201