Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow plugins to catch non ValidationError errors #7806

Merged
merged 1 commit into from Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions pydantic/plugin/__init__.py
Expand Up @@ -76,6 +76,14 @@ def on_error(self, error: ValidationError) -> None:
"""
return

def on_exception(self, exception: Exception) -> None:
"""Callback to be notified of validation exceptions.

Args:
exception: The exception raised during validation.
"""
return


class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
"""Event handler for `SchemaValidator.validate_python`."""
Expand Down
5 changes: 5 additions & 0 deletions pydantic/plugin/_schema_validator.py
Expand Up @@ -75,6 +75,7 @@ def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandler
on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter'))
on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success'))
on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error'))
on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception'))

@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
Expand All @@ -87,6 +88,10 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
for on_error_handler in on_errors:
on_error_handler(error)
raise
except Exception as exception:
for on_exception_handler in on_exceptions:
on_exception_handler(exception)
raise
else:
for on_success_handler in on_successes:
on_success_handler(result)
Expand Down
70 changes: 62 additions & 8 deletions tests/test_plugins.py
Expand Up @@ -5,7 +5,7 @@

from pydantic_core import ValidationError

from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from pydantic.plugin import (
PydanticPluginProtocol,
ValidateJsonHandlerProtocol,
Expand All @@ -18,8 +18,10 @@
@contextlib.contextmanager
def install_plugin(plugin: PydanticPluginProtocol) -> Generator[None, None, None]:
_plugins[plugin.__class__.__qualname__] = plugin
yield
_plugins.clear()
try:
yield
finally:
_plugins.clear()


def test_on_validate_json_on_success() -> None:
Expand Down Expand Up @@ -58,7 +60,7 @@ class Model(BaseModel, plugin_settings={'observe': 'all'}):

def test_on_validate_json_on_error() -> None:
class CustomOnValidateJson:
def enter(
def on_enter(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of these were getting called 😢. In my opinion, this is a strong argument for using a single method with a context manager instead of multiple optional callbacks.

self,
input: str | bytes | bytearray,
*,
Expand Down Expand Up @@ -101,7 +103,7 @@ class Model(BaseModel, plugin_settings={'observe': 'all'}):

def test_on_validate_python_on_success() -> None:
class CustomOnValidatePython(ValidatePythonHandlerProtocol):
def enter(
def on_enter(
self,
input: Any,
*,
Expand Down Expand Up @@ -136,7 +138,7 @@ class Model(BaseModel, plugin_settings={'observe': 'all'}):

def test_on_validate_python_on_error() -> None:
class CustomOnValidatePython(ValidatePythonHandlerProtocol):
def enter(
def on_enter(
self,
input: Any,
*,
Expand All @@ -149,8 +151,6 @@ def enter(
assert strict is None
assert context is None
assert self_instance is None
assert self.config == {'title': 'Model'}
assert self.plugin_settings == {'observe': 'all'}

def on_error(self, error: ValidationError) -> None:
assert error.title == 'Model'
Expand Down Expand Up @@ -180,6 +180,60 @@ class Model(BaseModel, plugin_settings={'observe': 'all'}):
Model.model_validate_json('{"a": 1}') == {'a': 1}


def test_stateful_plugin() -> None:
stack: list[Any] = []

class CustomOnValidatePython(ValidatePythonHandlerProtocol):
def on_enter(
self,
input: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> None:
stack.append(input)

def on_success(self, result: Any) -> None:
stack.pop()

def on_error(self, error: Exception) -> None:
stack.pop()

def on_exception(self, exception: Exception) -> None:
stack.pop()

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, config, plugin_settings):
return CustomOnValidatePython(), None, None

plugin = Plugin()

class MyException(Exception):
pass

with install_plugin(plugin):

class Model(BaseModel, plugin_settings={'observe': 'all'}):
a: int

@field_validator('a')
def validate_a(cls, v: int) -> int:
if v < 0:
raise MyException
return v

with contextlib.suppress(ValidationError):
Model.model_validate({'a': 'potato'})
assert not stack
with contextlib.suppress(MyException):
Model.model_validate({'a': -1})
assert not stack
assert Model.model_validate({'a': 1}).a == 1
assert not stack


def test_all_handlers():
log = []

Expand Down