Skip to content

Commit

Permalink
Add support customizing redirect URI
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Batischev authored and Max Batischev committed Apr 17, 2024
1 parent 4c44de7 commit cb06ab0
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@

import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
Expand All @@ -35,6 +36,7 @@
import org.springframework.security.web.server.authentication.logout.RedirectServerLogoutSuccessHandler;
import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

Expand All @@ -57,6 +59,8 @@ public class OidcClientInitiatedServerLogoutSuccessHandler implements ServerLogo

private String postLogoutRedirectUri;

private Converter<RedirectUriParameters, Mono<String>> redirectUriResolver = new DefaultRedirectUriResolver();

/**
* Constructs an {@link OidcClientInitiatedServerLogoutSuccessHandler} with the
* provided parameters
Expand All @@ -72,22 +76,11 @@ public OidcClientInitiatedServerLogoutSuccessHandler(

@Override
public Mono<Void> onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) {
RedirectUriParameters redirectUriParameters = new RedirectUriParameters();
redirectUriParameters.setAuthentication(authentication);
redirectUriParameters.setServerWebExchange(exchange.getExchange());
// @formatter:off
return Mono.just(authentication)
.filter(OAuth2AuthenticationToken.class::isInstance)
.filter((token) -> authentication.getPrincipal() instanceof OidcUser)
.map(OAuth2AuthenticationToken.class::cast)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId)
.flatMap(this.clientRegistrationRepository::findByRegistrationId)
.flatMap((clientRegistration) -> {
URI endSessionEndpoint = endSessionEndpoint(clientRegistration);
if (endSessionEndpoint == null) {
return Mono.empty();
}
String idToken = idToken(authentication);
String postLogoutRedirectUri = postLogoutRedirectUri(exchange.getExchange().getRequest(), clientRegistration);
return Mono.just(endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri));
})
return this.redirectUriResolver.convert(redirectUriParameters)
.switchIfEmpty(
this.serverLogoutSuccessHandler.onLogoutSuccess(exchange, authentication).then(Mono.empty())
)
Expand Down Expand Up @@ -189,4 +182,90 @@ public void setLogoutSuccessUrl(URI logoutSuccessUrl) {
this.serverLogoutSuccessHandler.setLogoutSuccessUrl(logoutSuccessUrl);
}

/**
* Set the {@link Converter} that converts {@link RedirectUriParameters} to redirect
* URI
* @param redirectUriResolver {@link Converter}
* @since 6.4
*/
public void setRedirectUriResolver(Converter<RedirectUriParameters, Mono<String>> redirectUriResolver) {
Assert.notNull(redirectUriResolver, "redirectUriResolver cannot be null");
this.redirectUriResolver = redirectUriResolver;
}

/**
* Parameters, required for redirect URI resolving.
*
* @author Max Batischev
* @since 6.4
*/
public static class RedirectUriParameters {

private ServerWebExchange serverWebExchange;

private Authentication authentication;

private ClientRegistration clientRegistration;

public ServerWebExchange getServerWebExchange() {
return this.serverWebExchange;
}

public Authentication getAuthentication() {
return this.authentication;
}

public ClientRegistration getClientRegistration() {
return this.clientRegistration;
}

public void setClientRegistration(ClientRegistration clientRegistration) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
this.clientRegistration = clientRegistration;
}

public void setServerWebExchange(ServerWebExchange serverWebExchange) {
Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
this.serverWebExchange = serverWebExchange;
}

public void setAuthentication(Authentication authentication) {
Assert.notNull(authentication, "authentication cannot be null");
this.authentication = authentication;
}

}

/**
* Default {@link Converter} for redirect uri resolving.
*
* @since 6.4
*/
private final class DefaultRedirectUriResolver implements Converter<RedirectUriParameters, Mono<String>> {

@Override
public Mono<String> convert(RedirectUriParameters redirectUriParameters) {
// @formatter:off
return Mono.just(redirectUriParameters.authentication)
.filter(OAuth2AuthenticationToken.class::isInstance)
.filter((token) -> redirectUriParameters.authentication.getPrincipal() instanceof OidcUser)
.map(OAuth2AuthenticationToken.class::cast)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId)
.flatMap(
OidcClientInitiatedServerLogoutSuccessHandler.this.clientRegistrationRepository::findByRegistrationId)
.flatMap((clientRegistration) -> {
URI endSessionEndpoint = endSessionEndpoint(clientRegistration);
if (endSessionEndpoint == null) {
return Mono.empty();
}
String idToken = idToken(redirectUriParameters.authentication);
String postLogoutRedirectUri = postLogoutRedirectUri(
redirectUriParameters.serverWebExchange.getRequest(), clientRegistration);
return Mono.just(endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri));
});
// @formatter:on
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.Objects;

import jakarta.servlet.ServletException;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -199,6 +200,25 @@ public void setPostLogoutRedirectUriTemplateWhenGivenNullThenThrowsException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.handler.setPostLogoutRedirectUri((String) null));
}

@Test
public void logoutWhenCustomRedirectUriResolverSetThenRedirects() {
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(),
AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId());
WebFilterExchange filterExchange = new WebFilterExchange(this.exchange, this.chain);
given(this.exchange.getRequest())
.willReturn(MockServerHttpRequest.get("/").queryParam("location", "https://test.com").build());
// @formatter:off
this.handler.setRedirectUriResolver((params) -> Mono.just(
Objects.requireNonNull(params.getServerWebExchange()
.getRequest()
.getQueryParams()
.getFirst("location"))));
// @formatter:on
this.handler.onLogoutSuccess(filterExchange, token).block();

assertThat(redirectedUrl(this.exchange)).isEqualTo("https://test.com");
}

private String redirectedUrl(ServerWebExchange exchange) {
return exchange.getResponse().getHeaders().getFirst("Location");
}
Expand Down

0 comments on commit cb06ab0

Please sign in to comment.