Skip to content

Commit

Permalink
Merge pull request from GHSA-24rp-q3w6-vc56
Browse files Browse the repository at this point in the history
* test: Add failing test for simple query mode parameter injection

Adds a failing test to demonstrate how direct parameter injection in simple query
mode allows for modifying the executed SQL. The issue arises when a bind placeholder
is prefixed with a negation. The direct replacement of a negative value causes the
resulting token to be considered a line comment.

For example the SQL:

    SELECT -?, ?

With parameter values of -1 and any text with a newline in the second parameter
allows arbitrary command execution, e.g. with values -1 and "\nWHERE false" causes
the query to return no rows. More complicated examples can be created by adding
statement terminators.

* fix: Escape literal parameter values in simple query mode

Escape all literal parameter values and wrap them in parentheses to prevent SQL injection
when using specially crafted parameters and SQL in simple query mode. Previously the raw
value of the parameter, e.g. 123, was injected into the ? placeholder. With this change
all parameters are injected as '...value...' literals that are cast to the desired type
by the server and wrapped in parentheses.

So the SQL SELECT -? with a parameter of -123 would become: SELECT -('-123'::int4)

* fix: Add parentheses around NULL parameter values in simple query mode

* fix: remove repeated quoteAndCast calls, and ensure numerics are quoted as well

* test: Add parameter injection tests for additional numerical types

* reformat file

---------

Co-authored-by: Sehrope Sarkuni <sehrope@jackdb.com>
Co-authored-by: Dave Cramer <davecramer@gmail.com>
  • Loading branch information
3 people committed Feb 20, 2024
1 parent 93b0fcb commit 06abfb7
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 62 deletions.
Expand Up @@ -209,7 +209,7 @@ public void setNull(@Positive int index, int oid) throws SQLException {
* {}
* </pre>
**/
private static String quoteAndCast(String text, String type, boolean standardConformingStrings) {
private static String quoteAndCast(String text, @Nullable String type, boolean standardConformingStrings) {
StringBuilder sb = new StringBuilder((text.length() + 10) / 10 * 11); // Add 10% for escaping.
sb.append("('");
try {
Expand Down Expand Up @@ -240,80 +240,103 @@ public String toString(@Positive int index, boolean standardConformingStrings) {
return "?";
} else if (paramValue == NULL_OBJECT) {
return "(NULL)";
} else if ((flags[index] & BINARY) == BINARY) {
}
String textValue;
String type;
if ((flags[index] & BINARY) == BINARY) {
// handle some of the numeric types

switch (paramTypes[index]) {
case Oid.INT2:
short s = ByteConverter.int2((byte[]) paramValue, 0);
return quoteAndCast(Short.toString(s), "int2", standardConformingStrings);
textValue = Short.toString(s);
type = "int2";
break;

case Oid.INT4:
int i = ByteConverter.int4((byte[]) paramValue, 0);
return quoteAndCast(Integer.toString(i), "int4", standardConformingStrings);
textValue = Integer.toString(i);
type = "int4";
break;

case Oid.INT8:
long l = ByteConverter.int8((byte[]) paramValue, 0);
return quoteAndCast(Long.toString(l), "int8", standardConformingStrings);
textValue = Long.toString(l);
type = "int8";
break;

case Oid.FLOAT4:
float f = ByteConverter.float4((byte[]) paramValue, 0);
if (Float.isNaN(f)) {
return "('NaN'::real)";
}
return quoteAndCast(Float.toString(f), "float", standardConformingStrings);
textValue = Float.toString(f);
type = "real";
break;

case Oid.FLOAT8:
double d = ByteConverter.float8((byte[]) paramValue, 0);
if (Double.isNaN(d)) {
return "('NaN'::double precision)";
}
return quoteAndCast(Double.toString(d), "double precision", standardConformingStrings);
textValue = Double.toString(d);
type = "double precision";
break;

case Oid.NUMERIC:
Number n = ByteConverter.numeric((byte[]) paramValue);
if (n instanceof Double) {
assert ((Double) n).isNaN();
return "('NaN'::numeric)";
}
return n.toString();
textValue = n.toString();
type = "numeric";
break;

case Oid.UUID:
String uuid =
textValue =
new UUIDArrayAssistant().buildElement((byte[]) paramValue, 0, 16).toString();
return quoteAndCast(uuid, "uuid", standardConformingStrings);
type = "uuid";
break;

case Oid.POINT:
PGpoint pgPoint = new PGpoint();
pgPoint.setByteValue((byte[]) paramValue, 0);
return quoteAndCast(pgPoint.toString(), "point", standardConformingStrings);
textValue = pgPoint.toString();
type = "point";
break;

case Oid.BOX:
PGbox pgBox = new PGbox();
pgBox.setByteValue((byte[]) paramValue, 0);
return quoteAndCast(pgBox.toString(), "box", standardConformingStrings);
textValue = pgBox.toString();
type = "box";
break;

default:
return "?";
}
return "?";
} else {
String param = paramValue.toString();
textValue = paramValue.toString();
int paramType = paramTypes[index];
if (paramType == Oid.TIMESTAMP) {
return quoteAndCast(param, "timestamp", standardConformingStrings);
type = "timestamp";
} else if (paramType == Oid.TIMESTAMPTZ) {
return quoteAndCast(param, "timestamp with time zone", standardConformingStrings);
type = "timestamp with time zone";
} else if (paramType == Oid.TIME) {
return quoteAndCast(param, "time", standardConformingStrings);
type = "time";
} else if (paramType == Oid.TIMETZ) {
return quoteAndCast(param, "time with time zone", standardConformingStrings);
type = "time with time zone";
} else if (paramType == Oid.DATE) {
return quoteAndCast(param, "date", standardConformingStrings);
type = "date";
} else if (paramType == Oid.INTERVAL) {
return quoteAndCast(param, "interval", standardConformingStrings);
type = "interval";
} else if (paramType == Oid.NUMERIC) {
return quoteAndCast(param, "numeric", standardConformingStrings);
type = "numeric";
} else {
type = null;
}
return quoteAndCast(param, null, standardConformingStrings);
}
return quoteAndCast(textValue, type, standardConformingStrings);
}

@Override
Expand Down
155 changes: 116 additions & 39 deletions pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java
Expand Up @@ -12,56 +12,133 @@

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

public class ParameterInjectionTest {
@Test
public void negateParameter() throws Exception {
try (Connection conn = TestUtil.openDB()) {
PreparedStatement stmt = conn.prepareStatement("SELECT -?");
private interface ParameterBinder {
void bind(PreparedStatement stmt) throws SQLException;
}

stmt.setInt(1, 1);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match");
int value = rs.getInt(1);
assertEquals(-1, value);
}
private void testParamInjection(ParameterBinder bindPositiveOne, ParameterBinder bindNegativeOne)
throws SQLException {
try (Connection conn = TestUtil.openDB()) {
{
PreparedStatement stmt = conn.prepareStatement("SELECT -?");
bindPositiveOne.bind(stmt);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(),
"number of result columns must match");
int value = rs.getInt(1);
assertEquals(-1, value);
}
bindNegativeOne.bind(stmt);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(),
"number of result columns must match");
int value = rs.getInt(1);
assertEquals(1, value);
}
}
{
PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?");
bindPositiveOne.bind(stmt);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(),
"rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(-1, value);
}

stmt.setInt(1, -1);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match");
int value = rs.getInt(1);
assertEquals(1, value);
}
bindNegativeOne.bind(stmt);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(1, value);
}

}
}
}

