Skip to content

Commit

Permalink
alts: Explicit buffer management to avoid too many ShortBufferException
Browse files Browse the repository at this point in the history
To avoid having too many ShortBufferException thrown in ALTS code path on Java 8, we came up with this workaround creating new managed buffer, filling it, and passing it to underlying Conscrypt not to hit the code path throwing the exception. This might look to introduce another inefficiency but it's more like making it explicit because Conscrypt will do for non-managed buffer which gRPC uses.

Fix: #6761
  • Loading branch information
veblush committed May 20, 2020
1 parent d667a67 commit c7e8990
Showing 1 changed file with 55 additions and 44 deletions.
99 changes: 55 additions & 44 deletions alts/src/main/java/io/grpc/alts/internal/AltsChannelCrypter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package io.grpc.alts.internal;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;

import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.List;
Expand Down Expand Up @@ -56,61 +56,72 @@ static int getCounterLength() {

@Override
public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
checkArgument(outBuf.nioBufferCount() == 1);
// Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes());
plainBuf.writerIndex(0);
for (ByteBuf inBuf : plainBufs) {
plainBuf.writeBytes(inBuf);
byte[] tempArr = new byte[outBuf.writableBytes()];

// Copy plaintext into tempArr.
{
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr, 0, tempArr.length - TAG_LENGTH);
tempBuf.resetWriterIndex();
for (ByteBuf plainBuf : plainBufs) {
tempBuf.writeBytes(plainBuf);
}
}

verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH);
ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes());
ByteBuffer plain = out.duplicate();
plain.limit(out.limit() - TAG_LENGTH);

byte[] counter = incrementOutCounter();
int outPosition = out.position();
aeadCrypter.encrypt(out, plain, counter);
int bytesWritten = out.position() - outPosition;
outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
verify(!outBuf.isWritable());
// Encrypt into tempArr.
{
ByteBuffer out = ByteBuffer.wrap(tempArr);
ByteBuffer plain = ByteBuffer.wrap(tempArr, 0, tempArr.length - TAG_LENGTH);

byte[] counter = incrementOutCounter();
aeadCrypter.encrypt(out, plain, counter);
}
outBuf.writeBytes(tempArr);
}

@Override
public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)
public void decrypt(ByteBuf outBuf, ByteBuf tagBuf, List<ByteBuf> ciphertextBufs)
throws GeneralSecurityException {
// There is enough space for the ciphertext including the tag in outBuf.
byte[] tempArr = new byte[outBuf.writableBytes()];

// Copy ciphertext and tag into tempArr.
{
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr);
tempBuf.resetWriterIndex();
for (ByteBuf ciphertextBuf : ciphertextBufs) {
tempBuf.writeBytes(ciphertextBuf);
}
tempBuf.writeBytes(tagBuf);
}

ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes());
cipherTextAndTag.writerIndex(0);
decryptInternal(outBuf, tempArr);
}

for (ByteBuf inBuf : ciphertextBufs) {
cipherTextAndTag.writeBytes(inBuf);
@Override
public void decrypt(
ByteBuf outBuf, ByteBuf ciphertextAndTagDirect) throws GeneralSecurityException {
byte[] tempArr = new byte[ciphertextAndTagDirect.readableBytes()];

// Copy ciphertext and tag into tempArr.
{
ByteBuf tempBuf = Unpooled.wrappedBuffer(tempArr);
tempBuf.resetWriterIndex();
tempBuf.writeBytes(ciphertextAndTagDirect);
}
cipherTextAndTag.writeBytes(tag);

decrypt(out, cipherTextAndTag);
decryptInternal(outBuf, tempArr);
}

@Override
public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
int bytesRead = ciphertextAndTag.readableBytes();
checkArgument(bytesRead == out.writableBytes());

checkArgument(out.nioBufferCount() == 1);
ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes());

checkArgument(ciphertextAndTag.nioBufferCount() == 1);
ByteBuffer ciphertextAndTagBuffer =
ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);

byte[] counter = incrementInCounter();
int outPosition = outBuffer.position();
aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter);
int bytesWritten = outBuffer.position() - outPosition;
out.writerIndex(out.writerIndex() + bytesWritten);
ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead);
verify(out.writableBytes() == TAG_LENGTH);
private void decryptInternal(ByteBuf outBuf, byte[] tempArr) throws GeneralSecurityException {
// Perform in-place decryption on tempArr.
{
ByteBuffer ciphertextAndTag = ByteBuffer.wrap(tempArr);
ByteBuffer out = ByteBuffer.wrap(tempArr);
byte[] counter = incrementInCounter();
aeadCrypter.decrypt(out, ciphertextAndTag, counter);
}

outBuf.writeBytes(tempArr, 0, tempArr.length - TAG_LENGTH);
}

@Override
Expand Down

0 comments on commit c7e8990

Please sign in to comment.