Skip to content

Commit

Permalink
key byte array cleanup as necessary (#846)
Browse files Browse the repository at this point in the history
  • Loading branch information
lhazlewood committed Oct 3, 2023
1 parent e78f3f5 commit b411b19
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 40 deletions.
92 changes: 59 additions & 33 deletions impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.jsonwebtoken.impl.security;

import io.jsonwebtoken.impl.lang.Bytes;
import io.jsonwebtoken.impl.lang.CheckedFunction;
import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.security.SecurityException;
Expand Down Expand Up @@ -114,43 +115,68 @@ public SecretKey deriveKey(final byte[] Z, final long derivedKeyBitLength, final
long inputBitLength = bitLength(counter) + bitLength(Z) + bitLength(OtherInfo);
Assert.state(inputBitLength <= MAX_HASH_INPUT_BIT_LENGTH, "Hash input is too large.");

byte[] derivedKeyBytes = jca().withMessageDigest(new CheckedFunction<MessageDigest, byte[]>() {
@Override
public byte[] apply(MessageDigest md) throws Exception {

final ByteArrayOutputStream stream = new ByteArrayOutputStream((int) derivedKeyByteLength);

// Section 5.8.1.1, Process step #5. We depart from Java idioms here by starting iteration index at 1
// (instead of 0) and continue to <= reps (instead of < reps) to match the NIST publication algorithm
// notation convention (so variables like Ki and kLast below match the NIST definitions).
for (long i = 1; i <= reps; i++) {

// Section 5.8.1.1, Process step #5.1:
md.update(counter);
md.update(Z);
md.update(OtherInfo);
byte[] Ki = md.digest();

// Section 5.8.1.1, Process step #5.2:
increment(counter);

// Section 5.8.1.1, Process step #6:
if (i == reps && kLastPartial) {
long leftmostBitLength = derivedKeyBitLength % hashBitLength;
int leftmostByteLength = (int) (leftmostBitLength / Byte.SIZE);
byte[] kLast = new byte[leftmostByteLength];
System.arraycopy(Ki, 0, kLast, 0, kLast.length);
Ki = kLast;
final ClearableByteArrayOutputStream stream = new ClearableByteArrayOutputStream((int) derivedKeyByteLength);
byte[] derivedKeyBytes = EMPTY;

try {
derivedKeyBytes = jca().withMessageDigest(new CheckedFunction<MessageDigest, byte[]>() {
@Override
public byte[] apply(MessageDigest md) throws Exception {

// Section 5.8.1.1, Process step #5. We depart from Java idioms here by starting iteration index at 1
// (instead of 0) and continue to <= reps (instead of < reps) to match the NIST publication algorithm
// notation convention (so variables like Ki and kLast below match the NIST definitions).
for (long i = 1; i <= reps; i++) {

// Section 5.8.1.1, Process step #5.1:
md.update(counter);
md.update(Z);
md.update(OtherInfo);
byte[] Ki = md.digest();

// Section 5.8.1.1, Process step #5.2:
increment(counter);

// Section 5.8.1.1, Process step #6:
if (i == reps && kLastPartial) {
long leftmostBitLength = derivedKeyBitLength % hashBitLength;
int leftmostByteLength = (int) (leftmostBitLength / Byte.SIZE);
byte[] kLast = new byte[leftmostByteLength];
System.arraycopy(Ki, 0, kLast, 0, kLast.length);
Ki = kLast;
}

stream.write(Ki);
}

stream.write(Ki);
// Section 5.8.1.1, Process step #7:
return stream.toByteArray();
}
});
return new SecretKeySpec(derivedKeyBytes, AesAlgorithm.KEY_ALG_NAME);
} finally {
// key cleanup
Bytes.clear(derivedKeyBytes); // SecretKeySpec clones this, so we can clear it out safely
Bytes.clear(counter);
stream.reset();
// we don't clear out 'Z', since that is the responsibility of the caller
}
}

// Section 5.8.1.1, Process step #7:
return stream.toByteArray();
}
});
/**
* Calling ByteArrayOutputStream.toByteArray returns a copy of the bytes, so this class allows us to completely
* zero-out the buffer upon reset (whereas BAOS just resets the position marker, leaving the bytes in tact)
*/
private static class ClearableByteArrayOutputStream extends ByteArrayOutputStream {

return new SecretKeySpec(derivedKeyBytes, AesAlgorithm.KEY_ALG_NAME);
public ClearableByteArrayOutputStream(int size) {
super(size);
}

@Override
public synchronized void reset() {
super.reset();
Bytes.clear(buf); // zero out internal buffer
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ private SecretKey deriveKey(KeyRequest<?> request, PublicKey publicKey, PrivateK
byte[] apv = request.getHeader().getAgreementPartyVInfo();
byte[] OtherInfo = createOtherInfo(requiredCekBitLen, AlgorithmID, apu, apv);
byte[] Z = generateZ(request, publicKey, privateKey);
return CONCAT_KDF.deriveKey(Z, requiredCekBitLen, OtherInfo);
try {
return CONCAT_KDF.deriveKey(Z, requiredCekBitLen, OtherInfo);
} finally {
Bytes.clear(Z);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,13 @@ public void encrypt(final AeadRequest req, final AeadResult res) {
int halfCount = compositeKeyBytes.length / 2; // https://tools.ietf.org/html/rfc7518#section-5.2
byte[] macKeyBytes = Arrays.copyOfRange(compositeKeyBytes, 0, halfCount);
byte[] encKeyBytes = Arrays.copyOfRange(compositeKeyBytes, halfCount, compositeKeyBytes.length);
final SecretKey encryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME);
final SecretKey encryptionKey;
try {
encryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME);
} finally {
Bytes.clear(encKeyBytes);
Bytes.clear(compositeKeyBytes);
}

final InputStream plaintext = Assert.notNull(req.getPayload(),
"Request content (plaintext) InputStream cannot be null.");
Expand All @@ -121,9 +127,13 @@ public Object apply(Cipher cipher) throws Exception {

byte[] aadBytes = aad == null ? Bytes.EMPTY : Streams.bytes(aad, "Unable to read AAD bytes.");

byte[] tag = sign(aadBytes, iv, Streams.of(copy.toByteArray()), macKeyBytes);

res.setTag(tag).setIv(iv);
byte[] tag;
try {
tag = sign(aadBytes, iv, Streams.of(copy.toByteArray()), macKeyBytes);
res.setTag(tag).setIv(iv);
} finally {
Bytes.clear(macKeyBytes);
}
}

private byte[] sign(byte[] aad, byte[] iv, InputStream ciphertext, byte[] macKeyBytes) {
Expand Down Expand Up @@ -162,7 +172,13 @@ public void decrypt(final DecryptAeadRequest req, final OutputStream plaintext)
int halfCount = compositeKeyBytes.length / 2; // https://tools.ietf.org/html/rfc7518#section-5.2
byte[] macKeyBytes = Arrays.copyOfRange(compositeKeyBytes, 0, halfCount);
byte[] encKeyBytes = Arrays.copyOfRange(compositeKeyBytes, halfCount, compositeKeyBytes.length);
final SecretKey decryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME);
final SecretKey decryptionKey;
try {
decryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME);
} finally {
Bytes.clear(encKeyBytes);
Bytes.clear(compositeKeyBytes);
}

InputStream in = Assert.notNull(req.getPayload(),
"Decryption request content (ciphertext) InputStream cannot be null.");
Expand All @@ -174,7 +190,12 @@ public void decrypt(final DecryptAeadRequest req, final OutputStream plaintext)
// Assert that the aad + iv + ciphertext provided, when signed, equals the tag provided,
// thereby verifying none of it has been tampered with:
byte[] aadBytes = aad == null ? Bytes.EMPTY : Streams.bytes(aad, "Unable to read AAD bytes.");
byte[] digest = sign(aadBytes, iv, in, macKeyBytes);
byte[] digest;
try {
digest = sign(aadBytes, iv, in, macKeyBytes);
} finally {
Bytes.clear(macKeyBytes);
}
if (!MessageDigest.isEqual(digest, tag)) { //constant time comparison to avoid side-channel attacks
String msg = "Ciphertext decryption failed: Authentication tag verification failed.";
throw new SignatureException(msg);
Expand Down

0 comments on commit b411b19

Please sign in to comment.