From 06abfb78a627277a580d4df825f210e96a4e14ee Mon Sep 17 00:00:00 2001 From: Vladimir Sitnikov Date: Tue, 20 Feb 2024 18:01:14 +0300 Subject: [PATCH] Merge pull request from GHSA-24rp-q3w6-vc56 * test: Add failing test for simple query mode parameter injection Adds a failing test to demonstrate how direct parameter injection in simple query mode allows for modifying the executed SQL. The issue arises when a bind placeholder is prefixed with a negation. The direct replacement of a negative value causes the resulting token to be considered a line comment. For example the SQL: SELECT -?, ? With parameter values of -1 and any text with a newline in the second parameter allows arbitrary command execution, e.g. with values -1 and "\nWHERE false" causes the query to return no rows. More complicated examples can be created by adding statement terminators. * fix: Escape literal parameter values in simple query mode Escape all literal parameter values and wrap them in parentheses to prevent SQL injection when using specially crafted parameters and SQL in simple query mode. Previously the raw value of the parameter, e.g. 123, was injected into the ? placeholder. With this change all parameters are injected as '...value...' literals that are cast to the desired type by the server and wrapped in parentheses. So the SQL SELECT -? with a parameter of -123 would become: SELECT -('-123'::int4) * fix: Add parentheses around NULL parameter values in simple query mode * fix: remove repeated quoteAndCast calls, and ensure numerics are quoted as well * test: Add parameter injection tests for additional numerical types * reformat file --------- Co-authored-by: Sehrope Sarkuni Co-authored-by: Dave Cramer --- .../core/v3/SimpleParameterList.java | 69 +++++--- .../jdbc/ParameterInjectionTest.java | 155 +++++++++++++----- 2 files changed, 162 insertions(+), 62 deletions(-) diff --git a/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java b/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java index 04fc782b5f..9741a05edf 100644 --- a/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java +++ b/pgjdbc/src/main/java/org/postgresql/core/v3/SimpleParameterList.java @@ -209,7 +209,7 @@ public void setNull(@Positive int index, int oid) throws SQLException { * {} * **/ - private static String quoteAndCast(String text, String type, boolean standardConformingStrings) { + private static String quoteAndCast(String text, @Nullable String type, boolean standardConformingStrings) { StringBuilder sb = new StringBuilder((text.length() + 10) / 10 * 11); // Add 10% for escaping. sb.append("('"); try { @@ -240,35 +240,47 @@ public String toString(@Positive int index, boolean standardConformingStrings) { return "?"; } else if (paramValue == NULL_OBJECT) { return "(NULL)"; - } else if ((flags[index] & BINARY) == BINARY) { + } + String textValue; + String type; + if ((flags[index] & BINARY) == BINARY) { // handle some of the numeric types - switch (paramTypes[index]) { case Oid.INT2: short s = ByteConverter.int2((byte[]) paramValue, 0); - return quoteAndCast(Short.toString(s), "int2", standardConformingStrings); + textValue = Short.toString(s); + type = "int2"; + break; case Oid.INT4: int i = ByteConverter.int4((byte[]) paramValue, 0); - return quoteAndCast(Integer.toString(i), "int4", standardConformingStrings); + textValue = Integer.toString(i); + type = "int4"; + break; case Oid.INT8: long l = ByteConverter.int8((byte[]) paramValue, 0); - return quoteAndCast(Long.toString(l), "int8", standardConformingStrings); + textValue = Long.toString(l); + type = "int8"; + break; case Oid.FLOAT4: float f = ByteConverter.float4((byte[]) paramValue, 0); if (Float.isNaN(f)) { return "('NaN'::real)"; } - return quoteAndCast(Float.toString(f), "float", standardConformingStrings); + textValue = Float.toString(f); + type = "real"; + break; case Oid.FLOAT8: double d = ByteConverter.float8((byte[]) paramValue, 0); if (Double.isNaN(d)) { return "('NaN'::double precision)"; } - return quoteAndCast(Double.toString(d), "double precision", standardConformingStrings); + textValue = Double.toString(d); + type = "double precision"; + break; case Oid.NUMERIC: Number n = ByteConverter.numeric((byte[]) paramValue); @@ -276,44 +288,55 @@ public String toString(@Positive int index, boolean standardConformingStrings) { assert ((Double) n).isNaN(); return "('NaN'::numeric)"; } - return n.toString(); + textValue = n.toString(); + type = "numeric"; + break; case Oid.UUID: - String uuid = + textValue = new UUIDArrayAssistant().buildElement((byte[]) paramValue, 0, 16).toString(); - return quoteAndCast(uuid, "uuid", standardConformingStrings); + type = "uuid"; + break; case Oid.POINT: PGpoint pgPoint = new PGpoint(); pgPoint.setByteValue((byte[]) paramValue, 0); - return quoteAndCast(pgPoint.toString(), "point", standardConformingStrings); + textValue = pgPoint.toString(); + type = "point"; + break; case Oid.BOX: PGbox pgBox = new PGbox(); pgBox.setByteValue((byte[]) paramValue, 0); - return quoteAndCast(pgBox.toString(), "box", standardConformingStrings); + textValue = pgBox.toString(); + type = "box"; + break; + + default: + return "?"; } - return "?"; } else { - String param = paramValue.toString(); + textValue = paramValue.toString(); int paramType = paramTypes[index]; if (paramType == Oid.TIMESTAMP) { - return quoteAndCast(param, "timestamp", standardConformingStrings); + type = "timestamp"; } else if (paramType == Oid.TIMESTAMPTZ) { - return quoteAndCast(param, "timestamp with time zone", standardConformingStrings); + type = "timestamp with time zone"; } else if (paramType == Oid.TIME) { - return quoteAndCast(param, "time", standardConformingStrings); + type = "time"; } else if (paramType == Oid.TIMETZ) { - return quoteAndCast(param, "time with time zone", standardConformingStrings); + type = "time with time zone"; } else if (paramType == Oid.DATE) { - return quoteAndCast(param, "date", standardConformingStrings); + type = "date"; } else if (paramType == Oid.INTERVAL) { - return quoteAndCast(param, "interval", standardConformingStrings); + type = "interval"; } else if (paramType == Oid.NUMERIC) { - return quoteAndCast(param, "numeric", standardConformingStrings); + type = "numeric"; + } else { + type = null; } - return quoteAndCast(param, null, standardConformingStrings); } + return quoteAndCast(textValue, type, standardConformingStrings); } @Override diff --git a/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java b/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java index ec4ab2b69d..10c0af3843 100644 --- a/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java +++ b/pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java @@ -12,56 +12,133 @@ import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.SQLException; public class ParameterInjectionTest { - @Test - public void negateParameter() throws Exception { - try (Connection conn = TestUtil.openDB()) { - PreparedStatement stmt = conn.prepareStatement("SELECT -?"); + private interface ParameterBinder { + void bind(PreparedStatement stmt) throws SQLException; + } - stmt.setInt(1, 1); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next()); - assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match"); - int value = rs.getInt(1); - assertEquals(-1, value); - } + private void testParamInjection(ParameterBinder bindPositiveOne, ParameterBinder bindNegativeOne) + throws SQLException { + try (Connection conn = TestUtil.openDB()) { + { + PreparedStatement stmt = conn.prepareStatement("SELECT -?"); + bindPositiveOne.bind(stmt); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getMetaData().getColumnCount(), + "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(-1, value); + } + bindNegativeOne.bind(stmt); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getMetaData().getColumnCount(), + "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(1, value); + } + } + { + PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?"); + bindPositiveOne.bind(stmt); + stmt.setString(2, "\nWHERE false --"); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "ResultSet should contain a row"); + assertEquals(2, rs.getMetaData().getColumnCount(), + "rs.getMetaData().getColumnCount("); + int value = rs.getInt(1); + assertEquals(-1, value); + } - stmt.setInt(1, -1); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next()); - assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match"); - int value = rs.getInt(1); - assertEquals(1, value); - } + bindNegativeOne.bind(stmt); + stmt.setString(2, "\nWHERE false --"); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "ResultSet should contain a row"); + assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount("); + int value = rs.getInt(1); + assertEquals(1, value); } + + } } + } - @Test - public void negateParameterWithContinuation() throws Exception { - try (Connection conn = TestUtil.openDB()) { - PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?"); + @Test + public void handleInt2() throws SQLException { + testParamInjection( + stmt -> { + stmt.setShort(1, (short) 1); + }, + stmt -> { + stmt.setShort(1, (short) -1); + } + ); + } - stmt.setInt(1, 1); - stmt.setString(2, "\nWHERE false --"); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next(), "ResultSet should contain a row"); - assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount("); - int value = rs.getInt(1); - assertEquals(-1, value); - } + @Test + public void handleInt4() throws SQLException { + testParamInjection( + stmt -> { + stmt.setInt(1, 1); + }, + stmt -> { + stmt.setInt(1, -1); + } + ); + } - stmt.setInt(1, -1); - stmt.setString(2, "\nWHERE false --"); - try (ResultSet rs = stmt.executeQuery()) { - assertTrue(rs.next(), "ResultSet should contain a row"); - assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount("); - int value = rs.getInt(1); - assertEquals(1, value); - } + @Test + public void handleBigInt() throws SQLException { + testParamInjection( + stmt -> { + stmt.setLong(1, (long) 1); + }, + stmt -> { + stmt.setLong(1, (long) -1); } - } + ); + } + + @Test + public void handleNumeric() throws SQLException { + testParamInjection( + stmt -> { + stmt.setBigDecimal(1, new BigDecimal("1")); + }, + stmt -> { + stmt.setBigDecimal(1, new BigDecimal("-1")); + } + ); + } + + @Test + public void handleFloat() throws SQLException { + testParamInjection( + stmt -> { + stmt.setFloat(1, 1); + }, + stmt -> { + stmt.setFloat(1, -1); + } + ); + } + + @Test + public void handleDouble() throws SQLException { + testParamInjection( + stmt -> { + stmt.setDouble(1, 1); + }, + stmt -> { + stmt.setDouble(1, -1); + } + ); + } }