Skip to content

Commit

Permalink
JWK .equals and .hashCode (#823)
Browse files Browse the repository at this point in the history
* Adjusted JWK .equals implementations to only account for kty value and material fields (two JWKs are equal if their type and key material are equal, regardless of other public parameters and/or custom name/value pairs).

* Adjusted JWK .hashCode implementation to pre-cache its value based on JwkThumpbrint fields since JWKs are immutable
  • Loading branch information
lhazlewood committed Sep 13, 2023
1 parent f60d560 commit b55f261
Show file tree
Hide file tree
Showing 24 changed files with 471 additions and 37 deletions.
10 changes: 10 additions & 0 deletions impl/src/main/java/io/jsonwebtoken/impl/lang/Bytes.java
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ public static byte[] concat(byte[]... arrays) {
return output;
}

/**
* Clears the array by filling it with all zeros. Does nothing with a null or empty argument.
*
* @param bytes the (possibly null or empty) byte array to clear
*/
public static void clear(byte[] bytes) {
if (isEmpty(bytes)) return;
java.util.Arrays.fill(bytes, (byte) 0);
}

public static boolean isEmpty(byte[] bytes) {
return length(bytes) == 0;
}
Expand Down
44 changes: 44 additions & 0 deletions impl/src/main/java/io/jsonwebtoken/impl/lang/Fields.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import io.jsonwebtoken.lang.Arrays;
import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Objects;
import io.jsonwebtoken.lang.Registry;

import java.math.BigInteger;
import java.net.URI;
import java.security.MessageDigest;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Date;
Expand Down Expand Up @@ -97,4 +99,46 @@ public static Registry<String, Field<?>> registry(Registry<String, Field<?>> par
newFields.put(id, field); // add new one
return registry(newFields.values());
}

private static byte[] bytes(BigInteger i) {
return i != null ? i.toByteArray() : null;
}

public static boolean bytesEquals(BigInteger a, BigInteger b) {
//noinspection NumberEquality
if (a == b) return true;
if (a == null || b == null) return false;
byte[] aBytes = bytes(a);
byte[] bBytes = bytes(b);
try {
return MessageDigest.isEqual(aBytes, bBytes);
} finally {
Bytes.clear(aBytes);
Bytes.clear(bBytes);
}
}

private static <T> boolean equals(T a, T b, Field<T> field) {
if (a == b) return true;
if (a == null || b == null) return false;
if (field.isSecret()) {
// byte[] and BigInteger are the only types of secret Fields in the JJWT codebase
// (i.e. Field.isSecret() == true). If a Field is ever marked as secret, and it's not one of these two
// data types, we need to know about it. So we use the 'assertSecret' helper above to ensure we do:
if (a instanceof byte[]) {
return b instanceof byte[] && MessageDigest.isEqual((byte[]) a, (byte[]) b);
} else if (a instanceof BigInteger) {
return b instanceof BigInteger && bytesEquals((BigInteger) a, (BigInteger) b);
}
}
// default to a standard null-safe comparison:
return Objects.nullSafeEquals(a, b);
}

public static <T> boolean equals(FieldReadable a, Object o, Field<T> field) {
if (a == o) return true;
if (a == null || !(o instanceof FieldReadable)) return false;
FieldReadable b = (FieldReadable) o;
return equals(a.get(field), b.get(field), field);
}
}
57 changes: 51 additions & 6 deletions impl/src/main/java/io/jsonwebtoken/impl/security/AbstractJwk.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@

import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
Expand All @@ -48,10 +51,11 @@ public abstract class AbstractJwk<K extends Key> implements Jwk<K>, FieldReadabl
.set().setId("key_ops").setName("Key Operations").build();
static final Field<String> KTY = Fields.string("kty", "Key Type");
static final Set<Field<?>> FIELDS = Collections.setOf(ALG, KID, KEY_OPS, KTY);

public static final String IMMUTABLE_MSG = "JWKs are immutable and may not be modified.";

protected final JwkContext<K> context;
private final List<Field<?>> THUMBPRINT_FIELDS;
private final int hashCode;

