Skip to content

Commit

Permalink
Make path of the item to validate available in plugin (#7861)
Browse files Browse the repository at this point in the history
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
hramezani and samuelcolvin committed Oct 31, 2023
1 parent cc21074 commit 60c5db6
Show file tree
Hide file tree
Showing 14 changed files with 302 additions and 42 deletions.
2 changes: 1 addition & 1 deletion docs/concepts/models.md
Expand Up @@ -1057,7 +1057,7 @@ BarModel = create_model(
__base__=FooModel,
)
print(BarModel)
#> <class 'pydantic.main.BarModel'>
#> <class '__main__.BarModel'>
print(BarModel.model_fields.keys())
#> dict_keys(['foo', 'bar', 'apple', 'banana'])
```
Expand Down
5 changes: 5 additions & 0 deletions docs/concepts/plugins.md
Expand Up @@ -62,6 +62,8 @@ from pydantic_core import CoreConfig, CoreSchema, ValidationError
from pydantic.plugin import (
NewSchemaReturns,
PydanticPluginProtocol,
SchemaKind,
SchemaTypePath,
ValidatePythonHandlerProtocol,
)

Expand Down Expand Up @@ -89,6 +91,9 @@ class Plugin(PydanticPluginProtocol):
def new_schema_validator(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: Union[CoreConfig, None],
plugin_settings: Dict[str, object],
) -> NewSchemaReturns:
Expand Down
2 changes: 1 addition & 1 deletion pydantic/_internal/_dataclasses.py
Expand Up @@ -172,7 +172,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -

cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, core_config, config_wrapper.plugin_settings
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

Expand Down
16 changes: 15 additions & 1 deletion pydantic/_internal/_model_construction.py
Expand Up @@ -63,6 +63,7 @@ def __new__(
namespace: dict[str, Any],
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_reset_parent_namespace__: bool = True,
_create_model_module: str | None = None,
**kwargs: Any,
) -> type:
"""Metaclass for creating Pydantic models.
Expand All @@ -73,6 +74,7 @@ def __new__(
namespace: The attribute dictionary of the class to be created.
__pydantic_generic_metadata__: Metadata for generic models.
__pydantic_reset_parent_namespace__: Reset parent namespace.
_create_model_module: The module of the class to be created, if created by `create_model`.
**kwargs: Catch-all for any other keyword arguments.
Returns:
Expand Down Expand Up @@ -182,6 +184,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
config_wrapper,
raise_errors=False,
types_namespace=types_namespace,
create_model_module=_create_model_module,
)
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
Expand Down Expand Up @@ -438,6 +441,7 @@ def complete_model_class(
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
create_model_module: str | None = None,
) -> bool:
"""Finish building a model class.
Expand All @@ -450,6 +454,7 @@ def complete_model_class(
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
types_namespace: Optional extra namespace to look for types in.
create_model_module: The module of the class to be created, if created by `create_model`.
Returns:
`True` if the model is successfully completed, else `False`.
Expand Down Expand Up @@ -493,7 +498,16 @@ def complete_model_class(

# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)

cls.__pydantic_validator__ = create_schema_validator(
schema,
cls,
create_model_module or cls.__module__,
cls.__qualname__,
'create_model' if create_model_module else 'BaseModel',
core_config,
config_wrapper.plugin_settings,
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True

Expand Down
23 changes: 21 additions & 2 deletions pydantic/_internal/_validate_call.py
Expand Up @@ -46,12 +46,14 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
self.__signature__ = inspect.signature(function)
if isinstance(function, partial):
func = function.func
schema_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__annotations__ = func.__annotations__
self.__module__ = func.__module__
self.__doc__ = func.__doc__
else:
schema_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__annotations__ = function.__annotations__
Expand All @@ -64,7 +66,16 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
self.__pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)
self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)

self.__pydantic_validator__ = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)

if self._validate_return:
return_type = (
Expand All @@ -75,7 +86,15 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
validator = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(self.raw_function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
Expand Down
25 changes: 19 additions & 6 deletions pydantic/main.py
@@ -1,6 +1,7 @@
"""Logic for creating models."""
from __future__ import annotations as _annotations

import sys
import types
import typing
import warnings
Expand Down Expand Up @@ -1345,13 +1346,13 @@ def create_model(
...


def create_model(
def create_model( # noqa: C901
__model_name: str,
*,
__config__: ConfigDict | None = None,
__doc__: str | None = None,
__base__: type[Model] | tuple[type[Model], ...] | None = None,
__module__: str = __name__,
__module__: str | None = None,
__validators__: dict[str, AnyClassMethod] | None = None,
__cls_kwargs__: dict[str, Any] | None = None,
__slots__: tuple[str, ...] | None = None,
Expand All @@ -1365,9 +1366,9 @@ def create_model(
__config__: The configuration of the new model.
__doc__: The docstring of the new model.
__base__: The base class for the new model.
__module__: The name of the module that the model belongs to.
__validators__: A dictionary of methods that validate
fields.
__module__: The name of the module that the model belongs to,
if `None` the value is taken from `sys._getframe(1)`
__validators__: A dictionary of methods that validate fields.
__cls_kwargs__: A dictionary of keyword arguments for class creation.
__slots__: Deprecated. Should not be passed to `create_model`.
**field_definitions: Attributes of the new model. They should be passed in the format:
Expand Down Expand Up @@ -1418,6 +1419,10 @@ def create_model(
annotations[f_name] = f_annotation
fields[f_name] = f_value

if __module__ is None:
f = sys._getframe(1)
__module__ = f.f_globals['__name__']

namespace: dict[str, Any] = {'__annotations__': annotations, '__module__': __module__}
if __doc__:
namespace.update({'__doc__': __doc__})
Expand All @@ -1431,7 +1436,15 @@ def create_model(
if resolved_bases is not __base__:
ns['__orig_bases__'] = __base__
namespace.update(ns)
return meta(__model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, **kwds)

return meta(
__model_name,
resolved_bases,
namespace,
__pydantic_reset_parent_namespace__=False,
_create_model_module=__module__,
**kwds,
)


__getattr__ = getattr_migration(__name__)
22 changes: 20 additions & 2 deletions pydantic/plugin/__init__.py
Expand Up @@ -4,10 +4,10 @@
"""
from __future__ import annotations

from typing import Any, Callable
from typing import Any, Callable, NamedTuple

from pydantic_core import CoreConfig, CoreSchema, ValidationError
from typing_extensions import Protocol, TypeAlias
from typing_extensions import Literal, Protocol, TypeAlias

__all__ = (
'PydanticPluginProtocol',
Expand All @@ -16,17 +16,32 @@
'ValidateJsonHandlerProtocol',
'ValidateStringsHandlerProtocol',
'NewSchemaReturns',
'SchemaTypePath',
'SchemaKind',
)

NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]'


class SchemaTypePath(NamedTuple):
"""Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""

module: str
name: str


SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call']


class PydanticPluginProtocol(Protocol):
"""Protocol defining the interface for Pydantic plugins."""

def new_schema_validator(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugin_settings: dict[str, object],
) -> tuple[
Expand All @@ -40,6 +55,9 @@ def new_schema_validator(
Args:
schema: The schema to validate against.
schema_type: The original type which the schema was created from, e.g. the model class.
schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called.
schema_kind: The kind of schema to validate against.
config: The config to use for validation.
plugin_settings: Any plugin settings.
Expand Down
28 changes: 24 additions & 4 deletions pydantic/plugin/_schema_validator.py
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import Literal, ParamSpec

if TYPE_CHECKING:
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath


P = ParamSpec('P')
Expand All @@ -18,18 +18,33 @@


def create_schema_validator(
schema: CoreSchema, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None
schema: CoreSchema,
schema_type: Any,
schema_type_module: str,
schema_type_name: str,
schema_kind: SchemaKind,
config: CoreConfig | None = None,
plugin_settings: dict[str, Any] | None = None,
) -> SchemaValidator:
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
Returns:
If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
"""
from . import SchemaTypePath
from ._loader import get_plugins

plugins = get_plugins()
if plugins:
return PluggableSchemaValidator(schema, config, plugins, plugin_settings or {}) # type: ignore
return PluggableSchemaValidator(
schema,
schema_type,
SchemaTypePath(schema_type_module, schema_type_name),
schema_kind,
config,
plugins,
plugin_settings or {},
) # type: ignore
else:
return SchemaValidator(schema, config)

Expand All @@ -42,6 +57,9 @@ class PluggableSchemaValidator:
def __init__(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugins: Iterable[PydanticPluginProtocol],
plugin_settings: dict[str, Any],
Expand All @@ -52,7 +70,9 @@ def __init__(
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
p, j, s = plugin.new_schema_validator(schema, config, plugin_settings)
p, j, s = plugin.new_schema_validator(
schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
)
if p is not None:
python_event_handlers.append(p)
if j is not None:
Expand Down
22 changes: 17 additions & 5 deletions pydantic/type_adapter.py
Expand Up @@ -78,7 +78,7 @@ class Item(BaseModel):

import sys
from dataclasses import is_dataclass
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, cast, overload

from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
from typing_extensions import Literal, is_typeddict
Expand Down Expand Up @@ -205,23 +205,30 @@ def __new__(cls, __type: Any, *, config: ConfigDict | None = ...) -> TypeAdapter
raise NotImplementedError

@overload
def __init__(self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
...

# this overload is for non-type things like Union[int, str]
# Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object]
# so an explicit type cast is needed
@overload
def __init__(self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
...

def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
"""Initializes the TypeAdapter object.
Args:
type: The type associated with the `TypeAdapter`.
config: Configuration for the `TypeAdapter`, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
_parent_depth: depth at which to search the parent namespace to construct the local namespace.
module: The module that passes to plugin if provided.
!!! note
You cannot use the `config` argument when instantiating a `TypeAdapter` if the type you're using has its own
Expand Down Expand Up @@ -264,7 +271,12 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
validator = create_schema_validator(core_schema, core_config, config_wrapper.plugin_settings)
if module is None:
f = sys._getframe(1)
module = cast(str, f.f_globals['__name__'])
validator = create_schema_validator(
core_schema, type, module, str(type), 'TypeAdapter', core_config, config_wrapper.plugin_settings
) # type: ignore

serializer: SchemaSerializer
try:
Expand Down

0 comments on commit 60c5db6

Please sign in to comment.