diff --git a/src/crypto/crypto_hkdf.cc b/src/crypto/crypto_hkdf.cc index bff8cc60c4c13b..081a930e69f3c4 100644 --- a/src/crypto/crypto_hkdf.cc +++ b/src/crypto/crypto_hkdf.cc @@ -112,47 +112,33 @@ bool HKDFTraits::DeriveBits( // TODO(panva): Once support for OpenSSL 1.1.1 is dropped the whole // of HKDFTraits::DeriveBits can be refactored to use // EVP_KDF which does handle zero length key. - if (params.key->GetSymmetricKeySize() != 0) { - if (!EVP_PKEY_CTX_hkdf_mode(ctx.get(), - EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND) || - !EVP_PKEY_CTX_set1_hkdf_salt( - ctx.get(), params.salt.data(), params.salt.size()) || - !EVP_PKEY_CTX_set1_hkdf_key(ctx.get(), - reinterpret_cast( - params.key->GetSymmetricKey()), - params.key->GetSymmetricKeySize())) { - return false; - } + + std::string_view salt; + if (params.salt.size() != 0) { + salt = {params.salt.data(), params.salt.size()}; } else { - // Workaround for EVP_PKEY_derive HKDF not handling zero length keys. - unsigned char temp_key[EVP_MAX_MD_SIZE]; - unsigned int len = sizeof(temp_key); - if (params.salt.size() != 0) { - if (HMAC(params.digest, - params.salt.data(), - params.salt.size(), - nullptr, - 0, - temp_key, - &len) == nullptr) { - return false; - } - } else { - char salt[EVP_MAX_MD_SIZE] = {0}; - if (HMAC(params.digest, - salt, - EVP_MD_size(params.digest), - nullptr, - 0, - temp_key, - &len) == nullptr) { - return false; - } - } - if (!EVP_PKEY_CTX_hkdf_mode(ctx.get(), EVP_PKEY_HKDEF_MODE_EXPAND_ONLY) || - !EVP_PKEY_CTX_set1_hkdf_key(ctx.get(), temp_key, len)) { - return false; - } + static const char default_salt[EVP_MAX_MD_SIZE] = {0}; + salt = {default_salt, static_cast(EVP_MD_size(params.digest))}; + } + + // We do not use EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND and instead implement + // the extraction step ourselves because EVP_PKEY_derive does not handle + // zero-length keys, which are required for Web Crypto. + unsigned char pseudorandom_key[EVP_MAX_MD_SIZE]; + unsigned int prk_len = sizeof(pseudorandom_key); + if (HMAC( + params.digest, + salt.data(), + salt.size(), + reinterpret_cast(params.key->GetSymmetricKey()), + params.key->GetSymmetricKeySize(), + pseudorandom_key, + &prk_len) == nullptr) { + return false; + } + if (!EVP_PKEY_CTX_hkdf_mode(ctx.get(), EVP_PKEY_HKDEF_MODE_EXPAND_ONLY) || + !EVP_PKEY_CTX_set1_hkdf_key(ctx.get(), pseudorandom_key, prk_len)) { + return false; } size_t length = params.length;