Skip to content

Commit

Permalink
Allow @validate_call to work on async methods (#7046)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Aug 9, 2023
1 parent 35fc879 commit 3b25e22
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
14 changes: 11 additions & 3 deletions pydantic/_internal/_validate_call.py
Expand Up @@ -3,7 +3,7 @@
import inspect
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
from typing import Any, Awaitable, Callable

import pydantic_core

Expand Down Expand Up @@ -80,7 +80,15 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
core_config = config_wrapper.core_config(self)
schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema))
simplified_schema = inline_schema_defs(schema)
self.__return_pydantic_validator__ = pydantic_core.SchemaValidator(simplified_schema, core_config)
validator = pydantic_core.SchemaValidator(simplified_schema, core_config)
if inspect.iscoroutinefunction(self.raw_function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
return validator.validate_python(await aw)

self.__return_pydantic_validator__ = return_val_wrapper
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_core_schema__ = None
self.__return_pydantic_validator__ = None
Expand All @@ -90,7 +98,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
def __call__(self, *args: Any, **kwargs: Any) -> Any:
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
if self.__return_pydantic_validator__:
return self.__return_pydantic_validator__.validate_python(res)
return self.__return_pydantic_validator__(res)
return res

def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
Expand Down
24 changes: 23 additions & 1 deletion tests/test_validate_call.py
Expand Up @@ -4,7 +4,7 @@
import sys
from datetime import datetime, timezone
from functools import partial
from typing import List, Tuple
from typing import Any, List, Tuple

import pytest
from pydantic_core import ArgsKwargs
Expand Down Expand Up @@ -703,3 +703,25 @@ class A:
@decorator
def method(self, x: int):
pass


def test_async_func() -> None:
@validate_call(validate_return=True)
async def foo(a: Any) -> int:
return a

res = asyncio.run(foo(1))
assert res == 1

with pytest.raises(ValidationError) as exc_info:
asyncio.run(foo('x'))

# insert_assert(exc_info.value.errors(include_url=False))
assert exc_info.value.errors(include_url=False) == [
{
'type': 'int_parsing',
'loc': (),
'msg': 'Input should be a valid integer, unable to parse string as an integer',
'input': 'x',
}
]

0 comments on commit 3b25e22

Please sign in to comment.