Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core, alts, cronet: fix ByteBuffer covariant method usages #7349

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 15 additions & 14 deletions alts/src/main/java/io/grpc/alts/internal/AltsFraming.java
Expand Up @@ -17,6 +17,7 @@
package io.grpc.alts.internal;

import com.google.common.base.Preconditions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.GeneralSecurityException;
Expand Down Expand Up @@ -63,10 +64,10 @@ static ByteBuffer toFrame(ByteBuffer input, int dataSize) throws GeneralSecurity
}
Producer producer = new Producer();
ByteBuffer inputAlias = input.duplicate();
inputAlias.limit(input.position() + dataSize);
((Buffer) inputAlias).limit(input.position() + dataSize);
producer.readBytes(inputAlias);
producer.flush();
input.position(inputAlias.position());
((Buffer) input).position(inputAlias.position());
ByteBuffer output = producer.getRawFrame();
return output;
}
Expand Down Expand Up @@ -166,10 +167,10 @@ void flush() throws GeneralSecurityException {
int frameLength = buffer.position() + getFrameSuffixLength();

// Set the limit and move to the start.
buffer.flip();
((Buffer) buffer).flip();

// Advance the limit to allow a crypto suffix.
buffer.limit(buffer.limit() + getFrameSuffixLength());
((Buffer) buffer).limit(buffer.limit() + getFrameSuffixLength());

// Write the data length and the message type.
int dataLength = frameLength - FRAME_LENGTH_HEADER_SIZE;
Expand All @@ -178,17 +179,17 @@ void flush() throws GeneralSecurityException {
buffer.putInt(MESSAGE_TYPE);

// Move the position back to 0, the frame is ready.
buffer.position(0);
((Buffer) buffer).position(0);
isComplete = true;
}

/** Resets the state, preparing to construct a new frame. Must be called between frames. */
private void reset() {
buffer.clear();
((Buffer) buffer).clear();

// Save some space for framing, we'll fill that in later.
buffer.position(getFramePrefixLength());
buffer.limit(buffer.limit() - getFrameSuffixLength());
((Buffer) buffer).position(getFramePrefixLength());
((Buffer) buffer).limit(buffer.limit() - getFrameSuffixLength());

isComplete = false;
}
Expand Down Expand Up @@ -279,7 +280,7 @@ public boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
// internal buffer is large enough.
if (buffer.position() == FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) {
ByteBuffer bufferAlias = buffer.duplicate();
bufferAlias.flip();
((Buffer) bufferAlias).flip();
bufferAlias.order(ByteOrder.LITTLE_ENDIAN);
int dataLength = bufferAlias.getInt();
if (dataLength < FRAME_MESSAGE_TYPE_HEADER_SIZE || dataLength > MAX_DATA_LENGTH) {
Expand All @@ -292,15 +293,15 @@ public boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(dataLength);
}
buffer.limit(frameLength);
((Buffer) buffer).limit(frameLength);
}

// TODO: Similarly extract and check message type.

// Read the remaining data into the internal buffer.
copy(buffer, input);
if (!buffer.hasRemaining()) {
buffer.flip();
((Buffer) buffer).flip();
isComplete = true;
}
return isComplete;
Expand All @@ -323,7 +324,7 @@ public boolean isComplete() {

/** Resets the state, preparing to parse a new frame. Must be called between frames. */
private void reset() {
buffer.clear();
((Buffer) buffer).clear();
isComplete = false;
}

Expand Down Expand Up @@ -356,9 +357,9 @@ private static void copy(ByteBuffer dst, ByteBuffer src) {
} else {
int count = Math.min(dst.remaining(), src.remaining());
ByteBuffer slice = src.slice();
slice.limit(count);
((Buffer) slice).limit(count);
dst.put(slice);
src.position(src.position() + count);
((Buffer) src).position(src.position() + count);
}
}
}
Expand Down
Expand Up @@ -23,6 +23,7 @@
import io.grpc.Status;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.logging.Level;
Expand Down Expand Up @@ -199,7 +200,7 @@ public ByteBuffer startServerHandshake(ByteBuffer inBytes) throws GeneralSecurit
throw new GeneralSecurityException(e);
}
handleResponse(resp);
inBytes.position(inBytes.position() + resp.getBytesConsumed());
((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}

Expand Down Expand Up @@ -227,7 +228,7 @@ public ByteBuffer next(ByteBuffer inBytes) throws GeneralSecurityException {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
inBytes.position(inBytes.position() + resp.getBytesConsumed());
((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}

Expand Down
Expand Up @@ -22,6 +22,7 @@
import com.google.common.base.Preconditions;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
import io.netty.buffer.ByteBufAllocator;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
Expand Down Expand Up @@ -151,10 +152,10 @@ public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityExcepti
ByteBuffer outputFrameAlias = outputFrame;
if (outputFrame.remaining() > bytes.remaining()) {
outputFrameAlias = outputFrame.duplicate();
outputFrameAlias.limit(outputFrameAlias.position() + bytes.remaining());
((Buffer) outputFrameAlias).limit(outputFrameAlias.position() + bytes.remaining());
}
bytes.put(outputFrameAlias);
outputFrame.position(outputFrameAlias.position());
((Buffer) outputFrame).position(outputFrameAlias.position());
}

/**
Expand Down
11 changes: 6 additions & 5 deletions alts/src/test/java/io/grpc/alts/internal/AltsFramingTest.java
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.GeneralSecurityException;
Expand All @@ -38,7 +39,7 @@ public void parserFrameLengthNegativeFails() throws GeneralSecurityException {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(-1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
((Buffer) buffer).flip();

try {
parser.readBytes(buffer);
Expand All @@ -56,7 +57,7 @@ public void parserFrameLengthSmallerMessageTypeFails() throws GeneralSecurityExc
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(AltsFraming.getFrameMessageTypeHeaderSize() - 1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
((Buffer) buffer).flip();

try {
parser.readBytes(buffer);
Expand All @@ -74,7 +75,7 @@ public void parserFrameLengthTooLargeFails() throws GeneralSecurityException {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(AltsFraming.getMaxDataLength() + 1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
((Buffer) buffer).flip();

try {
parser.readBytes(buffer);
Expand All @@ -97,7 +98,7 @@ public void parserFrameLengthMaxOk() throws GeneralSecurityException {
buffer.putInt(6); // default message type
buffer.put(new byte[dataLength - AltsFraming.getFrameMessageTypeHeaderSize()]); // write data
buffer.put((byte) 0);
buffer.flip();
((Buffer) buffer).flip();

parser.readBytes(buffer);

Expand All @@ -116,7 +117,7 @@ public void parserFrameLengthZeroOk() throws GeneralSecurityException {
buffer.putInt(dataLength); // write invalid length
buffer.putInt(6); // default message type
buffer.put((byte) 0);
buffer.flip();
((Buffer) buffer).flip();

parser.readBytes(buffer);

Expand Down
Expand Up @@ -29,6 +29,7 @@

import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import org.junit.Before;
Expand Down Expand Up @@ -178,7 +179,7 @@ public void startServerHandshakeWithPrefixBuffer() throws Exception {
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));

ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
inBytes.position(PREFIX_POSITION);
((Buffer) inBytes).position(PREFIX_POSITION);
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);

assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
Expand Down
Expand Up @@ -26,6 +26,7 @@
import static org.mockito.Mockito.when;

import com.google.protobuf.ByteString;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -112,7 +113,7 @@ public void processBytesFromPeerStartServer() throws Exception {
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(ArgumentMatchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and there is an output frame.
transportBuffer.position(transportBuffer.limit());
((Buffer) transportBuffer).position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(false);

Expand All @@ -127,7 +128,7 @@ public void processBytesFromPeerStartServerEmptyOutput() throws Exception {
verify(mockServer, never()).next(ArgumentMatchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and output frame is empty.
// Expect processBytesFromPeer return False, because more data are needed from the peer.
transportBuffer.position(transportBuffer.limit());
((Buffer) transportBuffer).position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
when(mockServer.isFinished()).thenReturn(false);

Expand Down Expand Up @@ -174,7 +175,7 @@ public void processBytesFromPeerClientNext() throws Exception {
when(mockClient.isFinished()).thenReturn(false);

handshakerClient.getBytesToSendToPeer(transportBuffer);
transportBuffer.position(transportBuffer.limit());
((Buffer) transportBuffer).position(transportBuffer.limit());
assertFalse(handshakerClient.processBytesFromPeer(transportBuffer));
}

Expand Down
35 changes: 18 additions & 17 deletions alts/src/test/java/io/grpc/alts/internal/FakeTsiTest.java
Expand Up @@ -27,6 +27,7 @@
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
Expand Down Expand Up @@ -86,11 +87,11 @@ public void handshakeStateOrderTest() {

byte[] transportBufferBytes = new byte[TsiTest.getDefaultTransportBufferSize()];
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
transportBuffer.limit(0); // Start off with an empty buffer
((Buffer) transportBuffer).limit(0); // Start off with an empty buffer

transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.CLIENT_INIT.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -99,14 +100,14 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// client shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());

transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.SERVER_INIT.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -115,14 +116,14 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// server shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());

transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.CLIENT_FINISHED.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -131,14 +132,14 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// client shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());

transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertEquals(
FakeTsiHandshaker.State.SERVER_FINISHED.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
Expand All @@ -147,9 +148,9 @@ public void handshakeStateOrderTest() {
assertFalse(transportBuffer.hasRemaining());

// server shouldn't offer any more bytes
transportBuffer.clear();
((Buffer) transportBuffer).clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
assertFalse(transportBuffer.hasRemaining());
} catch (GeneralSecurityException e) {
throw new AssertionError(e);
Expand Down
Expand Up @@ -20,6 +20,7 @@

import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.SecureRandom;
Expand Down Expand Up @@ -62,7 +63,7 @@ static ByteString getOutFrame() {
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(frameSize);
buffer.put(TEST_OUT_FRAME.getBytes(UTF_8));
buffer.flip();
((Buffer) buffer).flip();
return ByteString.copyFrom(buffer);
}

Expand Down
7 changes: 4 additions & 3 deletions alts/src/test/java/io/grpc/alts/internal/TsiTest.java
Expand Up @@ -26,6 +26,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
Expand Down Expand Up @@ -83,7 +84,7 @@ static void performHandshake(int transportBufferSize, Handshakers handshakers)

byte[] transportBufferBytes = new byte[transportBufferSize];
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
transportBuffer.limit(0); // Start off with an empty buffer
((Buffer) transportBuffer).limit(0); // Start off with an empty buffer

while (clientHandshaker.isInProgress() || serverHandshaker.isInProgress()) {
for (TsiHandshaker handshaker : new TsiHandshaker[] {clientHandshaker, serverHandshaker}) {
Expand All @@ -94,9 +95,9 @@ static void performHandshake(int transportBufferSize, Handshakers handshakers)
}
// Put new bytes on the wire, if needed.
if (handshaker.isInProgress()) {
transportBuffer.clear();
((Buffer) transportBuffer).clear();
handshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
((Buffer) transportBuffer).flip();
}
}
}
Expand Down