Skip to content

Commit

Permalink
Cancel cursor fetching if the outer stream gets canceled.
Browse files Browse the repository at this point in the history
[resolves #536]

Signed-off-by: Mark Paluch <mpaluch@vmware.com>
  • Loading branch information
mp911de committed Jun 19, 2023
1 parent 8ac05a8 commit 5688111
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 31 deletions.
46 changes: 23 additions & 23 deletions src/main/java/io/r2dbc/postgresql/ExtendedFlowDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,21 @@ class ExtendedFlowDelegate {
* Execute the {@code Parse/Bind/Describe/Execute/Sync} portion of the <a href="https://www.postgresql.org/docs/current/static/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY">Extended query</a>
* message flow.
*
* @param resources the {@link ConnectionResources} providing access to the {@link Client}
* @param factory the {@link ExceptionFactory}
* @param query the query to execute
* @param binding the {@link Binding} to bind
* @param values the binding values
* @param fetchSize the fetch size to apply. Use a single {@link Execute} with fetch all if {@code fetchSize} is zero. Otherwise, perform multiple roundtrips with smaller
* {@link Execute} sizes.
* @param resources the {@link ConnectionResources} providing access to the {@link Client}
* @param factory the {@link ExceptionFactory}
* @param query the query to execute
* @param binding the {@link Binding} to bind
* @param values the binding values
* @param fetchSize the fetch size to apply. Use a single {@link Execute} with fetch all if {@code fetchSize} is zero. Otherwise, perform multiple roundtrips with smaller
* {@link Execute} sizes.
* @param isCanceled whether the conversation is canceled
* @return the messages received in response to the exchange
* @throws IllegalArgumentException if {@code bindings}, {@code client}, {@code portalNameSupplier}, or {@code statementName} is {@code null}
*/
public static Flux<BackendMessage> runQuery(ConnectionResources resources, ExceptionFactory factory, String query, Binding binding, List<ByteBuf> values, int fetchSize) {
public static Flux<BackendMessage> runQuery(ConnectionResources resources, ExceptionFactory factory, String query, Binding binding, List<ByteBuf> values, int fetchSize, AtomicBoolean isCanceled) {

StatementCache cache = resources.getStatementCache();
Client client = resources.getClient();

String portal = resources.getPortalNameSupplier().get();

Flux<BackendMessage> exchange;
Expand All @@ -104,14 +104,14 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
if (fetchSize == NO_LIMIT || implicitTransactions) {
exchange = fetchAll(operator, client, portal);
} else {
exchange = fetchCursoredWithSync(operator, client, portal, fetchSize);
exchange = fetchCursoredWithSync(operator, client, portal, fetchSize, isCanceled);
}
} else {

if (fetchSize == NO_LIMIT) {
exchange = fetchAll(operator, client, portal);
} else {
exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize);
exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize, isCanceled);
}
}

Expand Down Expand Up @@ -147,16 +147,16 @@ private static Flux<BackendMessage> fetchAll(ExtendedFlowOperator operator, Clie
/**
* Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
*
* @param operator the flow operator
* @param client client to use
* @param portal the portal
* @param fetchSize fetch size per roundtrip
* @param operator the flow operator
* @param client client to use
* @param portal the portal
* @param fetchSize fetch size per roundtrip
* @param isCanceled whether the conversation is canceled
* @return the resulting message stream
*/
private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {
private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize, AtomicBoolean isCanceled) {

Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
AtomicBoolean isCanceled = new AtomicBoolean(false);
AtomicBoolean done = new AtomicBoolean(false);

MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Sync.INSTANCE));
Expand Down Expand Up @@ -210,16 +210,16 @@ private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator o
* Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
* cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
*
* @param operator the flow operator
* @param client client to use
* @param portal the portal
* @param fetchSize fetch size per roundtrip
* @param operator the flow operator
* @param client client to use
* @param portal the portal
* @param fetchSize fetch size per roundtrip
* @param isCanceled whether the conversation is canceled
* @return the resulting message stream
*/
private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {
private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize, AtomicBoolean isCanceled) {

Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
AtomicBoolean isCanceled = new AtomicBoolean(false);

MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Flush.INSTANCE));

