Skip to content

Commit

Permalink
Allow multiple JWS algorithms to be configured
Browse files Browse the repository at this point in the history
Closes gh-31321
  • Loading branch information
wilkinsona committed Jun 16, 2022
1 parent 5e1cd28 commit a1cc5bf
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 9 deletions.
Expand Up @@ -20,9 +20,11 @@
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.DeprecatedConfigurationProperty;
import org.springframework.boot.context.properties.source.InvalidConfigurationPropertyValueException;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -59,9 +61,9 @@ public static class Jwt {
private String jwkSetUri;

/**
* JSON Web Algorithm used for verifying the digital signatures.
* JSON Web Algorithms used for verifying the digital signatures.
*/
private String jwsAlgorithm = "RS256";
private List<String> jwsAlgorithms = Arrays.asList("RS256");

/**
* URI that can either be an OpenID Connect discovery endpoint or an OAuth 2.0
Expand All @@ -87,12 +89,23 @@ public void setJwkSetUri(String jwkSetUri) {
this.jwkSetUri = jwkSetUri;
}

@Deprecated
@DeprecatedConfigurationProperty(replacement = "spring.security.oauth2.resourceserver.jwt.jws-algorithms")
public String getJwsAlgorithm() {
return this.jwsAlgorithm;
return this.jwsAlgorithms.isEmpty() ? null : this.jwsAlgorithms.get(0);
}

@Deprecated
public void setJwsAlgorithm(String jwsAlgorithm) {
this.jwsAlgorithm = jwsAlgorithm;
this.jwsAlgorithms = new ArrayList<>(Arrays.asList(jwsAlgorithm));
}

public List<String> getJwsAlgorithms() {
return this.jwsAlgorithms;
}

public void setJwsAlgorithms(List<String> jwsAlgortithms) {
this.jwsAlgorithms = jwsAlgortithms;
}

public String getIssuerUri() {
Expand Down
Expand Up @@ -23,6 +23,7 @@
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;

import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
Expand Down Expand Up @@ -78,15 +79,20 @@ static class JwtConfiguration {
@ConditionalOnProperty(name = "spring.security.oauth2.resourceserver.jwt.jwk-set-uri")
ReactiveJwtDecoder jwtDecoder() {
NimbusReactiveJwtDecoder nimbusReactiveJwtDecoder = NimbusReactiveJwtDecoder
.withJwkSetUri(this.properties.getJwkSetUri())
.jwsAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
.withJwkSetUri(this.properties.getJwkSetUri()).jwsAlgorithms(this::jwsAlgorithms).build();
String issuerUri = this.properties.getIssuerUri();
Supplier<OAuth2TokenValidator<Jwt>> defaultValidator = (issuerUri != null)
? () -> JwtValidators.createDefaultWithIssuer(issuerUri) : JwtValidators::createDefault;
nimbusReactiveJwtDecoder.setJwtValidator(getValidators(defaultValidator));
return nimbusReactiveJwtDecoder;
}

private void jwsAlgorithms(Set<SignatureAlgorithm> signatureAlgorithms) {
for (String algorithm : this.properties.getJwsAlgorithms()) {
signatureAlgorithms.add(SignatureAlgorithm.from(algorithm));
}
}

private OAuth2TokenValidator<Jwt> getValidators(Supplier<OAuth2TokenValidator<Jwt>> defaultValidator) {
OAuth2TokenValidator<Jwt> defaultValidators = defaultValidator.get();
List<String> audiences = this.properties.getAudiences();
Expand All @@ -106,7 +112,7 @@ NimbusReactiveJwtDecoder jwtDecoderByPublicKeyValue() throws Exception {
RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA")
.generatePublic(new X509EncodedKeySpec(getKeySpec(this.properties.readPublicKey())));
NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withPublicKey(publicKey)
.signatureAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
.signatureAlgorithm(SignatureAlgorithm.from(exactlyOneAlgorithm())).build();
jwtDecoder.setJwtValidator(getValidators(JwtValidators::createDefault));
return jwtDecoder;
}
Expand All @@ -116,6 +122,17 @@ private byte[] getKeySpec(String keyValue) {
return Base64.getMimeDecoder().decode(keyValue);
}

private String exactlyOneAlgorithm() {
List<String> algorithms = this.properties.getJwsAlgorithms();
int count = (algorithms != null) ? algorithms.size() : 0;
if (count != 1) {
throw new IllegalStateException(
"Creating a JWT decoder using a public key requires exactly one JWS algorithm but " + count
+ " were configured");
}
return algorithms.get(0);
}

@Bean
@Conditional(IssuerUriCondition.class)
SupplierReactiveJwtDecoder jwtDecoderByIssuerUri() {
Expand Down
Expand Up @@ -23,6 +23,7 @@
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;

import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
Expand Down Expand Up @@ -78,14 +79,20 @@ static class JwtDecoderConfiguration {
@ConditionalOnProperty(name = "spring.security.oauth2.resourceserver.jwt.jwk-set-uri")
JwtDecoder jwtDecoderByJwkKeySetUri() {
NimbusJwtDecoder nimbusJwtDecoder = NimbusJwtDecoder.withJwkSetUri(this.properties.getJwkSetUri())
.jwsAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
.jwsAlgorithms(this::jwsAlgorithms).build();
String issuerUri = this.properties.getIssuerUri();
Supplier<OAuth2TokenValidator<Jwt>> defaultValidator = (issuerUri != null)
? () -> JwtValidators.createDefaultWithIssuer(issuerUri) : JwtValidators::createDefault;
nimbusJwtDecoder.setJwtValidator(getValidators(defaultValidator));
return nimbusJwtDecoder;
}

private void jwsAlgorithms(Set<SignatureAlgorithm> signatureAlgorithms) {
for (String algorithm : this.properties.getJwsAlgorithms()) {
signatureAlgorithms.add(SignatureAlgorithm.from(algorithm));
}
}

private OAuth2TokenValidator<Jwt> getValidators(Supplier<OAuth2TokenValidator<Jwt>> defaultValidator) {
OAuth2TokenValidator<Jwt> defaultValidators = defaultValidator.get();
List<String> audiences = this.properties.getAudiences();
Expand All @@ -105,7 +112,7 @@ JwtDecoder jwtDecoderByPublicKeyValue() throws Exception {
RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA")
.generatePublic(new X509EncodedKeySpec(getKeySpec(this.properties.readPublicKey())));
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(publicKey)
.signatureAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
.signatureAlgorithm(SignatureAlgorithm.from(exactlyOneAlgorithm())).build();
jwtDecoder.setJwtValidator(getValidators(JwtValidators::createDefault));
return jwtDecoder;
}
Expand All @@ -115,6 +122,17 @@ private byte[] getKeySpec(String keyValue) {
return Base64.getMimeDecoder().decode(keyValue);
}

private String exactlyOneAlgorithm() {
List<String> algorithms = this.properties.getJwsAlgorithms();
int count = (algorithms != null) ? algorithms.size() : 0;
if (count != 1) {
throw new IllegalStateException(
"Creating a JWT decoder using a public key requires exactly one JWS algorithm but " + count
+ " were configured");
}
return algorithms.get(0);
}

@Bean
@Conditional(IssuerUriCondition.class)
SupplierJwtDecoder jwtDecoderByIssuerUri() {
Expand Down
Expand Up @@ -2058,6 +2058,11 @@
"name": "spring.security.filter.order",
"defaultValue": -100
},
{
"name": "spring.security.oauth2.resourceserver.jwt.jws-algorithm",
"description": "JSON Web Algorithm used for verifying the digital signatures.",
"defaultValue": "RS256"
},
{
"name": "spring.session.hazelcast.flush-mode",
"defaultValue": "on-save"
Expand Down
Expand Up @@ -32,6 +32,7 @@
import com.nimbusds.jose.JWSAlgorithm;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -114,6 +115,7 @@ void autoConfigurationShouldConfigureResourceServer() {

@SuppressWarnings("unchecked")
@Test
@Deprecated
void autoConfigurationUsingJwkSetUriShouldConfigureResourceServerUsingJwsAlgorithm() {
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
Expand All @@ -126,6 +128,33 @@ void autoConfigurationUsingJwkSetUriShouldConfigureResourceServerUsingJwsAlgorit
}

@Test
void autoConfigurationUsingJwkSetUriShouldConfigureResourceServerUsingSingleJwsAlgorithm() {
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RS512")
.run((context) -> {
NimbusReactiveJwtDecoder nimbusReactiveJwtDecoder = context.getBean(NimbusReactiveJwtDecoder.class);
assertThat(nimbusReactiveJwtDecoder).extracting("jwtProcessor.arg$2.arg$1.jwsAlgs")
.asInstanceOf(InstanceOfAssertFactories.collection(JWSAlgorithm.class))
.containsExactlyInAnyOrder(JWSAlgorithm.RS512);
});
}

@Test
void autoConfigurationUsingJwkSetUriShouldConfigureResourceServerUsingMultipleJwsAlgorithms() {
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RS256, RS384, RS512")
.run((context) -> {
NimbusReactiveJwtDecoder nimbusReactiveJwtDecoder = context.getBean(NimbusReactiveJwtDecoder.class);
assertThat(nimbusReactiveJwtDecoder).extracting("jwtProcessor.arg$2.arg$1.jwsAlgs")
.asInstanceOf(InstanceOfAssertFactories.collection(JWSAlgorithm.class))
.containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS384, JWSAlgorithm.RS512);
});
}

@Test
@Deprecated
void autoConfigurationUsingPublicKeyValueShouldConfigureResourceServerUsingJwsAlgorithm() {
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
Expand All @@ -136,6 +165,29 @@ void autoConfigurationUsingPublicKeyValueShouldConfigureResourceServerUsingJwsAl
});
}

@Test
void autoConfigurationUsingPublicKeyValueShouldConfigureResourceServerUsingSingleJwsAlgorithm() {
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RS384").run((context) -> {
NimbusReactiveJwtDecoder nimbusReactiveJwtDecoder = context.getBean(NimbusReactiveJwtDecoder.class);
assertThat(nimbusReactiveJwtDecoder).extracting("jwtProcessor.arg$1.jwsKeySelector.expectedJWSAlg")
.isEqualTo(JWSAlgorithm.RS384);
});
}

@Test
void autoConfigurationUsingPublicKeyValueWithMultipleJwsAlgorithmsShouldFail() {
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RSA256,RS384").run((context) -> {
assertThat(context).hasFailed();
assertThat(context.getStartupFailure()).hasRootCauseMessage(
"Creating a JWT decoder using a public key requires exactly one JWS algorithm but 2 were "
+ "configured");
});
}

@Test
@SuppressWarnings("unchecked")
void autoConfigurationShouldConfigureResourceServerUsingOidcIssuerUri() throws IOException {
Expand Down
Expand Up @@ -33,6 +33,7 @@
import com.nimbusds.jose.JWSAlgorithm;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

Expand All @@ -55,6 +56,7 @@
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtIssuerValidator;
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.jwt.SupplierJwtDecoder;
import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider;
Expand Down Expand Up @@ -120,6 +122,7 @@ void autoConfigurationShouldMatchDefaultJwsAlgorithm() {
}

@Test
@Deprecated
void autoConfigurationShouldConfigureResourceServerWithJwsAlgorithm() {
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
Expand All @@ -134,6 +137,73 @@ void autoConfigurationShouldConfigureResourceServerWithJwsAlgorithm() {
});
}

@Test
void autoConfigurationShouldConfigureResourceServerWithSingleJwsAlgorithm() {
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RS384")
.run((context) -> {
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
Object processor = ReflectionTestUtils.getField(jwtDecoder, "jwtProcessor");
Object keySelector = ReflectionTestUtils.getField(processor, "jwsKeySelector");
assertThat(keySelector).extracting("jwsAlgs")
.asInstanceOf(InstanceOfAssertFactories.collection(JWSAlgorithm.class))
.containsExactlyInAnyOrder(JWSAlgorithm.RS384);
assertThat(getBearerTokenFilter(context)).isNotNull();
});
}

@Test
void autoConfigurationShouldConfigureResourceServerWithMultipleJwsAlgorithms() {
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RS256, RS384, RS512")
.run((context) -> {
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
Object processor = ReflectionTestUtils.getField(jwtDecoder, "jwtProcessor");
Object keySelector = ReflectionTestUtils.getField(processor, "jwsKeySelector");
assertThat(keySelector).extracting("jwsAlgs")
.asInstanceOf(InstanceOfAssertFactories.collection(JWSAlgorithm.class))
.containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS384, JWSAlgorithm.RS512);
assertThat(getBearerTokenFilter(context)).isNotNull();
});
}

@Test
@Deprecated
void autoConfigurationUsingPublicKeyValueShouldConfigureResourceServerUsingJwsAlgorithm() {
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
"spring.security.oauth2.resourceserver.jwt.jws-algorithm=RS384").run((context) -> {
NimbusJwtDecoder nimbusJwtDecoder = context.getBean(NimbusJwtDecoder.class);
assertThat(nimbusJwtDecoder).extracting("jwtProcessor.jwsKeySelector.expectedJWSAlg")
.isEqualTo(JWSAlgorithm.RS384);
});
}

@Test
void autoConfigurationUsingPublicKeyValueShouldConfigureResourceServerUsingSingleJwsAlgorithm() {
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RS384").run((context) -> {
NimbusJwtDecoder nimbusJwtDecoder = context.getBean(NimbusJwtDecoder.class);
assertThat(nimbusJwtDecoder).extracting("jwtProcessor.jwsKeySelector.expectedJWSAlg")
.isEqualTo(JWSAlgorithm.RS384);
});
}

@Test
void autoConfigurationUsingPublicKeyValueWithMultipleJwsAlgorithmsShouldFail() {
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
"spring.security.oauth2.resourceserver.jwt.jws-algorithms=RSA256,RS384").run((context) -> {
assertThat(context).hasFailed();
assertThat(context.getStartupFailure()).hasRootCauseMessage(
"Creating a JWT decoder using a public key requires exactly one JWS algorithm but 2 were "
+ "configured");
});
}

@Test
@SuppressWarnings("unchecked")
void autoConfigurationShouldConfigureResourceServerUsingOidcIssuerUri() throws Exception {
Expand Down

0 comments on commit a1cc5bf

Please sign in to comment.