@Test
public void negateParameterWithContinuation() throws Exception {
try (Connection conn = TestUtil.openDB()) {
PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?");
@Test
public void handleInt2() throws SQLException {
testParamInjection(
stmt -> {
stmt.setShort(1, (short) 1);
},
stmt -> {
stmt.setShort(1, (short) -1);
}
);
}

stmt.setInt(1, 1);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(-1, value);
}
@Test
public void handleInt4() throws SQLException {
testParamInjection(
stmt -> {
stmt.setInt(1, 1);
},
stmt -> {
stmt.setInt(1, -1);
}
);
}

stmt.setInt(1, -1);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(1, value);
}
@Test
public void handleBigInt() throws SQLException {
testParamInjection(
stmt -> {
stmt.setLong(1, (long) 1);
},
stmt -> {
stmt.setLong(1, (long) -1);
}
}
);
}

@Test
public void handleNumeric() throws SQLException {
testParamInjection(
stmt -> {
stmt.setBigDecimal(1, new BigDecimal("1"));
},
stmt -> {
stmt.setBigDecimal(1, new BigDecimal("-1"));
}
);
}

@Test
public void handleFloat() throws SQLException {
testParamInjection(
stmt -> {
stmt.setFloat(1, 1);
},
stmt -> {
stmt.setFloat(1, -1);
}
);
}

@Test
public void handleDouble() throws SQLException {
testParamInjection(
stmt -> {
stmt.setDouble(1, 1);
},
stmt -> {
stmt.setDouble(1, -1);
}
);
}
}

0 comments on commit 06abfb7

Please sign in to comment.