diff --git a/src/main/java/com/google/cloud/spanner/JdbcDataTypeConverter.java b/src/main/java/com/google/cloud/spanner/JdbcDataTypeConverter.java new file mode 100644 index 00000000..62d52c6c --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/JdbcDataTypeConverter.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.InternalApi; + +@InternalApi +public class JdbcDataTypeConverter { + + /** Converts a protobuf type to a Spanner type. */ + @InternalApi + public static Type toSpannerType(com.google.spanner.v1.Type proto) { + return Type.fromProto(proto); + } +} diff --git a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java index f577ac3c..b56aed03 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner.jdbc; +import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; import com.google.common.base.Preconditions; @@ -69,7 +70,74 @@ static int extractColumnType(Type type) { } } - /** Extract Spanner type name from {@link java.sql.Types} code. */ + static String getSpannerTypeName(Type type, Dialect dialect) { + // TODO: Use com.google.cloud.spanner.Type#getSpannerTypeName() when available. + Preconditions.checkNotNull(type); + switch (type.getCode()) { + case BOOL: + return dialect == Dialect.POSTGRESQL ? "boolean" : "BOOL"; + case BYTES: + return dialect == Dialect.POSTGRESQL ? "bytea" : "BYTES"; + case DATE: + return dialect == Dialect.POSTGRESQL ? "date" : "DATE"; + case FLOAT64: + return dialect == Dialect.POSTGRESQL ? "double precision" : "FLOAT64"; + case INT64: + return dialect == Dialect.POSTGRESQL ? "bigint" : "INT64"; + case NUMERIC: + return "NUMERIC"; + case PG_NUMERIC: + return "numeric"; + case STRING: + return dialect == Dialect.POSTGRESQL ? "character varying" : "STRING"; + case JSON: + return "JSON"; + case PG_JSONB: + return "jsonb"; + case TIMESTAMP: + return dialect == Dialect.POSTGRESQL ? "timestamp with time zone" : "TIMESTAMP"; + case STRUCT: + return "STRUCT"; + case ARRAY: + switch (type.getArrayElementType().getCode()) { + case BOOL: + return dialect == Dialect.POSTGRESQL ? "boolean[]" : "ARRAY"; + case BYTES: + return dialect == Dialect.POSTGRESQL ? "bytea[]" : "ARRAY"; + case DATE: + return dialect == Dialect.POSTGRESQL ? "date[]" : "ARRAY"; + case FLOAT64: + return dialect == Dialect.POSTGRESQL ? "double precision[]" : "ARRAY"; + case INT64: + return dialect == Dialect.POSTGRESQL ? "bigint[]" : "ARRAY"; + case NUMERIC: + return "ARRAY"; + case PG_NUMERIC: + return "numeric[]"; + case STRING: + return dialect == Dialect.POSTGRESQL ? "character varying[]" : "ARRAY"; + case JSON: + return "ARRAY"; + case PG_JSONB: + return "jsonb[]"; + case TIMESTAMP: + return dialect == Dialect.POSTGRESQL + ? "timestamp with time zone[]" + : "ARRAY"; + case STRUCT: + return "ARRAY"; + } + default: + return null; + } + } + + /** + * Extract Spanner type name from {@link java.sql.Types} code. + * + * @deprecated Use {@link #getSpannerTypeName(Type, Dialect)} instead. + */ + @Deprecated static String getSpannerTypeName(int sqlType) { if (sqlType == Types.BOOLEAN) return Type.bool().getCode().name(); if (sqlType == Types.BINARY) return Type.bytes().getCode().name(); @@ -89,7 +157,12 @@ static String getSpannerTypeName(int sqlType) { return OTHER_NAME; } - /** Get corresponding Java class name from {@link java.sql.Types} code. */ + /** + * Get corresponding Java class name from {@link java.sql.Types} code. + * + * @deprecated Use {@link #getClassName(Type)} instead. + */ + @Deprecated static String getClassName(int sqlType) { if (sqlType == Types.BOOLEAN) return Boolean.class.getName(); if (sqlType == Types.BINARY) return Byte[].class.getName(); diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java index c495bbe1..5dd082ef 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java @@ -390,14 +390,18 @@ public Set> getSupportedJavaClasses() { public static JdbcDataType getType(Class clazz) { for (JdbcDataType type : JdbcDataType.values()) { - if (type.getSupportedJavaClasses().contains(clazz)) return type; + if (type.getSupportedJavaClasses().contains(clazz)) { + return type; + } } return null; } public static JdbcDataType getType(Code code) { for (JdbcDataType type : JdbcDataType.values()) { - if (type.getCode() == code) return type; + if (type.getCode() == code) { + return type; + } } return null; } diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java index a520e221..82a4b913 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java @@ -16,7 +16,13 @@ package com.google.cloud.spanner.jdbc; -import com.google.cloud.spanner.connection.AbstractStatementParser.ParametersInfo; +import com.google.cloud.spanner.JdbcDataTypeConverter; +import com.google.cloud.spanner.ResultSet; +import com.google.rpc.Code; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; import java.math.BigDecimal; import java.sql.Date; import java.sql.ParameterMetaData; @@ -29,9 +35,23 @@ class JdbcParameterMetaData extends AbstractJdbcWrapper implements ParameterMetaData { private final JdbcPreparedStatement statement; - JdbcParameterMetaData(JdbcPreparedStatement statement) throws SQLException { + private final StructType parameters; + + JdbcParameterMetaData(JdbcPreparedStatement statement, ResultSet resultSet) { this.statement = statement; - statement.getParameters().fetchMetaData(statement.getConnection()); + this.parameters = resultSet.getMetadata().getUndeclaredParameters(); + } + + private Field getField(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + String paramName = "p" + param; + return parameters.getFieldsList().stream() + .filter(field -> field.getName().equals(paramName)) + .findAny() + .orElseThrow( + () -> + JdbcSqlExceptionFactory.of( + "Unknown parameter: " + paramName, Code.INVALID_ARGUMENT)); } @Override @@ -41,8 +61,7 @@ public boolean isClosed() { @Override public int getParameterCount() { - ParametersInfo info = statement.getParametersInfo(); - return info.numberOfParameters; + return parameters.getFieldsCount(); } @Override @@ -53,7 +72,7 @@ public int isNullable(int param) { } @Override - public boolean isSigned(int param) { + public boolean isSigned(int param) throws SQLException { int type = getParameterType(param); return type == Types.DOUBLE || type == Types.FLOAT @@ -77,9 +96,34 @@ public int getScale(int param) { } @Override - public int getParameterType(int param) { + public int getParameterType(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + int typeFromValue = getParameterTypeFromValue(param); + if (typeFromValue != Types.OTHER) { + return typeFromValue; + } + + Type type = getField(param).getType(); + // JDBC only has a generic ARRAY type. + if (type.getCode() == TypeCode.ARRAY) { + return Types.ARRAY; + } + JdbcDataType jdbcDataType = + JdbcDataType.getType(JdbcDataTypeConverter.toSpannerType(type).getCode()); + return jdbcDataType == null ? Types.OTHER : jdbcDataType.getSqlType(); + } + + /** + * This method returns the parameter type based on the parameter value that has been set. This was + * previously the only way to get the parameter types of a statement. Cloud Spanner can now return + * the types and names of parameters in a SQL string, which is what this method should return. + */ + // TODO: Remove this method for the next major version bump. + private int getParameterTypeFromValue(int param) { Integer type = statement.getParameters().getType(param); - if (type != null) return type; + if (type != null) { + return type; + } Object value = statement.getParameters().getParameter(param); if (value == null) { @@ -116,16 +160,49 @@ public int getParameterType(int param) { } @Override - public String getParameterTypeName(int param) { - return getSpannerTypeName(getParameterType(param)); + public String getParameterTypeName(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + String typeNameFromValue = getParameterTypeNameFromValue(param); + if (typeNameFromValue != null) { + return typeNameFromValue; + } + + com.google.cloud.spanner.Type type = + JdbcDataTypeConverter.toSpannerType(getField(param).getType()); + return getSpannerTypeName(type, statement.getConnection().getDialect()); + } + + private String getParameterTypeNameFromValue(int param) { + int type = getParameterTypeFromValue(param); + if (type != Types.OTHER) { + return getSpannerTypeName(type); + } + return null; } @Override - public String getParameterClassName(int param) { + public String getParameterClassName(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + String classNameFromValue = getParameterClassNameFromValue(param); + if (classNameFromValue != null) { + return classNameFromValue; + } + + com.google.cloud.spanner.Type type = + JdbcDataTypeConverter.toSpannerType(getField(param).getType()); + return getClassName(type); + } + + // TODO: Remove this method for the next major version bump. + private String getParameterClassNameFromValue(int param) { Object value = statement.getParameters().getParameter(param); - if (value != null) return value.getClass().getName(); + if (value != null) { + return value.getClass().getName(); + } Integer type = statement.getParameters().getType(param); - if (type != null) return getClassName(type); + if (type != null) { + return getClassName(type); + } return null; } @@ -136,22 +213,26 @@ public int getParameterMode(int param) { @Override public String toString() { - StringBuilder res = new StringBuilder(); - res.append("CloudSpannerPreparedStatementParameterMetaData, parameter count: ") - .append(getParameterCount()); - for (int param = 1; param <= getParameterCount(); param++) { - res.append("\nParameter ") - .append(param) - .append(":\n\t Class name: ") - .append(getParameterClassName(param)); - res.append(",\n\t Parameter type name: ").append(getParameterTypeName(param)); - res.append(",\n\t Parameter type: ").append(getParameterType(param)); - res.append(",\n\t Parameter precision: ").append(getPrecision(param)); - res.append(",\n\t Parameter scale: ").append(getScale(param)); - res.append(",\n\t Parameter signed: ").append(isSigned(param)); - res.append(",\n\t Parameter nullable: ").append(isNullable(param)); - res.append(",\n\t Parameter mode: ").append(getParameterMode(param)); + try { + StringBuilder res = new StringBuilder(); + res.append("CloudSpannerPreparedStatementParameterMetaData, parameter count: ") + .append(getParameterCount()); + for (int param = 1; param <= getParameterCount(); param++) { + res.append("\nParameter ") + .append(param) + .append(":\n\t Class name: ") + .append(getParameterClassName(param)); + res.append(",\n\t Parameter type name: ").append(getParameterTypeName(param)); + res.append(",\n\t Parameter type: ").append(getParameterType(param)); + res.append(",\n\t Parameter precision: ").append(getPrecision(param)); + res.append(",\n\t Parameter scale: ").append(getScale(param)); + res.append(",\n\t Parameter signed: ").append(isSigned(param)); + res.append(",\n\t Parameter nullable: ").append(isNullable(param)); + res.append(",\n\t Parameter mode: ").append(getParameterMode(param)); + } + return res.toString(); + } catch (SQLException exception) { + return "Failed to get parameter metadata: " + exception; } - return res.toString(); } } diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java index 518807dd..9ebbc98f 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java @@ -40,6 +40,7 @@ class JdbcPreparedStatement extends AbstractJdbcPreparedStatement private static final char POS_PARAM_CHAR = '?'; private final String sql; private final ParametersInfo parameters; + private JdbcParameterMetaData cachedParameterMetadata; private final ImmutableList generatedKeysColumns; JdbcPreparedStatement( @@ -118,7 +119,34 @@ public void addBatch() throws SQLException { @Override public JdbcParameterMetaData getParameterMetaData() throws SQLException { checkClosed(); - return new JdbcParameterMetaData(this); + if (cachedParameterMetadata == null) { + if (getConnection().getParser().isUpdateStatement(sql) + && !getConnection().getParser().checkReturningClause(sql)) { + cachedParameterMetadata = getParameterMetadataForUpdate(); + } else { + cachedParameterMetadata = getParameterMetadataForQuery(); + } + } + return cachedParameterMetadata; + } + + private JdbcParameterMetaData getParameterMetadataForUpdate() { + try (com.google.cloud.spanner.ResultSet resultSet = + getConnection() + .getSpannerConnection() + .analyzeUpdateStatement( + Statement.of(parameters.sqlWithNamedParameters), QueryAnalyzeMode.PLAN)) { + return new JdbcParameterMetaData(this, resultSet); + } + } + + private JdbcParameterMetaData getParameterMetadataForQuery() { + try (com.google.cloud.spanner.ResultSet resultSet = + getConnection() + .getSpannerConnection() + .analyzeQuery(Statement.of(parameters.sqlWithNamedParameters), QueryAnalyzeMode.PLAN)) { + return new JdbcParameterMetaData(this, resultSet); + } } @Override diff --git a/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java b/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java index 372bbb09..f8473f63 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner.jdbc; +import static com.google.cloud.spanner.jdbc.AbstractJdbcWrapper.getSpannerTypeName; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -23,6 +24,8 @@ import static org.junit.Assert.fail; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.Type; import com.google.rpc.Code; import java.math.BigDecimal; import java.math.BigInteger; @@ -426,4 +429,68 @@ public void testParseTimestampWithCalendar() throws SQLException { assertThat(((JdbcSqlException) e).getCode()).isEqualTo(Code.INVALID_ARGUMENT); } } + + @Test + public void testGoogleSQLTypeNames() { + assertEquals("INT64", getSpannerTypeName(Type.int64(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("BOOL", getSpannerTypeName(Type.bool(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("FLOAT64", getSpannerTypeName(Type.float64(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("STRING", getSpannerTypeName(Type.string(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("BYTES", getSpannerTypeName(Type.bytes(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("DATE", getSpannerTypeName(Type.date(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("TIMESTAMP", getSpannerTypeName(Type.timestamp(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("JSON", getSpannerTypeName(Type.json(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("NUMERIC", getSpannerTypeName(Type.numeric(), Dialect.GOOGLE_STANDARD_SQL)); + + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.int64()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.bool()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.float64()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.string()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.bytes()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.date()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.timestamp()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.json()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.numeric()), Dialect.GOOGLE_STANDARD_SQL)); + } + + @Test + public void testPostgreSQLTypeNames() { + assertEquals("bigint", getSpannerTypeName(Type.int64(), Dialect.POSTGRESQL)); + assertEquals("boolean", getSpannerTypeName(Type.bool(), Dialect.POSTGRESQL)); + assertEquals("double precision", getSpannerTypeName(Type.float64(), Dialect.POSTGRESQL)); + assertEquals("character varying", getSpannerTypeName(Type.string(), Dialect.POSTGRESQL)); + assertEquals("bytea", getSpannerTypeName(Type.bytes(), Dialect.POSTGRESQL)); + assertEquals("date", getSpannerTypeName(Type.date(), Dialect.POSTGRESQL)); + assertEquals( + "timestamp with time zone", getSpannerTypeName(Type.timestamp(), Dialect.POSTGRESQL)); + assertEquals("jsonb", getSpannerTypeName(Type.pgJsonb(), Dialect.POSTGRESQL)); + assertEquals("numeric", getSpannerTypeName(Type.pgNumeric(), Dialect.POSTGRESQL)); + + assertEquals("bigint[]", getSpannerTypeName(Type.array(Type.int64()), Dialect.POSTGRESQL)); + assertEquals("boolean[]", getSpannerTypeName(Type.array(Type.bool()), Dialect.POSTGRESQL)); + assertEquals( + "double precision[]", getSpannerTypeName(Type.array(Type.float64()), Dialect.POSTGRESQL)); + assertEquals( + "character varying[]", getSpannerTypeName(Type.array(Type.string()), Dialect.POSTGRESQL)); + assertEquals("bytea[]", getSpannerTypeName(Type.array(Type.bytes()), Dialect.POSTGRESQL)); + assertEquals("date[]", getSpannerTypeName(Type.array(Type.date()), Dialect.POSTGRESQL)); + assertEquals( + "timestamp with time zone[]", + getSpannerTypeName(Type.array(Type.timestamp()), Dialect.POSTGRESQL)); + assertEquals("jsonb[]", getSpannerTypeName(Type.array(Type.pgJsonb()), Dialect.POSTGRESQL)); + assertEquals("numeric[]", getSpannerTypeName(Type.array(Type.pgNumeric()), Dialect.POSTGRESQL)); + } } diff --git a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java index 310d1546..c5748d1c 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java @@ -18,9 +18,9 @@ import static com.google.cloud.spanner.jdbc.JdbcConnection.NO_GENERATED_KEY_COLUMNS; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; @@ -39,6 +39,10 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.AbstractStatementParser; import com.google.cloud.spanner.connection.Connection; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.TypeCode; import java.io.ByteArrayInputStream; import java.io.StringReader; import java.math.BigDecimal; @@ -55,6 +59,8 @@ import java.util.Collections; import java.util.TimeZone; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -158,7 +164,8 @@ public void testParameters() throws SQLException, MalformedURLException { final int numberOfParams = 53; String sql = generateSqlWithParameters(numberOfParams); - JdbcConnection connection = createMockConnection(); + Connection spannerConnection = createMockConnectionWithAnalyzeResults(numberOfParams); + JdbcConnection connection = createMockConnection(spannerConnection); try (JdbcPreparedStatement ps = new JdbcPreparedStatement(connection, sql, NO_GENERATED_KEY_COLUMNS)) { ps.setArray(1, connection.createArrayOf("INT64", new Long[] {1L, 2L, 3L})); @@ -252,10 +259,14 @@ public void testParameters() throws SQLException, MalformedURLException { assertEquals(String.class.getName(), pmd.getParameterClassName(35)); assertEquals(String.class.getName(), pmd.getParameterClassName(36)); assertEquals(String.class.getName(), pmd.getParameterClassName(37)); - assertNull(pmd.getParameterClassName(38)); - assertNull(pmd.getParameterClassName(39)); + + // These parameter values are not set, so the driver returns the type that was returned by + // Cloud Spanner. + assertEquals(String.class.getName(), pmd.getParameterClassName(38)); + assertEquals(String.class.getName(), pmd.getParameterClassName(39)); + assertEquals(Short.class.getName(), pmd.getParameterClassName(40)); - assertNull(pmd.getParameterClassName(41)); + assertEquals(String.class.getName(), pmd.getParameterClassName(41)); assertEquals(String.class.getName(), pmd.getParameterClassName(42)); assertEquals(Time.class.getName(), pmd.getParameterClassName(43)); assertEquals(Time.class.getName(), pmd.getParameterClassName(44)); @@ -279,8 +290,11 @@ public void testParameters() throws SQLException, MalformedURLException { public void testSetNullValues() throws SQLException { final int numberOfParameters = 31; String sql = generateSqlWithParameters(numberOfParameters); + + JdbcConnection connection = + createMockConnection(createMockConnectionWithAnalyzeResults(numberOfParameters)); try (JdbcPreparedStatement ps = - new JdbcPreparedStatement(createMockConnection(), sql, NO_GENERATED_KEY_COLUMNS)) { + new JdbcPreparedStatement(connection, sql, NO_GENERATED_KEY_COLUMNS)) { int index = 0; ps.setNull(++index, Types.BLOB); ps.setNull(++index, Types.NVARCHAR); @@ -396,4 +410,34 @@ public void testInvalidSql() { assertEquals( ErrorCode.INVALID_ARGUMENT.getGrpcStatusCode().value(), jdbcSqlException.getErrorCode()); } + + private Connection createMockConnectionWithAnalyzeResults(int numParams) { + Connection spannerConnection = mock(Connection.class); + ResultSet resultSet = mock(ResultSet.class); + when(spannerConnection.analyzeUpdateStatement(any(Statement.class), eq(QueryAnalyzeMode.PLAN))) + .thenReturn(resultSet); + when(spannerConnection.analyzeQuery(any(Statement.class), eq(QueryAnalyzeMode.PLAN))) + .thenReturn(resultSet); + ResultSetMetadata metadata = + ResultSetMetadata.newBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addAllFields( + IntStream.range(0, numParams) + .mapToObj( + i -> + Field.newBuilder() + .setName("p" + (i + 1)) + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.STRING) + .build()) + .build()) + .collect(Collectors.toList())) + .build()) + .build(); + when(resultSet.getMetadata()).thenReturn(metadata); + + return spannerConnection; + } } diff --git a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java index d3607d84..a3072e31 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java @@ -28,6 +28,13 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.SpannerPool; import com.google.cloud.spanner.jdbc.JdbcSqlExceptionFactory.JdbcSqlBatchUpdateException; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; import io.grpc.Server; import io.grpc.Status; import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; @@ -198,11 +205,180 @@ public void testExecuteBatch_withException() throws SQLException { @Test public void testInsertUntypedNullValues() throws SQLException { + String sql = + "insert into all_nullable_types (ColInt64, ColFloat64, ColBool, ColString, ColBytes, ColDate, ColTimestamp, ColNumeric, ColJson, ColInt64Array, ColFloat64Array, ColBoolArray, ColStringArray, ColBytesArray, ColDateArray, ColTimestampArray, ColNumericArray, ColJsonArray) " + + "values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11, @p12, @p13, @p14, @p15, @p16, @p17, @p18)"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p2") + .setType( + Type.newBuilder().setCode(TypeCode.FLOAT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p3") + .setType(Type.newBuilder().setCode(TypeCode.BOOL).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p4") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p5") + .setType(Type.newBuilder().setCode(TypeCode.BYTES).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p6") + .setType(Type.newBuilder().setCode(TypeCode.DATE).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p7") + .setType( + Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p8") + .setType( + Type.newBuilder().setCode(TypeCode.NUMERIC).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p9") + .setType(Type.newBuilder().setCode(TypeCode.JSON).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p10") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.INT64) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p11") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.FLOAT64) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p12") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.BOOL) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p13") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.STRING) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p14") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.BYTES) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p15") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.DATE) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p16") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.TIMESTAMP) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p17") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.NUMERIC) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p18") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.JSON) + .build()) + .build()) + .build()) + .build()) + .build()) + .setStats(ResultSetStats.newBuilder().build()) + .build())); mockSpanner.putStatementResult( StatementResult.update( - Statement.newBuilder( - "insert into all_nullable_types (ColInt64, ColFloat64, ColBool, ColString, ColBytes, ColDate, ColTimestamp, ColNumeric, ColJson, ColInt64Array, ColFloat64Array, ColBoolArray, ColStringArray, ColBytesArray, ColDateArray, ColTimestampArray, ColNumericArray, ColJsonArray) " - + "values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11, @p12, @p13, @p14, @p15, @p16, @p17, @p18)") + Statement.newBuilder(sql) .bind("p1") .to((Value) null) .bind("p2") diff --git a/src/test/java/com/google/cloud/spanner/jdbc/PreparedStatementParameterMetadataTest.java b/src/test/java/com/google/cloud/spanner/jdbc/PreparedStatementParameterMetadataTest.java new file mode 100644 index 00000000..8b7130ed --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/jdbc/PreparedStatementParameterMetadataTest.java @@ -0,0 +1,361 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.jdbc; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.connection.SpannerPool; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeAnnotationCode; +import com.google.spanner.v1.TypeCode; +import java.sql.Connection; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Types; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class PreparedStatementParameterMetadataTest extends AbstractMockServerTest { + + @After + public void reset() { + // This ensures that each test gets a fresh Spanner instance. This is necessary to get a new + // dialect result for each connection. + SpannerPool.closeSpannerPool(); + } + + @Test + public void testAllTypesParameterMetadata_GoogleSql() throws SQLException { + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.detectDialectResult(Dialect.GOOGLE_STANDARD_SQL)); + + String baseSql = + "insert into all_types (col_bool, col_bytes, col_date, col_float64, col_int64, " + + "col_json, col_numeric, col_string, col_timestamp, col_bool_array, col_bytes_array, " + + "col_date_array, col_float64_array, col_int64_array, col_json_array, col_numeric_array, " + + "col_string_array, col_timestamp_array) values (%s)"; + String jdbcSql = + String.format( + baseSql, + IntStream.range(0, 18).mapToObj(ignored -> "?").collect(Collectors.joining(", "))); + String googleSql = + String.format( + baseSql, + IntStream.range(1, 19) + .mapToObj(index -> "@p" + index) + .collect(Collectors.joining(", "))); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(googleSql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setUndeclaredParameters( + createAllTypesParameters(Dialect.GOOGLE_STANDARD_SQL)) + .build()) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + + try (Connection connection = createJdbcConnection()) { + try (PreparedStatement statement = connection.prepareStatement(jdbcSql)) { + ParameterMetaData metadata = statement.getParameterMetaData(); + assertEquals(18, metadata.getParameterCount()); + int index = 0; + assertEquals(Types.BOOLEAN, metadata.getParameterType(++index)); + assertEquals("BOOL", metadata.getParameterTypeName(index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals("BYTES", metadata.getParameterTypeName(index)); + assertEquals(Types.DATE, metadata.getParameterType(++index)); + assertEquals("DATE", metadata.getParameterTypeName(index)); + assertEquals(Types.DOUBLE, metadata.getParameterType(++index)); + assertEquals("FLOAT64", metadata.getParameterTypeName(index)); + assertEquals(Types.BIGINT, metadata.getParameterType(++index)); + assertEquals("INT64", metadata.getParameterTypeName(index)); + assertEquals(JsonType.VENDOR_TYPE_NUMBER, metadata.getParameterType(++index)); + assertEquals("JSON", metadata.getParameterTypeName(index)); + assertEquals(Types.NUMERIC, metadata.getParameterType(++index)); + assertEquals("NUMERIC", metadata.getParameterTypeName(index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals("STRING", metadata.getParameterTypeName(index)); + assertEquals(Types.TIMESTAMP, metadata.getParameterType(++index)); + assertEquals("TIMESTAMP", metadata.getParameterTypeName(index)); + + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + } + } + } + + @Test + public void testAllTypesParameterMetadata_PostgreSQL() throws SQLException { + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.detectDialectResult(Dialect.POSTGRESQL)); + + String baseSql = + "insert into all_types (col_bool, col_bytes, col_date, col_float64, col_int64, " + + "col_json, col_numeric, col_string, col_timestamp, col_bool_array, col_bytes_array, " + + "col_date_array, col_float64_array, col_int64_array, col_json_array, col_numeric_array, " + + "col_string_array, col_timestamp_array) values (%s)"; + String jdbcSql = + String.format( + baseSql, + IntStream.range(0, 18).mapToObj(ignored -> "?").collect(Collectors.joining(", "))); + String googleSql = + String.format( + baseSql, + IntStream.range(1, 19) + .mapToObj(index -> "$" + index) + .collect(Collectors.joining(", "))); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(googleSql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setUndeclaredParameters(createAllTypesParameters(Dialect.POSTGRESQL)) + .build()) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + + try (Connection connection = createJdbcConnection()) { + try (PreparedStatement statement = connection.prepareStatement(jdbcSql)) { + ParameterMetaData metadata = statement.getParameterMetaData(); + assertEquals(18, metadata.getParameterCount()); + int index = 0; + assertEquals(Types.BOOLEAN, metadata.getParameterType(++index)); + assertEquals("boolean", metadata.getParameterTypeName(index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals("bytea", metadata.getParameterTypeName(index)); + assertEquals(Types.DATE, metadata.getParameterType(++index)); + assertEquals("date", metadata.getParameterTypeName(index)); + assertEquals(Types.DOUBLE, metadata.getParameterType(++index)); + assertEquals("double precision", metadata.getParameterTypeName(index)); + assertEquals(Types.BIGINT, metadata.getParameterType(++index)); + assertEquals("bigint", metadata.getParameterTypeName(index)); + assertEquals(PgJsonbType.VENDOR_TYPE_NUMBER, metadata.getParameterType(++index)); + assertEquals("jsonb", metadata.getParameterTypeName(index)); + assertEquals(Types.NUMERIC, metadata.getParameterType(++index)); + assertEquals("numeric", metadata.getParameterTypeName(index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals("character varying", metadata.getParameterTypeName(index)); + assertEquals(Types.TIMESTAMP, metadata.getParameterType(++index)); + assertEquals("timestamp with time zone", metadata.getParameterTypeName(index)); + + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("boolean[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("bytea[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("date[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("double precision[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("bigint[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("jsonb[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("numeric[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("character varying[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("timestamp with time zone[]", metadata.getParameterTypeName(index)); + } + } + } + + static StructType createAllTypesParameters(Dialect dialect) { + return StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.BOOL).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p2") + .setType(Type.newBuilder().setCode(TypeCode.BYTES).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p3") + .setType(Type.newBuilder().setCode(TypeCode.DATE).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p4") + .setType(Type.newBuilder().setCode(TypeCode.FLOAT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p5") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p6") + .setType( + Type.newBuilder() + .setCode(TypeCode.JSON) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_JSONB + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p7") + .setType( + Type.newBuilder() + .setCode(TypeCode.NUMERIC) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_NUMERIC + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p8") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p9") + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p10") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.BOOL).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p11") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.BYTES).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p12") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.DATE).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p13") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.FLOAT64).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p14") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.INT64).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p15") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.JSON) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_JSONB + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p16") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.NUMERIC) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_NUMERIC + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p17") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p18") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build())) + .build()) + .build(); + } +} diff --git a/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java b/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java index 2864559c..8f014937 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java @@ -21,7 +21,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeFalse; @@ -35,6 +34,7 @@ import com.google.cloud.spanner.jdbc.JsonType; import com.google.cloud.spanner.testing.EmulatorSpannerHelper; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import com.google.common.io.BaseEncoding; import com.google.common.io.CharStreams; import java.io.IOException; @@ -394,7 +394,26 @@ public void test01_InsertTestData() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?)")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 5); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.DATE), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "character varying", + "character varying", + "bytea", + "character varying") + : ImmutableList.of("INT64", "STRING", "STRING", "BYTES", "DATE"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Long.class, String.class, String.class, byte[].class, String.class) + : ImmutableList.of( + Long.class, String.class, String.class, byte[].class, Date.class)); for (Singer singer : createSingers()) { singer.setPreparedStatement(ps, getDialect()); assertInsertSingerParameterMetadata(ps.getParameterMetaData()); @@ -410,7 +429,13 @@ public void test01_InsertTestData() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?)")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 4); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of(Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.BIGINT), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of("bigint", "bigint", "character varying", "bigint") + : ImmutableList.of("INT64", "INT64", "STRING", "INT64"), + ImmutableList.of(Long.class, Long.class, String.class, Long.class)); for (Album album : createAlbums()) { ps.setLong(1, album.singerId); ps.setLong(2, album.albumId); @@ -425,7 +450,26 @@ public void test01_InsertTestData() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?);")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.BIGINT, + Types.NVARCHAR, + Types.BIGINT, + Types.NVARCHAR), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "bigint", + "character varying", + "bigint", + "character varying") + : ImmutableList.of("INT64", "INT64", "INT64", "STRING", "INT64", "STRING"), + ImmutableList.of( + Long.class, Long.class, Long.class, String.class, Long.class, String.class)); for (Song song : createSongs()) { ps.setByte(1, (byte) song.singerId); ps.setInt(2, (int) song.albumId); @@ -441,8 +485,36 @@ public void test01_InsertTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getConcertsInsertQuery(dialect.dialect))) { - assertDefaultParameterMetaData( - ps.getParameterMetaData(), getConcertExpectedParamCount(dialect.dialect)); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "character varying", + "character varying", + "character varying") + : ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of(Long.class, Long.class, String.class, String.class, String.class) + : ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); for (Concert concert : createConcerts()) { concert.setPreparedStatement(connection, ps, getDialect()); assertInsertConcertParameterMetadata(ps.getParameterMetaData()); @@ -564,7 +636,24 @@ public void test03_Dates() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?);")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); ps.setLong(1, 100); ps.setLong(2, 19); ps.setDate(3, testDate); @@ -660,7 +749,24 @@ public void test04_Timestamps() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?);")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); ps.setLong(1, 100); ps.setLong(2, 19); ps.setDate(3, new Date(System.currentTimeMillis())); @@ -868,7 +974,33 @@ public void test08_InsertAllColumnTypes() throws SQLException { + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, PENDING_COMMIT_TIMESTAMP(), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (Connection con = createConnection(env, database)) { try (PreparedStatement ps = con.prepareStatement(sql)) { + ParameterMetaData metadata = ps.getParameterMetaData(); + assertEquals(22, metadata.getParameterCount()); int index = 0; + assertEquals(Types.BIGINT, metadata.getParameterType(++index)); + assertEquals(Types.DOUBLE, metadata.getParameterType(++index)); + assertEquals(Types.BOOLEAN, metadata.getParameterType(++index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals(Types.DATE, metadata.getParameterType(++index)); + assertEquals(Types.TIMESTAMP, metadata.getParameterType(++index)); + assertEquals(Types.NUMERIC, metadata.getParameterType(++index)); + assertEquals(JsonType.VENDOR_TYPE_NUMBER, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + + index = 0; ps.setLong(++index, 1L); ps.setDouble(++index, 2D); ps.setBoolean(++index, true); @@ -1182,18 +1314,28 @@ public void test11_InsertDataUsingSpannerValue() throws SQLException { } } - private void assertDefaultParameterMetaData(ParameterMetaData pmd, int expectedParamCount) + private void assertParameterMetaData( + ParameterMetaData pmd, + ImmutableList sqlTypes, + ImmutableList typeNames, + ImmutableList> classNames) throws SQLException { - assertEquals(expectedParamCount, pmd.getParameterCount()); - for (int param = 1; param <= expectedParamCount; param++) { - assertEquals(Types.OTHER, pmd.getParameterType(param)); - assertEquals("OTHER", pmd.getParameterTypeName(param)); + assertEquals(sqlTypes.size(), typeNames.size()); + assertEquals(sqlTypes.size(), classNames.size()); + + ImmutableList signedTypes = + ImmutableList.of(Types.BIGINT, Types.NUMERIC, Types.DOUBLE); + assertEquals(sqlTypes.size(), pmd.getParameterCount()); + for (int param = 1; param <= sqlTypes.size(); param++) { + String msg = "Param " + param; + assertEquals(msg, sqlTypes.get(param - 1).intValue(), pmd.getParameterType(param)); + assertEquals(msg, typeNames.get(param - 1), pmd.getParameterTypeName(param)); assertEquals(0, pmd.getPrecision(param)); assertEquals(0, pmd.getScale(param)); - assertNull(pmd.getParameterClassName(param)); + assertEquals(msg, classNames.get(param - 1).getName(), pmd.getParameterClassName(param)); assertEquals(ParameterMetaData.parameterModeIn, pmd.getParameterMode(param)); assertEquals(ParameterMetaData.parameterNullableUnknown, pmd.isNullable(param)); - assertFalse(pmd.isSigned(param)); + assertEquals(msg, signedTypes.contains(sqlTypes.get(param - 1)), pmd.isSigned(param)); } } @@ -1214,7 +1356,26 @@ public void test12_InsertReturningTestData() throws SQLException { deleteStatements.executeBatch(); try (PreparedStatement ps = connection.prepareStatement(getSingersInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 5); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.DATE), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "character varying", + "character varying", + "bytea", + "character varying") + : ImmutableList.of("INT64", "STRING", "STRING", "BYTES", "DATE"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Long.class, String.class, String.class, byte[].class, String.class) + : ImmutableList.of( + Long.class, String.class, String.class, byte[].class, Date.class)); for (Singer singer : createSingers()) { singer.setPreparedStatement(ps, getDialect()); assertInsertSingerParameterMetadata(ps.getParameterMetaData()); @@ -1229,7 +1390,13 @@ public void test12_InsertReturningTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getAlbumsInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 4); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of(Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.BIGINT), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of("bigint", "bigint", "character varying", "bigint") + : ImmutableList.of("INT64", "INT64", "STRING", "INT64"), + ImmutableList.of(Long.class, Long.class, String.class, Long.class)); for (Album album : createAlbums()) { ps.setLong(1, album.singerId); ps.setLong(2, album.albumId); @@ -1249,7 +1416,26 @@ public void test12_InsertReturningTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getSongsInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.BIGINT, + Types.NVARCHAR, + Types.BIGINT, + Types.NVARCHAR), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "bigint", + "character varying", + "bigint", + "character varying") + : ImmutableList.of("INT64", "INT64", "INT64", "STRING", "INT64", "STRING"), + ImmutableList.of( + Long.class, Long.class, Long.class, String.class, Long.class, String.class)); for (Song song : createSongs()) { ps.setByte(1, (byte) song.singerId); ps.setInt(2, (int) song.albumId); @@ -1277,8 +1463,36 @@ public void test12_InsertReturningTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getConcertsInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData( - ps.getParameterMetaData(), getConcertExpectedParamCount(dialect.dialect)); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "character varying", + "character varying", + "character varying") + : ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of(Long.class, Long.class, String.class, String.class, String.class) + : ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); for (Concert concert : createConcerts()) { concert.setPreparedStatement(connection, ps, getDialect()); assertInsertConcertParameterMetadata(ps.getParameterMetaData());