Skip to content

Commit

Permalink
Return the result of an immediate refreshAfterWrite (#688)
Browse files Browse the repository at this point in the history
In Guava if the CacheLoader returns a completed future on a reload then
the new value is returned to the caller and overwritten in the cache.
Otherwise the last read value is returned and the future overwrites
when it is done. This behavior is now supported instead.

A quirk in Guava is when using Cache.get(key, callable). In Guava this
is wrapped as the cache loader and is used if a refresh is triggered.
That causes the supplied CacheLoader's reload to not be called, which
could be more surprising. However, their asMap().computeIfAbsent does
not trigger a refresh on read so the value is not reloaded. As this
behavior is confusing and accidental, Caffeine will always use the
attached CacheLoader and its reload function for refreshing.
  • Loading branch information
ben-manes committed Mar 27, 2022
1 parent 5bc7bcb commit 6522c7f
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 34 deletions.
Expand Up @@ -1223,8 +1223,9 @@ void demoteFromMainProtected() {
* @param node the entry in the page replacement policy
* @param now the current time, in nanoseconds
* @param recordHit if the hit count should be incremented
* @return the refreshed value if immediately loaded, else null
*/
void afterRead(Node<K, V> node, long now, boolean recordHit) {
@Nullable V afterRead(Node<K, V> node, long now, boolean recordHit) {
if (recordHit) {
statsCounter().recordHits(1);
}
Expand All @@ -1233,7 +1234,7 @@ void afterRead(Node<K, V> node, long now, boolean recordHit) {
if (shouldDrainBuffers(delayable)) {
scheduleDrainBuffers();
}
refreshIfNeeded(node, now);
return refreshIfNeeded(node, now);
}

/** Returns if the cache should bypass the read buffer. */
Expand All @@ -1246,11 +1247,12 @@ boolean skipReadBuffer() {
*
* @param node the entry in the cache to refresh
* @param now the current time, in nanoseconds
* @return the refreshed value if immediately loaded, else null
*/
@SuppressWarnings("FutureReturnValueIgnored")
void refreshIfNeeded(Node<K, V> node, long now) {
@Nullable V refreshIfNeeded(Node<K, V> node, long now) {
if (!refreshAfterWrite()) {
return;
return null;
}

K key;
Expand Down Expand Up @@ -1301,6 +1303,8 @@ void refreshIfNeeded(Node<K, V> node, long now) {
}

if (refreshFuture[0] != null) {
@SuppressWarnings("unchecked")
V[] refreshedValue = (V[]) new Object[1];
refreshFuture[0].whenComplete((newValue, error) -> {
long loadTime = statsTicker().read() - startTime[0];
if (error != null) {
Expand All @@ -1316,7 +1320,7 @@ void refreshIfNeeded(Node<K, V> node, long now) {
V value = (isAsync && (newValue != null)) ? (V) refreshFuture[0] : newValue;

boolean[] discard = new boolean[1];
compute(key, (k, currentValue) -> {
refreshedValue[0] = compute(key, (k, currentValue) -> {
if (currentValue == null) {
// If the entry is absent then discard the refresh and maybe notifying the listener
discard[0] = (value != null);
Expand Down Expand Up @@ -1349,8 +1353,10 @@ void refreshIfNeeded(Node<K, V> node, long now) {

refreshes.remove(keyReference, refreshFuture[0]);
});
return refreshedValue[0];
}
}
return null;
}

/**
Expand Down Expand Up @@ -2074,8 +2080,8 @@ public boolean containsValue(Object value) {
setAccessTime(node, now);
tryExpireAfterRead(node, castedKey, value, expiry(), now);
}
afterRead(node, now, recordStats);
return value;
V refreshed = afterRead(node, now, recordStats);
return (refreshed == null) ? value : refreshed;
}

@Override
Expand Down Expand Up @@ -2117,15 +2123,18 @@ public Map<K, V> getAllPresent(Iterable<? extends K> keys) {
if ((node == null) || ((value = node.getValue()) == null) || hasExpired(node, now)) {
iter.remove();
} else {
entry.setValue(value);

if (!isComputingAsync(node)) {
@SuppressWarnings("unchecked")
K castedKey = (K) entry.getKey();
tryExpireAfterRead(node, castedKey, value, expiry(), now);
setAccessTime(node, now);
}
afterRead(node, now, /* recordHit */ false);
V refreshed = afterRead(node, now, /* recordHit */ false);
if (refreshed == null) {
entry.setValue(value);
} else {
entry.setValue(refreshed);
}
}
}
statsCounter().recordHits(result.size());
Expand Down Expand Up @@ -2489,9 +2498,8 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
tryExpireAfterRead(node, key, value, expiry(), now);
setAccessTime(node, now);
}

afterRead(node, now, /* recordHit */ recordStats);
return value;
var refreshed = afterRead(node, now, /* recordHit */ recordStats);
return (refreshed == null) ? value : refreshed;
}
}
if (recordStats) {
Expand Down
Expand Up @@ -29,6 +29,7 @@
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat;
import static java.util.Map.entry;
import static java.util.function.Function.identity;
import static org.hamcrest.Matchers.is;
import static uk.org.lidalia.slf4jext.ConventionalLevelHierarchy.INFO_LEVELS;
import static uk.org.lidalia.slf4jext.Level.WARN;
Expand Down Expand Up @@ -63,6 +64,7 @@
import com.github.benmanes.caffeine.testing.ConcurrentTestHarness;
import com.github.benmanes.caffeine.testing.Int;
import com.github.valfirst.slf4jtest.TestLoggerFactory;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;

/**
Expand Down Expand Up @@ -340,18 +342,36 @@ public void refreshIfNeeded_error_log(CacheContext context) {
/* --------------- getIfPresent --------------- */

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.NEGATIVE,
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.IDENTITY,
population = { Population.SINGLETON, Population.PARTIAL, Population.FULL })
public void getIfPresent(LoadingCache<Int, Int> cache, CacheContext context) {
public void getIfPresent_immediate(LoadingCache<Int, Int> cache, CacheContext context) {
context.ticker().advance(30, TimeUnit.SECONDS);
assertThat(cache.getIfPresent(context.middleKey())).isEqualTo(context.middleKey().negate());
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.getIfPresent(context.middleKey())).isEqualTo(context.middleKey().negate());
assertThat(cache.getIfPresent(context.middleKey())).isEqualTo(context.middleKey());

assertThat(cache).hasSize(context.initialSize());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(1).exclusively();
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.ASYNC_INCOMPLETE,
population = { Population.SINGLETON, Population.PARTIAL, Population.FULL })
public void getIfPresent_delayed(LoadingCache<Int, Int> cache, CacheContext context) {
context.ticker().advance(30, TimeUnit.SECONDS);
assertThat(cache.getIfPresent(context.middleKey())).isEqualTo(context.middleKey().negate());
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.getIfPresent(context.middleKey())).isEqualTo(context.middleKey().negate());

assertThat(cache).hasSize(context.initialSize());
assertThat(context).removalNotifications().isEmpty();

if (context.isCaffeine()) {
cache.policy().refreshes().get(context.middleKey()).complete(context.middleKey());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(1).exclusively();
}
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.NEGATIVE,
population = { Population.SINGLETON, Population.PARTIAL, Population.FULL })
Expand All @@ -368,35 +388,71 @@ public void getIfPresent_async(AsyncLoadingCache<Int, Int> cache, CacheContext c
/* --------------- getAllPresent --------------- */

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE,
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.IDENTITY,
population = { Population.PARTIAL, Population.FULL })
public void getAllPresent(LoadingCache<Int, Int> cache, CacheContext context) {
public void getAllPresent_immediate(LoadingCache<Int, Int> cache, CacheContext context) {
int count = context.firstMiddleLastKeys().size();
context.ticker().advance(30, TimeUnit.SECONDS);
cache.getAllPresent(context.firstMiddleLastKeys());
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.getAllPresent(context.firstMiddleLastKeys())).hasSize(count);
assertThat(cache.getAllPresent(context.firstMiddleLastKeys())).containsExactly(
context.firstKey(), context.firstKey(), context.middleKey(), context.middleKey(),
context.lastKey(), context.lastKey());

assertThat(cache).hasSize(context.initialSize());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(count).exclusively();
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.ASYNC_INCOMPLETE,
population = { Population.SINGLETON, Population.PARTIAL, Population.FULL })
public void getAllPresent_delayed(LoadingCache<Int, Int> cache, CacheContext context) {
context.ticker().advance(30, TimeUnit.SECONDS);
var expected = cache.getAllPresent(context.firstMiddleLastKeys());
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.getAllPresent(context.firstMiddleLastKeys()))
.containsExactlyEntriesIn(expected);

if (context.isCaffeine()) {
for (var key : context.firstMiddleLastKeys()) {
cache.policy().refreshes().get(key).complete(key);
}
assertThat(context).removalNotifications().withCause(REPLACED)
.hasSize(expected.size()).exclusively();
}
}

/* --------------- getFunc --------------- */

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE,
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.IDENTITY,
population = { Population.PARTIAL, Population.FULL })
public void getFunc(LoadingCache<Int, Int> cache, CacheContext context) {
Function<Int, Int> mappingFunction = context.original()::get;
public void getFunc_immediate(LoadingCache<Int, Int> cache, CacheContext context) {
context.ticker().advance(30, TimeUnit.SECONDS);
cache.get(context.firstKey(), mappingFunction);
cache.get(context.firstKey(), identity());
context.ticker().advance(45, TimeUnit.SECONDS);
cache.get(context.lastKey(), mappingFunction); // refreshed
assertThat(cache.get(context.lastKey(), identity())).isEqualTo(context.lastKey());

assertThat(cache).hasSize(context.initialSize());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(1).exclusively();
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.ASYNC_INCOMPLETE,
population = { Population.PARTIAL, Population.FULL })
public void getFunc_delayed(LoadingCache<Int, Int> cache, CacheContext context) {
Function<Int, Int> mappingFunction = context.original()::get;
context.ticker().advance(30, TimeUnit.SECONDS);
cache.get(context.firstKey(), mappingFunction);
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.get(context.lastKey(), mappingFunction)).isEqualTo(context.lastKey().negate());

if (context.isCaffeine()) {
cache.policy().refreshes().get(context.lastKey()).complete(context.lastKey());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(1).exclusively();
}
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE,
population = { Population.PARTIAL, Population.FULL })
Expand All @@ -414,18 +470,33 @@ public void getFunc_async(AsyncLoadingCache<Int, Int> cache, CacheContext contex
/* --------------- get --------------- */

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE,
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.IDENTITY,
population = { Population.PARTIAL, Population.FULL })
public void get(LoadingCache<Int, Int> cache, CacheContext context) {
public void get_immediate(LoadingCache<Int, Int> cache, CacheContext context) {
context.ticker().advance(30, TimeUnit.SECONDS);
cache.get(context.firstKey());
cache.get(context.absentKey());
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.get(context.firstKey())).isEqualTo(context.firstKey());

assertThat(cache).containsEntry(context.firstKey(), context.firstKey().negate());
assertThat(cache).hasSize(context.initialSize());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(1).exclusively();
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.ASYNC_INCOMPLETE,
population = { Population.PARTIAL, Population.FULL })
public void get_delayed(LoadingCache<Int, Int> cache, CacheContext context) {
context.ticker().advance(30, TimeUnit.SECONDS);
cache.get(context.firstKey());
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.get(context.firstKey())).isEqualTo(context.firstKey().negate());

if (context.isCaffeine()) {
cache.policy().refreshes().get(context.firstKey()).complete(context.firstKey());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(1).exclusively();
}
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.IDENTITY,
population = { Population.PARTIAL, Population.FULL })
Expand Down Expand Up @@ -511,23 +582,44 @@ public void get_null(AsyncLoadingCache<Int, Int> cache, CacheContext context) {
@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.IDENTITY,
population = { Population.PARTIAL, Population.FULL })
public void getAll(LoadingCache<Int, Int> cache, CacheContext context) {
public void getAll_immediate(LoadingCache<Int, Int> cache, CacheContext context) {
var keys = List.of(context.firstKey(), context.absentKey());
context.ticker().advance(30, TimeUnit.SECONDS);
assertThat(cache.getAll(keys)).containsExactly(
context.firstKey(), context.firstKey().negate(),
context.absentKey(), context.absentKey());

// Trigger a refresh, may return old values
context.ticker().advance(45, TimeUnit.SECONDS);
cache.getAll(keys);

// Ensure new values are present
// Trigger a refresh, ensure new values are present
assertThat(cache.getAll(keys)).containsExactly(
context.firstKey(), context.firstKey(), context.absentKey(), context.absentKey());
assertThat(context).removalNotifications().withCause(REPLACED).hasSize(1).exclusively();
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.ASYNC_INCOMPLETE,
population = { Population.PARTIAL, Population.FULL })
public void getAll_delayed(LoadingCache<Int, Int> cache, CacheContext context) {
var keys = context.firstMiddleLastKeys();
var expected = ImmutableMap.of(
context.firstKey(), context.firstKey().negate(),
context.middleKey(), context.middleKey().negate(),
context.lastKey(), context.lastKey().negate());
context.ticker().advance(30, TimeUnit.SECONDS);
assertThat(cache.getAll(keys)).containsExactlyEntriesIn(expected);

// Trigger a refresh, returns old values
context.ticker().advance(45, TimeUnit.SECONDS);
assertThat(cache.getAll(keys)).containsExactlyEntriesIn(expected);

if (context.isCaffeine()) {
for (var key : keys) {
cache.policy().refreshes().get(key).complete(key);
}
assertThat(context).removalNotifications().withCause(REPLACED)
.hasSize(keys.size()).exclusively();
}
}

@Test(dataProvider = "caches")
@CacheSpec(refreshAfterWrite = Expire.ONE_MINUTE, loader = Loader.IDENTITY,
population = { Population.PARTIAL, Population.FULL })
Expand Down
Expand Up @@ -63,6 +63,8 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.common.util.concurrent.UncheckedExecutionException;

/**
Expand Down Expand Up @@ -627,6 +629,20 @@ public V load(K key) throws Exception {
throw e;
}
}

@Override
@SuppressWarnings("FutureReturnValueIgnored")
public ListenableFuture<V> reload(K key, V oldValue) throws Exception {
var future = SettableFuture.<V>create();
delegate.asyncReload(key, oldValue, Runnable::run).whenComplete((r, e) -> {
if (e == null) {
future.set(r);
} else {
future.setException(e);
}
});
return future;
}
}

static class BulkLoader<K, V> extends SingleLoader<K, V> {
Expand Down

0 comments on commit 6522c7f

Please sign in to comment.