Skip to content

Commit

Permalink
Batch SSE events writes when possible
Browse files Browse the repository at this point in the history
Prior to this commit, the `SseEventBuilder` would be used to create SSE
events and write them to the connection using the `ResponseBodyEmitter`.
This would send each data item one by one, effectively writing and
flushing to the network for each. Since multiple data lines are prepared
by the `SseEventBuilder`, a typical write of an SSE event performs
multiple flushes operations.

This commit adds a method on `ResponseBodyEmitter` to perform batch
writes (given a `Set<DataWithMediaType>`) and only flush once all
elements of the set have been written.
This also applies in case of early writes, where now all buffered
elements are written then flushed altogether.

Fixes gh-30912
  • Loading branch information
bclozel committed Aug 4, 2023
1 parent 18966d0 commit e83793b
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 33 deletions.
Expand Up @@ -128,9 +128,7 @@ synchronized void initialize(Handler handler) throws IOException {
this.handler = handler;

try {
for (DataWithMediaType sendAttempt : this.earlySendAttempts) {
sendInternal(sendAttempt.getData(), sendAttempt.getMediaType());
}
sendInternal(this.earlySendAttempts);
}
finally {
this.earlySendAttempts.clear();
Expand Down Expand Up @@ -194,11 +192,7 @@ public void send(Object object) throws IOException {
*/
public synchronized void send(Object object, @Nullable MediaType mediaType) throws IOException {
Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" +
(this.failure != null ? " with error: " + this.failure : ""));
sendInternal(object, mediaType);
}

private void sendInternal(Object object, @Nullable MediaType mediaType) throws IOException {
(this.failure != null ? " with error: " + this.failure : ""));
if (this.handler != null) {
try {
this.handler.send(object, mediaType);
Expand All @@ -217,6 +211,43 @@ private void sendInternal(Object object, @Nullable MediaType mediaType) throws I
}
}

/**
* Write a set of data and MediaType pairs in a batch.
* <p>Compared to {@link #send(Object, MediaType)}, this batches the write operations
* and flushes to the network at the end.
* @param items the object and media type pairs to write
* @throws IOException raised when an I/O error occurs
* @throws java.lang.IllegalStateException wraps any other errors
* @since 6.0.12
*/
public synchronized void send(Set<DataWithMediaType> items) throws IOException {
Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" +
(this.failure != null ? " with error: " + this.failure : ""));
sendInternal(items);
}

private void sendInternal(Set<DataWithMediaType> items) throws IOException {
if (items.isEmpty()) {
return;
}
if (this.handler != null) {
try {
this.handler.send(items);
}
catch (IOException ex) {
this.sendFailed = true;
throw ex;
}
catch (Throwable ex) {
this.sendFailed = true;
throw new IllegalStateException("Failed to send " + items, ex);
}
}
else {
this.earlySendAttempts.addAll(items);
}
}

/**
* Complete request processing by performing a dispatch into the servlet
* container, where Spring MVC is invoked once more, and completes the
Expand Down Expand Up @@ -302,8 +333,17 @@ public String toString() {
*/
interface Handler {

/**
* Immediately write and flush the given data to the network.
*/
void send(Object data, @Nullable MediaType mediaType) throws IOException;

/**
* Immediately write all data items then flush to the network.
* @since 6.0.12
*/
void send(Set<DataWithMediaType> items) throws IOException;

void complete();

void completeWithError(Throwable failure);
Expand Down
Expand Up @@ -20,6 +20,7 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;

import jakarta.servlet.ServletRequest;
Expand Down Expand Up @@ -202,14 +203,22 @@ public HttpMessageConvertingHandler(ServerHttpResponse outputMessage, DeferredRe
@Override
public void send(Object data, @Nullable MediaType mediaType) throws IOException {
sendInternal(data, mediaType);
this.outputMessage.flush();
}

@Override
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
for (ResponseBodyEmitter.DataWithMediaType item : items) {
sendInternal(item.getData(), item.getMediaType());
}
this.outputMessage.flush();
}

@SuppressWarnings("unchecked")
private <T> void sendInternal(T data, @Nullable MediaType mediaType) throws IOException {
for (HttpMessageConverter<?> converter : ResponseBodyEmitterReturnValueHandler.this.sseMessageConverters) {
if (converter.canWrite(data.getClass(), mediaType)) {
((HttpMessageConverter<T>) converter).write(data, mediaType, this.outputMessage);
this.outputMessage.flush();
return;
}
}
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -123,9 +123,7 @@ public void send(Object object, @Nullable MediaType mediaType) throws IOExceptio
public void send(SseEventBuilder builder) throws IOException {
Set<DataWithMediaType> dataToSend = builder.build();
synchronized (this) {
for (DataWithMediaType entry : dataToSend) {
super.send(entry.getData(), entry.getMediaType());
}
super.send(dataToSend);
}
}

Expand Down
Expand Up @@ -365,6 +365,11 @@ public void send(Object data, MediaType mediaType) throws IOException {
this.values.add(data);
}

@Override
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
items.forEach(item -> this.values.add(item.getData()));
}

@Override
public void complete() {
}
Expand Down
Expand Up @@ -30,9 +30,9 @@
import static org.assertj.core.api.Assertions.assertThatIOException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

Expand All @@ -52,56 +52,54 @@ public class ResponseBodyEmitterTests {


@Test
public void sendBeforeHandlerInitialized() throws Exception {
void sendBeforeHandlerInitialized() throws Exception {
this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.send("bar", MediaType.TEXT_PLAIN);
this.emitter.complete();
verifyNoMoreInteractions(this.handler);

this.emitter.initialize(this.handler);
verify(this.handler).send("foo", MediaType.TEXT_PLAIN);
verify(this.handler).send("bar", MediaType.TEXT_PLAIN);
verify(this.handler).send(anySet());
verify(this.handler).complete();
verifyNoMoreInteractions(this.handler);
}

@Test
public void sendDuplicateBeforeHandlerInitialized() throws Exception {
void sendDuplicateBeforeHandlerInitialized() throws Exception {
this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.complete();
verifyNoMoreInteractions(this.handler);

this.emitter.initialize(this.handler);
verify(this.handler, times(2)).send("foo", MediaType.TEXT_PLAIN);
verify(this.handler).send(anySet());
verify(this.handler).complete();
verifyNoMoreInteractions(this.handler);
}

@Test
public void sendBeforeHandlerInitializedWithError() throws Exception {
void sendBeforeHandlerInitializedWithError() throws Exception {
IllegalStateException ex = new IllegalStateException();
this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.send("bar", MediaType.TEXT_PLAIN);
this.emitter.completeWithError(ex);
verifyNoMoreInteractions(this.handler);

this.emitter.initialize(this.handler);
verify(this.handler).send("foo", MediaType.TEXT_PLAIN);
verify(this.handler).send("bar", MediaType.TEXT_PLAIN);
verify(this.handler).send(anySet());
verify(this.handler).completeWithError(ex);
verifyNoMoreInteractions(this.handler);
}

@Test
public void sendFailsAfterComplete() throws Exception {
void sendFailsAfterComplete() throws Exception {
this.emitter.complete();
assertThatIllegalStateException().isThrownBy(() ->
this.emitter.send("foo"));
}

@Test
public void sendAfterHandlerInitialized() throws Exception {
void sendAfterHandlerInitialized() throws Exception {
this.emitter.initialize(this.handler);
verify(this.handler).onTimeout(any());
verify(this.handler).onError(any());
Expand All @@ -119,7 +117,7 @@ public void sendAfterHandlerInitialized() throws Exception {
}

@Test
public void sendAfterHandlerInitializedWithError() throws Exception {
void sendAfterHandlerInitializedWithError() throws Exception {
this.emitter.initialize(this.handler);
verify(this.handler).onTimeout(any());
verify(this.handler).onError(any());
Expand All @@ -138,7 +136,7 @@ public void sendAfterHandlerInitializedWithError() throws Exception {
}

@Test
public void sendWithError() throws Exception {
void sendWithError() throws Exception {
this.emitter.initialize(this.handler);
verify(this.handler).onTimeout(any());
verify(this.handler).onError(any());
Expand All @@ -154,7 +152,7 @@ public void sendWithError() throws Exception {
}

@Test
public void onTimeoutBeforeHandlerInitialized() throws Exception {
void onTimeoutBeforeHandlerInitialized() throws Exception {
Runnable runnable = mock();
this.emitter.onTimeout(runnable);
this.emitter.initialize(this.handler);
Expand All @@ -169,7 +167,7 @@ public void onTimeoutBeforeHandlerInitialized() throws Exception {
}

@Test
public void onTimeoutAfterHandlerInitialized() throws Exception {
void onTimeoutAfterHandlerInitialized() throws Exception {
this.emitter.initialize(this.handler);

ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
Expand All @@ -185,7 +183,7 @@ public void onTimeoutAfterHandlerInitialized() throws Exception {
}

@Test
public void onCompletionBeforeHandlerInitialized() throws Exception {
void onCompletionBeforeHandlerInitialized() throws Exception {
Runnable runnable = mock();
this.emitter.onCompletion(runnable);
this.emitter.initialize(this.handler);
Expand All @@ -200,7 +198,7 @@ public void onCompletionBeforeHandlerInitialized() throws Exception {
}

@Test
public void onCompletionAfterHandlerInitialized() throws Exception {
void onCompletionAfterHandlerInitialized() throws Exception {
this.emitter.initialize(this.handler);

ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
Expand Down
Expand Up @@ -20,12 +20,14 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event;
Expand Down Expand Up @@ -60,6 +62,7 @@ public void send() throws Exception {
this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo");
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
}

@Test
Expand All @@ -69,12 +72,14 @@ public void sendWithMediaType() throws Exception {
this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo", MediaType.TEXT_PLAIN);
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
}

@Test
public void sendEventEmpty() throws Exception {
this.emitter.send(event());
this.handler.assertSentObjectCount(0);
this.handler.assertWriteCount(0);
}

@Test
Expand All @@ -84,6 +89,7 @@ public void sendEventWithDataLine() throws Exception {
this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo");
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
}

@Test
Expand All @@ -95,6 +101,7 @@ public void sendEventWithTwoDataLines() throws Exception {
this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8);
this.handler.assertObject(3, "bar");
this.handler.assertObject(4, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
}

@Test
Expand All @@ -104,6 +111,7 @@ public void sendEventFull() throws Exception {
this.handler.assertObject(0, ":blah\nevent:test\nretry:5000\nid:1\ndata:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo");
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
}

@Test
Expand All @@ -115,14 +123,17 @@ public void sendEventFullWithTwoDataLinesInTheMiddle() throws Exception {
this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8);
this.handler.assertObject(3, "bar");
this.handler.assertObject(4, "\nevent:test\nretry:5000\nid:1\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
}


private static class TestHandler implements ResponseBodyEmitter.Handler {

private List<Object> objects = new ArrayList<>();
private final List<Object> objects = new ArrayList<>();

private List<MediaType> mediaTypes = new ArrayList<>();
private final List<MediaType> mediaTypes = new ArrayList<>();

private int writeCount;


public void assertSentObjectCount(int size) {
Expand All @@ -139,10 +150,24 @@ public void assertObject(int index, Object object, MediaType mediaType) {
assertThat(this.mediaTypes.get(index)).isEqualTo(mediaType);
}

public void assertWriteCount(int writeCount) {
assertThat(this.writeCount).isEqualTo(writeCount);
}

@Override
public void send(Object data, MediaType mediaType) throws IOException {
public void send(Object data, @Nullable MediaType mediaType) throws IOException {
this.objects.add(data);
this.mediaTypes.add(mediaType);
this.writeCount++;
}

@Override
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
for (ResponseBodyEmitter.DataWithMediaType item : items) {
this.objects.add(item.getData());
this.mediaTypes.add(item.getMediaType());
}
this.writeCount++;
}

@Override
Expand Down

0 comments on commit e83793b

Please sign in to comment.