Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap weakref.ref instead of subclassing to fix cloudpickle serialization #7780

Merged
merged 16 commits into from Oct 10, 2023
35 changes: 33 additions & 2 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -587,8 +587,39 @@ def generate_model_signature(
return Signature(parameters=list(merged_params.values()), return_annotation=None)


# NOTE(review): this two-line definition appears to be the pre-change side of the
# diff (the diff header reports 2 deletions): a direct subclass of
# `weakref.ReferenceType` with no added behavior. A class of the same name that
# wraps `weakref.ref` follows immediately below — confirm only one is kept.
class _PydanticWeakRef(weakref.ReferenceType):
pass
class _PydanticWeakRef:
    """Wrapper for `weakref.ref` that enables `pickle` serialization.

    Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
    to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
    `weakref.ref` instead of subclassing it.

    See https://github.com/pydantic/pydantic/issues/6763 for context.

    Semantics:
        - If not pickled, calling the instance behaves the same as calling a
          `weakref.ref`: it returns the referent, or `None` once it is gone.
        - If pickled along with the referenced object, the same `weakref.ref` behavior
          will be maintained between them after unpickling.
        - If pickled without the referenced object, after unpickling the underlying
          reference will be cleared (`__call__` will always return `None`).
    """

    def __init__(self, obj: Any):
        """Create a weak reference to `obj`.

        Args:
            obj: The object to reference weakly, or `None` to create an
                already-cleared reference.
        """
        if obj is None:
            # The object will be `None` upon deserialization if the serialized weakref
            # had lost its underlying object. `weakref.ref(None)` is not allowed, so
            # the cleared state is represented by storing no weakref at all.
            self._wr = None
        else:
            self._wr = weakref.ref(obj)

    def __call__(self) -> Any:
        """Return the referenced object, or `None` if it is no longer available."""
        return None if self._wr is None else self._wr()

    def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any]]:
        """Describe how to rebuild this wrapper when unpickling.

        The reconstructor argument is the *referent itself* (or `None` when the
        reference is already cleared), not a `weakref.ref`. Pickle therefore
        serializes the referent; if the same object is also reachable through a hard
        reference in the pickled data, both end up sharing one object after loading.
        Otherwise nothing keeps the rebuilt referent alive and the new wrapper's
        reference clears.
        """
        return _PydanticWeakRef, (self(),)


def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
Expand Down
52 changes: 52 additions & 0 deletions tests/test_pickle_pydantic_weakref.py
@@ -0,0 +1,52 @@
import gc
import pickle

from pydantic._internal._model_construction import _PydanticWeakRef


class IntWrapper:
    """Minimal picklable object that wraps an `int`.

    A plain `int` cannot be the target of a weak reference, so the test uses this
    ordinary class (no `__slots__`) as the referent instead. Instances compare equal
    by wrapped value; because `__eq__` is defined without `__hash__`, they are
    unhashable.
    """

    def __init__(self, v: int):
        self._v = v

    def get(self) -> int:
        """Return the wrapped integer."""
        return self._v

    def __eq__(self, other: object) -> bool:
        # Defer to Python's default handling for foreign types instead of failing
        # with AttributeError on `other.get()`.
        if not isinstance(other, IntWrapper):
            return NotImplemented
        return self.get() == other.get()

    def __repr__(self) -> str:
        # Makes failed equality assertions readable.
        return f'{type(self).__name__}({self._v!r})'


def test_pickle_pydantic_weakref():
    """Round-trip `_PydanticWeakRef` instances through `pickle`.

    Three references are pickled together in one dict:
    - one whose referent is also stored in the dict by hard reference,
    - one whose referent is alive but not included in the dict,
    - one whose referent was already gone before pickling.
    Only the first is expected to still resolve after loading.
    """
    shared_target = IntWrapper(1)
    shared_ref = _PydanticWeakRef(shared_target)
    assert shared_ref() is shared_target

    excluded_target = IntWrapper(2)
    excluded_ref = _PydanticWeakRef(excluded_target)
    assert excluded_ref() is excluded_target

    # The referent here is a temporary that nothing else holds on to.
    dead_ref = _PydanticWeakRef(IntWrapper(3))
    gc.collect()  # PyPy does not use reference counting and always relies on GC.
    assert dead_ref() is None

    payload = {
        # Hard reference to shared_ref's referent; it is pickled alongside the weakref.
        'hard_ref': shared_target,
        # The referent is kept alive by 'hard_ref' in the pickled data, so this
        # reference should still resolve after deserialization.
        'has_hard_ref': shared_ref,
        # Nothing in the pickled data holds excluded_ref's referent, so this
        # reference should resolve to `None` after deserialization.
        'has_no_hard_ref': excluded_ref,
        # The referent had already gone out of scope before pickling, so this
        # reference should resolve to `None` after deserialization.
        'ref_out_of_scope': dead_ref,
    }

    serialized = pickle.dumps(payload)
    loaded = pickle.loads(serialized)
    gc.collect()  # PyPy does not use reference counting and always relies on GC.

    restored_target = loaded['hard_ref']
    assert restored_target == IntWrapper(1)
    assert loaded['has_hard_ref']() is restored_target
    assert loaded['has_no_hard_ref']() is None
    assert loaded['ref_out_of_scope']() is None