Skip to content

Commit

Permalink
Make @validate_call return a function instead of a custom descripto…
Browse files Browse the repository at this point in the history
…r - fixes binding issue with inheritance and adds `self/cls` argument to validation errors (#8268)
  • Loading branch information
alexmojaki committed Dec 11, 2023
1 parent fd0dfff commit 68df4af
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 89 deletions.
87 changes: 5 additions & 82 deletions pydantic/_internal/_validate_call.py
@@ -1,7 +1,6 @@
from __future__ import annotations as _annotations

import inspect
from dataclasses import dataclass
from functools import partial
from typing import Any, Awaitable, Callable

Expand All @@ -13,58 +12,34 @@
from ._config import ConfigWrapper


@dataclass
class CallMarker:
function: Callable[..., Any]
validate_return: bool


class ValidateCallWrapper:
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.
It's partially inspired by `wraps` which in turn uses `partial`, but extended to be a descriptor so
these functions can be applied to instance methods, class methods, static methods, as well as normal functions.
"""
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""

__slots__ = (
'raw_function',
'_config',
'_validate_return',
'__pydantic_core_schema__',
'__pydantic_validator__',
'__signature__',
'__name__',
'__qualname__',
'__annotations__',
'__dict__', # required for __module__
)

def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
self.raw_function = function
self._config = config
self._validate_return = validate_return
self.__signature__ = inspect.signature(function)
if isinstance(function, partial):
func = function.func
schema_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__annotations__ = func.__annotations__
self.__module__ = func.__module__
self.__doc__ = func.__doc__
else:
schema_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__annotations__ = function.__annotations__
self.__module__ = function.__module__
self.__doc__ = function.__doc__

namespace = _typing_extra.add_module_globals(function, None)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
self.__pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)

self.__pydantic_validator__ = create_schema_validator(
Expand All @@ -77,15 +52,11 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
config_wrapper.plugin_settings,
)

if self._validate_return:
return_type = (
self.__signature__.return_annotation
if self.__signature__.return_annotation is not self.__signature__.empty
else Any
)
if validate_return:
signature = inspect.signature(function)
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(
schema,
schema_type,
Expand All @@ -95,7 +66,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(self.raw_function):
if inspect.iscoroutinefunction(function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
return validator.validate_python(await aw)
Expand All @@ -104,58 +75,10 @@ async def return_val_wrapper(aw: Awaitable[Any]) -> None:
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_core_schema__ = None
self.__return_pydantic_validator__ = None

self._name: str | None = None # set by __get__, used to set the instance attribute when decorating methods

def __call__(self, *args: Any, **kwargs: Any) -> Any:
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
if self.__return_pydantic_validator__:
return self.__return_pydantic_validator__(res)
return res

def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""
if obj is None:
# It's possible this wrapper is dynamically applied to a class attribute not allowing
# name to be populated by __set_name__. In this case, we'll manually acquire the name
# from the function reference.
if self._name is None:
self._name = self.raw_function.__name__
try:
# Handle the case where a method is accessed as a class attribute
return objtype.__getattribute__(objtype, self._name) # type: ignore
except AttributeError:
# This will happen the first time the attribute is accessed
pass

bound_function = self.raw_function.__get__(obj, objtype)
result = self.__class__(bound_function, self._config, self._validate_return)

# skip binding to instance when obj or objtype has __slots__ attribute
if hasattr(obj, '__slots__') or hasattr(objtype, '__slots__'):
return result

if self._name is not None:
if obj is not None:
object.__setattr__(obj, self._name, result)
else:
object.__setattr__(objtype, self._name, result)
return result

def __set_name__(self, owner: Any, name: str) -> None:
self._name = name

def __repr__(self) -> str:
return f'ValidateCallWrapper({self.raw_function})'

def __eq__(self, other) -> bool:
return (
(self.raw_function == other.raw_function)
and (self._config == other._config)
and (self._validate_return == other._validate_return)
)

def __hash__(self):
return hash(self.raw_function)
11 changes: 10 additions & 1 deletion pydantic/validate_call_decorator.py
@@ -1,6 +1,7 @@
"""Decorator for validating function calls."""
from __future__ import annotations as _annotations

import functools
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload

from ._internal import _validate_call
Expand Down Expand Up @@ -50,7 +51,15 @@ def validate(function: AnyCallableT) -> AnyCallableT:
if isinstance(function, (classmethod, staticmethod)):
name = type(function).__name__
raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)')
return _validate_call.ValidateCallWrapper(function, config, validate_return) # type: ignore
validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return)

@functools.wraps(function)
def wrapper_function(*args, **kwargs):
return validate_call_wrapper(*args, **kwargs)

wrapper_function.raw_function = function # type: ignore

return wrapper_function # type: ignore

if __func:
return validate(__func)
Expand Down
10 changes: 4 additions & 6 deletions tests/test_validate_call.py
Expand Up @@ -93,9 +93,7 @@ def foo_bar(a: int, b: int):
assert foo_bar.__name__ == 'foo_bar'
assert foo_bar.__module__ == 'tests.test_validate_call'
assert foo_bar.__qualname__ == 'test_wrap.<locals>.foo_bar'
assert isinstance(foo_bar.__pydantic_core_schema__, dict)
assert callable(foo_bar.raw_function)
assert repr(foo_bar) == f'ValidateCallWrapper({repr(foo_bar.raw_function)})'
assert repr(inspect.signature(foo_bar)) == '<Signature (a: int, b: int)>'


Expand Down Expand Up @@ -369,8 +367,8 @@ def foo(self, a: int, b: int):

# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs((x,))},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs((x,))},
]


Expand All @@ -392,8 +390,8 @@ def foo(cls, a: int, b: int):

# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs((X,))},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs((X,))},
]


Expand Down

0 comments on commit 68df4af

Please sign in to comment.