Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make @validate_call return a function instead of a custom descriptor - fixes binding issue with inheritance and adds self/cls argument to validation errors #8268

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 5 additions & 82 deletions pydantic/_internal/_validate_call.py
@@ -1,7 +1,6 @@
from __future__ import annotations as _annotations

import inspect
from dataclasses import dataclass
from functools import partial
from typing import Any, Awaitable, Callable

Expand All @@ -13,58 +12,34 @@
from ._config import ConfigWrapper


@dataclass
class CallMarker:
function: Callable[..., Any]
validate_return: bool


class ValidateCallWrapper:
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.

It's partially inspired by `wraps` which in turn uses `partial`, but extended to be a descriptor so
these functions can be applied to instance methods, class methods, static methods, as well as normal functions.
"""
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""

__slots__ = (
'raw_function',
'_config',
'_validate_return',
'__pydantic_core_schema__',
'__pydantic_validator__',
'__signature__',
'__name__',
'__qualname__',
'__annotations__',
'__dict__', # required for __module__
)

def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
self.raw_function = function
self._config = config
self._validate_return = validate_return
self.__signature__ = inspect.signature(function)
if isinstance(function, partial):
func = function.func
schema_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__annotations__ = func.__annotations__
self.__module__ = func.__module__
self.__doc__ = func.__doc__
else:
schema_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__annotations__ = function.__annotations__
self.__module__ = function.__module__
self.__doc__ = function.__doc__

namespace = _typing_extra.add_module_globals(function, None)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
self.__pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)

self.__pydantic_validator__ = create_schema_validator(
Expand All @@ -77,15 +52,11 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
config_wrapper.plugin_settings,
)

if self._validate_return:
return_type = (
self.__signature__.return_annotation
if self.__signature__.return_annotation is not self.__signature__.empty
else Any
)
if validate_return:
signature = inspect.signature(function)
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(
schema,
schema_type,
Expand All @@ -95,7 +66,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(self.raw_function):
if inspect.iscoroutinefunction(function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
return validator.validate_python(await aw)
Expand All @@ -104,58 +75,10 @@ async def return_val_wrapper(aw: Awaitable[Any]) -> None:
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_core_schema__ = None
self.__return_pydantic_validator__ = None

self._name: str | None = None # set by __get__, used to set the instance attribute when decorating methods

def __call__(self, *args: Any, **kwargs: Any) -> Any:
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
if self.__return_pydantic_validator__:
return self.__return_pydantic_validator__(res)
return res

def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""
if obj is None:
# It's possible this wrapper is dynamically applied to a class attribute not allowing
# name to be populated by __set_name__. In this case, we'll manually acquire the name
# from the function reference.
if self._name is None:
self._name = self.raw_function.__name__
try:
# Handle the case where a method is accessed as a class attribute
return objtype.__getattribute__(objtype, self._name) # type: ignore
except AttributeError:
# This will happen the first time the attribute is accessed
pass

bound_function = self.raw_function.__get__(obj, objtype)
result = self.__class__(bound_function, self._config, self._validate_return)

# skip binding to instance when obj or objtype has __slots__ attribute
if hasattr(obj, '__slots__') or hasattr(objtype, '__slots__'):
return result

if self._name is not None:
if obj is not None:
object.__setattr__(obj, self._name, result)
else:
object.__setattr__(objtype, self._name, result)
return result

def __set_name__(self, owner: Any, name: str) -> None:
self._name = name

def __repr__(self) -> str:
return f'ValidateCallWrapper({self.raw_function})'

def __eq__(self, other) -> bool:
return (
(self.raw_function == other.raw_function)
and (self._config == other._config)
and (self._validate_return == other._validate_return)
)

def __hash__(self):
return hash(self.raw_function)
11 changes: 10 additions & 1 deletion pydantic/validate_call_decorator.py
@@ -1,6 +1,7 @@
"""Decorator for validating function calls."""
from __future__ import annotations as _annotations

import functools
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload

from ._internal import _validate_call
Expand Down Expand Up @@ -50,7 +51,15 @@ def validate(function: AnyCallableT) -> AnyCallableT:
if isinstance(function, (classmethod, staticmethod)):
name = type(function).__name__
raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)')
return _validate_call.ValidateCallWrapper(function, config, validate_return) # type: ignore
validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return)

@functools.wraps(function)
def wrapper_function(*args, **kwargs):
return validate_call_wrapper(*args, **kwargs)

wrapper_function.raw_function = function # type: ignore

return wrapper_function # type: ignore

if __func:
return validate(__func)
Expand Down
10 changes: 4 additions & 6 deletions tests/test_validate_call.py
Expand Up @@ -93,9 +93,7 @@ def foo_bar(a: int, b: int):
assert foo_bar.__name__ == 'foo_bar'
assert foo_bar.__module__ == 'tests.test_validate_call'
assert foo_bar.__qualname__ == 'test_wrap.<locals>.foo_bar'
assert isinstance(foo_bar.__pydantic_core_schema__, dict)
assert callable(foo_bar.raw_function)
assert repr(foo_bar) == f'ValidateCallWrapper({repr(foo_bar.raw_function)})'
assert repr(inspect.signature(foo_bar)) == '<Signature (a: int, b: int)>'


Expand Down Expand Up @@ -369,8 +367,8 @@ def foo(self, a: int, b: int):

# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs((x,))},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs((x,))},
Comment on lines +370 to +371
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the observable change with this approach looks like the error messages will now always show the self or cls argument. There's debate to be had whether this is a good thing; showing this argument is less implicit but now it doesn't match exactly what the user typed as input arguments, which has potential for confusion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO Python already forces users to get used to the idea of the extra self argument with error messages like this:

class C:
  def foo(self, x):
    pass

# TypeError: C.foo() takes 2 positional arguments but 3 were given
C().foo(1, 2)

I think there may be times when showing the value of self (assuming it has a useful repr) might make the error message more useful by providing some extra context about the state of things aside from the actual validation error.

But I'm worried there might be cases where users broadly catch validation errors, serialize them to JSON, and return that in an API. I don't know if that's common but it looks like an intended use case. Then I expect that JSON serialization errors may start appearing, or the validation errors seen by clients (who are far removed from the Python code) will look confusing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, good points. I think overall we tend to be ok with error messages changing (we generally justify this by saying that we work to improve error messages). So given that Python itself is a bit mixed here, I'm tempted to argue let's merge this patch as it's a fantastic simplification, and then if we want to work again to improve this later we revisit.

]


Expand All @@ -392,8 +390,8 @@ def foo(cls, a: int, b: int):

# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs(())},
{'type': 'missing_argument', 'loc': ('a',), 'msg': 'Missing required argument', 'input': ArgsKwargs((X,))},
{'type': 'missing_argument', 'loc': ('b',), 'msg': 'Missing required argument', 'input': ArgsKwargs((X,))},
]


Expand Down