Skip to content

Commit

Permalink
FIX revert MemorizedFunc.call API change (#1576)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomMoral committed May 2, 2024
1 parent 2be8dcd commit 398d8ee
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 21 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ Latest changes
In development
--------------

- Fix a backward-incompatible change in ``MemorizedFunc.call``, which needs to
return the metadata. Also make sure that ``NotMemorizedFunc.call`` returns
an empty dict as metadata, for consistency.
https://github.com/joblib/joblib/pull/1576

Release 1.4.0 -- 2024/04/08
---------------------------
Expand Down
9 changes: 9 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from joblib.parallel import mp
from joblib.backports import LooseVersion
from joblib import Memory
try:
import lz4
except ImportError:
Expand Down Expand Up @@ -84,3 +85,11 @@ def pytest_unconfigure(config):
# Note that we also use a shorter timeout for the per-test callback
# configured via the pytest-timeout extension.
faulthandler.dump_traceback_later(60, exit=True)


@pytest.fixture(scope='function')
def memory(tmp_path):
    """Yield a ``Memory`` located in ``tmp_path`` and clear it afterwards.

    The instance is created with ``verbose=0``; ``clear()`` is called on it
    once the code using the fixture has finished.
    """
    cache = Memory(location=tmp_path, verbose=0)
    yield cache
    cache.clear()
42 changes: 27 additions & 15 deletions joblib/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def clear(self, warn=True):
pass

def call(self, *args, **kwargs):
return self.func(*args, **kwargs)
return self.func(*args, **kwargs), {}

def check_call_in_cache(self, *args, **kwargs):
return False
Expand Down Expand Up @@ -476,8 +476,9 @@ def _cached_call(self, args, kwargs, shelving):
Returns
-------
Output of the wrapped function if shelving is false, or a
MemorizedResult reference to the value if shelving is true.
output: Output of the wrapped function if shelving is false, or a
MemorizedResult reference to the value if shelving is true.
metadata: dict containing the metadata associated with the call.
"""
args_id = self._get_args_id(*args, **kwargs)
call_id = (self.func_id, args_id)
Expand Down Expand Up @@ -506,15 +507,15 @@ def _cached_call(self, args, kwargs, shelving):
# the cache.
if self._is_in_cache_and_valid(call_id):
if shelving:
return self._get_memorized_result(call_id)
return self._get_memorized_result(call_id), {}

try:
start_time = time.time()
output = self._load_item(call_id)
if self._verbose > 4:
self._print_duration(time.time() - start_time,
context='cache loaded ')
return output
return output, {}
except Exception:
# XXX: Should use an exception logger
_, signature = format_signature(self.func, *args, **kwargs)
Expand All @@ -527,6 +528,7 @@ def _cached_call(self, args, kwargs, shelving):
f"in location {location}"
)

# Returns both the output and the metadata
return self._call(call_id, args, kwargs, shelving)

@property
Expand Down Expand Up @@ -567,10 +569,12 @@ def call_and_shelve(self, *args, **kwargs):
class "NotMemorizedResult" is used when there is no cache
activated (e.g. location=None in Memory).
"""
return self._cached_call(args, kwargs, shelving=True)
# Return the wrapped output, without the metadata
return self._cached_call(args, kwargs, shelving=True)[0]

def __call__(self, *args, **kwargs):
return self._cached_call(args, kwargs, shelving=False)
# Return the output, without the metadata
return self._cached_call(args, kwargs, shelving=False)[0]

def __getstate__(self):
# Make sure self.func's source is introspected prior to being pickled -
Expand Down Expand Up @@ -752,11 +756,16 @@ def call(self, *args, **kwargs):
-------
output : object
The output of the function call.
metadata : dict
The metadata associated with the call.
"""
call_id = (self.func_id, self._get_args_id(*args, **kwargs))

# Return the output and the metadata
return self._call(call_id, args, kwargs)

def _call(self, call_id, args, kwargs, shelving=False):
# Return the output and the metadata
self._before_call(args, kwargs)
start_time = time.time()
output = self.func(*args, **kwargs)
Expand All @@ -774,13 +783,13 @@ def _after_call(self, call_id, args, kwargs, shelving, output, start_time):
self._print_duration(duration)
metadata = self._persist_input(duration, call_id, args, kwargs)
if shelving:
return self._get_memorized_result(call_id, metadata)
return self._get_memorized_result(call_id, metadata), metadata

if self.mmap_mode is not None:
# Memmap the output at the first call to be consistent with
# later calls
output = self._load_item(call_id, metadata)
return output
return output, metadata

def _persist_input(self, duration, call_id, args, kwargs,
this_duration_limit=0.5):
Expand Down Expand Up @@ -861,12 +870,14 @@ def __repr__(self):
###############################################################################
class AsyncMemorizedFunc(MemorizedFunc):
async def __call__(self, *args, **kwargs):
out = super().__call__(*args, **kwargs)
return await out if asyncio.iscoroutine(out) else out
out = self._cached_call(args, kwargs, shelving=False)
out = await out if asyncio.iscoroutine(out) else out
return out[0] # Don't return metadata

async def call_and_shelve(self, *args, **kwargs):
out = super().call_and_shelve(*args, **kwargs)
return await out if asyncio.iscoroutine(out) else out
out = self._cached_call(args, kwargs, shelving=True)
out = await out if asyncio.iscoroutine(out) else out
return out[0] # Don't return metadata

async def call(self, *args, **kwargs):
out = super().call(*args, **kwargs)
Expand All @@ -876,8 +887,9 @@ async def _call(self, call_id, args, kwargs, shelving=False):
self._before_call(args, kwargs)
start_time = time.time()
output = await self.func(*args, **kwargs)
return self._after_call(call_id, args, kwargs, shelving,
output, start_time)
return self._after_call(
call_id, args, kwargs, shelving, output, start_time
)


###############################################################################
Expand Down
45 changes: 39 additions & 6 deletions joblib/test/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,12 +1402,6 @@ def f(x):
class TestCacheValidationCallback:
"Tests on parameter `cache_validation_callback`"

@pytest.fixture()
def memory(self, tmp_path):
mem = Memory(location=tmp_path)
yield mem
mem.clear()

def foo(self, x, d, delay=None):
d["run"] = True
if delay is not None:
Expand Down Expand Up @@ -1481,3 +1475,42 @@ def test_memory_expires_after(self, memory):
assert d1["run"]
assert not d2["run"]
assert d3["run"]


class TestMemorizedFunc:
    "Tests for the MemorizedFunc and NotMemorizedFunc classes"

    @staticmethod
    def f(x, counter):
        "Increment ``counter[x]`` (starting from 0) and return its new value"
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    def test_call_method_memorized(self, memory):
        "Test that MemorizedFunc.call runs the function and returns metadata"

        f = memory.cache(self.f, ignore=['counter'])

        counter = {}
        assert f(2, counter) == 1
        # Same result on the second call: the counter was not incremented.
        assert f(2, counter) == 1

        # ``call`` is expected to run ``f`` again, hence the incremented value.
        x, meta = f.call(2, counter)
        assert x == 2, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by MemorizedFunc.call."
        )

    def test_call_method_not_memorized(self, memory):
        "Test that NotMemorizedFunc.call returns the output and a metadata dict"

        f = NotMemorizedFunc(self.f)

        counter = {}
        # The counter increases on every call: ``f`` is executed each time.
        assert f(2, counter) == 1
        assert f(2, counter) == 2

        x, meta = f.call(2, counter)
        assert x == 3, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by NotMemorizedFunc.call."
        )
21 changes: 21 additions & 0 deletions joblib/test/test_memory_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,24 @@ async def f(x, y=1):
with raises(KeyError):
result.get()
result.clear() # Do nothing if there is no cache.


@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):
    """Check that ``call`` on a cached coroutine function returns metadata."""

    async def ff(x, counter):
        await asyncio.sleep(0.1)
        hits = counter.get(x, 0) + 1
        counter[x] = hits
        return hits

    call_counts = {}
    cached_ff = memory.cache(ff, ignore=['counter'])

    # Both regular calls must report a single execution of ``ff``.
    first = await cached_ff(2, call_counts)
    assert first == 1
    second = await cached_ff(2, call_counts)
    assert second == 1

    # ``call`` is expected to execute ``ff`` once more and return a pair.
    result, metadata = await cached_ff.call(2, call_counts)
    assert result == 2, "f has not been called properly"
    assert isinstance(metadata, dict), (
        "Metadata are not returned by MemorizedFunc.call."
    )

0 comments on commit 398d8ee

Please sign in to comment.