Skip to content

Commit

Permalink
Make cluster health API cancellable (#96990)
Browse files Browse the repository at this point in the history
This API can be quite heavy in large clusters, and might spam the
`MANAGEMENT` threadpool queue with work for clients that have long since
given up. This commit adds some basic cancellability checks to mitigate
the problem.

Backport of #96551 to 7.17
  • Loading branch information
DaveCTurner committed Jun 22, 2023
1 parent e1995a7 commit eeedb98
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 28 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/96551.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96551
summary: Make cluster health API cancellable
area: Distributed
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.http;

import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Priority;

import java.util.concurrent.CancellationException;
import java.util.concurrent.CyclicBarrier;

import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener;
import static org.elasticsearch.test.TaskAssertions.assertAllCancellableTasksAreCancelled;
import static org.elasticsearch.test.TaskAssertions.awaitTaskWithPrefixOnMaster;

/**
 * Integration test checking that cancelling an in-flight {@code GET /_cluster/health} REST request results in the cancellable
 * {@link ClusterHealthAction} tasks being cancelled.
 */
public class ClusterHealthRestCancellationIT extends HttpSmokeTestCase {

    /**
     * Blocks a cluster state update task on the master, sends a cluster health request with {@code wait_for_events}, cancels the
     * request through the REST client, and then asserts that the request fails with a {@link CancellationException} and that the
     * cluster health tasks are cancelled.
     */
    public void testClusterHealthRestCancellation() throws Exception {

        // Two parties: this test thread and the thread executing the blocking cluster state update task below.
        final CyclicBarrier barrier = new CyclicBarrier(2);

        // The task parks in execute() until the test thread has arrived at the barrier twice: once to learn that the task has
        // started, and once more (at the very end of the test) to let it finish.
        // NOTE(review): presumably this keeps the health request's wait_for_events from completing while blocked - confirm.
        internalCluster().getCurrentMasterNodeInstance(ClusterService.class)
            .submitStateUpdateTask("blocking", new ClusterStateUpdateTask() {
                @Override
                public ClusterState execute(ClusterState currentState) {
                    safeAwait(barrier);
                    safeAwait(barrier);
                    // no-op update: the state is returned unchanged
                    return currentState;
                }

                @Override
                public void onFailure(String source, Exception e) {
                    throw new AssertionError(e);
                }
            });

        final Request clusterHealthRequest = new Request(HttpGet.METHOD_NAME, "/_cluster/health");
        clusterHealthRequest.addParameter("wait_for_events", Priority.LANGUID.toString());

        final PlainActionFuture<Response> future = new PlainActionFuture<>();
        logger.info("--> sending cluster health request");
        final Cancellable cancellable = getRestClient().performRequestAsync(clusterHealthRequest, wrapAsRestResponseListener(future));

        // First rendezvous: the blocking task is now running on the master.
        safeAwait(barrier);

        awaitTaskWithPrefixOnMaster(ClusterHealthAction.NAME);

        logger.info("--> cancelling cluster health request");
        cancellable.cancel();
        expectThrows(CancellationException.class, future::actionGet);

        logger.info("--> checking cluster health task cancelled");
        assertAllCancellableTasksAreCancelled(ClusterHealthAction.NAME);

        // Second rendezvous: release the blocking task so that it can complete.
        safeAwait(barrier);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -264,6 +268,11 @@ public ActionRequestValidationException validate() {
return null;
}

/**
 * Creates the task that tracks this request as a {@link CancellableTask}, using an empty description.
 */
@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
    final String description = "";
    return new CancellableTask(id, type, action, description, parentTaskId, headers);
}

public enum Level {
CLUSTER,
INDICES,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -91,28 +92,42 @@ protected void masterOperation(
final ClusterState unusedState,
final ActionListener<ClusterHealthResponse> listener
) {
assert task instanceof CancellableTask;
final CancellableTask cancellableTask = (CancellableTask) task;

final int waitCount = getWaitCount(request);

if (request.waitForEvents() != null) {
waitForEventsAndExecuteHealth(request, listener, waitCount, threadPool.relativeTimeInMillis() + request.timeout().millis());
waitForEventsAndExecuteHealth(
cancellableTask,
request,
listener,
waitCount,
threadPool.relativeTimeInMillis() + request.timeout().millis()
);
} else {
executeHealth(
cancellableTask,
request,
clusterService.state(),
listener,
waitCount,
clusterState -> listener.onResponse(getResponse(request, clusterState, waitCount, TimeoutState.OK))
clusterState -> sendResponse(cancellableTask, request, clusterState, waitCount, TimeoutState.OK, listener)
);
}
}

private void waitForEventsAndExecuteHealth(
final CancellableTask task,
final ClusterHealthRequest request,
final ActionListener<ClusterHealthResponse> listener,
final int waitCount,
final long endTimeRelativeMillis
) {
if (task.notifyIfCancelled(listener)) {
return;
}

assert request.waitForEvents() != null;
if (request.local()) {
clusterService.submitStateUpdateTask(
Expand All @@ -129,11 +144,12 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS
final TimeValue newTimeout = TimeValue.timeValueMillis(timeoutInMillis);
request.timeout(newTimeout);
executeHealth(
task,
request,
clusterService.state(),
listener,
waitCount,
observedState -> waitForEventsAndExecuteHealth(request, listener, waitCount, endTimeRelativeMillis)
observedState -> waitForEventsAndExecuteHealth(task, request, listener, waitCount, endTimeRelativeMillis)
);
}

Expand Down Expand Up @@ -166,11 +182,12 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS
assert newState.stateUUID().equals(appliedState.stateUUID())
: newState.stateUUID() + " vs " + appliedState.stateUUID();
executeHealth(
task,
request,
appliedState,
listener,
waitCount,
observedState -> waitForEventsAndExecuteHealth(request, listener, waitCount, endTimeRelativeMillis)
observedState -> waitForEventsAndExecuteHealth(task, request, listener, waitCount, endTimeRelativeMillis)
);
}

Expand All @@ -187,7 +204,7 @@ public void onNoLongerMaster(String source) {
@Override
public void onFailure(String source, Exception e) {
if (e instanceof ProcessClusterEventTimeoutException) {
listener.onResponse(getResponse(request, clusterService.state(), waitCount, TimeoutState.TIMED_OUT));
sendResponse(task, request, clusterService.state(), waitCount, TimeoutState.TIMED_OUT, listener);
} else {
logger.error(() -> new ParameterizedMessage("unexpected failure during [{}]", source), e);
listener.onFailure(e);
Expand All @@ -199,21 +216,25 @@ public void onFailure(String source, Exception e) {
}

private void executeHealth(
final CancellableTask task,
final ClusterHealthRequest request,
final ClusterState currentState,
final ActionListener<ClusterHealthResponse> listener,
final int waitCount,
final Consumer<ClusterState> onNewClusterStateAfterDelay
) {
if (task.notifyIfCancelled(listener)) {
return;
}

if (request.timeout().millis() == 0) {
listener.onResponse(getResponse(request, currentState, waitCount, TimeoutState.ZERO_TIMEOUT));
sendResponse(task, request, currentState, waitCount, TimeoutState.ZERO_TIMEOUT, listener);
return;
}

final Predicate<ClusterState> validationPredicate = newState -> validateRequest(request, newState, waitCount);
if (validationPredicate.test(currentState)) {
listener.onResponse(getResponse(request, currentState, waitCount, TimeoutState.OK));
sendResponse(task, request, currentState, waitCount, TimeoutState.OK, listener);
} else {
final ClusterStateObserver observer = new ClusterStateObserver(
currentState,
Expand All @@ -235,7 +256,7 @@ public void onClusterServiceClose() {

@Override
public void onTimeout(TimeValue timeout) {
listener.onResponse(getResponse(request, observer.setAndGetObservedState(), waitCount, TimeoutState.TIMED_OUT));
sendResponse(task, request, observer.setAndGetObservedState(), waitCount, TimeoutState.TIMED_OUT, listener);
}
};
observer.waitForNextChange(stateListener, validationPredicate, request.timeout());
Expand Down Expand Up @@ -282,27 +303,32 @@ private enum TimeoutState {
ZERO_TIMEOUT
}

private ClusterHealthResponse getResponse(
private void sendResponse(
final CancellableTask task,
final ClusterHealthRequest request,
ClusterState clusterState,
final ClusterState clusterState,
final int waitFor,
TimeoutState timeoutState
final TimeoutState timeoutState,
final ActionListener<ClusterHealthResponse> listener
) {
ClusterHealthResponse response = clusterHealth(
request,
clusterState,
clusterService.getMasterService().numberOfPendingTasks(),
allocationService.getNumberOfInFlightFetches(),
clusterService.getMasterService().getMaxTaskWaitTime()
);
int readyCounter = prepareResponse(request, response, clusterState, indexNameExpressionResolver);
boolean valid = (readyCounter == waitFor);
assert valid || (timeoutState != TimeoutState.OK);
// If valid && timeoutState == TimeoutState.ZERO_TIMEOUT then we immediately found **and processed** a valid state, so we don't
// consider this a timeout. However if timeoutState == TimeoutState.TIMED_OUT then we didn't process a valid state (perhaps we
// failed on wait_for_events) so this does count as a timeout.
response.setTimedOut(valid == false || timeoutState == TimeoutState.TIMED_OUT);
return response;
ActionListener.completeWith(listener, () -> {
task.ensureNotCancelled();
ClusterHealthResponse response = clusterHealth(
request,
clusterState,
clusterService.getMasterService().numberOfPendingTasks(),
allocationService.getNumberOfInFlightFetches(),
clusterService.getMasterService().getMaxTaskWaitTime()
);
int readyCounter = prepareResponse(request, response, clusterState, indexNameExpressionResolver);
boolean valid = (readyCounter == waitFor);
assert valid || (timeoutState != TimeoutState.OK);
// If valid && timeoutState == TimeoutState.ZERO_TIMEOUT then we immediately found **and processed** a valid state, so we don't
// consider this a timeout. However if timeoutState == TimeoutState.TIMED_OUT then we didn't process a valid state (perhaps we
// failed on wait_for_events) so this does count as a timeout.
response.setTimedOut(valid == false || timeoutState == TimeoutState.TIMED_OUT);
return response;
});
}

static int prepareResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestStatusToXContentListener;

import java.io.IOException;
Expand Down Expand Up @@ -50,7 +51,9 @@ public boolean allowSystemIndexAccessByDefault() {
@Override
public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
final ClusterHealthRequest clusterHealthRequest = fromRequest(request);
return channel -> client.admin().cluster().health(clusterHealthRequest, new RestStatusToXContentListener<>(channel));
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()).admin()
.cluster()
.health(clusterHealthRequest, new RestStatusToXContentListener<>(channel));
}

public static ClusterHealthRequest fromRequest(final RestRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthRequest;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse;
import org.elasticsearch.action.admin.cluster.health.TransportClusterHealthAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.ClusterName;
Expand All @@ -40,6 +42,8 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.indices.TestIndexNameExpressionResolver;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.gateway.TestGatewayAllocator;
import org.elasticsearch.test.transport.CapturingTransport;
Expand Down Expand Up @@ -156,7 +160,12 @@ public void testClusterHealthWaitsForClusterStateApplication() throws Interrupte
new AllocationService(null, new TestGatewayAllocator(), null, null, null)
);
PlainActionFuture<ClusterHealthResponse> listener = new PlainActionFuture<>();
action.execute(new ClusterHealthRequest().waitForGreenStatus(), listener);
ActionTestUtils.execute(
action,
new CancellableTask(1, "direct", ClusterHealthAction.NAME, "", TaskId.EMPTY_TASK_ID, Collections.emptyMap()),
new ClusterHealthRequest().waitForGreenStatus(),
listener
);

assertFalse(listener.isDone());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import java.util.Set;
import java.util.TimeZone;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -1824,4 +1825,15 @@ public String toString() {
return String.format(Locale.ROOT, "%s: %s", level.name(), message);
}
}

/**
 * Awaits the given barrier for up to 10 seconds, turning any failure to pass it into an {@link AssertionError}.
 *
 * @param barrier the barrier to wait on
 * @throws AssertionError with message {@code "unexpected"} and the original exception as its cause if the wait does not
 *                        complete normally; if the cause is an {@link InterruptedException} the calling thread's interrupt
 *                        flag is set again before throwing
 */
public static void safeAwait(CyclicBarrier barrier) {
    try {
        barrier.await(10, TimeUnit.SECONDS);
    } catch (Exception e) {
        if (e instanceof InterruptedException) {
            // restore the interrupt status so that code further up the stack can still observe it
            Thread.currentThread().interrupt();
        }
        throw new AssertionError("unexpected", e);
    }
}
}

0 comments on commit eeedb98

Please sign in to comment.