Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix misbehavior related to copying of RootModel #6918

Merged
merged 1 commit into from Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 30 additions & 1 deletion pydantic/root_model.py
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations as _annotations

import typing
from copy import copy, deepcopy

from pydantic_core import PydanticUndefined

from ._internal import _repr
from .main import BaseModel
from .main import BaseModel, _object_setattr

if typing.TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -68,6 +69,34 @@ def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[
"""
return super().model_construct(root=root, _fields_set=_fields_set)

def __getstate__(self) -> dict[Any, Any]:
return {
'__dict__': self.__dict__,
'__pydantic_fields_set__': self.__pydantic_fields_set__,
}

def __setstate__(self, state: dict[Any, Any]) -> None:
_object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
_object_setattr(self, '__dict__', state['__dict__'])

def __copy__(self: Model) -> Model:
"""Returns a shallow copy of the model."""
cls = type(self)
m = cls.__new__(cls)
_object_setattr(m, '__dict__', copy(self.__dict__))
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
return m

def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
"""Returns a deep copy of the model."""
cls = type(self)
m = cls.__new__(cls)
_object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
# This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
# and attempting a deepcopy would be marginally slower.
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
return m

if typing.TYPE_CHECKING:

def model_dump(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_root_model.py
Expand Up @@ -612,3 +612,13 @@ def test_help(create_module):
)

assert 'class RootModel' in module.help_result_string


def test_copy_preserves_equality():
model = RootModel()

copied = model.__copy__()
assert model == copied

deepcopied = model.__deepcopy__()
assert model == deepcopied