Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX revert MemorizedFunc.call API change #1576

Merged
merged 5 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ Latest changes
In development
--------------

- Fix a backward-incompatible change in ``MemorizedFunc.call`` which needs to
tomMoral marked this conversation as resolved.
Show resolved Hide resolved
return the metadata. Also make sure that ``NotMemorizedFunc.call`` returns
an empty dict for metadata.
tomMoral marked this conversation as resolved.
Show resolved Hide resolved
https://github.com/joblib/joblib/pull/1576

Release 1.4.0 -- 2024/04/08
---------------------------
Expand Down
8 changes: 8 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from joblib.parallel import mp
from joblib.backports import LooseVersion
from joblib import Memory
try:
import lz4
except ImportError:
Expand Down Expand Up @@ -84,3 +85,10 @@ def pytest_unconfigure(config):
# Note that we also use a shorter timeout for the per-test callback
# configured via the pytest-timeout extension.
faulthandler.dump_traceback_later(60, exit=True)


@pytest.fixture(scope='function')
def memory(tmp_path):
    """Yield a silent ``Memory`` stored under ``tmp_path``, then clear it."""
    cache = Memory(location=tmp_path, verbose=0)
    yield cache
    # Teardown: executed after the requesting test has finished.
    cache.clear()
42 changes: 27 additions & 15 deletions joblib/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def clear(self, warn=True):
pass

def call(self, *args, **kwargs):
return self.func(*args, **kwargs)
return self.func(*args, **kwargs), {}

def check_call_in_cache(self, *args, **kwargs):
return False
Expand Down Expand Up @@ -476,8 +476,9 @@ def _cached_call(self, args, kwargs, shelving):

Returns
-------
Output of the wrapped function if shelving is false, or a
MemorizedResult reference to the value if shelving is true.
output: Output of the wrapped function if shelving is false, or a
MemorizedResult reference to the value if shelving is true.
metadata: dict containing the metadata associated with the call.
"""
args_id = self._get_args_id(*args, **kwargs)
call_id = (self.func_id, args_id)
Expand Down Expand Up @@ -506,15 +507,15 @@ def _cached_call(self, args, kwargs, shelving):
# the cache.
if self._is_in_cache_and_valid(call_id):
if shelving:
return self._get_memorized_result(call_id)
return self._get_memorized_result(call_id), {}

try:
start_time = time.time()
output = self._load_item(call_id)
if self._verbose > 4:
self._print_duration(time.time() - start_time,
context='cache loaded ')
return output
return output, {}
except Exception:
# XXX: Should use an exception logger
_, signature = format_signature(self.func, *args, **kwargs)
Expand All @@ -527,6 +528,7 @@ def _cached_call(self, args, kwargs, shelving):
f"in location {location}"
)

# Return the output and the metadata
return self._call(call_id, args, kwargs, shelving)

@property
Expand Down Expand Up @@ -567,10 +569,12 @@ def call_and_shelve(self, *args, **kwargs):
class "NotMemorizedResult" is used when there is no cache
activated (e.g. location=None in Memory).
"""
return self._cached_call(args, kwargs, shelving=True)
# Return the wrapped output, without the metadata
return self._cached_call(args, kwargs, shelving=True)[0]

def __call__(self, *args, **kwargs):
return self._cached_call(args, kwargs, shelving=False)
# Return the output, without the metadata
return self._cached_call(args, kwargs, shelving=False)[0]

