Skip to content

Commit

Permalink
Extend PostgresqlSqlLexer to handle PG14 SQL-standard function body syntax
Browse files Browse the repository at this point in the history

- Lexing/parsing is now done in two steps: first only tokenize, then parse into statements
- Added support for function bodies ("BEGIN ATOMIC")
- Added a test case for newly supported grammar

[resolves #512][#513]
  • Loading branch information
toverdijk authored and mp911de committed May 25, 2022
1 parent c74bd6e commit 7f9b349
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 73 deletions.
32 changes: 10 additions & 22 deletions src/main/java/io/r2dbc/postgresql/ParsedSql.java
Expand Up @@ -24,20 +24,20 @@ class ParsedSql {

private final String sql;

private final List<TokenizedStatement> statements;
private final List<Statement> statements;

private final int statementCount;

private final int parameterCount;

public ParsedSql(String sql, List<TokenizedStatement> statements) {
public ParsedSql(String sql, List<Statement> statements) {
this.sql = sql;
this.statements = statements;
this.statementCount = statements.size();
this.parameterCount = getParameterCount(statements);
}

List<TokenizedStatement> getStatements() {
List<Statement> getStatements() {
return this.statements;
}

Expand All @@ -53,16 +53,16 @@ public String getSql() {
return sql;
}

private static int getParameterCount(List<TokenizedStatement> statements) {
private static int getParameterCount(List<Statement> statements) {
int sum = 0;
for (TokenizedStatement statement : statements){
for (Statement statement : statements){
sum += statement.getParameterCount();
}
return sum;
}

public boolean hasDefaultTokenValue(String... tokenValues) {
for (TokenizedStatement statement : this.statements) {
for (Statement statement : this.statements) {
for (Token token : statement.getTokens()) {
if (token.getType() == TokenType.DEFAULT) {
for (String value : tokenValues) {
Expand Down Expand Up @@ -129,24 +129,17 @@ public String toString() {

}

static class TokenizedStatement {

private final String sql;
static class Statement {

private final List<Token> tokens;

private final int parameterCount;

public TokenizedStatement(String sql, List<Token> tokens) {
public Statement(List<Token> tokens) {
this.tokens = tokens;
this.sql = sql;
this.parameterCount = readParameterCount(tokens);
}

public String getSql() {
return this.sql;
}

public List<Token> getTokens() {
return this.tokens;
}
Expand All @@ -164,19 +157,14 @@ public boolean equals(Object o) {
return false;
}

TokenizedStatement that = (TokenizedStatement) o;
Statement that = (Statement) o;

if (!this.sql.equals(that.sql)) {
return false;
}
return this.tokens.equals(that.tokens);
}

@Override
public int hashCode() {
int result = this.sql.hashCode();
result = 31 * result + this.tokens.hashCode();
return result;
return this.tokens.hashCode();
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java
Expand Up @@ -40,7 +40,7 @@ final class PostgresqlBatch implements io.r2dbc.postgresql.api.PostgresqlBatch {
public PostgresqlBatch add(String sql) {
Assert.requireNonNull(sql, "sql must not be null");

if (!(PostgresqlSqlParser.tokenize(sql).getParameterCount() == 0)) {
if (!(PostgresqlSqlParser.parse(sql).getParameterCount() == 0)) {
throw new IllegalArgumentException(String.format("Statement '%s' is not supported. This is often due to the presence of parameters.", sql));
}

Expand Down
57 changes: 41 additions & 16 deletions src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java
Expand Up @@ -38,11 +38,8 @@ class PostgresqlSqlParser {
Arrays.sort(SPECIAL_AND_OPERATOR_CHARS);
}

public static ParsedSql tokenize(String sql) {
private static List<ParsedSql.Token> tokenize(String sql) {
List<ParsedSql.Token> tokens = new ArrayList<>();
List<ParsedSql.TokenizedStatement> statements = new ArrayList<>();

int statementStartIndex = 0;
int i = 0;
while (i < sql.length()) {
char c = sql.charAt(i);
Expand Down Expand Up @@ -87,21 +84,48 @@ public static ParsedSql tokenize(String sql) {
}

i += token.getValue().length();
tokens.add(token);
}
return tokens;
}

if (token.getType() == ParsedSql.TokenType.STATEMENT_END) {
public static ParsedSql parse(String sql) {
List<ParsedSql.Token> tokens = tokenize(sql);
List<ParsedSql.Statement> statements = new ArrayList<>();
List<Boolean> functionBodyList = new ArrayList<>();

tokens.add(token);
statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens));
List<ParsedSql.Token> currentStatementTokens = new ArrayList<>();
for (int i = 0; i < tokens.size(); i++) {
ParsedSql.Token current = tokens.get(i);
currentStatementTokens.add(current);

tokens = new ArrayList<>();
statementStartIndex = i + 1;
} else {
tokens.add(token);
if (current.getType() == ParsedSql.TokenType.DEFAULT) {
String currentValue = current.getValue();

if (currentValue.equalsIgnoreCase("BEGIN")) {
if (i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC")) {
functionBodyList.add(true);
} else {
functionBodyList.add(false);
}
} else if (currentValue.equalsIgnoreCase("END") && !functionBodyList.isEmpty()) {
functionBodyList.remove(functionBodyList.size() - 1);
}
} else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) {
boolean inFunctionBody = false;

for (boolean b : functionBodyList) {
inFunctionBody |= b;
}
if (!inFunctionBody) {
statements.add(new ParsedSql.Statement(currentStatementTokens));
currentStatementTokens = new ArrayList<>();
}
}
}
// If tokens is not empty, implicit statement end
if (!tokens.isEmpty()) {
statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens));

if (!currentStatementTokens.isEmpty()) {
statements.add(new ParsedSql.Statement(currentStatementTokens));
}

return new ParsedSql(sql, statements);
Expand Down Expand Up @@ -209,12 +233,13 @@ private static ParsedSql.Token getQuotedIdentifierToken(String sql, int beginInd
}
}

private static boolean isAsciiLetter(char c){
/**
 * Tells whether {@code c} is one of the 52 ASCII letters {@code a-z} or {@code A-Z}.
 *
 * <p>Both ranges are compared directly instead of lower-casing first:
 * {@link Character#toLowerCase(char)} applies Unicode case mapping, which folds some
 * non-ASCII characters (for example U+212A KELVIN SIGN and U+0130 LATIN CAPITAL LETTER I
 * WITH DOT ABOVE) onto ASCII letters and would make them pass an "ASCII" check.
 *
 * @param c the character to classify
 * @return {@code true} only if {@code c} lies in {@code [a-z]} or {@code [A-Z]}
 */
private static boolean isAsciiLetter(char c) {
    return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}

private static boolean isAsciiDigit(char c){
/**
 * Tells whether {@code c} is an ASCII decimal digit.
 *
 * @param c the character to classify
 * @return {@code true} if {@code c} is between {@code '0'} and {@code '9'} inclusive
 */
private static boolean isAsciiDigit(char c) {
    if (c < '0') {
        return false;
    }
    return c <= '9';
}

}
2 changes: 1 addition & 1 deletion src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java
Expand Up @@ -73,7 +73,7 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta

PostgresqlStatement(ConnectionResources resources, String sql) {
this.resources = Assert.requireNonNull(resources, "resources must not be null");
this.parsedSql = PostgresqlSqlParser.tokenize(Assert.requireNonNull(sql, "sql must not be null"));
this.parsedSql = PostgresqlSqlParser.parse(Assert.requireNonNull(sql, "sql must not be null"));
this.connectionContext = resources.getClient().getContext();
this.bindings = new ArrayDeque<>(this.parsedSql.getParameterCount());

Expand Down
88 changes: 55 additions & 33 deletions src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java
Expand Up @@ -23,6 +23,7 @@
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

class PostgresqlSqlParserTest {
Expand Down Expand Up @@ -116,48 +117,49 @@ class SingleTokenExceptionTests {

@Test
void unclosedSingleQuotedStringThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("'test"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("'test"));
}

@Test
void unclosedDollarQuotedStringThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$$test"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$$test"));
}

@Test
void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc$test"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc$test"));
}

@Test
void unclosedQuotedIdentifierThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("\"test"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("\"test"));
}

@Test
void unclosedBlockCommentThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*test"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*test"));
}

@Test
void unclosedNestedBlockCommentThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*/*test*/"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*/*test*/"));
}

@Test
void invalidParameterCharacterThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$1test"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$1test"));
}

@Test
void invalidTaggedDollarQuoteThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$a b$test$a b$"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$a b$test$a b$"));
}