/**
* @param ctx the backing JwkContext containing the JWK field values.
Expand All @@ -71,6 +75,40 @@ public abstract class AbstractJwk<K extends Key> implements Jwk<K>, FieldReadabl
String kid = thumbprint.toString();
ctx.setId(kid);
}
this.hashCode = computeHashCode();
}

/**
* Compute and return the JWK hashCode. As JWKs are immutable, this value will be cached as a final constant
* upon JWK instantiation. This uses the JWK's thumbprint fields during computation, but differs from JwkThumbprint
* calculation in two ways:
* <ol>
* <li>JwkThumbprints use a MessageDigest calculation, which is unnecessary overhead for a hashcode</li>
* <li>The hashCode calculation uses each field's idiomatic (Java) object value instead of the
* JwkThumbprint-required canonical (String) value.</li>
* </ol>
*
* @return the JWK hashcode
*/
private int computeHashCode() {
List<Object> list = new ArrayList<>(this.THUMBPRINT_FIELDS.size() + 1 /* possible discriminator */);
// So we don't leak information about the private key value, we need a discriminator to ensure that
// public and private key hashCodes are not identical (in case both JWKs need to be in the same hash set).
// So we add a discriminator String to the list of values that are used during hashCode calculation
Key key = Assert.notNull(toKey(), "JWK toKey() value cannot be null.");
if (key instanceof PublicKey) {
list.add("Public");
} else if (key instanceof PrivateKey) {
list.add("Private");
}
for (Field<?> field : this.THUMBPRINT_FIELDS) {
// Unlike thumbprint calculation, we get the idiomatic (Java) value, not canonical (String) value
// (We could have used either actually, but the idiomatic value hashCode calculation is probably
// faster).
Object val = Assert.notNull(get(field), "computeHashCode: Field idiomatic value cannot be null.");
list.add(val);
}
return Objects.nullSafeHashCode(list.toArray());
}