def __getstate__(self):
# Make sure self.func's source is introspected prior to being pickled -
Expand Down Expand Up @@ -752,11 +756,16 @@ def call(self, *args, **kwargs):
-------
output : object
The output of the function call.
metadata : dict
The metadata associated with the call.
"""
call_id = (self.func_id, self._get_args_id(*args, **kwargs))

# Return the output and the metadata
return self._call(call_id, args, kwargs)

def _call(self, call_id, args, kwargs, shelving=False):
# Return the output and the metadata
self._before_call(args, kwargs)
start_time = time.time()
output = self.func(*args, **kwargs)
Expand All @@ -774,13 +783,13 @@ def _after_call(self, call_id, args, kwargs, shelving, output, start_time):
self._print_duration(duration)
metadata = self._persist_input(duration, call_id, args, kwargs)
if shelving:
return self._get_memorized_result(call_id, metadata)
return self._get_memorized_result(call_id, metadata), metadata

if self.mmap_mode is not None:
# Memmap the output at the first call to be consistent with
# later calls
output = self._load_item(call_id, metadata)
return output
return output, metadata

def _persist_input(self, duration, call_id, args, kwargs,
this_duration_limit=0.5):
Expand Down Expand Up @@ -861,12 +870,14 @@ def __repr__(self):
###############################################################################
class AsyncMemorizedFunc(MemorizedFunc):
async def __call__(self, *args, **kwargs):
out = super().__call__(*args, **kwargs)
return await out if asyncio.iscoroutine(out) else out
out = self._cached_call(args, kwargs, shelving=False)
out = await out if asyncio.iscoroutine(out) else out
return out[0] # Don't return metadata

async def call_and_shelve(self, *args, **kwargs):
out = super().call_and_shelve(*args, **kwargs)
return await out if asyncio.iscoroutine(out) else out
out = self._cached_call(args, kwargs, shelving=True)
out = await out if asyncio.iscoroutine(out) else out
return out[0] # Don't return metadata

async def call(self, *args, **kwargs):
out = super().call(*args, **kwargs)
Expand All @@ -876,8 +887,9 @@ async def _call(self, call_id, args, kwargs, shelving=False):
self._before_call(args, kwargs)
start_time = time.time()
output = await self.func(*args, **kwargs)
return self._after_call(call_id, args, kwargs, shelving,
output, start_time)
return self._after_call(
call_id, args, kwargs, shelving, output, start_time
)


###############################################################################
Expand Down
45 changes: 39 additions & 6 deletions joblib/test/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,12 +1402,6 @@ def f(x):
class TestCacheValidationCallback:
"Tests on parameter `cache_validation_callback`"

@pytest.fixture()
def memory(self, tmp_path):
mem = Memory(location=tmp_path)
yield mem
mem.clear()

def foo(self, x, d, delay=None):
d["run"] = True
if delay is not None:
Expand Down Expand Up @@ -1481,3 +1475,42 @@ def test_memory_expires_after(self, memory):
assert d1["run"]
assert not d2["run"]
assert d3["run"]


class TestMemorizedFunc:
    "Tests for the MemorizedFunc and NotMemorizedFunc classes"

    @staticmethod
    def f(x, counter):
        """Increment the number of calls seen for ``x`` and return it.

        ``counter`` is mutated in place, so the returned value tells the
        tests how many times the function body actually ran for ``x``.
        """
        counter[x] = counter.get(x, 0) + 1
        return counter[x]

    def test_call_method_memorized(self, memory):
        "Test that MemorizedFunc.call runs the function and returns metadata"

        f = memory.cache(self.f, ignore=['counter'])

        counter = {}
        # The second call is expected to hit the cache: the body of ``f``
        # runs only once, so both calls return 1.
        assert f(2, counter) == 1
        assert f(2, counter) == 1

        # ``call`` executes the function again and returns a
        # ``(output, metadata)`` pair.
        x, meta = f.call(2, counter)
        assert x == 2, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by MemorizedFunc.call."
        )

    def test_call_method_not_memorized(self, memory):
        "Test that NotMemorizedFunc.call returns the output and empty metadata"

        # NOTE: ``memory`` is not used by this test; the parameter is kept
        # so that the test signature stays unchanged.
        f = NotMemorizedFunc(self.f)

        counter = {}
        # Nothing is cached: every call runs the body of ``f``.
        assert f(2, counter) == 1
        assert f(2, counter) == 2

        x, meta = f.call(2, counter)
        assert x == 3, "f has not been called properly"
        assert isinstance(meta, dict), (
            "Metadata are not returned by NotMemorizedFunc.call."
        )
        assert meta == {}, (
            "NotMemorizedFunc.call should return an empty metadata dict."
        )
21 changes: 21 additions & 0 deletions joblib/test/test_memory_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,24 @@ async def f(x, y=1):
with raises(KeyError):
result.get()
result.clear() # Do nothing if there is no cache.


@pytest.mark.asyncio
async def test_memorized_func_call_async(memory):
    """Check that awaiting ``call`` on a cached coroutine function gives
    back an ``(output, metadata)`` pair.
    """

    async def ff(x, counter):
        await asyncio.sleep(0.1)
        calls = counter.get(x, 0) + 1
        counter[x] = calls
        return calls

    cached_ff = memory.cache(ff, ignore=['counter'])

    calls_per_arg = {}
    # Both awaits are expected to return 1: the second one should be
    # answered from the cache without running ``ff`` again.
    for _ in range(2):
        assert await cached_ff(2, calls_per_arg) == 1

    output, metadata = await cached_ff.call(2, calls_per_arg)
    assert output == 2, "f has not been called properly"
    assert isinstance(metadata, dict), (
        "Metadata are not returned by MemorizedFunc.call."
    )