@Test
void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc"));
assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc"));
}

}

@Nested
Expand Down Expand Up @@ -242,13 +244,33 @@ void simpleSelectStatementIsTokenized() {
);
}

@Test
void simpleSelectStatementWithFunctionBodyIsTokenized() {
    // The input contains two STATEMENT_END tokens, both between BEGIN ATOMIC and END.
    // The helper asserts that the parser still reports exactly one statement carrying
    // every token listed below, in order.
    final String sql = "CREATE FUNCTION test() BEGIN ATOMIC SELECT 1; SELECT 2; END";

    final ParsedSql.Token[] expectedTokens = {
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "CREATE"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FUNCTION"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "test"),
        new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "("),
        new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, ")"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "BEGIN"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "ATOMIC"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"),
        new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "2"),
        new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"),
        new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "END")
    };

    assertSingleStatementEquals(sql, expectedTokens);
}

}

void assertSingleStatementEquals(String sql, ParsedSql.Token... tokens) {
ParsedSql parsedSql = PostgresqlSqlParser.tokenize(sql);
ParsedSql parsedSql = PostgresqlSqlParser.parse(sql);
assertEquals(1, parsedSql.getStatements().size(), "Parse returned zero or more than 2 statements");
ParsedSql.TokenizedStatement statement = parsedSql.getStatements().get(0);
assertEquals(new ParsedSql.TokenizedStatement(sql, Arrays.asList(tokens)), statement);
ParsedSql.Statement statement = parsedSql.getStatements().get(0);
assertIterableEquals(Arrays.asList(tokens), statement.getTokens());
}

}
Expand All @@ -258,30 +280,30 @@ class MultipleStatementTests {

@Test
void simpleMultipleStatementIsTokenized() {
ParsedSql parsedSql = PostgresqlSqlParser.tokenize("DELETE * FROM X; SELECT 1;");
List<ParsedSql.TokenizedStatement> statements = parsedSql.getStatements();
ParsedSql parsedSql = PostgresqlSqlParser.parse("DELETE * FROM X; SELECT 1;");
List<ParsedSql.Statement> statements = parsedSql.getStatements();
assertEquals(2, statements.size());
ParsedSql.TokenizedStatement statementA = statements.get(0);
ParsedSql.TokenizedStatement statementB = statements.get(1);

assertEquals(new ParsedSql.TokenizedStatement("DELETE * FROM X;",
Arrays.asList(
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"),
new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"),
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"),
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"),
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
)),
statementA
ParsedSql.Statement statementA = statements.get(0);
ParsedSql.Statement statementB = statements.get(1);

assertIterableEquals(
Arrays.asList(
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"),
new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"),
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"),
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"),
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
),
statementA.getTokens()
);

assertEquals(new ParsedSql.TokenizedStatement("SELECT 1;",
Arrays.asList(
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"),
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
)),
statementB
assertIterableEquals(
Arrays.asList(
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"),
new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"),
new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")
),
statementB.getTokens()
);

}
Expand Down

0 comments on commit 7f9b349

Please sign in to comment.