diff --git a/src/main/java/io/r2dbc/postgresql/ParsedSql.java b/src/main/java/io/r2dbc/postgresql/ParsedSql.java index e1ee6804..3480d6b9 100644 --- a/src/main/java/io/r2dbc/postgresql/ParsedSql.java +++ b/src/main/java/io/r2dbc/postgresql/ParsedSql.java @@ -24,20 +24,20 @@ class ParsedSql { private final String sql; - private final List statements; + private final List statements; private final int statementCount; private final int parameterCount; - public ParsedSql(String sql, List statements) { + public ParsedSql(String sql, List statements) { this.sql = sql; this.statements = statements; this.statementCount = statements.size(); this.parameterCount = getParameterCount(statements); } - List getStatements() { + List getStatements() { return this.statements; } @@ -53,16 +53,16 @@ public String getSql() { return sql; } - private static int getParameterCount(List statements) { + private static int getParameterCount(List statements) { int sum = 0; - for (TokenizedStatement statement : statements){ + for (Statement statement : statements){ sum += statement.getParameterCount(); } return sum; } public boolean hasDefaultTokenValue(String... tokenValues) { - for (TokenizedStatement statement : this.statements) { + for (Statement statement : this.statements) { for (Token token : statement.getTokens()) { if (token.getType() == TokenType.DEFAULT) { for (String value : tokenValues) { @@ -129,24 +129,17 @@ public String toString() { } - static class TokenizedStatement { - - private final String sql; + static class Statement { private final List tokens; private final int parameterCount; - public TokenizedStatement(String sql, List tokens) { + public Statement(List tokens) { this.tokens = tokens; - this.sql = sql; this.parameterCount = readParameterCount(tokens); } - public String getSql() { - return this.sql; - } - public List getTokens() { return this.tokens; } @@ -164,19 +157,14 @@ public boolean equals(Object o) { return false; } - TokenizedStatement that = (TokenizedStatement) o; + Statement that = (Statement) o; - if (!this.sql.equals(that.sql)) { - return false; - } return this.tokens.equals(that.tokens); } @Override public int hashCode() { - int result = this.sql.hashCode(); - result = 31 * result + this.tokens.hashCode(); - return result; + return this.tokens.hashCode(); } @Override diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java index 0a68118a..41ed9291 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java @@ -40,7 +40,7 @@ final class PostgresqlBatch implements io.r2dbc.postgresql.api.PostgresqlBatch { public PostgresqlBatch add(String sql) { Assert.requireNonNull(sql, "sql must not be null"); - if (!(PostgresqlSqlParser.tokenize(sql).getParameterCount() == 0)) { + if (!(PostgresqlSqlParser.parse(sql).getParameterCount() == 0)) { throw new IllegalArgumentException(String.format("Statement '%s' is not supported. This is often due to the presence of parameters.", sql)); } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java index 9932ac17..b9884a37 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java @@ -38,11 +38,8 @@ class PostgresqlSqlParser { Arrays.sort(SPECIAL_AND_OPERATOR_CHARS); } - public static ParsedSql tokenize(String sql) { + private static List tokenize(String sql) { List tokens = new ArrayList<>(); - List statements = new ArrayList<>(); - - int statementStartIndex = 0; int i = 0; while (i < sql.length()) { char c = sql.charAt(i); @@ -87,21 +84,48 @@ public static ParsedSql tokenize(String sql) { } i += token.getValue().length(); + tokens.add(token); + } + return tokens; + } - if (token.getType() == ParsedSql.TokenType.STATEMENT_END) { + public static ParsedSql parse(String sql) { + List tokens = tokenize(sql); + List statements = new ArrayList<>(); + List functionBodyList = new ArrayList<>(); - tokens.add(token); - statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens)); + List currentStatementTokens = new ArrayList<>(); + for (int i = 0; i < tokens.size(); i++) { + ParsedSql.Token current = tokens.get(i); + currentStatementTokens.add(current); - tokens = new ArrayList<>(); - statementStartIndex = i + 1; - } else { - tokens.add(token); + if (current.getType() == ParsedSql.TokenType.DEFAULT) { + String currentValue = current.getValue(); + + if (currentValue.equalsIgnoreCase("BEGIN")) { + if (i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC")) { + functionBodyList.add(true); + } else { + functionBodyList.add(false); + } + } else if (currentValue.equalsIgnoreCase("END") && !functionBodyList.isEmpty()) { + functionBodyList.remove(functionBodyList.size() - 1); + } + } else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) { + boolean inFunctionBody = false; + + for (boolean b : functionBodyList) { + inFunctionBody |= b; + } + if (!inFunctionBody) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); + currentStatementTokens = new ArrayList<>(); + } } } - // If tokens is not empty, implicit statement end - if (!tokens.isEmpty()) { - statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens)); + + if (!currentStatementTokens.isEmpty()) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); } return new ParsedSql(sql, statements); @@ -209,12 +233,13 @@ private static ParsedSql.Token getQuotedIdentifierToken(String sql, int beginInd } } - private static boolean isAsciiLetter(char c){ + private static boolean isAsciiLetter(char c) { char lower = Character.toLowerCase(c); return lower >= 'a' && lower <= 'z'; } - private static boolean isAsciiDigit(char c){ + private static boolean isAsciiDigit(char c) { return c >= '0' && c <= '9'; } + } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java index 62b49e5e..d612472a 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java @@ -73,7 +73,7 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta PostgresqlStatement(ConnectionResources resources, String sql) { this.resources = Assert.requireNonNull(resources, "resources must not be null"); - this.parsedSql = PostgresqlSqlParser.tokenize(Assert.requireNonNull(sql, "sql must not be null")); + this.parsedSql = PostgresqlSqlParser.parse(Assert.requireNonNull(sql, "sql must not be null")); this.connectionContext = resources.getClient().getContext(); this.bindings = new ArrayDeque<>(this.parsedSql.getParameterCount()); diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java index 08ba9428..ad1008b9 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java @@ -23,6 +23,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertThrows; class PostgresqlSqlParserTest { @@ -116,48 +117,49 @@ class SingleTokenExceptionTests { @Test void unclosedSingleQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("'test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("'test")); } @Test void unclosedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$$test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$$test")); } @Test void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc$test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc$test")); } @Test void unclosedQuotedIdentifierThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("\"test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("\"test")); } @Test void unclosedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*test")); } @Test void unclosedNestedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*/*test*/")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*/*test*/")); } @Test void invalidParameterCharacterThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$1test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$1test")); } @Test void invalidTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$a b$test$a b$")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$a b$test$a b$")); } @Test void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc")); } + } @Nested @@ -242,13 +244,33 @@ void simpleSelectStatementIsTokenized() { ); } + @Test + void simpleSelectStatementWithFunctionBodyIsTokenized() { + assertSingleStatementEquals("CREATE FUNCTION test() BEGIN ATOMIC SELECT 1; SELECT 2; END", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "CREATE"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FUNCTION"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "test"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "("), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, ")"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "BEGIN"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "ATOMIC"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "2"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "END") + ); + } + } void assertSingleStatementEquals(String sql, ParsedSql.Token... tokens) { - ParsedSql parsedSql = PostgresqlSqlParser.tokenize(sql); + ParsedSql parsedSql = PostgresqlSqlParser.parse(sql); assertEquals(1, parsedSql.getStatements().size(), "Parse returned zero or more than 2 statements"); - ParsedSql.TokenizedStatement statement = parsedSql.getStatements().get(0); - assertEquals(new ParsedSql.TokenizedStatement(sql, Arrays.asList(tokens)), statement); + ParsedSql.Statement statement = parsedSql.getStatements().get(0); + assertIterableEquals(Arrays.asList(tokens), statement.getTokens()); } } @@ -258,30 +280,30 @@ class MultipleStatementTests { @Test void simpleMultipleStatementIsTokenized() { - ParsedSql parsedSql = PostgresqlSqlParser.tokenize("DELETE * FROM X; SELECT 1;"); - List statements = parsedSql.getStatements(); + ParsedSql parsedSql = PostgresqlSqlParser.parse("DELETE * FROM X; SELECT 1;"); + List statements = parsedSql.getStatements(); assertEquals(2, statements.size()); - ParsedSql.TokenizedStatement statementA = statements.get(0); - ParsedSql.TokenizedStatement statementB = statements.get(1); - - assertEquals(new ParsedSql.TokenizedStatement("DELETE * FROM X;", - Arrays.asList( - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"), - new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"), - new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") - )), - statementA + ParsedSql.Statement statementA = statements.get(0); + ParsedSql.Statement statementB = statements.get(1); + + assertIterableEquals( + Arrays.asList( + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") + ), + statementA.getTokens() ); - assertEquals(new ParsedSql.TokenizedStatement("SELECT 1;", - Arrays.asList( - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), - new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") - )), - statementB + assertIterableEquals( + Arrays.asList( + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") + ), + statementB.getTokens() ); }