private String getRequiredThumbprintValue(Field<?> field) {
Expand Down Expand Up @@ -230,13 +268,20 @@ public String toString() {
}

@Override
public int hashCode() {
return this.context.hashCode();
public final int hashCode() {
return this.hashCode;
}

@SuppressWarnings("EqualsWhichDoesntCheckParameterClass")
@Override
public boolean equals(Object obj) {
return this.context.equals(obj);
public final boolean equals(Object obj) {
if (obj == this) return true;
if (obj instanceof Jwk<?>) {
Jwk<?> other = (Jwk<?>) obj;
// this.getType() guaranteed non-null in constructor:
return getType().equals(other.getType()) && equals(other);
}
return false;
}

protected abstract boolean equals(Jwk<?> jwk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ static class DefaultSecretJwkBuilder extends AbstractJwkBuilder<SecretKey, Secre
implements SecretJwkBuilder {
public DefaultSecretJwkBuilder(JwkContext<SecretKey> ctx) {
super(ctx);
// assign a standard algorithm if possible:
Key key = Assert.notNull(ctx.getKey(), "SecretKey cannot be null.");
DefaultMacAlgorithm mac = DefaultMacAlgorithm.findByKey(key);
if (mac != null) {
algorithm(mac.getId());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.security.Jwk;
import io.jsonwebtoken.security.KeyPair;
import io.jsonwebtoken.security.PrivateJwk;
import io.jsonwebtoken.security.PublicJwk;
Expand Down Expand Up @@ -47,4 +48,11 @@ public M toPublicJwk() {
public KeyPair<L, K> toKeyPair() {
return this.keyPair;
}

@Override
protected final boolean equals(Jwk<?> jwk) {
return jwk instanceof PrivateJwk && equals((PrivateJwk<?, ?, ?>) jwk);
}

protected abstract boolean equals(PrivateJwk<?, ?, ?> jwk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.jsonwebtoken.impl.security;

import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.security.Jwk;
import io.jsonwebtoken.security.PublicJwk;

import java.security.PublicKey;
Expand All @@ -25,4 +26,11 @@ abstract class AbstractPublicJwk<K extends PublicKey> extends AbstractAsymmetric
AbstractPublicJwk(JwkContext<K> ctx, List<Field<?>> thumbprintFields) {
super(ctx, thumbprintFields);
}

@Override
protected final boolean equals(Jwk<?> jwk) {
return jwk instanceof PublicJwk && equals((PublicJwk<?>) jwk);
}

protected abstract boolean equals(PublicJwk<?> jwk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.EcPrivateJwk;
import io.jsonwebtoken.security.EcPublicJwk;
import io.jsonwebtoken.security.PrivateJwk;

import java.math.BigInteger;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.util.Set;

import static io.jsonwebtoken.impl.security.DefaultEcPublicJwk.equalsPublic;

class DefaultEcPrivateJwk extends AbstractPrivateJwk<ECPrivateKey, ECPublicKey, EcPublicJwk> implements EcPrivateJwk {

static final Field<BigInteger> D = Fields.secretBigInt("d", "ECC Private Key");
Expand All @@ -38,4 +41,9 @@ class DefaultEcPrivateJwk extends AbstractPrivateJwk<ECPrivateKey, ECPublicKey,
DefaultEcPublicJwk.THUMBPRINT_FIELDS,
pubJwk);
}

@Override
protected boolean equals(PrivateJwk<?, ?, ?> jwk) {
return jwk instanceof EcPrivateJwk && equalsPublic(this, jwk) && Fields.equals(this, jwk, D);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package io.jsonwebtoken.impl.security;

import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.impl.lang.FieldReadable;
import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.EcPublicJwk;
import io.jsonwebtoken.security.PublicJwk;

import java.math.BigInteger;
import java.security.interfaces.ECPublicKey;
Expand All @@ -39,4 +41,15 @@ class DefaultEcPublicJwk extends AbstractPublicJwk<ECPublicKey> implements EcPub
DefaultEcPublicJwk(JwkContext<ECPublicKey> ctx) {
super(ctx, THUMBPRINT_FIELDS);
}

static boolean equalsPublic(FieldReadable self, Object candidate) {
return Fields.equals(self, candidate, CRV) &&
Fields.equals(self, candidate, X) &&
Fields.equals(self, candidate, Y);
}

@Override
protected boolean equals(PublicJwk<?> jwk) {
return jwk instanceof EcPublicJwk && equalsPublic(this, jwk);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ final class DefaultMacAlgorithm extends AbstractSecureDigestAlgorithm<SecretKey,
static final DefaultMacAlgorithm HS384 = new DefaultMacAlgorithm(384);
static final DefaultMacAlgorithm HS512 = new DefaultMacAlgorithm(512);

private static final Map<String, MacAlgorithm> JCA_NAME_MAP;
private static final Map<String, DefaultMacAlgorithm> JCA_NAME_MAP;

static {
JCA_NAME_MAP = new LinkedHashMap<>(6);
Expand Down Expand Up @@ -96,15 +96,15 @@ private static boolean isJwaStandardJcaName(String jcaName) {
return JCA_NAME_MAP.containsKey(key);
}

static MacAlgorithm findByKey(Key key) {
static DefaultMacAlgorithm findByKey(Key key) {

String alg = KeysBridge.findAlgorithm(key);
if (!Strings.hasText(alg)) {
return null;
}

String upper = alg.toUpperCase(Locale.ENGLISH);
MacAlgorithm mac = JCA_NAME_MAP.get(upper);
DefaultMacAlgorithm mac = JCA_NAME_MAP.get(upper);
if (mac == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.OctetPrivateJwk;
import io.jsonwebtoken.security.OctetPublicJwk;
import io.jsonwebtoken.security.PrivateJwk;

import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.Set;

public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey> extends AbstractPrivateJwk<T, P, OctetPublicJwk<P>> implements OctetPrivateJwk<T, P> {
import static io.jsonwebtoken.impl.security.DefaultOctetPublicJwk.equalsPublic;

public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey>
extends AbstractPrivateJwk<T, P, OctetPublicJwk<P>> implements OctetPrivateJwk<T, P> {

static final Field<byte[]> D = Fields.bytes("d", "The private key").setSecret(true).build();

Expand All @@ -37,4 +41,9 @@ public class DefaultOctetPrivateJwk<T extends PrivateKey, P extends PublicKey> e
// https://www.rfc-editor.org/rfc/rfc7638#section-3.2.1
DefaultOctetPublicJwk.THUMBPRINT_FIELDS, pubJwk);
}

@Override
protected boolean equals(PrivateJwk<?, ?, ?> jwk) {
return jwk instanceof OctetPrivateJwk && equalsPublic(this, jwk) && Fields.equals(this, jwk, D);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
package io.jsonwebtoken.impl.security;

import io.jsonwebtoken.impl.lang.Field;
import io.jsonwebtoken.impl.lang.FieldReadable;
import io.jsonwebtoken.impl.lang.Fields;
import io.jsonwebtoken.lang.Collections;
import io.jsonwebtoken.security.OctetPublicJwk;
import io.jsonwebtoken.security.PublicJwk;

import java.security.PublicKey;
import java.util.List;
Expand All @@ -37,4 +39,13 @@ public class DefaultOctetPublicJwk<T extends PublicKey> extends AbstractPublicJw
DefaultOctetPublicJwk(JwkContext<T> ctx) {
super(ctx, THUMBPRINT_FIELDS);
}

static boolean equalsPublic(FieldReadable self, Object candidate) {
return Fields.equals(self, candidate, CRV) && Fields.equals(self, candidate, X);
}

@Override
protected boolean equals(PublicJwk<?> jwk) {
return jwk instanceof OctetPublicJwk && equalsPublic(this, jwk);
}
}

0 comments on commit b55f261

Please sign in to comment.