diff --git a/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvoker.java b/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvoker.java index 719a968ce0d6..00f8ca160555 100644 --- a/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvoker.java +++ b/spring-boot-project/spring-boot-actuator/src/main/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvoker.java @@ -18,9 +18,10 @@ import java.security.Principal; import java.time.Duration; +import java.util.Iterator; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,6 +31,7 @@ import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.ConcurrentReferenceHashMap; import org.springframework.util.ObjectUtils; /** @@ -45,6 +47,8 @@ public class CachingOperationInvoker implements OperationInvoker { private static final boolean IS_REACTOR_PRESENT = ClassUtils.isPresent("reactor.core.publisher.Mono", null); + private static final int CACHE_CLEANUP_THRESHOLD = 40; + private final OperationInvoker invoker; private final long timeToLive; @@ -61,7 +65,7 @@ public class CachingOperationInvoker implements OperationInvoker { Assert.isTrue(timeToLive > 0, "TimeToLive must be strictly positive"); this.invoker = invoker; this.timeToLive = timeToLive; - this.cachedResponses = new ConcurrentHashMap<>(); + this.cachedResponses = new ConcurrentReferenceHashMap<>(); } /** @@ -78,6 +82,9 @@ public Object invoke(InvocationContext context) { return this.invoker.invoke(context); } long accessTime = System.currentTimeMillis(); + if (this.cachedResponses.size() > CACHE_CLEANUP_THRESHOLD) { + cleanExpiredCachedResponses(accessTime); + } ApiVersion contextApiVersion = context.resolveArgument(ApiVersion.class); Principal principal = context.resolveArgument(Principal.class); CacheKey cacheKey = new CacheKey(contextApiVersion, principal); @@ -90,6 +97,20 @@ public Object invoke(InvocationContext context) { return cached.getResponse(); } + private void cleanExpiredCachedResponses(long accessTime) { + try { + Iterator> iterator = this.cachedResponses.entrySet().iterator(); + while (iterator.hasNext()) { + Entry entry = iterator.next(); + if (entry.getValue().isStale(accessTime, this.timeToLive)) { + iterator.remove(); + } + } + } + catch (Exception ex) { + } + } + private boolean hasInput(InvocationContext context) { Map arguments = context.getArguments(); if (!ObjectUtils.isEmpty(arguments)) { diff --git a/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvokerTests.java b/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvokerTests.java index d4800cbac88c..6a8edada256a 100644 --- a/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvokerTests.java +++ b/spring-boot-project/spring-boot-actuator/src/test/java/org/springframework/boot/actuate/endpoint/invoker/cache/CachingOperationInvokerTests.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -101,6 +102,31 @@ void cacheInTtlWithFluxResponse() { assertThat(response).isSameAs(cachedResponse); } + @Test // gh-28313 + void cacheWhenEachPrincipalIsUniqueDoesNotConsumeTooMuchMemory() throws Exception { + MonoOperationInvoker target = new MonoOperationInvoker(); + CachingOperationInvoker invoker = new CachingOperationInvoker(target, 50L); + int count = 1000; + for (int i = 0; i < count; i++) { + invokeWithUniquePrincipal(invoker); + } + long expired = System.currentTimeMillis() + 50; + while (System.currentTimeMillis() < expired) { + Thread.sleep(10); + } + invokeWithUniquePrincipal(invoker); + assertThat(invoker).extracting("cachedResponses").asInstanceOf(InstanceOfAssertFactories.MAP) + .hasSizeLessThan(count); + } + + private void invokeWithUniquePrincipal(CachingOperationInvoker invoker) { + SecurityContext securityContext = mock(SecurityContext.class); + Principal principal = mock(Principal.class); + given(securityContext.getPrincipal()).willReturn(principal); + InvocationContext context = new InvocationContext(securityContext, Collections.emptyMap()); + ((Mono) invoker.invoke(context)).block(); + } + private void assertCacheIsUsed(Map parameters) { assertCacheIsUsed(parameters, null); }