diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java index fc6152959d..e22d38cce0 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/ConnectionWorkerPool.java @@ -392,6 +392,21 @@ public void close(StreamWriter streamWriter) { } } + /** Fetch the wait seconds from corresponding worker. */ + public long getInflightWaitSeconds(StreamWriter streamWriter) { + lock.lock(); + try { + ConnectionWorker connectionWorker = streamWriterToConnection.get(streamWriter); + if (connectionWorker == null) { + return 0; + } else { + return connectionWorker.getInflightWaitSeconds(); + } + } finally { + lock.unlock(); + } + } + /** Enable Test related logic. */ public static void enableTestingLogic() { enableTesting = true; diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java index e4dc85e5ca..92631af228 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java @@ -141,10 +141,9 @@ public void close(StreamWriter streamWriter) { } } - long getInflightWaitSeconds() { + long getInflightWaitSeconds(StreamWriter streamWriter) { if (getKind() == Kind.CONNECTION_WORKER_POOL) { - throw new IllegalStateException( - "getInflightWaitSeconds is not supported in multiplexing mode."); + return connectionWorkerPool().getInflightWaitSeconds(streamWriter); } return connectionWorker().getInflightWaitSeconds(); } @@ -363,7 +362,7 @@ public ApiFuture append(ProtoRows rows, long offset) { * stream case. */ public long getInflightWaitSeconds() { - return singleConnectionOrConnectionPool.getInflightWaitSeconds(); + return singleConnectionOrConnectionPool.getInflightWaitSeconds(this); } /** @return a unique Id for the writer. */ diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java index bd9409ea52..3f029ac811 100644 --- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java @@ -29,6 +29,7 @@ import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.rpc.UnknownException; import com.google.cloud.bigquery.storage.test.Test.FooType; +import com.google.cloud.bigquery.storage.v1.ConnectionWorkerPool.Settings; import com.google.cloud.bigquery.storage.v1.StorageError.StorageErrorCode; import com.google.cloud.bigquery.storage.v1.StreamWriter.SingleConnectionOrConnectionPool.Kind; import com.google.common.base.Strings; @@ -60,7 +61,8 @@ @RunWith(JUnit4.class) public class StreamWriterTest { private static final Logger log = Logger.getLogger(StreamWriterTest.class.getName()); - private static final String TEST_STREAM = "projects/p/datasets/d/tables/t/streams/s"; + private static final String TEST_STREAM_1 = "projects/p/datasets/d/tables/t/streams/s"; + private static final String TEST_STREAM_2 = "projects/p/datasets/d/tables/t/streams/s"; private static final String TEST_TRACE_ID = "DATAFLOW:job_id"; private FakeScheduledExecutorService fakeExecutor; private FakeBigQueryWrite testBigQueryWrite; @@ -94,7 +96,7 @@ public void tearDown() throws Exception { } private StreamWriter getMultiplexingTestStreamWriter() throws IOException { - return StreamWriter.newBuilder(TEST_STREAM, client) + return StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setTraceId(TEST_TRACE_ID) .setLocation("US") @@ -103,7 +105,7 @@ private StreamWriter getMultiplexingTestStreamWriter() throws IOException { } private StreamWriter getTestStreamWriter() throws IOException { - return StreamWriter.newBuilder(TEST_STREAM, client) + return StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setTraceId(TEST_TRACE_ID) .build(); @@ -197,7 +199,7 @@ private void verifyAppendRequests(long appendCount) { if (i == 0) { // First request received by server should have schema and stream name. assertTrue(serverRequest.getProtoRows().hasWriterSchema()); - assertEquals(serverRequest.getWriteStream(), TEST_STREAM); + assertEquals(serverRequest.getWriteStream(), TEST_STREAM_1); assertEquals(serverRequest.getTraceId(), TEST_TRACE_ID); } else { // Following request should not have schema and stream name. @@ -210,7 +212,7 @@ private void verifyAppendRequests(long appendCount) { public void testBuildBigQueryWriteClientInWriter() throws Exception { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM) + StreamWriter.newBuilder(TEST_STREAM_1) .setCredentialsProvider(NoCredentialsProvider.create()) .setChannelProvider(serviceHelper.createChannelProvider()) .setWriterSchema(createProtoSchema()) @@ -253,7 +255,7 @@ public void testNoSchema() throws Exception { new ThrowingRunnable() { @Override public void run() throws Throwable { - StreamWriter.newBuilder(TEST_STREAM, client).build(); + StreamWriter.newBuilder(TEST_STREAM_1, client).build(); } }); assertEquals(ex.getStatus().getCode(), Status.INVALID_ARGUMENT.getCode()); @@ -267,7 +269,7 @@ public void testInvalidTraceId() throws Exception { new ThrowingRunnable() { @Override public void run() throws Throwable { - StreamWriter.newBuilder(TEST_STREAM).setTraceId("abc"); + StreamWriter.newBuilder(TEST_STREAM_1).setTraceId("abc"); } }); assertThrows( @@ -275,7 +277,7 @@ public void run() throws Throwable { new ThrowingRunnable() { @Override public void run() throws Throwable { - StreamWriter.newBuilder(TEST_STREAM).setTraceId("abc:"); + StreamWriter.newBuilder(TEST_STREAM_1).setTraceId("abc:"); } }); assertThrows( @@ -283,7 +285,7 @@ public void run() throws Throwable { new ThrowingRunnable() { @Override public void run() throws Throwable { - StreamWriter.newBuilder(TEST_STREAM).setTraceId(":abc"); + StreamWriter.newBuilder(TEST_STREAM_1).setTraceId(":abc"); } }); } @@ -487,7 +489,7 @@ public void serverCloseWhileRequestsInflight() throws Exception { @Test public void testZeroMaxInflightRequests() throws Exception { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM, client) + StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setMaxInflightRequests(0) .build(); @@ -499,7 +501,7 @@ public void testZeroMaxInflightRequests() throws Exception { @Test public void testZeroMaxInflightBytes() throws Exception { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM, client) + StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setMaxInflightBytes(0) .build(); @@ -511,7 +513,7 @@ public void testZeroMaxInflightBytes() throws Exception { @Test public void testOneMaxInflightRequests() throws Exception { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM, client) + StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setMaxInflightRequests(1) .build(); @@ -525,10 +527,45 @@ public void testOneMaxInflightRequests() throws Exception { writer.close(); } + @Test + public void testOneMaxInflightRequests_MultiplexingCase() throws Exception { + ConnectionWorkerPool.setOptions(Settings.builder().setMaxConnectionsPerRegion(2).build()); + StreamWriter writer1 = + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setWriterSchema(createProtoSchema()) + .setLocation("US") + .setEnableConnectionPool(true) + .setMaxInflightRequests(1) + .build(); + StreamWriter writer2 = + StreamWriter.newBuilder(TEST_STREAM_2, client) + .setWriterSchema(createProtoSchema()) + .setMaxInflightRequests(1) + .setEnableConnectionPool(true) + .setMaxInflightRequests(1) + .setLocation("US") + .build(); + + // Server will sleep 1 second before every response. + testBigQueryWrite.setResponseSleep(Duration.ofSeconds(1)); + testBigQueryWrite.addResponse(createAppendResponse(0)); + testBigQueryWrite.addResponse(createAppendResponse(1)); + + ApiFuture appendFuture1 = sendTestMessage(writer1, new String[] {"A"}); + ApiFuture appendFuture2 = sendTestMessage(writer2, new String[] {"A"}); + + assertTrue(writer1.getInflightWaitSeconds() >= 1); + assertTrue(writer2.getInflightWaitSeconds() >= 1); + assertEquals(0, appendFuture1.get().getAppendResult().getOffset().getValue()); + assertEquals(1, appendFuture2.get().getAppendResult().getOffset().getValue()); + writer1.close(); + writer2.close(); + } + @Test public void testAppendsWithTinyMaxInflightBytes() throws Exception { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM, client) + StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setMaxInflightBytes(1) .build(); @@ -560,7 +597,7 @@ public void testAppendsWithTinyMaxInflightBytes() throws Exception { @Test public void testAppendsWithTinyMaxInflightBytesThrow() throws Exception { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM, client) + StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setMaxInflightBytes(1) .setLimitExceededBehavior(FlowController.LimitExceededBehavior.ThrowException) @@ -595,7 +632,7 @@ public void testLimitBehaviorIgnoreNotAccepted() throws Exception { @Override public void run() throws Throwable { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM, client) + StreamWriter.newBuilder(TEST_STREAM_1, client) .setWriterSchema(createProtoSchema()) .setMaxInflightBytes(1) .setLimitExceededBehavior(FlowController.LimitExceededBehavior.Ignore) @@ -745,7 +782,7 @@ public void testExtractDatasetName() throws Exception { @Test(timeout = 10000) public void testCloseDisconnectedStream() throws Exception { StreamWriter writer = - StreamWriter.newBuilder(TEST_STREAM) + StreamWriter.newBuilder(TEST_STREAM_1) .setCredentialsProvider(NoCredentialsProvider.create()) .setChannelProvider(serviceHelper.createChannelProvider()) .setWriterSchema(createProtoSchema())