diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java index aa96ae66dd..6a65b30f99 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java @@ -33,7 +33,10 @@ import java.util.Objects; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** * A BigQuery Stream Writer that can be used to write data into BigQuery Table. @@ -43,6 +46,12 @@ public class StreamWriter implements AutoCloseable { private static final Logger log = Logger.getLogger(StreamWriter.class.getName()); + private static String datasetsMatching = "projects/[^/]+/datasets/[^/]+/"; + private static Pattern streamPattern = Pattern.compile(datasetsMatching); + + // Cache of location info for a given dataset. + private static Map projectAndDatasetToLocation = new ConcurrentHashMap<>(); + /* * The identifier of stream to write to. */ @@ -167,12 +176,11 @@ public static SingleConnectionOrConnectionPool ofConnectionPool( } private StreamWriter(Builder builder) throws IOException { - BigQueryWriteClient client; this.streamName = builder.streamName; this.writerSchema = builder.writerSchema; - this.location = builder.location; boolean ownsBigQueryWriteClient = builder.client == null; if (!builder.enableConnectionPool) { + this.location = builder.location; this.singleConnectionOrConnectionPool = SingleConnectionOrConnectionPool.ofSingleConnection( new ConnectionWorker( @@ -185,9 +193,38 @@ private StreamWriter(Builder builder) throws IOException { getBigQueryWriteClient(builder), ownsBigQueryWriteClient)); } else { - if (builder.location == null || builder.location.isEmpty()) { - throw new IllegalArgumentException("Location must be specified for multiplexing client!"); + BigQueryWriteClient client = getBigQueryWriteClient(builder); + String location = builder.location; + if (location == null || location.isEmpty()) { + // Location is not passed in, try to fetch from RPC + String datasetAndProjectName = extractDatasetAndProjectName(builder.streamName); + location = + projectAndDatasetToLocation.computeIfAbsent( + datasetAndProjectName, + (key) -> { + GetWriteStreamRequest writeStreamRequest = + GetWriteStreamRequest.newBuilder() + .setName(this.getStreamName()) + .setView(WriteStreamView.BASIC) + .build(); + + WriteStream writeStream = client.getWriteStream(writeStreamRequest); + TableSchema writeStreamTableSchema = writeStream.getTableSchema(); + String fetchedLocation = writeStream.getLocation(); + log.info( + String.format( + "Fethed location %s for stream name %s", fetchedLocation, streamName)); + return fetchedLocation; + }); + if (location.isEmpty()) { + throw new IllegalStateException( + String.format( + "The location is empty for both user passed in value and looked up value for " + + "stream: %s", + streamName)); + } } + this.location = location; // Assume the connection in the same pool share the same client and trace id. // The first StreamWriter for a new stub will create the pool for the other // streams in the same region, meaning the per StreamWriter settings are no @@ -195,21 +232,40 @@ private StreamWriter(Builder builder) throws IOException { this.singleConnectionOrConnectionPool = SingleConnectionOrConnectionPool.ofConnectionPool( connectionPoolMap.computeIfAbsent( - ConnectionPoolKey.create(builder.location), + ConnectionPoolKey.create(location), (key) -> { - try { - return new ConnectionWorkerPool( - builder.maxInflightRequest, - builder.maxInflightBytes, - builder.limitExceededBehavior, - builder.traceId, - getBigQueryWriteClient(builder), - ownsBigQueryWriteClient); - } catch (IOException e) { - throw new RuntimeException(e); - } + return new ConnectionWorkerPool( + builder.maxInflightRequest, + builder.maxInflightBytes, + builder.limitExceededBehavior, + builder.traceId, + client, + ownsBigQueryWriteClient); })); validateFetchedConnectonPool(builder); + // Shut down the passed in client. Internally we will create another client inside connection + // pool for every new connection worker. + if (client != singleConnectionOrConnectionPool.connectionWorkerPool().bigQueryWriteClient() + && ownsBigQueryWriteClient) { + client.shutdown(); + try { + client.awaitTermination(150, TimeUnit.SECONDS); + } catch (InterruptedException unused) { + // Ignore interruption as this client is not used. + } + client.close(); + } + } + } + + @VisibleForTesting + static String extractDatasetAndProjectName(String streamName) { + Matcher streamMatcher = streamPattern.matcher(streamName); + if (streamMatcher.find()) { + return streamMatcher.group(); + } else { + throw new IllegalStateException( + String.format("The passed in stream name does not match standard format %s", streamName)); } } diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/JsonStreamWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/JsonStreamWriterTest.java index 468df368c0..71b2bee1d6 100644 --- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/JsonStreamWriterTest.java +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/JsonStreamWriterTest.java @@ -391,7 +391,7 @@ public void testAppendOutOfRangeException() throws Exception { } @Test - public void testCreateDefaultStream() throws Exception { + public void testCreateDefaultStream_withNoSchemaPassedIn() throws Exception { TableSchema tableSchema = TableSchema.newBuilder().addFields(0, TEST_INT).addFields(1, TEST_STRING).build(); testBigQueryWrite.addResponse( @@ -411,6 +411,28 @@ public void testCreateDefaultStream() throws Exception { } } + @Test + public void testCreateDefaultStream_withNoClientPassedIn() throws Exception { + TableSchema tableSchema = + TableSchema.newBuilder().addFields(0, TEST_INT).addFields(1, TEST_STRING).build(); + testBigQueryWrite.addResponse( + WriteStream.newBuilder() + .setName(TEST_STREAM) + .setLocation("aa") + .setTableSchema(tableSchema) + .build()); + try (JsonStreamWriter writer = + JsonStreamWriter.newBuilder(TEST_TABLE, tableSchema) + .setChannelProvider(channelProvider) + .setCredentialsProvider(NoCredentialsProvider.create()) + .setExecutorProvider(InstantiatingExecutorProvider.newBuilder().build()) + .setEnableConnectionPool(true) + .build()) { + assertEquals("projects/p/datasets/d/tables/t/_default", writer.getStreamName()); + assertEquals("aa", writer.getLocation()); + } + } + @Test public void testCreateDefaultStreamWrongLocation() throws Exception { TableSchema tableSchema = diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java index 2cf8a60b29..bd9409ea52 100644 --- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java @@ -725,22 +725,20 @@ public void testInitialization_operationKind() throws Exception { } @Test - public void createStreamWithDifferentWhetherOwnsClient() throws Exception { - StreamWriter streamWriter1 = getMultiplexingTestStreamWriter(); + public void testExtractDatasetName() throws Exception { + Assert.assertEquals( + StreamWriter.extractDatasetAndProjectName( + "projects/project1/datasets/dataset2/tables/something"), + "projects/project1/datasets/dataset2/"); - assertThrows( - IllegalArgumentException.class, - new ThrowingRunnable() { - @Override - public void run() throws Throwable { - StreamWriter.newBuilder(TEST_STREAM) - .setWriterSchema(createProtoSchema()) - .setTraceId(TEST_TRACE_ID) - .setLocation("US") - .setEnableConnectionPool(true) - .build(); - } - }); + IllegalStateException ex = + assertThrows( + IllegalStateException.class, + () -> { + StreamWriter.extractDatasetAndProjectName( + "wrong/projects/project1/wrong/datasets/dataset2/tables/something"); + }); + Assert.assertTrue(ex.getMessage().contains("The passed in stream name does not match")); } // Timeout to ensure close() doesn't wait for done callback timeout.