Expand Down
2 changes: 1 addition & 1 deletion src/main/java/io/r2dbc/postgresql/PostgresqlResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public Mono<Long> getRowsUpdated() {
public <T> Flux<T> map(BiFunction<Row, RowMetadata, ? extends T> f) {
Assert.requireNonNull(f, "f must not be null");

return this.messages
return (Flux<T>) this.messages
.handle((message, sink) -> {

try {
Expand Down
22 changes: 17 additions & 5 deletions src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;

Expand Down Expand Up @@ -199,6 +200,9 @@ private int getIdentifierIndex(String identifier) {
private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
ExceptionFactory factory = ExceptionFactory.withSql(sql);

CompletableFuture<Void> onCancel = new CompletableFuture<>();
AtomicBoolean canceled = new AtomicBoolean();

if (this.parsedSql.getParameterCount() != 0) {
// Extended query protocol
if (this.bindings.size() == 0) {
Expand All @@ -213,17 +217,22 @@ private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
if (this.bindings.size() == 1) {

Binding binding = this.bindings.peekFirst();
Flux<BackendMessage> messages = collectBindingParameters(binding).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, binding, values, fetchSize));
Flux<BackendMessage> messages = collectBindingParameters(binding).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, binding, values, fetchSize,
new AtomicBoolean()));
return Flux.just(PostgresqlResult.toResult(this.resources, messages, factory));
}

Iterator<Binding> iterator = this.bindings.iterator();
Sinks.Many<Binding> bindings = Sinks.many().unicast().onBackpressureBuffer();
AtomicBoolean canceled = new AtomicBoolean();

onCancel.whenComplete((unused, throwable) -> {
clearBindings(iterator, canceled);
});

return bindings.asFlux()
.map(it -> {
Flux<BackendMessage> messages =
collectBindingParameters(it).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, it, values, this.fetchSize)).doOnComplete(() -> tryNextBinding(iterator, bindings, canceled));
collectBindingParameters(it).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, it, values, this.fetchSize, canceled)).doOnComplete(() -> tryNextBinding(iterator, bindings, canceled));

return PostgresqlResult.toResult(this.resources, messages, factory);
})
Expand All @@ -237,15 +246,18 @@ private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
Flux<BackendMessage> exchange;
// Simple Query protocol
if (this.fetchSize != NO_LIMIT) {
exchange = ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize);
exchange = ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize, canceled);
} else {
exchange = SimpleQueryMessageFlow.exchange(this.resources.getClient(), sql);
}

return exchange.windowUntil(WINDOW_UNTIL)
.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release) // ensure release of rows within WindowPredicate
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
.as(Operators::discardOnCancel);
.as(source -> Operators.discardOnCancel(source, () -> {
canceled.set(true);
onCancel.complete(null);
}));
}

private static void tryNextBinding(Iterator<Binding> iterator, Sinks.Many<Binding> bindingSink, AtomicBoolean canceled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.Timeout;
import org.springframework.jdbc.core.JdbcOperations;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

Expand All @@ -42,12 +44,18 @@ void setUp() {

super.setUp();

SERVER.getJdbcOperations().execute("DROP TABLE IF EXISTS insert_test;");
SERVER.getJdbcOperations().execute("CREATE TABLE insert_test\n" +
JdbcOperations jdbc = SERVER.getJdbcOperations();
jdbc.execute("DROP TABLE IF EXISTS insert_test;");
jdbc.execute("CREATE TABLE insert_test\n" +
"(\n" +
" id SERIAL PRIMARY KEY,\n" +
" value CHAR(1) NOT NULL\n" +
");");


jdbc.execute("DROP TABLE IF EXISTS lots_of_data;");
jdbc.execute("CREATE TABLE lots_of_data AS \n"
+ " SELECT i FROM generate_series(1,200000) as i;");
}

@AfterAll
Expand Down Expand Up @@ -111,4 +119,34 @@ void cancelRequest() {
.verify(Duration.ofSeconds(5));
}

@Timeout(10)
@RepeatedTest(20)
void shouldCancelParametrizedWithFetchSize() {

    // Extended-protocol query with a small fetch size so the cursor is still open
    // when the subscriber cancels; verifies cursor fetching stops on cancellation.
    StepVerifier.create(this.connection.createStatement("SELECT * FROM lots_of_data WHERE $1 = $1 ORDER BY i")
            .fetchSize(10)
            .bind(0, 1)
            .execute()
            .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Integer.class))))
        .expectNext(1)
        .expectNextCount(5)
        .thenCancel()
        .verify(Duration.ofSeconds(5));
}

@Timeout(10)
@RepeatedTest(20)
void shouldCancelSimpleWithFetchSize() {

    // Unparametrized statement with a fetch size forces the cursored (Execute/Sync)
    // flow; cancelling mid-stream must terminate the conversation promptly.
    StepVerifier.create(this.connection.createStatement("SELECT * FROM lots_of_data ORDER BY i")
            .fetchSize(10)
            .execute()
            .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Integer.class))))
        .expectNext(1)
        .expectNextCount(5)
        .thenCancel()
        .verify(Duration.ofSeconds(5));
}
}

0 comments on commit 5688111

Please sign in to comment.