From 721ff4552104efba47c19ef511282071c3b334c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 22 Dec 2023 11:34:06 +0100 Subject: [PATCH] feat: support PreparedStatement#getParameterMetaData() (#1218) * feat: support PreparedStatement#getParameterMetaData() Add actual support for `PreparedStatement#getParameterMetaData()`. The first time this method is called for a PreparedStatement, the connection will now send the query to Cloud Spanner in analyze mode and without any parameter values. This will instruct Cloud Spanner to return the names and types of any query parameters in the statement. Fixes #35 * fix: restore previous behavior * fix: PostgreSQL string type name should be 'character varying' * fix: update type name to 'character varying' in integration test --- .../cloud/spanner/JdbcDataTypeConverter.java | 29 ++ .../spanner/jdbc/AbstractJdbcWrapper.java | 77 +++- .../cloud/spanner/jdbc/JdbcDataType.java | 8 +- .../spanner/jdbc/JdbcParameterMetaData.java | 139 +++++-- .../spanner/jdbc/JdbcPreparedStatement.java | 30 +- .../spanner/jdbc/AbstractJdbcWrapperTest.java | 67 ++++ .../jdbc/JdbcPreparedStatementTest.java | 56 ++- ...PreparedStatementWithMockedServerTest.java | 182 ++++++++- ...reparedStatementParameterMetadataTest.java | 361 ++++++++++++++++++ .../jdbc/it/ITJdbcPreparedStatementTest.java | 254 +++++++++++- 10 files changed, 1140 insertions(+), 63 deletions(-) create mode 100644 src/main/java/com/google/cloud/spanner/JdbcDataTypeConverter.java create mode 100644 src/test/java/com/google/cloud/spanner/jdbc/PreparedStatementParameterMetadataTest.java diff --git a/src/main/java/com/google/cloud/spanner/JdbcDataTypeConverter.java b/src/main/java/com/google/cloud/spanner/JdbcDataTypeConverter.java new file mode 100644 index 00000000..62d52c6c --- /dev/null +++ b/src/main/java/com/google/cloud/spanner/JdbcDataTypeConverter.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.InternalApi; + +@InternalApi +public class JdbcDataTypeConverter { + + /** Converts a protobuf type to a Spanner type. */ + @InternalApi + public static Type toSpannerType(com.google.spanner.v1.Type proto) { + return Type.fromProto(proto); + } +} diff --git a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java index f577ac3c..b56aed03 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapper.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner.jdbc; +import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; import com.google.common.base.Preconditions; @@ -69,7 +70,74 @@ static int extractColumnType(Type type) { } } - /** Extract Spanner type name from {@link java.sql.Types} code. */ + static String getSpannerTypeName(Type type, Dialect dialect) { + // TODO: Use com.google.cloud.spanner.Type#getSpannerTypeName() when available. + Preconditions.checkNotNull(type); + switch (type.getCode()) { + case BOOL: + return dialect == Dialect.POSTGRESQL ? "boolean" : "BOOL"; + case BYTES: + return dialect == Dialect.POSTGRESQL ? "bytea" : "BYTES"; + case DATE: + return dialect == Dialect.POSTGRESQL ? "date" : "DATE"; + case FLOAT64: + return dialect == Dialect.POSTGRESQL ? "double precision" : "FLOAT64"; + case INT64: + return dialect == Dialect.POSTGRESQL ? "bigint" : "INT64"; + case NUMERIC: + return "NUMERIC"; + case PG_NUMERIC: + return "numeric"; + case STRING: + return dialect == Dialect.POSTGRESQL ? "character varying" : "STRING"; + case JSON: + return "JSON"; + case PG_JSONB: + return "jsonb"; + case TIMESTAMP: + return dialect == Dialect.POSTGRESQL ? "timestamp with time zone" : "TIMESTAMP"; + case STRUCT: + return "STRUCT"; + case ARRAY: + switch (type.getArrayElementType().getCode()) { + case BOOL: + return dialect == Dialect.POSTGRESQL ? "boolean[]" : "ARRAY"; + case BYTES: + return dialect == Dialect.POSTGRESQL ? "bytea[]" : "ARRAY"; + case DATE: + return dialect == Dialect.POSTGRESQL ? "date[]" : "ARRAY"; + case FLOAT64: + return dialect == Dialect.POSTGRESQL ? "double precision[]" : "ARRAY"; + case INT64: + return dialect == Dialect.POSTGRESQL ? "bigint[]" : "ARRAY"; + case NUMERIC: + return "ARRAY"; + case PG_NUMERIC: + return "numeric[]"; + case STRING: + return dialect == Dialect.POSTGRESQL ? "character varying[]" : "ARRAY"; + case JSON: + return "ARRAY"; + case PG_JSONB: + return "jsonb[]"; + case TIMESTAMP: + return dialect == Dialect.POSTGRESQL + ? "timestamp with time zone[]" + : "ARRAY"; + case STRUCT: + return "ARRAY"; + } + default: + return null; + } + } + + /** + * Extract Spanner type name from {@link java.sql.Types} code. + * + * @deprecated Use {@link #getSpannerTypeName(Type, Dialect)} instead. + */ + @Deprecated static String getSpannerTypeName(int sqlType) { if (sqlType == Types.BOOLEAN) return Type.bool().getCode().name(); if (sqlType == Types.BINARY) return Type.bytes().getCode().name(); @@ -89,7 +157,12 @@ static String getSpannerTypeName(int sqlType) { return OTHER_NAME; } - /** Get corresponding Java class name from {@link java.sql.Types} code. */ + /** + * Get corresponding Java class name from {@link java.sql.Types} code. + * + * @deprecated Use {@link #getClassName(Type)} instead. + */ + @Deprecated static String getClassName(int sqlType) { if (sqlType == Types.BOOLEAN) return Boolean.class.getName(); if (sqlType == Types.BINARY) return Byte[].class.getName(); diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java index c495bbe1..5dd082ef 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcDataType.java @@ -390,14 +390,18 @@ public Set> getSupportedJavaClasses() { public static JdbcDataType getType(Class clazz) { for (JdbcDataType type : JdbcDataType.values()) { - if (type.getSupportedJavaClasses().contains(clazz)) return type; + if (type.getSupportedJavaClasses().contains(clazz)) { + return type; + } } return null; } public static JdbcDataType getType(Code code) { for (JdbcDataType type : JdbcDataType.values()) { - if (type.getCode() == code) return type; + if (type.getCode() == code) { + return type; + } } return null; } diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java index a520e221..82a4b913 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcParameterMetaData.java @@ -16,7 +16,13 @@ package com.google.cloud.spanner.jdbc; -import com.google.cloud.spanner.connection.AbstractStatementParser.ParametersInfo; +import com.google.cloud.spanner.JdbcDataTypeConverter; +import com.google.cloud.spanner.ResultSet; +import com.google.rpc.Code; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; import java.math.BigDecimal; import java.sql.Date; import java.sql.ParameterMetaData; @@ -29,9 +35,23 @@ class JdbcParameterMetaData extends AbstractJdbcWrapper implements ParameterMetaData { private final JdbcPreparedStatement statement; - JdbcParameterMetaData(JdbcPreparedStatement statement) throws SQLException { + private final StructType parameters; + + JdbcParameterMetaData(JdbcPreparedStatement statement, ResultSet resultSet) { this.statement = statement; - statement.getParameters().fetchMetaData(statement.getConnection()); + this.parameters = resultSet.getMetadata().getUndeclaredParameters(); + } + + private Field getField(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + String paramName = "p" + param; + return parameters.getFieldsList().stream() + .filter(field -> field.getName().equals(paramName)) + .findAny() + .orElseThrow( + () -> + JdbcSqlExceptionFactory.of( + "Unknown parameter: " + paramName, Code.INVALID_ARGUMENT)); } @Override @@ -41,8 +61,7 @@ public boolean isClosed() { @Override public int getParameterCount() { - ParametersInfo info = statement.getParametersInfo(); - return info.numberOfParameters; + return parameters.getFieldsCount(); } @Override @@ -53,7 +72,7 @@ public int isNullable(int param) { } @Override - public boolean isSigned(int param) { + public boolean isSigned(int param) throws SQLException { int type = getParameterType(param); return type == Types.DOUBLE || type == Types.FLOAT @@ -77,9 +96,34 @@ public int getScale(int param) { } @Override - public int getParameterType(int param) { + public int getParameterType(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + int typeFromValue = getParameterTypeFromValue(param); + if (typeFromValue != Types.OTHER) { + return typeFromValue; + } + + Type type = getField(param).getType(); + // JDBC only has a generic ARRAY type. + if (type.getCode() == TypeCode.ARRAY) { + return Types.ARRAY; + } + JdbcDataType jdbcDataType = + JdbcDataType.getType(JdbcDataTypeConverter.toSpannerType(type).getCode()); + return jdbcDataType == null ? Types.OTHER : jdbcDataType.getSqlType(); + } + + /** + * This method returns the parameter type based on the parameter value that has been set. This was + * previously the only way to get the parameter types of a statement. Cloud Spanner can now return + * the types and names of parameters in a SQL string, which is what this method should return. + */ + // TODO: Remove this method for the next major version bump. + private int getParameterTypeFromValue(int param) { Integer type = statement.getParameters().getType(param); - if (type != null) return type; + if (type != null) { + return type; + } Object value = statement.getParameters().getParameter(param); if (value == null) { @@ -116,16 +160,49 @@ public int getParameterType(int param) { } @Override - public String getParameterTypeName(int param) { - return getSpannerTypeName(getParameterType(param)); + public String getParameterTypeName(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + String typeNameFromValue = getParameterTypeNameFromValue(param); + if (typeNameFromValue != null) { + return typeNameFromValue; + } + + com.google.cloud.spanner.Type type = + JdbcDataTypeConverter.toSpannerType(getField(param).getType()); + return getSpannerTypeName(type, statement.getConnection().getDialect()); + } + + private String getParameterTypeNameFromValue(int param) { + int type = getParameterTypeFromValue(param); + if (type != Types.OTHER) { + return getSpannerTypeName(type); + } + return null; } @Override - public String getParameterClassName(int param) { + public String getParameterClassName(int param) throws SQLException { + JdbcPreconditions.checkArgument(param > 0 && param <= parameters.getFieldsCount(), param); + String classNameFromValue = getParameterClassNameFromValue(param); + if (classNameFromValue != null) { + return classNameFromValue; + } + + com.google.cloud.spanner.Type type = + JdbcDataTypeConverter.toSpannerType(getField(param).getType()); + return getClassName(type); + } + + // TODO: Remove this method for the next major version bump. + private String getParameterClassNameFromValue(int param) { Object value = statement.getParameters().getParameter(param); - if (value != null) return value.getClass().getName(); + if (value != null) { + return value.getClass().getName(); + } Integer type = statement.getParameters().getType(param); - if (type != null) return getClassName(type); + if (type != null) { + return getClassName(type); + } return null; } @@ -136,22 +213,26 @@ public int getParameterMode(int param) { @Override public String toString() { - StringBuilder res = new StringBuilder(); - res.append("CloudSpannerPreparedStatementParameterMetaData, parameter count: ") - .append(getParameterCount()); - for (int param = 1; param <= getParameterCount(); param++) { - res.append("\nParameter ") - .append(param) - .append(":\n\t Class name: ") - .append(getParameterClassName(param)); - res.append(",\n\t Parameter type name: ").append(getParameterTypeName(param)); - res.append(",\n\t Parameter type: ").append(getParameterType(param)); - res.append(",\n\t Parameter precision: ").append(getPrecision(param)); - res.append(",\n\t Parameter scale: ").append(getScale(param)); - res.append(",\n\t Parameter signed: ").append(isSigned(param)); - res.append(",\n\t Parameter nullable: ").append(isNullable(param)); - res.append(",\n\t Parameter mode: ").append(getParameterMode(param)); + try { + StringBuilder res = new StringBuilder(); + res.append("CloudSpannerPreparedStatementParameterMetaData, parameter count: ") + .append(getParameterCount()); + for (int param = 1; param <= getParameterCount(); param++) { + res.append("\nParameter ") + .append(param) + .append(":\n\t Class name: ") + .append(getParameterClassName(param)); + res.append(",\n\t Parameter type name: ").append(getParameterTypeName(param)); + res.append(",\n\t Parameter type: ").append(getParameterType(param)); + res.append(",\n\t Parameter precision: ").append(getPrecision(param)); + res.append(",\n\t Parameter scale: ").append(getScale(param)); + res.append(",\n\t Parameter signed: ").append(isSigned(param)); + res.append(",\n\t Parameter nullable: ").append(isNullable(param)); + res.append(",\n\t Parameter mode: ").append(getParameterMode(param)); + } + return res.toString(); + } catch (SQLException exception) { + return "Failed to get parameter metadata: " + exception; } - return res.toString(); } } diff --git a/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java b/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java index 518807dd..9ebbc98f 100644 --- a/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java +++ b/src/main/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatement.java @@ -40,6 +40,7 @@ class JdbcPreparedStatement extends AbstractJdbcPreparedStatement private static final char POS_PARAM_CHAR = '?'; private final String sql; private final ParametersInfo parameters; + private JdbcParameterMetaData cachedParameterMetadata; private final ImmutableList generatedKeysColumns; JdbcPreparedStatement( @@ -118,7 +119,34 @@ public void addBatch() throws SQLException { @Override public JdbcParameterMetaData getParameterMetaData() throws SQLException { checkClosed(); - return new JdbcParameterMetaData(this); + if (cachedParameterMetadata == null) { + if (getConnection().getParser().isUpdateStatement(sql) + && !getConnection().getParser().checkReturningClause(sql)) { + cachedParameterMetadata = getParameterMetadataForUpdate(); + } else { + cachedParameterMetadata = getParameterMetadataForQuery(); + } + } + return cachedParameterMetadata; + } + + private JdbcParameterMetaData getParameterMetadataForUpdate() { + try (com.google.cloud.spanner.ResultSet resultSet = + getConnection() + .getSpannerConnection() + .analyzeUpdateStatement( + Statement.of(parameters.sqlWithNamedParameters), QueryAnalyzeMode.PLAN)) { + return new JdbcParameterMetaData(this, resultSet); + } + } + + private JdbcParameterMetaData getParameterMetadataForQuery() { + try (com.google.cloud.spanner.ResultSet resultSet = + getConnection() + .getSpannerConnection() + .analyzeQuery(Statement.of(parameters.sqlWithNamedParameters), QueryAnalyzeMode.PLAN)) { + return new JdbcParameterMetaData(this, resultSet); + } } @Override diff --git a/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java b/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java index 372bbb09..f8473f63 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/AbstractJdbcWrapperTest.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner.jdbc; +import static com.google.cloud.spanner.jdbc.AbstractJdbcWrapper.getSpannerTypeName; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -23,6 +24,8 @@ import static org.junit.Assert.fail; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.Type; import com.google.rpc.Code; import java.math.BigDecimal; import java.math.BigInteger; @@ -426,4 +429,68 @@ public void testParseTimestampWithCalendar() throws SQLException { assertThat(((JdbcSqlException) e).getCode()).isEqualTo(Code.INVALID_ARGUMENT); } } + + @Test + public void testGoogleSQLTypeNames() { + assertEquals("INT64", getSpannerTypeName(Type.int64(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("BOOL", getSpannerTypeName(Type.bool(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("FLOAT64", getSpannerTypeName(Type.float64(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("STRING", getSpannerTypeName(Type.string(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("BYTES", getSpannerTypeName(Type.bytes(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("DATE", getSpannerTypeName(Type.date(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("TIMESTAMP", getSpannerTypeName(Type.timestamp(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("JSON", getSpannerTypeName(Type.json(), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals("NUMERIC", getSpannerTypeName(Type.numeric(), Dialect.GOOGLE_STANDARD_SQL)); + + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.int64()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.bool()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.float64()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.string()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.bytes()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.date()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.timestamp()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", getSpannerTypeName(Type.array(Type.json()), Dialect.GOOGLE_STANDARD_SQL)); + assertEquals( + "ARRAY", + getSpannerTypeName(Type.array(Type.numeric()), Dialect.GOOGLE_STANDARD_SQL)); + } + + @Test + public void testPostgreSQLTypeNames() { + assertEquals("bigint", getSpannerTypeName(Type.int64(), Dialect.POSTGRESQL)); + assertEquals("boolean", getSpannerTypeName(Type.bool(), Dialect.POSTGRESQL)); + assertEquals("double precision", getSpannerTypeName(Type.float64(), Dialect.POSTGRESQL)); + assertEquals("character varying", getSpannerTypeName(Type.string(), Dialect.POSTGRESQL)); + assertEquals("bytea", getSpannerTypeName(Type.bytes(), Dialect.POSTGRESQL)); + assertEquals("date", getSpannerTypeName(Type.date(), Dialect.POSTGRESQL)); + assertEquals( + "timestamp with time zone", getSpannerTypeName(Type.timestamp(), Dialect.POSTGRESQL)); + assertEquals("jsonb", getSpannerTypeName(Type.pgJsonb(), Dialect.POSTGRESQL)); + assertEquals("numeric", getSpannerTypeName(Type.pgNumeric(), Dialect.POSTGRESQL)); + + assertEquals("bigint[]", getSpannerTypeName(Type.array(Type.int64()), Dialect.POSTGRESQL)); + assertEquals("boolean[]", getSpannerTypeName(Type.array(Type.bool()), Dialect.POSTGRESQL)); + assertEquals( + "double precision[]", getSpannerTypeName(Type.array(Type.float64()), Dialect.POSTGRESQL)); + assertEquals( + "character varying[]", getSpannerTypeName(Type.array(Type.string()), Dialect.POSTGRESQL)); + assertEquals("bytea[]", getSpannerTypeName(Type.array(Type.bytes()), Dialect.POSTGRESQL)); + assertEquals("date[]", getSpannerTypeName(Type.array(Type.date()), Dialect.POSTGRESQL)); + assertEquals( + "timestamp with time zone[]", + getSpannerTypeName(Type.array(Type.timestamp()), Dialect.POSTGRESQL)); + assertEquals("jsonb[]", getSpannerTypeName(Type.array(Type.pgJsonb()), Dialect.POSTGRESQL)); + assertEquals("numeric[]", getSpannerTypeName(Type.array(Type.pgNumeric()), Dialect.POSTGRESQL)); + } } diff --git a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java index 310d1546..c5748d1c 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementTest.java @@ -18,9 +18,9 @@ import static com.google.cloud.spanner.jdbc.JdbcConnection.NO_GENERATED_KEY_COLUMNS; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; @@ -39,6 +39,10 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.AbstractStatementParser; import com.google.cloud.spanner.connection.Connection; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.TypeCode; import java.io.ByteArrayInputStream; import java.io.StringReader; import java.math.BigDecimal; @@ -55,6 +59,8 @@ import java.util.Collections; import java.util.TimeZone; import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -158,7 +164,8 @@ public void testParameters() throws SQLException, MalformedURLException { final int numberOfParams = 53; String sql = generateSqlWithParameters(numberOfParams); - JdbcConnection connection = createMockConnection(); + Connection spannerConnection = createMockConnectionWithAnalyzeResults(numberOfParams); + JdbcConnection connection = createMockConnection(spannerConnection); try (JdbcPreparedStatement ps = new JdbcPreparedStatement(connection, sql, NO_GENERATED_KEY_COLUMNS)) { ps.setArray(1, connection.createArrayOf("INT64", new Long[] {1L, 2L, 3L})); @@ -252,10 +259,14 @@ public void testParameters() throws SQLException, MalformedURLException { assertEquals(String.class.getName(), pmd.getParameterClassName(35)); assertEquals(String.class.getName(), pmd.getParameterClassName(36)); assertEquals(String.class.getName(), pmd.getParameterClassName(37)); - assertNull(pmd.getParameterClassName(38)); - assertNull(pmd.getParameterClassName(39)); + + // These parameter values are not set, so the driver returns the type that was returned by + // Cloud Spanner. + assertEquals(String.class.getName(), pmd.getParameterClassName(38)); + assertEquals(String.class.getName(), pmd.getParameterClassName(39)); + assertEquals(Short.class.getName(), pmd.getParameterClassName(40)); - assertNull(pmd.getParameterClassName(41)); + assertEquals(String.class.getName(), pmd.getParameterClassName(41)); assertEquals(String.class.getName(), pmd.getParameterClassName(42)); assertEquals(Time.class.getName(), pmd.getParameterClassName(43)); assertEquals(Time.class.getName(), pmd.getParameterClassName(44)); @@ -279,8 +290,11 @@ public void testParameters() throws SQLException, MalformedURLException { public void testSetNullValues() throws SQLException { final int numberOfParameters = 31; String sql = generateSqlWithParameters(numberOfParameters); + + JdbcConnection connection = + createMockConnection(createMockConnectionWithAnalyzeResults(numberOfParameters)); try (JdbcPreparedStatement ps = - new JdbcPreparedStatement(createMockConnection(), sql, NO_GENERATED_KEY_COLUMNS)) { + new JdbcPreparedStatement(connection, sql, NO_GENERATED_KEY_COLUMNS)) { int index = 0; ps.setNull(++index, Types.BLOB); ps.setNull(++index, Types.NVARCHAR); @@ -396,4 +410,34 @@ public void testInvalidSql() { assertEquals( ErrorCode.INVALID_ARGUMENT.getGrpcStatusCode().value(), jdbcSqlException.getErrorCode()); } + + private Connection createMockConnectionWithAnalyzeResults(int numParams) { + Connection spannerConnection = mock(Connection.class); + ResultSet resultSet = mock(ResultSet.class); + when(spannerConnection.analyzeUpdateStatement(any(Statement.class), eq(QueryAnalyzeMode.PLAN))) + .thenReturn(resultSet); + when(spannerConnection.analyzeQuery(any(Statement.class), eq(QueryAnalyzeMode.PLAN))) + .thenReturn(resultSet); + ResultSetMetadata metadata = + ResultSetMetadata.newBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addAllFields( + IntStream.range(0, numParams) + .mapToObj( + i -> + Field.newBuilder() + .setName("p" + (i + 1)) + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.STRING) + .build()) + .build()) + .collect(Collectors.toList())) + .build()) + .build(); + when(resultSet.getMetadata()).thenReturn(metadata); + + return spannerConnection; + } } diff --git a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java index d3607d84..a3072e31 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/JdbcPreparedStatementWithMockedServerTest.java @@ -28,6 +28,13 @@ import com.google.cloud.spanner.Value; import com.google.cloud.spanner.connection.SpannerPool; import com.google.cloud.spanner.jdbc.JdbcSqlExceptionFactory.JdbcSqlBatchUpdateException; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; import io.grpc.Server; import io.grpc.Status; import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; @@ -198,11 +205,180 @@ public void testExecuteBatch_withException() throws SQLException { @Test public void testInsertUntypedNullValues() throws SQLException { + String sql = + "insert into all_nullable_types (ColInt64, ColFloat64, ColBool, ColString, ColBytes, ColDate, ColTimestamp, ColNumeric, ColJson, ColInt64Array, ColFloat64Array, ColBoolArray, ColStringArray, ColBytesArray, ColDateArray, ColTimestampArray, ColNumericArray, ColJsonArray) " + + "values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11, @p12, @p13, @p14, @p15, @p16, @p17, @p18)"; + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(sql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setUndeclaredParameters( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p2") + .setType( + Type.newBuilder().setCode(TypeCode.FLOAT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p3") + .setType(Type.newBuilder().setCode(TypeCode.BOOL).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p4") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p5") + .setType(Type.newBuilder().setCode(TypeCode.BYTES).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p6") + .setType(Type.newBuilder().setCode(TypeCode.DATE).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p7") + .setType( + Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p8") + .setType( + Type.newBuilder().setCode(TypeCode.NUMERIC).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p9") + .setType(Type.newBuilder().setCode(TypeCode.JSON).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p10") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.INT64) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p11") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.FLOAT64) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p12") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.BOOL) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p13") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.STRING) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p14") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.BYTES) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p15") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.DATE) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p16") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.TIMESTAMP) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p17") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.NUMERIC) + .build()) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p18") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.JSON) + .build()) + .build()) + .build()) + .build()) + .build()) + .setStats(ResultSetStats.newBuilder().build()) + .build())); mockSpanner.putStatementResult( StatementResult.update( - Statement.newBuilder( - "insert into all_nullable_types (ColInt64, ColFloat64, ColBool, ColString, ColBytes, ColDate, ColTimestamp, ColNumeric, ColJson, ColInt64Array, ColFloat64Array, ColBoolArray, ColStringArray, ColBytesArray, ColDateArray, ColTimestampArray, ColNumericArray, ColJsonArray) " - + "values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11, @p12, @p13, @p14, @p15, @p16, @p17, @p18)") + Statement.newBuilder(sql) .bind("p1") .to((Value) null) .bind("p2") diff --git a/src/test/java/com/google/cloud/spanner/jdbc/PreparedStatementParameterMetadataTest.java b/src/test/java/com/google/cloud/spanner/jdbc/PreparedStatementParameterMetadataTest.java new file mode 100644 index 00000000..8b7130ed --- /dev/null +++ b/src/test/java/com/google/cloud/spanner/jdbc/PreparedStatementParameterMetadataTest.java @@ -0,0 +1,361 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.jdbc; + +import static org.junit.Assert.assertEquals; + +import com.google.cloud.spanner.Dialect; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.connection.SpannerPool; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeAnnotationCode; +import com.google.spanner.v1.TypeCode; +import java.sql.Connection; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Types; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class PreparedStatementParameterMetadataTest extends AbstractMockServerTest { + + @After + public void reset() { + // This ensures that each test gets a fresh Spanner instance. This is necessary to get a new + // dialect result for each connection. + SpannerPool.closeSpannerPool(); + } + + @Test + public void testAllTypesParameterMetadata_GoogleSql() throws SQLException { + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.detectDialectResult(Dialect.GOOGLE_STANDARD_SQL)); + + String baseSql = + "insert into all_types (col_bool, col_bytes, col_date, col_float64, col_int64, " + + "col_json, col_numeric, col_string, col_timestamp, col_bool_array, col_bytes_array, " + + "col_date_array, col_float64_array, col_int64_array, col_json_array, col_numeric_array, " + + "col_string_array, col_timestamp_array) values (%s)"; + String jdbcSql = + String.format( + baseSql, + IntStream.range(0, 18).mapToObj(ignored -> "?").collect(Collectors.joining(", "))); + String googleSql = + String.format( + baseSql, + IntStream.range(1, 19) + .mapToObj(index -> "@p" + index) + .collect(Collectors.joining(", "))); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(googleSql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setUndeclaredParameters( + createAllTypesParameters(Dialect.GOOGLE_STANDARD_SQL)) + .build()) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + + try (Connection connection = createJdbcConnection()) { + try (PreparedStatement statement = connection.prepareStatement(jdbcSql)) { + ParameterMetaData metadata = statement.getParameterMetaData(); + assertEquals(18, metadata.getParameterCount()); + int index = 0; + assertEquals(Types.BOOLEAN, metadata.getParameterType(++index)); + assertEquals("BOOL", metadata.getParameterTypeName(index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals("BYTES", metadata.getParameterTypeName(index)); + assertEquals(Types.DATE, metadata.getParameterType(++index)); + assertEquals("DATE", metadata.getParameterTypeName(index)); + assertEquals(Types.DOUBLE, metadata.getParameterType(++index)); + assertEquals("FLOAT64", metadata.getParameterTypeName(index)); + assertEquals(Types.BIGINT, metadata.getParameterType(++index)); + assertEquals("INT64", metadata.getParameterTypeName(index)); + assertEquals(JsonType.VENDOR_TYPE_NUMBER, metadata.getParameterType(++index)); + assertEquals("JSON", metadata.getParameterTypeName(index)); + assertEquals(Types.NUMERIC, metadata.getParameterType(++index)); + assertEquals("NUMERIC", metadata.getParameterTypeName(index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals("STRING", metadata.getParameterTypeName(index)); + assertEquals(Types.TIMESTAMP, metadata.getParameterType(++index)); + assertEquals("TIMESTAMP", metadata.getParameterTypeName(index)); + + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("ARRAY", metadata.getParameterTypeName(index)); + } + } + } + + @Test + public void testAllTypesParameterMetadata_PostgreSQL() throws SQLException { + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.detectDialectResult(Dialect.POSTGRESQL)); + + String baseSql = + "insert into all_types (col_bool, col_bytes, col_date, col_float64, col_int64, " + + "col_json, col_numeric, col_string, col_timestamp, col_bool_array, col_bytes_array, " + + "col_date_array, col_float64_array, col_int64_array, col_json_array, col_numeric_array, " + + "col_string_array, col_timestamp_array) values (%s)"; + String jdbcSql = + String.format( + baseSql, + IntStream.range(0, 18).mapToObj(ignored -> "?").collect(Collectors.joining(", "))); + String googleSql = + String.format( + baseSql, + IntStream.range(1, 19) + .mapToObj(index -> "$" + index) + .collect(Collectors.joining(", "))); + mockSpanner.putStatementResult( + StatementResult.query( + Statement.of(googleSql), + ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setUndeclaredParameters(createAllTypesParameters(Dialect.POSTGRESQL)) + .build()) + .setStats(ResultSetStats.newBuilder().build()) + .build())); + + try (Connection connection = createJdbcConnection()) { + try (PreparedStatement statement = connection.prepareStatement(jdbcSql)) { + ParameterMetaData metadata = statement.getParameterMetaData(); + assertEquals(18, metadata.getParameterCount()); + int index = 0; + assertEquals(Types.BOOLEAN, metadata.getParameterType(++index)); + assertEquals("boolean", metadata.getParameterTypeName(index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals("bytea", metadata.getParameterTypeName(index)); + assertEquals(Types.DATE, metadata.getParameterType(++index)); + assertEquals("date", metadata.getParameterTypeName(index)); + assertEquals(Types.DOUBLE, metadata.getParameterType(++index)); + assertEquals("double precision", metadata.getParameterTypeName(index)); + assertEquals(Types.BIGINT, metadata.getParameterType(++index)); + assertEquals("bigint", metadata.getParameterTypeName(index)); + assertEquals(PgJsonbType.VENDOR_TYPE_NUMBER, metadata.getParameterType(++index)); + assertEquals("jsonb", metadata.getParameterTypeName(index)); + assertEquals(Types.NUMERIC, metadata.getParameterType(++index)); + assertEquals("numeric", metadata.getParameterTypeName(index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals("character varying", metadata.getParameterTypeName(index)); + assertEquals(Types.TIMESTAMP, metadata.getParameterType(++index)); + assertEquals("timestamp with time zone", metadata.getParameterTypeName(index)); + + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("boolean[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("bytea[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("date[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("double precision[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("bigint[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("jsonb[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("numeric[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("character varying[]", metadata.getParameterTypeName(index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals("timestamp with time zone[]", metadata.getParameterTypeName(index)); + } + } + } + + static StructType createAllTypesParameters(Dialect dialect) { + return StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("p1") + .setType(Type.newBuilder().setCode(TypeCode.BOOL).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p2") + .setType(Type.newBuilder().setCode(TypeCode.BYTES).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p3") + .setType(Type.newBuilder().setCode(TypeCode.DATE).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p4") + .setType(Type.newBuilder().setCode(TypeCode.FLOAT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p5") + .setType(Type.newBuilder().setCode(TypeCode.INT64).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p6") + .setType( + Type.newBuilder() + .setCode(TypeCode.JSON) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_JSONB + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p7") + .setType( + Type.newBuilder() + .setCode(TypeCode.NUMERIC) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_NUMERIC + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p8") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p9") + .setType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("p10") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.BOOL).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p11") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.BYTES).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p12") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.DATE).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p13") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.FLOAT64).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p14") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.INT64).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p15") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.JSON) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_JSONB + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p16") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType( + Type.newBuilder() + .setCode(TypeCode.NUMERIC) + .setTypeAnnotation( + dialect == Dialect.POSTGRESQL + ? TypeAnnotationCode.PG_NUMERIC + : TypeAnnotationCode.TYPE_ANNOTATION_CODE_UNSPECIFIED) + .build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p17") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .addFields( + Field.newBuilder() + .setName("p18") + .setType( + Type.newBuilder() + .setCode(TypeCode.ARRAY) + .setArrayElementType(Type.newBuilder().setCode(TypeCode.TIMESTAMP).build())) + .build()) + .build(); + } +} diff --git a/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java b/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java index 2864559c..8f014937 100644 --- a/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/jdbc/it/ITJdbcPreparedStatementTest.java @@ -21,7 +21,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeFalse; @@ -35,6 +34,7 @@ import com.google.cloud.spanner.jdbc.JsonType; import com.google.cloud.spanner.testing.EmulatorSpannerHelper; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import com.google.common.io.BaseEncoding; import com.google.common.io.CharStreams; import java.io.IOException; @@ -394,7 +394,26 @@ public void test01_InsertTestData() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, BirthDate) values (?,?,?,?,?)")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 5); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.DATE), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "character varying", + "character varying", + "bytea", + "character varying") + : ImmutableList.of("INT64", "STRING", "STRING", "BYTES", "DATE"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Long.class, String.class, String.class, byte[].class, String.class) + : ImmutableList.of( + Long.class, String.class, String.class, byte[].class, Date.class)); for (Singer singer : createSingers()) { singer.setPreparedStatement(ps, getDialect()); assertInsertSingerParameterMetadata(ps.getParameterMetaData()); @@ -410,7 +429,13 @@ public void test01_InsertTestData() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Albums (SingerId, AlbumId, AlbumTitle, MarketingBudget) VALUES (?,?,?,?)")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 4); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of(Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.BIGINT), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of("bigint", "bigint", "character varying", "bigint") + : ImmutableList.of("INT64", "INT64", "STRING", "INT64"), + ImmutableList.of(Long.class, Long.class, String.class, Long.class)); for (Album album : createAlbums()) { ps.setLong(1, album.singerId); ps.setLong(2, album.albumId); @@ -425,7 +450,26 @@ public void test01_InsertTestData() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Songs (SingerId, AlbumId, TrackId, SongName, Duration, SongGenre) VALUES (?,?,?,?,?,?);")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.BIGINT, + Types.NVARCHAR, + Types.BIGINT, + Types.NVARCHAR), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "bigint", + "character varying", + "bigint", + "character varying") + : ImmutableList.of("INT64", "INT64", "INT64", "STRING", "INT64", "STRING"), + ImmutableList.of( + Long.class, Long.class, Long.class, String.class, Long.class, String.class)); for (Song song : createSongs()) { ps.setByte(1, (byte) song.singerId); ps.setInt(2, (int) song.albumId); @@ -441,8 +485,36 @@ public void test01_InsertTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getConcertsInsertQuery(dialect.dialect))) { - assertDefaultParameterMetaData( - ps.getParameterMetaData(), getConcertExpectedParamCount(dialect.dialect)); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "character varying", + "character varying", + "character varying") + : ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of(Long.class, Long.class, String.class, String.class, String.class) + : ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); for (Concert concert : createConcerts()) { concert.setPreparedStatement(connection, ps, getDialect()); assertInsertConcertParameterMetadata(ps.getParameterMetaData()); @@ -564,7 +636,24 @@ public void test03_Dates() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?);")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); ps.setLong(1, 100); ps.setLong(2, 19); ps.setDate(3, testDate); @@ -660,7 +749,24 @@ public void test04_Timestamps() throws SQLException { try (PreparedStatement ps = connection.prepareStatement( "INSERT INTO Concerts (VenueId, SingerId, ConcertDate, BeginTime, EndTime, TicketPrices) VALUES (?,?,?,?,?,?);")) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); ps.setLong(1, 100); ps.setLong(2, 19); ps.setDate(3, new Date(System.currentTimeMillis())); @@ -868,7 +974,33 @@ public void test08_InsertAllColumnTypes() throws SQLException { + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, PENDING_COMMIT_TIMESTAMP(), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; try (Connection con = createConnection(env, database)) { try (PreparedStatement ps = con.prepareStatement(sql)) { + ParameterMetaData metadata = ps.getParameterMetaData(); + assertEquals(22, metadata.getParameterCount()); int index = 0; + assertEquals(Types.BIGINT, metadata.getParameterType(++index)); + assertEquals(Types.DOUBLE, metadata.getParameterType(++index)); + assertEquals(Types.BOOLEAN, metadata.getParameterType(++index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals(Types.NVARCHAR, metadata.getParameterType(++index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals(Types.BINARY, metadata.getParameterType(++index)); + assertEquals(Types.DATE, metadata.getParameterType(++index)); + assertEquals(Types.TIMESTAMP, metadata.getParameterType(++index)); + assertEquals(Types.NUMERIC, metadata.getParameterType(++index)); + assertEquals(JsonType.VENDOR_TYPE_NUMBER, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + assertEquals(Types.ARRAY, metadata.getParameterType(++index)); + + index = 0; ps.setLong(++index, 1L); ps.setDouble(++index, 2D); ps.setBoolean(++index, true); @@ -1182,18 +1314,28 @@ public void test11_InsertDataUsingSpannerValue() throws SQLException { } } - private void assertDefaultParameterMetaData(ParameterMetaData pmd, int expectedParamCount) + private void assertParameterMetaData( + ParameterMetaData pmd, + ImmutableList sqlTypes, + ImmutableList typeNames, + ImmutableList> classNames) throws SQLException { - assertEquals(expectedParamCount, pmd.getParameterCount()); - for (int param = 1; param <= expectedParamCount; param++) { - assertEquals(Types.OTHER, pmd.getParameterType(param)); - assertEquals("OTHER", pmd.getParameterTypeName(param)); + assertEquals(sqlTypes.size(), typeNames.size()); + assertEquals(sqlTypes.size(), classNames.size()); + + ImmutableList signedTypes = + ImmutableList.of(Types.BIGINT, Types.NUMERIC, Types.DOUBLE); + assertEquals(sqlTypes.size(), pmd.getParameterCount()); + for (int param = 1; param <= sqlTypes.size(); param++) { + String msg = "Param " + param; + assertEquals(msg, sqlTypes.get(param - 1).intValue(), pmd.getParameterType(param)); + assertEquals(msg, typeNames.get(param - 1), pmd.getParameterTypeName(param)); assertEquals(0, pmd.getPrecision(param)); assertEquals(0, pmd.getScale(param)); - assertNull(pmd.getParameterClassName(param)); + assertEquals(msg, classNames.get(param - 1).getName(), pmd.getParameterClassName(param)); assertEquals(ParameterMetaData.parameterModeIn, pmd.getParameterMode(param)); assertEquals(ParameterMetaData.parameterNullableUnknown, pmd.isNullable(param)); - assertFalse(pmd.isSigned(param)); + assertEquals(msg, signedTypes.contains(sqlTypes.get(param - 1)), pmd.isSigned(param)); } } @@ -1214,7 +1356,26 @@ public void test12_InsertReturningTestData() throws SQLException { deleteStatements.executeBatch(); try (PreparedStatement ps = connection.prepareStatement(getSingersInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 5); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.BINARY, Types.DATE), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "character varying", + "character varying", + "bytea", + "character varying") + : ImmutableList.of("INT64", "STRING", "STRING", "BYTES", "DATE"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Long.class, String.class, String.class, byte[].class, String.class) + : ImmutableList.of( + Long.class, String.class, String.class, byte[].class, Date.class)); for (Singer singer : createSingers()) { singer.setPreparedStatement(ps, getDialect()); assertInsertSingerParameterMetadata(ps.getParameterMetaData()); @@ -1229,7 +1390,13 @@ public void test12_InsertReturningTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getAlbumsInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 4); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of(Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.BIGINT), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of("bigint", "bigint", "character varying", "bigint") + : ImmutableList.of("INT64", "INT64", "STRING", "INT64"), + ImmutableList.of(Long.class, Long.class, String.class, Long.class)); for (Album album : createAlbums()) { ps.setLong(1, album.singerId); ps.setLong(2, album.albumId); @@ -1249,7 +1416,26 @@ public void test12_InsertReturningTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getSongsInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData(ps.getParameterMetaData(), 6); + assertParameterMetaData( + ps.getParameterMetaData(), + ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.BIGINT, + Types.NVARCHAR, + Types.BIGINT, + Types.NVARCHAR), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "bigint", + "character varying", + "bigint", + "character varying") + : ImmutableList.of("INT64", "INT64", "INT64", "STRING", "INT64", "STRING"), + ImmutableList.of( + Long.class, Long.class, Long.class, String.class, Long.class, String.class)); for (Song song : createSongs()) { ps.setByte(1, (byte) song.singerId); ps.setInt(2, (int) song.albumId); @@ -1277,8 +1463,36 @@ public void test12_InsertReturningTestData() throws SQLException { } try (PreparedStatement ps = connection.prepareStatement(getConcertsInsertReturningQuery(dialect.dialect))) { - assertDefaultParameterMetaData( - ps.getParameterMetaData(), getConcertExpectedParamCount(dialect.dialect)); + assertParameterMetaData( + ps.getParameterMetaData(), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + Types.BIGINT, Types.BIGINT, Types.NVARCHAR, Types.NVARCHAR, Types.NVARCHAR) + : ImmutableList.of( + Types.BIGINT, + Types.BIGINT, + Types.DATE, + Types.TIMESTAMP, + Types.TIMESTAMP, + Types.ARRAY), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of( + "bigint", + "bigint", + "character varying", + "character varying", + "character varying") + : ImmutableList.of( + "INT64", "INT64", "DATE", "TIMESTAMP", "TIMESTAMP", "ARRAY"), + dialect.dialect == Dialect.POSTGRESQL + ? ImmutableList.of(Long.class, Long.class, String.class, String.class, String.class) + : ImmutableList.of( + Long.class, + Long.class, + Date.class, + Timestamp.class, + Timestamp.class, + Long[].class)); for (Concert concert : createConcerts()) { concert.setPreparedStatement(connection, ps, getDialect()); assertInsertConcertParameterMetadata(ps.getParameterMetaData());