diff --git a/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java b/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java index 04fc782b5f..9741a05edf 100644 --- a/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java +++ b/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java @@ -209,7 +209,7 @@ public void setNull(@Positive int index, int oid) throws SQLException { * {} * **/ - private static String quoteAndCast(String text, String type, boolean standardConformingStrings) { + private static String quoteAndCast(String text, @Nullable String type, boolean standardConformingStrings) { StringBuilder sb = new StringBuilder((text.length() + 10) / 10 * 11); // Add 10% for escaping. sb.append("('"); try { @@ -240,35 +240,47 @@ public String toString(@Positive int index, boolean standardConformingStrings) { return "?"; } else if (paramValue == NULL_OBJECT) { return "(NULL)"; - } else if ((flags[index] & BINARY) == BINARY) { + } + String textValue; + String type; + if ((flags[index] & BINARY) == BINARY) { // handle some of the numeric types - switch (paramTypes[index]) { case Oid.INT2: short s = ByteConverter.int2((byte[]) paramValue, 0); - return quoteAndCast(Short.toString(s), "int2", standardConformingStrings); + textValue = Short.toString(s); + type = "int2"; + break; case Oid.INT4: int i = ByteConverter.int4((byte[]) paramValue, 0); - return quoteAndCast(Integer.toString(i), "int4", standardConformingStrings); + textValue = Integer.toString(i); + type = "int4"; + break; case Oid.INT8: long l = ByteConverter.int8((byte[]) paramValue, 0); - return quoteAndCast(Long.toString(l), "int8", standardConformingStrings); + textValue = Long.toString(l); + type = "int8"; + break; case Oid.FLOAT4: float f = ByteConverter.float4((byte[]) paramValue, 0); if (Float.isNaN(f)) { return "('NaN'::real)"; } - return quoteAndCast(Float.toString(f), "float", standardConformingStrings); + textValue = Float.toString(f); + type = "real"; + break; case Oid.FLOAT8: double d = ByteConverter.float8((byte[]) paramValue, 0); if (Double.isNaN(d)) { return "('NaN'::double precision)"; } - return quoteAndCast(Double.toString(d), "double precision", standardConformingStrings); + textValue = Double.toString(d); + type = "double precision"; + break; case Oid.NUMERIC: Number n = ByteConverter.numeric((byte[]) paramValue); @@ -276,44 +288,55 @@ public String toString(@Positive int index, boolean standardConformingStrings) { assert ((Double) n).isNaN(); return "('NaN'::numeric)"; } - return n.toString(); + textValue = n.toString(); + type = "numeric"; + break; case Oid.UUID: - String uuid = + textValue = new UUIDArrayAssistant().buildElement((byte[]) paramValue, 0, 16).toString(); - return quoteAndCast(uuid, "uuid", standardConformingStrings); + type = "uuid"; + break; case Oid.POINT: PGpoint pgPoint = new PGpoint(); pgPoint.setByteValue((byte[]) paramValue, 0); - return quoteAndCast(pgPoint.toString(), "point", standardConformingStrings); + textValue = pgPoint.toString(); + type = "point"; + break; case Oid.BOX: PGbox pgBox = new PGbox(); pgBox.setByteValue((byte[]) paramValue, 0); - return quoteAndCast(pgBox.toString(), "box", standardConformingStrings); + textValue = pgBox.toString(); + type = "box"; + break; + + default: + return "?"; } - return "?"; } else { - String param = paramValue.toString(); + textValue = paramValue.toString(); int paramType = paramTypes[index]; if (paramType == Oid.TIMESTAMP) { - return quoteAndCast(param, "timestamp", standardConformingStrings); + type = "timestamp"; } else if (paramType == Oid.TIMESTAMPTZ) { - return quoteAndCast(param, "timestamp with time zone", standardConformingStrings); + type = "timestamp with time zone"; } else if (paramType == Oid.TIME) { - return quoteAndCast(param, "time", standardConformingStrings); + type = "time"; } else if (paramType == Oid.TIMETZ) { - return quoteAndCast(param, "time with time zone", standardConformingStrings); + type = "time with time zone"; } else if (paramType == Oid.DATE) { - return quoteAndCast(param, "date", standardConformingStrings); + type = "date"; } else if (paramType == Oid.INTERVAL) { - return quoteAndCast(param, "interval", standardConformingStrings); + type = "interval"; } else if (paramType == Oid.NUMERIC) { - return quoteAndCast(param, "numeric", standardConformingStrings); + type = "numeric"; + } else { + type = null; } - return quoteAndCast(param, null, standardConformingStrings); } + return quoteAndCast(textValue, type, standardConformingStrings); } @Override diff --git a/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java b/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java index ec4ab2b69d..10c0af3843 100644 --- a/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java +++ b/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java @@ -12,56 +12,133 @@ import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.SQLException; public class ParameterInjectionTest { - @Test - public void negateParameter() throws Exception { - try (Connection conn = TestUtil.openDB()) { - PreparedStatement stmt = conn.prepareStatement("SELECT -?"); + private interface ParameterBinder { + void bind(PreparedStatement stmt) throws SQLException; + } - stmt.setInt(1, 1); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next()); - assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match"); - int value = rs.getInt(1); - assertEquals(-1, value); - } + private void testParamInjection(ParameterBinder bindPositiveOne, ParameterBinder bindNegativeOne) + throws SQLException { + try (Connection conn = TestUtil.openDB()) { + { + PreparedStatement stmt = conn.prepareStatement("SELECT -?"); + bindPositiveOne.bind(stmt); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getMetaData().getColumnCount(), + "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(-1, value); + } + bindNegativeOne.bind(stmt); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getMetaData().getColumnCount(), + "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(1, value); + } + } + { + PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?"); + bindPositiveOne.bind(stmt); + stmt.setString(2, "\nWHERE false --"); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "ResultSet should contain a row"); + assertEquals(2, rs.getMetaData().getColumnCount(), + "rs.getMetaData().getColumnCount("); + int value = rs.getInt(1); + assertEquals(-1, value); + } - stmt.setInt(1, -1); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next()); - assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match"); - int value = rs.getInt(1); - assertEquals(1, value); - } + bindNegativeOne.bind(stmt); + stmt.setString(2, "\nWHERE false --"); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "ResultSet should contain a row"); + assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount("); + int value = rs.getInt(1); + assertEquals(1, value); } + + } } + } - @Test - public void negateParameterWithContinuation() throws Exception { - try (Connection conn = TestUtil.openDB()) { - PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?"); + @Test + public void handleInt2() throws SQLException { + testParamInjection( + stmt -> { + stmt.setShort(1, (short) 1); + }, + stmt -> { + stmt.setShort(1, (short) -1); + } + ); + } - stmt.setInt(1, 1); - stmt.setString(2, "\nWHERE false --"); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next(), "ResultSet should contain a row"); - assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount("); - int value = rs.getInt(1); - assertEquals(-1, value); - } + @Test + public void handleInt4() throws SQLException { + testParamInjection( + stmt -> { + stmt.setInt(1, 1); + }, + stmt -> { + stmt.setInt(1, -1); + } + ); + } - stmt.setInt(1, -1); - stmt.setString(2, "\nWHERE false --"); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next(), "ResultSet should contain a row"); - assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount("); - int value = rs.getInt(1); - assertEquals(1, value); - } + @Test + public void handleBigInt() throws SQLException { + testParamInjection( + stmt -> { + stmt.setLong(1, (long) 1); + }, + stmt -> { + stmt.setLong(1, (long) -1); } - } + ); + } + + @Test + public void handleNumeric() throws SQLException { + testParamInjection( + stmt -> { + stmt.setBigDecimal(1, new BigDecimal("1")); + }, + stmt -> { + stmt.setBigDecimal(1, new BigDecimal("-1")); + } + ); + } + + @Test + public void handleFloat() throws SQLException { + testParamInjection( + stmt -> { + stmt.setFloat(1, 1); + }, + stmt -> { + stmt.setFloat(1, -1); + } + ); + } + + @Test + public void handleDouble() throws SQLException { + testParamInjection( + stmt -> { + stmt.setDouble(1, 1); + }, + stmt -> { + stmt.setDouble(1, -1); + } + ); + } }