/
_decorators.py
117 lines (99 loc) · 3.73 KB
/
_decorators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright 2020 Alethea Katherine Flowers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
import types
from typing import Any, Callable, Dict, Iterable, List, Optional
from . import _typing
if _typing.TYPE_CHECKING:
from ._parametrize import Param
class FunctionDecorator:
    """Base class whose instances impersonate the function they decorate.

    At construction time the new instance receives the wrapped function's
    metadata (``__name__``, ``__qualname__``, ``__doc__``, ``__module__``,
    ``__dict__`` contents, ``__wrapped__``...), so a decorator object looks
    like the original function to introspection.
    """

    def __new__(
        cls, func: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> "FunctionDecorator":
        # Extra positional/keyword arguments are accepted only so that the
        # subclass's __init__ signature can be used; they are ignored here.
        instance = super().__new__(cls)
        functools.update_wrapper(instance, func)
        return instance
def _copy_func(src: Callable, name: Optional[str] = None) -> Callable:
dst = types.FunctionType(
src.__code__,
src.__globals__, # type: ignore
name=name or src.__name__,
argdefs=src.__defaults__, # type: ignore
closure=src.__closure__, # type: ignore
)
dst.__dict__.update(copy.deepcopy(src.__dict__))
dst = functools.update_wrapper(dst, src)
dst.__kwdefaults__ = src.__kwdefaults__ # type: ignore
return dst
class Func(FunctionDecorator):
    """A callable session function bundled with its registration options.

    Calling a ``Func`` forwards all arguments to the wrapped function. Via
    ``FunctionDecorator`` the instance also takes on the wrapped function's
    metadata (``__name__``, ``__doc__``, ``__dict__`` entries, ...).
    """

    def __init__(
        self,
        func: Callable,
        python: _typing.Python = None,
        reuse_venv: Optional[bool] = None,
        name: Optional[str] = None,
        venv_backend: Any = None,
        venv_params: Any = None,
        should_warn: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store the wrapped function and its options.

        Args:
            func: The function invoked by ``__call__``.
            python: Interpreter selection, stored as-is.
            reuse_venv: Stored as-is.
            name: Optional explicit name. It was previously accepted but
                discarded; it is now kept on ``self.name``.
            venv_backend: Stored as-is.
            venv_params: Stored as-is.
            should_warn: Mapping stored as-is; ``None`` (or an empty mapping)
                becomes a fresh empty dict.
        """
        self.func = func
        self.python = python
        self.reuse_venv = reuse_venv
        self.name = name
        self.venv_backend = venv_backend
        self.venv_params = venv_params
        self.should_warn = should_warn or dict()

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped function with the given arguments."""
        return self.func(*args, **kwargs)

    def copy(self, name: Optional[str] = None) -> "Func":
        """Return an independent copy of this session function.

        The underlying function object is duplicated with ``_copy_func`` and
        every option is carried over. ``should_warn`` is shallow-copied so
        the two ``Func`` objects never share one mutable dict.

        Args:
            name: Optional new name, forwarded both to ``_copy_func`` and to
                the new ``Func``.
        """
        return Func(
            _copy_func(self.func, name),
            self.python,
            self.reuse_venv,
            name,
            self.venv_backend,
            self.venv_params,
            dict(self.should_warn),
        )
class Call(Func):
    """One parametrized invocation of a session ``Func``.

    Wraps a ``Func`` together with a single parameter set (``param_spec``).
    When called, the parameter values are merged into the keyword arguments
    before delegating to the wrapped ``Func``.
    """

    def __init__(self, func: Func, param_spec: "Param") -> None:
        """Bind ``func`` to the parameter set described by ``param_spec``.

        Args:
            func: The session function being parametrized. Its options are
                inherited by this call.
            param_spec: The parameter set. Its ``call_spec`` mapping supplies
                the keyword arguments; its string form becomes
                ``session_signature``.
        """
        # Work on a private copy: "python" may be popped below, and that must
        # not leak back into whatever mapping param_spec.call_spec returned.
        call_spec = dict(param_spec.call_spec)
        session_signature = f"({param_spec})"

        # Determine the Python interpreter for the session using either @session
        # or @parametrize. For backwards compatibility, we only use a "python"
        # parameter in @parametrize if the session function does not expect it
        # as a normal argument, and if the @session decorator does not already
        # specify `python`.
        python = func.python
        if python is None and "python" in call_spec:
            signature = inspect.signature(func.func)
            if "python" not in signature.parameters:
                python = call_spec.pop("python")

        super().__init__(
            func,
            python,
            func.reuse_venv,
            None,
            func.venv_backend,
            func.venv_params,
            func.should_warn,
        )
        self.call_spec = call_spec
        self.session_signature = session_signature

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped ``Func`` with the parametrized keyword arguments.

        Values from ``call_spec`` take precedence over caller-supplied keyword
        arguments that have the same name.
        """
        kwargs.update(self.call_spec)
        return super().__call__(*args, **kwargs)

    @classmethod
    def generate_calls(cls, func: Func, param_specs: "Iterable[Param]") -> "List[Call]":
        """Build one ``Call`` per parameter set, preserving input order."""
        return [cls(func, param_spec) for param_spec in param_specs]