Skip to content

Commit

Permalink
Wrap weakref.ref instead of subclassing to fix cloudpickle serial…
Browse files Browse the repository at this point in the history
…ization (#7780)

Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
edoakes committed Oct 10, 2023
1 parent 4d52abb commit 5c21e7c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 2 deletions.
35 changes: 33 additions & 2 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -587,8 +587,39 @@ def generate_model_signature(
return Signature(parameters=list(merged_params.values()), return_annotation=None)


class _PydanticWeakRef(weakref.ReferenceType):
    """Bare subclass of `weakref.ReferenceType` that adds no behaviour of its own."""
class _PydanticWeakRef:
    """Wrapper for `weakref.ref` that enables `pickle` serialization.

    Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
    to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
    `weakref.ref` instead of subclassing it.

    See https://github.com/pydantic/pydantic/issues/6763 for context.

    Semantics:
        - If not pickled, behaves the same as a `weakref.ref`.
        - If pickled along with the referenced object, the same `weakref.ref` behavior
          will be maintained between them after unpickling.
        - If pickled without the referenced object, after unpickling the underlying
          reference will be cleared (`__call__` will always return `None`).
    """

    def __init__(self, obj: Any):
        """Store a weak reference to `obj`, or no reference at all when `obj` is `None`."""
        if obj is None:
            # The object will be `None` upon deserialization if the serialized weakref
            # had lost its underlying object.
            self._wr = None
        else:
            self._wr = weakref.ref(obj)

    def __call__(self) -> Any:
        """Return the referenced object, or `None` if there is none (or it is gone)."""
        if self._wr is None:
            return None
        return self._wr()

    def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any]]:
        """Tell `pickle` to rebuild this wrapper from its current referent.

        The argument tuple carries the referent itself (or `None` when the reference
        is empty or dead), not a `weakref.ref`, so unpickling simply calls
        `_PydanticWeakRef(<referent or None>)`.
        """
        return _PydanticWeakRef, (self(),)


def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
Expand Down
52 changes: 52 additions & 0 deletions tests/test_pickle_pydantic_weakref.py
@@ -0,0 +1,52 @@
import gc
import pickle

from pydantic._internal._model_construction import _PydanticWeakRef


class IntWrapper:
    """Minimal picklable, weak-referenceable holder for a single `int`."""

    def __init__(self, v: int):
        self._v = v

    def get(self) -> int:
        """Return the wrapped integer."""
        return self._v

    def __eq__(self, other: object) -> bool:
        """Compare by wrapped value; defer to the other operand for foreign types."""
        if not isinstance(other, IntWrapper):
            # Avoid an AttributeError on `other.get()` when compared with a
            # non-IntWrapper; let Python fall back to its default handling.
            return NotImplemented
        return self.get() == other.get()


def test_pickle_pydantic_weakref():
    """Round-trip `_PydanticWeakRef` objects through `pickle` and check what survives."""
    kept_obj = IntWrapper(1)
    kept_ref = _PydanticWeakRef(kept_obj)
    assert kept_ref() is kept_obj

    unshipped_obj = IntWrapper(2)
    unshipped_ref = _PydanticWeakRef(unshipped_obj)
    assert unshipped_ref() is unshipped_obj

    # The referent is a temporary, so nothing keeps it alive past this statement.
    dead_ref = _PydanticWeakRef(IntWrapper(3))
    gc.collect()  # PyPy does not use reference counting and always relies on GC.
    assert dead_ref() is None

    payload = {
        # Strong reference to the first object; it travels inside the pickle.
        'hard_ref': kept_obj,
        # Its referent is also pickled via 'hard_ref', so the link should survive
        # deserialization.
        'has_hard_ref': kept_ref,
        # Nothing in the pickled payload holds this referent strongly, so the
        # reference should come back as `None`.
        'has_no_hard_ref': unshipped_ref,
        # The referent was already gone before pickling, so the reference should
        # come back as `None`.
        'ref_out_of_scope': dead_ref,
    }

    serialized = pickle.dumps(payload)
    loaded = pickle.loads(serialized)
    gc.collect()  # PyPy does not use reference counting and always relies on GC.

    assert loaded['hard_ref'] == IntWrapper(1)
    assert loaded['has_hard_ref']() is loaded['hard_ref']
    assert loaded['has_no_hard_ref']() is None
    assert loaded['ref_out_of_scope']() is None

0 comments on commit 5c21e7c

Please sign in to comment.