Skip to content

Commit

Permalink
Support HTTPX params argument
Browse files Browse the repository at this point in the history
  • Loading branch information
jocke-l committed Oct 13, 2020
1 parent 4cbb4f7 commit d63f338
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 69 deletions.
25 changes: 24 additions & 1 deletion respx/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Callable, Optional, Pattern, Union, overload

from .mocks import MockTransport
from .models import CallList, ContentDataTypes, DefaultType, HeaderTypes, RequestPattern
from .models import (
CallList,
ContentDataTypes,
DefaultType,
HeaderTypes,
QueryParamTypes,
RequestPattern,
)

mock = MockTransport(assert_all_called=False)

Expand Down Expand Up @@ -49,6 +56,7 @@ def add(
method: Union[str, Callable],
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -60,6 +68,7 @@ def add(
return mock.add(
method,
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -72,6 +81,7 @@ def add(
def get(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -82,6 +92,7 @@ def get(
global mock
return mock.get(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -94,6 +105,7 @@ def get(
def post(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -104,6 +116,7 @@ def post(
global mock
return mock.post(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -116,6 +129,7 @@ def post(
def put(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -126,6 +140,7 @@ def put(
global mock
return mock.put(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -138,6 +153,7 @@ def put(
def patch(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -148,6 +164,7 @@ def patch(
global mock
return mock.patch(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -160,6 +177,7 @@ def patch(
def delete(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -170,6 +188,7 @@ def delete(
global mock
return mock.delete(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -182,6 +201,7 @@ def delete(
def head(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -192,6 +212,7 @@ def head(
global mock
return mock.head(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand All @@ -204,6 +225,7 @@ def head(
def options(
url: Optional[Union[str, Pattern]] = None,
*,
params: Optional[QueryParamTypes] = None,
status_code: Optional[int] = None,
content: Optional[ContentDataTypes] = None,
content_type: Optional[str] = None,
Expand All @@ -214,6 +236,7 @@ def options(
global mock
return mock.options(
url=url,
params=params,
status_code=status_code,
content=content,
content_type=content_type,
Expand Down
95 changes: 45 additions & 50 deletions respx/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import re
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -19,7 +18,7 @@
Union,
)
from unittest import mock
from urllib.parse import urljoin, urlparse
from urllib.parse import urljoin

import httpx
from httpcore import AsyncByteStream, SyncByteStream
Expand Down Expand Up @@ -65,14 +64,47 @@

DefaultType = TypeVar("DefaultType", bound=Any)

Regex = type(re.compile(""))
Kwargs = Dict[str, Any]
URLPatternTypes = Union[str, Pattern[str], URL]
URLPatternTypes = Union[str, Pattern[str], URL, httpx.URL]
JSONTypes = Union[str, List, Dict]
ContentDataTypes = Union[bytes, str, JSONTypes, Callable, Exception]

istype = lambda t, o: isinstance(o, t)
isregex = partial(istype, Regex)
QueryParamTypes = Union[bytes, str, List[Tuple[str, Any]], Dict[str, Any]]


def build_url(
url: Optional[URLPatternTypes] = None,
base: Optional[str] = None,
params: Optional[QueryParamTypes] = None,
) -> Optional[Union[httpx.URL, Pattern[str]]]:
url = url or ""
if not base:
base_url = httpx.URL("")
elif base.endswith("/"):
base_url = httpx.URL(base)
else:
base_url = httpx.URL(base + "/")

if not url and not base:
return None
elif isinstance(url, (str, tuple, httpx.URL)):
return base_url.join(httpx.URL(url, params=params))
elif isinstance(url, Pattern):
if params is not None:
if r"\?" in url.pattern and params is not None:
raise ValueError(
"Request url pattern contains a query string, which is not "
"supported in conjuction with params argument."
)
query_params = str(httpx.QueryParams(params))
url = re.compile(url.pattern + re.escape(fr"?{query_params}"))

return re.compile(urljoin(str(base_url), url.pattern))
else:
raise ValueError(
"Request url pattern must be str or compiled regex, got {}.".format(
type(url).__name__
)
)


def decode_request(request: Request) -> httpx.Request:
Expand Down Expand Up @@ -287,6 +319,7 @@ def __init__(
self,
method: Union[str, Callable],
url: Optional[URLPatternTypes],
params: Optional[QueryParamTypes] = None,
response: Optional[ResponseTemplate] = None,
pass_through: bool = False,
alias: Optional[str] = None,
Expand All @@ -301,7 +334,7 @@ def __init__(
self._match_func = method
else:
self.method = method.upper()
self.set_url(url, base=base_url)
self.url = build_url(url, base=base_url, params=params)
self.pass_through = pass_through

self.response = response or ResponseTemplate()
Expand All @@ -320,43 +353,6 @@ def call_count(self) -> int:
def calls(self) -> CallList:
return CallList.from_unittest_call_list(self.stats.call_args_list)

def get_url(self) -> Optional[URLPatternTypes]:
return self._url

def set_url(
self, url: Optional[URLPatternTypes], base: Optional[str] = None
) -> None:
url = url or None
if url is None:
url = base
elif isinstance(url, str):
url = url if base is None else urljoin(base, url)
parsed_url = urlparse(url)
if not parsed_url.path:
url = parsed_url._replace(path="/").geturl()
elif isinstance(url, tuple):
url = self.build_url(url)
elif isregex(url):
url = url if base is None else re.compile(urljoin(base, url.pattern))
else:
raise ValueError(
"Request url pattern must be str or compiled regex, got {}.".format(
type(url).__name__
)
)
self._url = url

url = property(get_url, set_url)

def build_url(self, parts: URL) -> str:
scheme, host, port, full_path = parts
port_str = (
""
if not port or port == {b"https": 443, b"http": 80}[scheme]
else f":{port}"
)
return f"{scheme.decode()}://{host.decode()}{port_str}{full_path.decode()}"

def match(self, request: Request) -> Optional[Union[Request, ResponseTemplate]]:
"""
Matches request with configured pattern;
Expand All @@ -383,13 +379,12 @@ def match(self, request: Request) -> Optional[Union[Request, ResponseTemplate]]:
if self.method != request_method.decode():
return None

request_url = self.build_url(_request_url)
if not self._url:
if not self.url:
matches = True
elif isinstance(self._url, str):
matches = self._url == request_url
elif isinstance(self.url, httpx.URL):
matches = self.url.raw == _request_url
else:
match = self._url.match(request_url)
match = self.url.match(str(httpx.URL(_request_url)))
if match:
matches = True
url_params = match.groupdict()
Expand Down

0 comments on commit d63f338

Please sign in to comment.