Skip to content

Commit

Permalink
core, alts, cronet: fix ByteBuffer covariant method usages (grpc#7349)
Browse files Browse the repository at this point in the history
Java 9 introduces overridden methods with covariant return types for the following methods in java.nio.ByteBuffer:

- position​(int newPosition)
- limit​(int newLimit)
- flip​()
- clear​()
- mark​()
- reset​()
- rewind​()

In Java 9 they all now return ByteBuffer, whereas the methods they override return Buffer, resulting in exceptions like this when executing on Java 8 and lower:

java.lang.NoSuchMethodError: java.nio.ByteBuffer.limit(I)Ljava/nio/ByteBuffer

This is because the generated byte code includes the static return type of the method, which is not found on Java 8 and lower because the overloaded methods with covariant return types don't exist (the issue appears even with source and target 8 or lower in compilation parameters).
The solution is to cast ByteBuffer instances to Buffer before calling the method.
  • Loading branch information
voidzcy authored and dfawley committed Jan 15, 2021
1 parent c9b08b6 commit bd489e6
Show file tree
Hide file tree
Showing 15 changed files with 84 additions and 74 deletions.
29 changes: 15 additions & 14 deletions alts/src/main/java/io/grpc/alts/internal/AltsFraming.java
Expand Up @@ -17,6 +17,7 @@
package io.grpc.alts.internal;

import com.google.common.base.Preconditions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.GeneralSecurityException;
Expand Down Expand Up @@ -63,10 +64,10 @@ static ByteBuffer toFrame(ByteBuffer input, int dataSize) throws GeneralSecurity
}
Producer producer = new Producer();
ByteBuffer inputAlias = input.duplicate();
inputAlias.limit(input.position() + dataSize);
((Buffer) inputAlias).limit(input.position() + dataSize);
producer.readBytes(inputAlias);
producer.flush();
input.position(inputAlias.position());
((Buffer) input).position(inputAlias.position());
ByteBuffer output = producer.getRawFrame();
return output;
}
Expand Down Expand Up @@ -166,10 +167,10 @@ void flush() throws GeneralSecurityException {
int frameLength = buffer.position() + getFrameSuffixLength();

// Set the limit and move to the start.
buffer.flip();
((Buffer) buffer).flip();

// Advance the limit to allow a crypto suffix.
buffer.limit(buffer.limit() + getFrameSuffixLength());
((Buffer) buffer).limit(buffer.limit() + getFrameSuffixLength());

// Write the data length and the message type.
int dataLength = frameLength - FRAME_LENGTH_HEADER_SIZE;
Expand All @@ -178,17 +179,17 @@ void flush() throws GeneralSecurityException {
buffer.putInt(MESSAGE_TYPE);

// Move the position back to 0, the frame is ready.
buffer.position(0);
((Buffer) buffer).position(0);
isComplete = true;
}

/** Resets the state, preparing to construct a new frame. Must be called between frames. */
private void reset() {
buffer.clear();
((Buffer) buffer).clear();

// Save some space for framing, we'll fill that in later.
buffer.position(getFramePrefixLength());
buffer.limit(buffer.limit() - getFrameSuffixLength());
((Buffer) buffer).position(getFramePrefixLength());
((Buffer) buffer).limit(buffer.limit() - getFrameSuffixLength());

isComplete = false;
}
Expand Down Expand Up @@ -279,7 +280,7 @@ public boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
// internal buffer is large enough.
if (buffer.position() == FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) {
ByteBuffer bufferAlias = buffer.duplicate();
bufferAlias.flip();
((Buffer) bufferAlias).flip();
bufferAlias.order(ByteOrder.LITTLE_ENDIAN);
int dataLength = bufferAlias.getInt();
if (dataLength < FRAME_MESSAGE_TYPE_HEADER_SIZE || dataLength > MAX_DATA_LENGTH) {
Expand All @@ -292,15 +293,15 @@ public boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(dataLength);
}
buffer.limit(frameLength);
((Buffer) buffer).limit(frameLength);
}

// TODO: Similarly extract and check message type.

// Read the remaining data into the internal buffer.
copy(buffer, input);
if (!buffer.hasRemaining()) {
buffer.flip();
((Buffer) buffer).flip();
isComplete = true;
}
return isComplete;
Expand All @@ -323,7 +324,7 @@ public boolean isComplete() {

/** Resets the state, preparing to parse a new frame. Must be called between frames. */
private void reset() {
buffer.clear();
((Buffer) buffer).clear();
isComplete = false;
}

Expand Down Expand Up @@ -356,9 +357,9 @@ private static void copy(ByteBuffer dst, ByteBuffer src) {
} else {
int count = Math.min(dst.remaining(), src.remaining());
ByteBuffer slice = src.slice();
slice.limit(count);
((Buffer) slice).limit(count);
dst.put(slice);
src.position(src.position() + count);
((Buffer) src).position(src.position() + count);
}
}
}
Expand Down
Expand Up @@ -23,6 +23,7 @@
import io.grpc.Status;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.logging.Level;
Expand Down Expand Up @@ -199,7 +200,7 @@ public ByteBuffer startServerHandshake(ByteBuffer inBytes) throws GeneralSecurit
throw new GeneralSecurityException(e);
}
handleResponse(resp);
inBytes.position(inBytes.position() + resp.getBytesConsumed());
((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}

Expand Down Expand Up @@ -227,7 +228,7 @@ public ByteBuffer next(ByteBuffer inBytes) throws GeneralSecurityException {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
inBytes.position(inBytes.position() + resp.getBytesConsumed());
((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}

Expand Down
Expand Up @@ -22,6 +22,7 @@
import com.google.common.base.Preconditions;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
import io.netty.buffer.ByteBufAllocator;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
Expand Down Expand Up @@ -151,10 +152,10 @@ public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityExcepti
ByteBuffer outputFrameAlias = outputFrame;
if (outputFrame.remaining() > bytes.remaining()) {
outputFrameAlias = outputFrame.duplicate();
outputFrameAlias.limit(outputFrameAlias.position() + bytes.remaining());
((Buffer) outputFrameAlias).limit(outputFrameAlias.position() + bytes.remaining());
}
bytes.put(outputFrameAlias);
outputFrame.position(outputFrameAlias.position());
((Buffer) outputFrame).position(outputFrameAlias.position());
}

/**
Expand Down
11 changes: 6 additions & 5 deletions alts/src/test/java/io/grpc/alts/internal/AltsFramingTest.java
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.GeneralSecurityException;
Expand All @@ -38,7 +39,7 @@ public void parserFrameLengthNegativeFails() throws GeneralSecurityException {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(-1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
((Buffer) buffer).flip();

try {
parser.readBytes(buffer);
Expand All @@ -56,7 +57,7 @@ public void parserFrameLengthSmallerMessageTypeFails() throws GeneralSecurityExc
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(AltsFraming.getFrameMessageTypeHeaderSize() - 1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
((Buffer) buffer).flip();

try {
parser.readBytes(buffer);
Expand All @@ -74,7 +75,7 @@ public void parserFrameLengthTooLargeFails() throws GeneralSecurityException {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(AltsFraming.getMaxDataLength() + 1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
((Buffer) buffer).flip();

try {
parser.readBytes(buffer);
Expand All @@ -97,7 +98,7 @@ public void parserFrameLengthMaxOk() throws GeneralSecurityException {
buffer.putInt(6); // default message type
buffer.put(new byte[dataLength - AltsFraming.getFrameMessageTypeHeaderSize()]); // write data
buffer.put((byte) 0);
buffer.flip();
((Buffer) buffer).flip();

parser.readBytes(buffer);

Expand All @@ -116,7 +117,7 @@ public void parserFrameLengthZeroOk() throws GeneralSecurityException {
buffer.putInt(dataLength); // write invalid length
buffer.putInt(6); // default message type
buffer.put((byte) 0);
buffer.flip();
((Buffer) buffer).flip();

parser.readBytes(buffer);

Expand Down
Expand Up @@ -29,6 +29,7 @@

import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import org.junit.Before;
Expand Down Expand Up @@ -178,7 +179,7 @@ public void startServerHandshakeWithPrefixBuffer() throws Exception {
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));

ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
inBytes.position(PREFIX_POSITION);
((Buffer) inBytes).position(PREFIX_POSITION);
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);

assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
Expand Down
Expand Up @@ -26,6 +26,7 @@
import static org.mockito.Mockito.when;

import com.google.protobuf.ByteString;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -112,7 +113,7 @@ public void processBytesFromPeerStartServer() throws Exception {
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(ArgumentMatchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and there is an output frame.
transportBuffer.position(transportBuffer.limit());
((Buffer) transportBuffer).position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(false);

Expand All @@ -127,7 +128,7 @@ public void processBytesFromPeerStartServerEmptyOutput() throws Exception {
verify(mockServer, never()).next(ArgumentMatchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and output frame is empty.
// Expect processBytesFromPeer return False, because more data are needed from the peer.
transportBuffer.position(transportBuffer.limit());
((Buffer) transportBuffer).position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
when(mockServer.isFinished()).thenReturn(false);

Expand Down Expand Up @@ -174,7 +175,7 @@ public void processBytesFromPeerClientNext() throws Exception {
when(mockClient.isFinished()).thenReturn(false);

handshakerClient.getBytesToSendToPeer(transportBuffer);
transportBuffer.position(transportBuffer.limit());
((Buffer) transportBuffer).position(transportBuffer.limit());
assertFalse(handshakerClient.processBytesFromPeer(transportBuffer));
}

Expand Down
35 changes: 18 additions & 17 deletions alts/src/test/java/io/grpc/alts/internal/FakeTsiTest.java
Expand Up @@ -27,6 +27,7 @@
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
Expand Down Expand Up @@ -86,11 +87,11 @@ public void handshakeStateOrderTest() {

byte[] transportBufferBytes = new byte[TsiTest.getDefaultTransportBufferSize()];
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
transportBuffer.limit(0); // Start off with an empty buffer
((Buffer) transportBuffer).limit(0); // Start off with an empty buffer

transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.CLIENT_INIT.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -99,14 +100,14 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// client shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());

transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.SERVER_INIT.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -115,14 +116,14 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// server shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());

transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.CLIENT_FINISHED.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -131,14 +132,14 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// client shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());

transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.SERVER_FINISHED.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -147,9 +148,9 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// server shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());
} catch (GeneralSecurityException e) {
throw new AssertionError(e);
Expand Down
Expand Up @@ -20,6 +20,7 @@

import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.SecureRandom;
Expand Down Expand Up @@ -62,7 +63,7 @@ static ByteString getOutFrame() {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(frameSize);
buffer.put(TEST_OUT_FRAME.getBytes(UTF_8));
buffer.flip();
((Buffer) buffer).flip();
return ByteString.copyFrom(buffer);
}

Expand Down
7 changes: 4 additions & 3 deletions alts/src/test/java/io/grpc/alts/internal/TsiTest.java
Expand Up @@ -26,6 +26,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
Expand Down Expand Up @@ -83,7 +84,7 @@ static void performHandshake(int transportBufferSize, Handshakers handshakers)

byte[] transportBufferBytes = new byte[transportBufferSize];
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
transportBuffer.limit(0); // Start off with an empty buffer
((Buffer) transportBuffer).limit(0); // Start off with an empty buffer

while (clientHandshaker.isInProgress() || serverHandshaker.isInProgress()) {
for (TsiHandshaker handshaker : new TsiHandshaker[] {clientHandshaker, serverHandshaker}) {
Expand All @@ -94,9 +95,9 @@ static void performHandshake(int transportBufferSize, Handshakers handshakers)
}
// Put new bytes on the wire, if needed.
if (handshaker.isInProgress()) {
transportBuffer.clear();
((Buffer) transportBuffer).clear();
handshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
}
}
}
Expand Down

0 comments on commit bd489e6

Please sign in to comment.