Skip to content

Commit

Permalink
Add shortcuts for frequently used assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
snicoll committed May 7, 2024
1 parent c8967de commit bcecce7
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatus.Series;
import org.springframework.http.MediaType;
import org.springframework.test.http.HttpHeadersAssert;
import org.springframework.test.http.MediaTypeAssert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.function.SingletonSupplier;
Expand All @@ -48,15 +50,24 @@
public abstract class AbstractHttpServletResponseAssert<R extends HttpServletResponse, SELF extends AbstractHttpServletResponseAssert<R, SELF, ACTUAL>, ACTUAL>
extends AbstractObjectAssert<SELF, ACTUAL> {

private final Supplier<AbstractIntegerAssert<?>> statusAssert;
private final Supplier<MediaTypeAssert> contentTypeAssertSupplier;

private final Supplier<HttpHeadersAssert> headersAssertSupplier;

private final Supplier<AbstractIntegerAssert<?>> statusAssert;


protected AbstractHttpServletResponseAssert(ACTUAL actual, Class<?> selfType) {
super(actual, selfType);
this.statusAssert = SingletonSupplier.of(() -> Assertions.assertThat(getResponse().getStatus()).as("HTTP status code"));
this.contentTypeAssertSupplier = SingletonSupplier.of(() -> new MediaTypeAssert(getResponse().getContentType()));
this.headersAssertSupplier = SingletonSupplier.of(() -> new HttpHeadersAssert(getHttpHeaders(getResponse())));
this.statusAssert = SingletonSupplier.of(() -> Assertions.assertThat(getResponse().getStatus()).as("HTTP status code"));
}

private static HttpHeaders getHttpHeaders(HttpServletResponse response) {
MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
response.getHeaderNames().forEach(name -> headers.put(name, new ArrayList<>(response.getHeaders(name))));
return new HttpHeaders(headers);
}

/**
Expand All @@ -67,6 +78,14 @@ protected AbstractHttpServletResponseAssert(ACTUAL actual, Class<?> selfType) {
*/
protected abstract R getResponse();

/**
* Return a new {@linkplain MediaTypeAssert assertion} object that uses the
* response's {@linkplain MediaType content type} as the object to test.
*/
public MediaTypeAssert contentType() {
return this.contentTypeAssertSupplier.get();
}

/**
* Return a new {@linkplain HttpHeadersAssert assertion} object that uses
* {@link HttpHeaders} as the object to test. The returned assertion
Expand All @@ -84,6 +103,82 @@ public HttpHeadersAssert headers() {
return this.headersAssertSupplier.get();
}

// Content-type shortcuts

/**
* Verify that the response's {@code Content-Type} is equal to the given value.
* @param contentType the expected content type
*/
public SELF hasContentType(MediaType contentType) {
contentType().isEqualTo(contentType);
return this.myself;
}

/**
* Verify that the response's {@code Content-Type} is equal to the given
* string representation.
* @param contentType the expected content type
*/
public SELF hasContentType(String contentType) {
contentType().isEqualTo(contentType);
return this.myself;
}

/**
* Verify that the response's {@code Content-Type} is
* {@linkplain MediaType#isCompatibleWith(MediaType) compatible} with the
* given value.
* @param contentType the expected compatible content type
*/
public SELF hasContentTypeCompatibleWith(MediaType contentType) {
contentType().isCompatibleWith(contentType);
return this.myself;
}

/**
* Verify that the response's {@code Content-Type} is
* {@linkplain MediaType#isCompatibleWith(MediaType) compatible} with the
* given string representation.
* @param contentType the expected compatible content type
*/
public SELF hasContentTypeCompatibleWith(String contentType) {
contentType().isCompatibleWith(contentType);
return this.myself;
}

// Headers shortcuts

/**
* Verify that the response contains a header with the given {@code name}.
* @param name the name of an expected HTTP header
*/
public SELF containsHeader(String name) {
headers().containsHeader(name);
return this.myself;
}

/**
* Verify that the response does not contain a header with the given {@code name}.
* @param name the name of an HTTP header that should not be present
*/
public SELF doesNotContainHeader(String name) {
headers().doesNotContainHeader(name);
return this.myself;
}

/**
* Verify that the response contains a header with the given {@code name}
* and primary {@code value}.
* @param name the name of an expected HTTP header
* @param value the expected value of the header
*/
public SELF hasHeader(String name, String value) {
headers().hasValue(name, value);
return this.myself;
}

// Status

/**
* Verify that the HTTP status is equal to the specified status code.
* @param status the expected HTTP status code
Expand Down Expand Up @@ -159,10 +254,4 @@ private AbstractIntegerAssert<?> status() {
return this.statusAssert.get();
}

private static HttpHeaders getHttpHeaders(HttpServletResponse response) {
MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
response.getHeaderNames().forEach(name -> headers.put(name, new ArrayList<>(response.getHeaders(name))));
return new HttpHeaders(headers);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@
import org.assertj.core.error.BasicErrorMessageFactory;
import org.assertj.core.internal.Failures;

import org.springframework.http.MediaType;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.test.http.MediaTypeAssert;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultHandler;
import org.springframework.test.web.servlet.ResultMatcher;
Expand Down Expand Up @@ -87,14 +85,6 @@ public CookieMapAssert cookies() {
return new CookieMapAssert(getMvcResult().getResponse().getCookies());
}

/**
* Return a new {@linkplain MediaTypeAssert assertion} object that uses the
* response's {@linkplain MediaType content type} as the object to test.
*/
public MediaTypeAssert contentType() {
return new MediaTypeAssert(getMvcResult().getResponse().getContentType());
}

/**
* Return a new {@linkplain HandlerResultAssert assertion} object that uses
* the handler as the object to test. For a method invocation on a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.junit.jupiter.api.Test;

import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletResponse;

import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
Expand All @@ -37,20 +38,76 @@ class AbstractHttpServletResponseAssertTests {
@Nested
class HeadersTests {

@Test
void containsHeader() {
MockHttpServletResponse response = createResponse(Map.of("n1", "v1", "n2", "v2", "n3", "v3"));
assertThat(response).containsHeader("n1");
}

@Test
void doesNotContainHeader() {
MockHttpServletResponse response = createResponse(Map.of("n1", "v1", "n2", "v2", "n3", "v3"));
assertThat(response).doesNotContainHeader("n4");
}

@Test
void hasHeader() {
MockHttpServletResponse response = createResponse(Map.of("n1", "v1", "n2", "v2", "n3", "v3"));
assertThat(response).hasHeader("n1", "v1");
}

@Test
void headersAreMatching() {
MockHttpServletResponse response = createResponse(Map.of("n1", "v1", "n2", "v2", "n3", "v3"));
assertThat(response).headers().containsHeaders("n1", "n2", "n3");
}


private MockHttpServletResponse createResponse(Map<String, String> headers) {
MockHttpServletResponse response = new MockHttpServletResponse();
headers.forEach(response::addHeader);
return response;
}
}

@Nested
class ContentTypeTests {

@Test
void contentType() {
MockHttpServletResponse response = createResponse("text/plain");
assertThat(response).hasContentType(MediaType.TEXT_PLAIN);
}

@Test
void contentTypeAndRepresentation() {
MockHttpServletResponse response = createResponse("text/plain");
assertThat(response).hasContentType("text/plain");
}

@Test
void contentTypeCompatibleWith() {
MockHttpServletResponse response = createResponse("application/json;charset=UTF-8");
assertThat(response).hasContentTypeCompatibleWith(MediaType.APPLICATION_JSON);
}

@Test
void contentTypeCompatibleWithAndStringRepresentation() {
MockHttpServletResponse response = createResponse("text/plain");
assertThat(response).hasContentTypeCompatibleWith("text/*");
}

@Test
void contentTypeCanBeAsserted() {
MockHttpServletResponse response = createResponse("text/plain");
assertThat(response).contentType().isInstanceOf(MediaType.class).isCompatibleWith("text/*").isNotNull();
}

private MockHttpServletResponse createResponse(String contentType) {
MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentType(contentType);
return response;
}
}

@Nested
class StatusTests {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,6 @@ public boolean preHandle(HttpServletRequest request, HttpServletResponse respons
}
}

@Nested
class ContentTypeTests {

@Test
void contentType() {
assertThat(perform(get("/greet"))).contentType().isCompatibleWith("text/plain");
}

}

@Nested
class StatusTests {

Expand All @@ -168,8 +158,8 @@ class HeadersTests {

@Test
void shouldAssertHeader() {
assertThat(perform(get("/greet"))).headers()
.hasValue("Content-Type", "text/plain;charset=ISO-8859-1");
assertThat(perform(get("/greet")))
.hasHeader("Content-Type", "text/plain;charset=ISO-8859-1");
}

@Test
Expand Down

0 comments on commit bcecce7

Please sign in to comment.