Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47172][CORE] Add support for AES-GCM for RPC encryption #46515

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

sweisdb
Copy link
Contributor

@sweisdb sweisdb commented May 9, 2024

What changes were proposed in this pull request?

This change adds AES-GCM as an optional AES cipher mode for RPC encryption. The current default is using AES-CTR without any authentication. That would allow someone on the network to easily modify RPC contents on the wire and impact Spark behavior. See SPARK-47172 for more details.

Why are the changes needed?

The current default is using AES-CTR without any authentication. That would allow someone on the network to easily modify RPC contents on the wire and impact Spark behavior.

Does this PR introduce any user-facing change?

Yes, it adds an additional configuration flag is reflected in the documentation.

How was this patch tested?

Existing unit tests are all ensured to pass. New unit tests are written to explicitly test GCM support and to verify that modifying ciphertext content will cause an exception and fail.

build/sbt "network-common/test:testOnly"
build/sbt "network-common/test:testOnly org.apache.spark.network.crypto.AuthIntegrationSuite"
build/sbt "network-common/test:testOnly org.apache.spark.network.crypto.AuthEngineSuite"

Was this patch authored or co-authored using generative AI tooling?

Nope.

@dongjoon-hyun dongjoon-hyun changed the title [SPARK-47172] Add support for AES-GCM for RPC encryption [SPARK-47172][CORE] Add support for AES-GCM for RPC encryption May 9, 2024
@dongjoon-hyun
Copy link
Member

cc @mridulm

@mridulm
Copy link
Contributor

mridulm commented May 14, 2024

Took a quick pass through it, sorry for the delay.

+CC @JoshRosen as well.

Copy link
Contributor

@mridulm mridulm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did a quick pass, not yet looked through tests and doc.
Thanks for working on this @sweisdb !

import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.WritableByteChannel;
import java.security.GeneralSecurityException;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix the import order here.

private final ByteBuffer ciphertextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;

EncryptionHandler() throws GeneralSecurityException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
EncryptionHandler() throws GeneralSecurityException {
EncryptionHandler() throws InvalidAlgorithmParameterException {

private static final byte[] DEFAULT_AAD = new byte[0];
private static final int LENGTH_HEADER_BYTES = 8;
@VisibleForTesting
static final int CIPHERTEXT_BUFFER_SIZE = 1024;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: A larger buffer ? 32k or 64k ?
Existing TransportCipher uses 32k for ex.

plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion,
"Unrecognized message type: %s", plaintextMessage.getClass().getName());
this.plaintextMessage = plaintextMessage;
this.bytesToRead = getReadableBytes();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: move initialization of this.bytesToRead to the end of the constructor, after all other fields have been initialized.

Comment on lines +145 to +155
if (!headerWritten) {
ByteBuffer expectedLength = ByteBuffer
.allocate(LENGTH_HEADER_BYTES)
.putLong(encryptedCount)
.flip();
target.write(expectedLength);
int headerWritten = LENGTH_HEADER_BYTES + target.write(encrypter.getHeader());
transferredThisCall += headerWritten;
this.transferred += headerWritten;
this.headerWritten = true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not gauranteed that the writes to target will result in all the bytes being written out, and we have to handle partial writes here.

In constructor, initialize a headerByteBuffer:

  this.headerByteBuffer = createHeaderByteBuffer();
}


          // The format of the output is:
          // [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
        private ByteBuffer createHeaderByteBuffer() {
            ByteBuffer encrypterHeader = encrypter.getHeader();
            return ByteBuffer
                    .allocate(encrypterHeader.remaining() + LENGTH_HEADER_BYTES)
                    .putLong(encryptedCount)
                    .put(encrypterHeader)
                    .flip();
        }

This then becomes:

Suggested change
if (!headerWritten) {
ByteBuffer expectedLength = ByteBuffer
.allocate(LENGTH_HEADER_BYTES)
.putLong(encryptedCount)
.flip();
target.write(expectedLength);
int headerWritten = LENGTH_HEADER_BYTES + target.write(encrypter.getHeader());
transferredThisCall += headerWritten;
this.transferred += headerWritten;
this.headerWritten = true;
}
if (headerByteBuffer.hasRemaining()) {
int written = target.write(headerByteBuffer);
if (headerByteBuffer.hasRemaining()) return written;
transferredThisCall += written;
this.transferred += written;
}

Also, we can remove headerWritten field.

Comment on lines +190 to +194
while (ciphertextBuffer.hasRemaining()) {
target.write(ciphertextBuffer);
}
transferredThisCall += outputRemaining;
transferred += outputRemaining;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If target cannot write any more, we will go into a busy loop here.
Instead, we should check before entering this loop.

After if (headerByteBuffer.hasRemaining()) { block I proposed above, do this:

            if (ciphertextBuffer.hasRemaining()) {
                int written = target.write(ciphertextBuffer);
                transferredThisCall += written;
                this.transferred += written;
                if (ciphertextBuffer.hasRemaining()) return transferredThisCall;
            }

This code then becomes:

Suggested change
while (ciphertextBuffer.hasRemaining()) {
target.write(ciphertextBuffer);
}
transferredThisCall += outputRemaining;
transferred += outputRemaining;
int written = target.write(ciphertextBuffer);
transferredThisCall += written;
this.transferred += written;
if (ciphertextBuffer.hasRemaining()) return transferredThisCall;

@Override
public long count() {
return encryptedCount;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add touch, retain and release - the first helps with debugging, and the others would be useful as the code evolves: please see TransportCipher on how to add it, should be simple change.

"HmacSha256",
aesKey.getEncoded().length,
CIPHERTEXT_BUFFER_SIZE,
0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a static method to create AesGcmHkdfStreaming in GcmTransportCipher and use it from both encryption and decryption handler.

ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage;
// The format of the output is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
try {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this method, we cannot make assumptions about how much data is available to read from the incoming ByteBuf - when reading a segment, or even when we are reading the header.

These cases are currently not handled.

// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
ciphertextNettyBuf.readableBytes(),
ciphertextBuffer.remaining());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Do we want to enforce we are reading only upto expectedLength ? (currently we throw an exception below if we end up reading more ...).
I am assuming it is possible for input to have more in case multiple messages are being encrypted one after another ?

}
}
plaintextBuffer.flip();
ciphertextBuffer.clear();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: If plaintextBuffer.remaining() != plaintextMessage.capacity(), the segment encryption will end up with incorrect sizes, no ?
For ByteBuf case this should not happen, but can for FileRegion reads.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
4 participants