Skip to content

Commit

Permalink
Merge pull request #108 from lundberg/auto-register-patterns
Browse files Browse the repository at this point in the history
Auto register patterns
  • Loading branch information
lundberg committed Nov 13, 2020
2 parents abb369d + 1587874 commit 117537f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
67 changes: 33 additions & 34 deletions respx/patterns.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json as jsonlib
import operator
import re
from abc import ABC
from enum import Enum
from functools import reduce
from http.cookies import SimpleCookie
from types import MappingProxyType
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Expand Down Expand Up @@ -54,14 +57,30 @@ def __repr__(self): # pragma: nocover
return f"<Match {self.matches}>"


class Pattern:
lookups: Tuple[Lookup, ...] = (Lookup.EQUAL,)
key: str
class Pattern(ABC):
key: ClassVar[str]
lookups: ClassVar[Tuple[Lookup, ...]] = (Lookup.EQUAL,)

lookup: Lookup
base: Optional["Pattern"]
value: Any

# Automatically register all the subclasses in this dict
__registry: ClassVar[Dict[str, Type["Pattern"]]] = {}
registry = MappingProxyType(__registry)

def __init_subclass__(cls) -> None:
if not getattr(cls, "key", None) or ABC in cls.__bases__:
return

if cls.key in cls.__registry:
raise TypeError(
"Subclasses of Pattern must define a unique key. "
f"{cls.key!r} is already defined in {cls.__registry[cls.key]!r}"
)

cls.__registry[cls.key] = cls

def __init__(self, value: Any, lookup: Optional[Lookup] = None) -> None:
if lookup and lookup not in self.lookups:
raise NotImplementedError(
Expand Down Expand Up @@ -207,8 +226,8 @@ def match(self, request: RequestTypes) -> Match:


class Method(Pattern):
lookups = (Lookup.EQUAL, Lookup.IN)
key = "method"
lookups = (Lookup.EQUAL, Lookup.IN)
value: Union[str, Sequence[str]]

def clean(self, value: Union[str, Sequence[str]]) -> Union[str, Sequence[str]]:
Expand Down Expand Up @@ -247,8 +266,8 @@ def _contains(self, value: Any) -> Match:


class Headers(MultiItemsMixin, Pattern):
lookups = (Lookup.CONTAINS, Lookup.EQUAL)
key = "headers"
lookups = (Lookup.CONTAINS, Lookup.EQUAL)
value: httpx.Headers

def clean(self, value: HeaderTypes) -> httpx.Headers:
Expand All @@ -265,8 +284,8 @@ def parse(self, request: RequestTypes) -> httpx.Headers:


class Cookies(Pattern):
lookups = (Lookup.CONTAINS, Lookup.EQUAL)
key = "cookies"
lookups = (Lookup.CONTAINS, Lookup.EQUAL)
value: Set[Tuple[str, str]]

def __hash__(self):
Expand Down Expand Up @@ -299,8 +318,8 @@ def _contains(self, value: Set[Tuple[str, str]]) -> Match:


class Scheme(Pattern):
lookups = (Lookup.EQUAL, Lookup.IN)
key = "scheme"
lookups = (Lookup.EQUAL, Lookup.IN)
value: Union[str, Sequence[str]]

def clean(self, value: Union[str, Sequence[str]]) -> Union[str, Sequence[str]]:
Expand All @@ -318,8 +337,8 @@ def parse(self, request: RequestTypes) -> str:


class Host(Pattern):
lookups = (Lookup.EQUAL, Lookup.IN)
key = "host"
lookups = (Lookup.EQUAL, Lookup.IN)
value: Union[str, Sequence[str]]

def parse(self, request: RequestTypes) -> str:
Expand All @@ -332,8 +351,8 @@ def parse(self, request: RequestTypes) -> str:


class Port(Pattern):
lookups = (Lookup.EQUAL, Lookup.IN)
key = "port"
lookups = (Lookup.EQUAL, Lookup.IN)
value: Optional[int]

def parse(self, request: RequestTypes) -> Optional[int]:
Expand All @@ -350,8 +369,8 @@ def parse(self, request: RequestTypes) -> Optional[int]:


class Path(Pattern):
lookups = (Lookup.EQUAL, Lookup.REGEX, Lookup.STARTS_WITH, Lookup.IN)
key = "path"
lookups = (Lookup.EQUAL, Lookup.REGEX, Lookup.STARTS_WITH, Lookup.IN)
value: Union[str, Sequence[str], RegexPattern[str]]

def clean(
Expand Down Expand Up @@ -379,8 +398,8 @@ def strip_base(self, value: str) -> str:


class Params(MultiItemsMixin, Pattern):
lookups = (Lookup.CONTAINS, Lookup.EQUAL)
key = "params"
lookups = (Lookup.CONTAINS, Lookup.EQUAL)
value: httpx.QueryParams

def clean(self, value: QueryParamTypes) -> httpx.QueryParams:
Expand All @@ -397,12 +416,12 @@ def parse(self, request: RequestTypes) -> httpx.QueryParams:


class URL(Pattern):
key = "url"
lookups = (
Lookup.EQUAL,
Lookup.REGEX,
Lookup.STARTS_WITH,
)
key = "url"
value: Union[str, RegexPattern[str]]

def clean(self, value: URLPatternTypes) -> Union[str, RegexPattern[str]]:
Expand Down Expand Up @@ -495,26 +514,6 @@ def clean(self, value: Dict) -> bytes:
return data


# TODO: Refactor to registration when subclassing Pattern
PATTERNS: Dict[str, Type[Union[Pattern, PathPattern]]] = {
P.key: P # type: ignore
for P in (
Method,
Headers,
Cookies,
Scheme,
Host,
Port,
Path,
Params,
URL,
Content,
Data,
JSON,
)
}


def M(*patterns: Pattern, **lookups: Any) -> Pattern:
extras = None

Expand All @@ -530,11 +529,11 @@ def M(*patterns: Pattern, **lookups: Any) -> Pattern:
# Parse pattern key and lookup
pattern_key, __, rest = pattern__lookup.partition("__")
path, __, lookup_name = rest.rpartition("__")
if pattern_key not in PATTERNS:
if pattern_key not in Pattern.registry:
raise KeyError(f"{pattern_key!r} is not a valid Pattern")

# Get pattern class
P = PATTERNS[pattern_key]
P = Pattern.registry[pattern_key]
pattern: Union[Pattern, PathPattern]

if issubclass(P, PathPattern):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Method,
Params,
Path,
Pattern,
Port,
Scheme,
merge_patterns,
Expand Down Expand Up @@ -395,3 +396,10 @@ def test_merge_patterns():
base = Path("/ham/", Lookup.STARTS_WITH)
merged_pattern = merge_patterns(pattern, path=base)
assert any([p.base == base for p in iter(merged_pattern)])


def test_unique_pattern_key():
with pytest.raises(TypeError, match="unique key"):

class Foobar(Pattern):
key = "url"

0 comments on commit 117537f

Please sign in to comment.