Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed #63: Support HTTPX params argument #81

Merged
merged 2 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 32 additions & 1 deletion respx/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Callable, Optional, Pattern, Union, overload

from .mocks import MockTransport
from .models import CallList, ContentDataTypes, DefaultType, HeaderTypes, RequestPattern
from .models import (
CallList,
ContentDataTypes,
DefaultType,
HeaderTypes,
QueryParamTypes,
RequestPattern,
)

mock = MockTransport(assert_all_called=False)

Expand Down Expand Up @@ -48,6 +55,8 @@ def pop(alias, default=...):
def add(
method: Union[str, Callable],
url: Optional[Union[str, Pattern]] = None,
*,
lundberg marked this conversation as resolved.
Show resolved Hide resolved
params: Optional[QueryParamTypes] = None,
jocke-l marked this conversation as resolved.
Show resolved Hide resolved
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -59,6 +68,7 @@ def add(
return mock.add(
method,
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -70,6 +80,8 @@ def add(

def get(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -80,6 +92,7 @@ def get(
global mock
return mock.get(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -91,6 +104,8 @@ def get(

def post(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -101,6 +116,7 @@ def post(
global mock
return mock.post(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -112,6 +128,8 @@ def post(

def put(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -122,6 +140,7 @@ def put(
global mock
return mock.put(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -133,6 +152,8 @@ def put(

def patch(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -143,6 +164,7 @@ def patch(
global mock
return mock.patch(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -154,6 +176,8 @@ def patch(

def delete(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -164,6 +188,7 @@ def delete(
global mock
return mock.delete(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -175,6 +200,8 @@ def delete(

def head(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -185,6 +212,7 @@ def head(
global mock
return mock.head(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -196,6 +224,8 @@ def head(

def options(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -206,6 +236,7 @@ def options(
global mock
return mock.options(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand Down
95 changes: 45 additions & 50 deletions respx/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import re
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -19,7 +18,7 @@
Union,
)
from unittest import mock
from urllib.parse import urljoin, urlparse
from urllib.parse import urljoin

import httpx
from httpcore import AsyncByteStream, SyncByteStream
Expand Down Expand Up @@ -65,14 +64,47 @@

DefaultType = TypeVar("DefaultType", bound=Any)

Regex = type(re.compile(""))
Kwargs = Dict[str, Any]
URLPatternTypes = Union[str, Pattern[str], URL]
URLPatternTypes = Union[str, Pattern[str], URL, httpx.URL]
JSONTypes = Union[str, List, Dict]
ContentDataTypes = Union[bytes, str, JSONTypes, Callable, Exception]

istype = lambda t, o: isinstance(o, t)
isregex = partial(istype, Regex)
QueryParamTypes = Union[bytes, str, List[Tuple[str, Any]], Dict[str, Any]]


def build_url(
url: Optional[URLPatternTypes] = None,
base: Optional[str] = None,
params: Optional[QueryParamTypes] = None,
) -> Optional[Union[httpx.URL, Pattern[str]]]:
url = url or ""
if not base:
base_url = httpx.URL("")
elif base.endswith("/"):
base_url = httpx.URL(base)
else:
base_url = httpx.URL(base + "/")

if not url and not base:
return None
elif isinstance(url, (str, tuple, httpx.URL)):
return base_url.join(httpx.URL(url, params=params))
elif isinstance(url, Pattern):
if params is not None:
if r"\?" in url.pattern and params is not None:
raise ValueError(
"Request url pattern contains a query string, which is not "
"supported in conjuction with params argument."
)
query_params = str(httpx.QueryParams(params))
url = re.compile(url.pattern + re.escape(fr"?{query_params}"))

return re.compile(urljoin(str(base_url), url.pattern))
else:
raise ValueError(
"Request url pattern must be str or compiled regex, got {}.".format(
type(url).__name__
)
)


def decode_request(request: Request) -> httpx.Request:
Expand Down Expand Up @@ -287,6 +319,7 @@ def __init__(
self,
method: Union[str, Callable],
url: Optional[URLPatternTypes],
params: Optional[QueryParamTypes] = None,
response: Optional[ResponseTemplate] = None,
pass_through: bool = False,
alias: Optional[str] = None,
Expand All @@ -301,7 +334,7 @@ def __init__(
self._match_func = method
else:
self.method = method.upper()
self.set_url(url, base=base_url)
self.url = build_url(url, base=base_url, params=params)
self.pass_through = pass_through

self.response = response or ResponseTemplate()
Expand All @@ -320,43 +353,6 @@ def call_count(self) -> int:
def calls(self) -> CallList:
return CallList.from_unittest_call_list(self.stats.call_args_list)

def get_url(self) -> Optional[URLPatternTypes]:
return self._url

def set_url(
self, url: Optional[URLPatternTypes], base: Optional[str] = None
) -> None:
url = url or None
if url is None:
url = base
elif isinstance(url, str):
url = url if base is None else urljoin(base, url)
parsed_url = urlparse(url)
if not parsed_url.path:
url = parsed_url._replace(path="/").geturl()
elif isinstance(url, tuple):
url = self.build_url(url)
elif isregex(url):
url = url if base is None else re.compile(urljoin(base, url.pattern))
else:
raise ValueError(
"Request url pattern must be str or compiled regex, got {}.".format(
type(url).__name__
)
)
self._url = url

url = property(get_url, set_url)

def build_url(self, parts: URL) -> str:
scheme, host, port, full_path = parts
port_str = (
""
if not port or port == {b"https": 443, b"http": 80}[scheme]
else f":{port}"
)
return f"{scheme.decode()}://{host.decode()}{port_str}{full_path.decode()}"

def match(self, request: Request) -> Optional[Union[Request, ResponseTemplate]]:
"""
Matches request with configured pattern;
Expand All @@ -383,13 +379,12 @@ def match(self, request: Request) -> Optional[Union[Request, ResponseTemplate]]:
if self.method != request_method.decode():
return None

request_url = self.build_url(_request_url)
if not self._url:
if not self.url:
matches = True
elif isinstance(self._url, str):
matches = self._url == request_url
elif isinstance(self.url, httpx.URL):
matches = self.url.raw == _request_url
else:
match = self._url.match(request_url)
match = self.url.match(str(httpx.URL(_request_url)))
if match:
matches = True
url_params = match.groupdict()
Expand Down