diff --git a/impl/src/main/java/io/jsonwebtoken/impl/DefaultJwtBuilder.java b/impl/src/main/java/io/jsonwebtoken/impl/DefaultJwtBuilder.java index ef41c7aec..afa6a4804 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/DefaultJwtBuilder.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/DefaultJwtBuilder.java @@ -595,14 +595,8 @@ private String sign(final Payload payload, final Key key, final Provider provide // Next, b64 extension requires the raw (non-encoded) payload to be included directly in the signing input, // so we ensure we have an input stream for that: - if (payload.isClaims() || payload.isCompressed()) { - ByteArrayOutputStream claimsOut = new ByteArrayOutputStream(8192); - writeAndClose("JWS Unencoded Payload", payload, claimsOut); - payloadStream = Streams.of(claimsOut.toByteArray()); - } else { - // No claims and not compressed, so just get the direct InputStream: - payloadStream = Assert.stateNotNull(payload.toInputStream(), "Payload InputStream cannot be null."); - } + payloadStream = toInputStream("JWS Unencoded Payload", payload); + if (!payload.isClaims()) { payloadStream = new CountingInputStream(payloadStream); // we'll need to assert if it's empty later } @@ -693,14 +687,7 @@ private String encrypt(final Payload content, final Key key, final Provider keyP Assert.stateNotNull(keyAlgFunction, "KeyAlgorithm function cannot be null."); assertPayloadEncoding("JWE"); - InputStream plaintext; - if (content.isClaims()) { - ByteArrayOutputStream out = new ByteArrayOutputStream(4096); - writeAndClose("JWE Claims", content, out); - plaintext = Streams.of(out.toByteArray()); - } else { - plaintext = content.toInputStream(); - } + InputStream plaintext = toInputStream("JWE Payload", content); //only expose (mutable) JweHeader functionality to KeyAlgorithm instances, not the full headerBuilder // (which exposes this JwtBuilder and shouldn't be referenced by KeyAlgorithms): @@ -820,4 +807,15 @@ private void encodeAndWrite(String name, byte[] data, OutputStream out) { Streams.writeAndClose(out, data, "Unable to write bytes"); } + private InputStream toInputStream(final String name, Payload payload) { + if (payload.isClaims() || payload.isCompressed()) { + ByteArrayOutputStream claimsOut = new ByteArrayOutputStream(8192); + writeAndClose(name, payload, claimsOut); + return Streams.of(claimsOut.toByteArray()); + } else { + // No claims and not compressed, so just get the direct InputStream: + return Assert.stateNotNull(payload.toInputStream(), "Payload InputStream cannot be null."); + } + } + } diff --git a/impl/src/test/groovy/io/jsonwebtoken/JwtsTest.groovy b/impl/src/test/groovy/io/jsonwebtoken/JwtsTest.groovy index 2f1c432c8..5f6b82229 100644 --- a/impl/src/test/groovy/io/jsonwebtoken/JwtsTest.groovy +++ b/impl/src/test/groovy/io/jsonwebtoken/JwtsTest.groovy @@ -22,6 +22,7 @@ import io.jsonwebtoken.impl.io.Streams import io.jsonwebtoken.impl.lang.Bytes import io.jsonwebtoken.impl.lang.Services import io.jsonwebtoken.impl.security.* +import io.jsonwebtoken.io.CompressionAlgorithm import io.jsonwebtoken.io.Decoders import io.jsonwebtoken.io.Deserializer import io.jsonwebtoken.io.Encoders @@ -1398,6 +1399,97 @@ class JwtsTest { } } + @Test + void testJweCompressionWithArbitraryContentString() { + def codecs = [Jwts.ZIP.DEF, Jwts.ZIP.GZIP] + + for (CompressionAlgorithm zip : codecs) { + + for (AeadAlgorithm enc : Jwts.ENC.get().values()) { + + SecretKey key = enc.key().build() + + String payload = 'hello, world!' + + // encrypt and compress: + String jwe = Jwts.builder() + .content(payload) + .compressWith(zip) + .encryptWith(key, enc) + .compact() + + //decompress and decrypt: + def jwt = Jwts.parser() + .decryptWith(key) + .build() + .parseEncryptedContent(jwe) + assertEquals payload, new String(jwt.getPayload(), StandardCharsets.UTF_8) + } + } + } + + @Test + void testJweCompressionWithArbitraryContentByteArray() { + def codecs = [Jwts.ZIP.DEF, Jwts.ZIP.GZIP] + + for (CompressionAlgorithm zip : codecs) { + + for (AeadAlgorithm enc : Jwts.ENC.get().values()) { + + SecretKey key = enc.key().build() + + byte[] payload = new byte[14]; + Randoms.secureRandom().nextBytes(payload) + + // encrypt and compress: + String jwe = Jwts.builder() + .content(payload) + .compressWith(zip) + .encryptWith(key, enc) + .compact() + + //decompress and decrypt: + def jwt = Jwts.parser() + .decryptWith(key) + .build() + .parseEncryptedContent(jwe) + assertArrayEquals payload, jwt.getPayload() + } + } + } + + @Test + void testJweCompressionWithArbitraryContentInputStream() { + def codecs = [Jwts.ZIP.DEF, Jwts.ZIP.GZIP] + + for (CompressionAlgorithm zip : codecs) { + + for (AeadAlgorithm enc : Jwts.ENC.get().values()) { + + SecretKey key = enc.key().build() + + byte[] payloadBytes = new byte[14]; + Randoms.secureRandom().nextBytes(payloadBytes) + + ByteArrayInputStream payload = new ByteArrayInputStream(payloadBytes) + + // encrypt and compress: + String jwe = Jwts.builder() + .content(payload) + .compressWith(zip) + .encryptWith(key, enc) + .compact() + + //decompress and decrypt: + def jwt = Jwts.parser() + .decryptWith(key) + .build() + .parseEncryptedContent(jwe) + assertArrayEquals payloadBytes, jwt.getPayload() + } + } + } + @Test void testPasswordJwes() {