Skip to content

Commit

Permalink
Polishing.
Browse files Browse the repository at this point in the history
Reorder methods. Use CharObjectMap for operator character lookup instead of binary search over a sorted array. Extract peek/has-next token functionality into dedicated methods.

Migrate assertions to AssertJ. Add benchmark.

[#512][resolves #513]

Signed-off-by: Mark Paluch <mpaluch@vmware.com>
  • Loading branch information
mp911de committed May 25, 2022
1 parent 7f9b349 commit b65db10
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 95 deletions.
@@ -0,0 +1,56 @@
/*
* Copyright 2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.r2dbc.postgresql;

import org.junit.platform.commons.annotation.Testable;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.infra.Blackhole;

import java.util.concurrent.TimeUnit;

/**
 * Throughput benchmarks for {@link PostgresqlSqlParser#parse(String)}, reported in operations per second.
 * <p>
 * Each benchmark parses one fixed SQL text and hands the result to the {@link Blackhole}.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.SECONDS)
@Testable
public class PostgresqlSqlParserBenchmarks extends BenchmarkSettings {

    /** Single statement without any {@code $n} placeholders. */
    private static final String SIMPLE_SQL = "SELECT * FROM FOO";

    /** Single statement referencing the {@code $2} and {@code $1} placeholders. */
    private static final String PARAMETRIZED_SQL = "SELECT * FROM FOO WHERE $2 = $1";

    /** Function definition whose {@code BEGIN ATOMIC} body contains a semicolon of its own. */
    private static final String CREATE_FUNCTION_SQL =
        "CREATE OR REPLACE FUNCTION asterisks(n int)\n"
            + " RETURNS SETOF text\n"
            + " LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE\n"
            + "BEGIN ATOMIC\n"
            + "SELECT repeat('*', g) FROM generate_series (1, n) g; -- <-- Note this semicolon\n"
            + "END;";

    /**
     * Parses {@link #SIMPLE_SQL}.
     *
     * @param blackhole sink that receives the parse result
     */
    @Benchmark
    public void simpleStatement(Blackhole blackhole) {
        blackhole.consume(PostgresqlSqlParser.parse(SIMPLE_SQL));
    }

    /**
     * Parses {@link #PARAMETRIZED_SQL}.
     *
     * @param blackhole sink that receives the parse result
     */
    @Benchmark
    public void parametrizedStatement(Blackhole blackhole) {
        blackhole.consume(PostgresqlSqlParser.parse(PARAMETRIZED_SQL));
    }

    /**
     * Parses {@link #CREATE_FUNCTION_SQL}.
     *
     * @param blackhole sink that receives the parse result
     */
    @Benchmark
    public void createOrReplaceFunction(Blackhole blackhole) {
        blackhole.consume(PostgresqlSqlParser.parse(CREATE_FUNCTION_SQL));
    }

}
4 changes: 2 additions & 2 deletions src/main/java/io/r2dbc/postgresql/ParsedSql.java
Expand Up @@ -50,12 +50,12 @@ public int getParameterCount() {
}

public String getSql() {
return sql;
return this.sql;
}

private static int getParameterCount(List<Statement> statements) {
int sum = 0;
for (Statement statement : statements){
for (Statement statement : statements) {
sum += statement.getParameterCount();
}
return sum;
Expand Down
129 changes: 78 additions & 51 deletions src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java
Expand Up @@ -16,8 +16,11 @@

package io.r2dbc.postgresql;

import io.netty.util.collection.CharObjectHashMap;
import io.netty.util.collection.CharObjectMap;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import static java.lang.Character.isWhitespace;
Expand All @@ -29,13 +32,79 @@
*/
class PostgresqlSqlParser {

private static final char[] SPECIAL_AND_OPERATOR_CHARS = {
'+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?',
'(', ')', '[', ']', ',', ';', ':', '*', '.', '\'', '"'
};
private static final CharObjectMap<Object> SPECIAL_AND_OPERATOR_CHARS = new CharObjectHashMap<>();

static {
Arrays.sort(SPECIAL_AND_OPERATOR_CHARS);
char[] specialCharsAndOperators = {'+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?',
'(', ')', '[', ']', ',', ';', ':', '*', '.', '\'', '"'};

for (char c : specialCharsAndOperators) {
SPECIAL_AND_OPERATOR_CHARS.put(c, new Object());
}
}

/**
 * Tokenizes {@code sql} and groups the tokens into statements.
 * <p>
 * Every token is appended to the statement currently being built. A {@code STATEMENT_END} token closes
 * that statement unless at least one still-open {@code BEGIN} was directly followed by an
 * {@code ATOMIC} token; in that case the terminator stays inside the current statement. Tokens left
 * over after the last terminator form a final statement.
 *
 * @param sql the SQL text to parse
 * @return the parsed representation holding the original text and the detected statements
 */
public static ParsedSql parse(String sql) {
    List<ParsedSql.Token> tokens = tokenize(sql);
    List<ParsedSql.Statement> statements = new ArrayList<>();
    // One entry per open BEGIN: true for BEGIN ATOMIC, false otherwise. Created on first BEGIN only.
    LinkedList<Boolean> openBlocks = null;

    List<ParsedSql.Token> pending = new ArrayList<>(tokens.size());

    for (int i = 0; i < tokens.size(); i++) {
        ParsedSql.Token token = tokens.get(i);
        pending.add(token);

        if (token.getType() == ParsedSql.TokenType.DEFAULT) {
            String word = token.getValue();

            if (word.equalsIgnoreCase("BEGIN")) {
                if (openBlocks == null) {
                    openBlocks = new LinkedList<>();
                }
                boolean atomic = hasNextToken(tokens, i) && peekNext(tokens, i).getValue().equalsIgnoreCase("ATOMIC");
                openBlocks.add(atomic);
            } else if (word.equalsIgnoreCase("END") && openBlocks != null && !openBlocks.isEmpty()) {
                // END closes the most recently opened BEGIN.
                openBlocks.removeLast();
            }
            continue;
        }

        if (!token.getType().equals(ParsedSql.TokenType.STATEMENT_END)) {
            continue;
        }

        // A terminator only ends the statement when no enclosing block is a BEGIN ATOMIC body.
        boolean insideAtomicBody = openBlocks != null && openBlocks.contains(Boolean.TRUE);
        if (insideAtomicBody) {
            continue;
        }

        statements.add(new ParsedSql.Statement(pending));
        pending = new ArrayList<>();
    }

    if (!pending.isEmpty()) {
        statements.add(new ParsedSql.Statement(pending));
    }

    return new ParsedSql(sql, statements);
}

/**
 * Returns the token that directly follows position {@code index} without any bounds check of its own;
 * an out-of-range position is left to the list's own index checking.
 *
 * @param tokens the token list to look into
 * @param index  the current position
 * @return the token at {@code index + 1}
 */
private static ParsedSql.Token peekNext(List<ParsedSql.Token> tokens, int index) {
    final int next = index + 1;
    return tokens.get(next);
}

/**
 * Tells whether the list holds an element after position {@code index}.
 *
 * @param tokens the token list to inspect
 * @param index  the current position
 * @return {@code true} if {@code index + 1} is smaller than the list size
 */
private static boolean hasNextToken(List<ParsedSql.Token> tokens, int index) {
    final int next = index + 1;
    return next < tokens.size();
}

/**
 * Returns the character that directly follows position {@code index} without any bounds check of its
 * own; an out-of-range position is left to {@link CharSequence#charAt(int)}.
 *
 * @param sequence the character sequence to look into
 * @param index    the current position
 * @return the character at {@code index + 1}
 */
private static char peekNext(CharSequence sequence, int index) {
    final int next = index + 1;
    return sequence.charAt(next);
}

/**
 * Tells whether the sequence holds a character after position {@code index}.
 *
 * @param sequence the character sequence to inspect
 * @param index    the current position
 * @return {@code true} if {@code index + 1} is smaller than the sequence length
 */
private static boolean hasNextToken(CharSequence sequence, int index) {
    final int next = index + 1;
    return next < sequence.length();
}

private static List<ParsedSql.Token> tokenize(String sql) {
Expand All @@ -57,12 +126,12 @@ private static List<ParsedSql.Token> tokenize(String sql) {
token = getQuotedIdentifierToken(sql, i);
break;
case '-': // Possible start of double-dash comment
if ((i + 1) < sql.length() && sql.charAt(i + 1) == '-') {
if (hasNextToken(sql, i) && peekNext(sql, i) == '-') {
token = getCommentToLineEndToken(sql, i);
}
break;
case '/': // Possible start of c-style comment
if ((i + 1) < sql.length() && sql.charAt(i + 1) == '*') {
if (hasNextToken(sql, i) && peekNext(sql, i) == '*') {
token = getBlockCommentToken(sql, i);
}
break;
Expand All @@ -89,48 +158,6 @@ private static List<ParsedSql.Token> tokenize(String sql) {
return tokens;
}

/**
 * Tokenizes {@code sql} and groups the tokens into statements.
 * <p>
 * Every token is appended to the statement currently being built. A {@code STATEMENT_END} token closes
 * that statement unless at least one still-open {@code BEGIN} was directly followed by an
 * {@code ATOMIC} token. Tokens left over after the last terminator form a final statement.
 *
 * @param sql the SQL text to parse
 * @return the parsed representation holding the original text and the detected statements
 */
public static ParsedSql parse(String sql) {
    List<ParsedSql.Token> tokens = tokenize(sql);
    List<ParsedSql.Statement> statements = new ArrayList<>();
    // One entry per open BEGIN: true for BEGIN ATOMIC, false otherwise.
    List<Boolean> openBlocks = new ArrayList<>();

    List<ParsedSql.Token> pending = new ArrayList<>();
    for (int i = 0; i < tokens.size(); i++) {
        ParsedSql.Token token = tokens.get(i);
        pending.add(token);

        if (token.getType() == ParsedSql.TokenType.DEFAULT) {
            String word = token.getValue();

            if (word.equalsIgnoreCase("BEGIN")) {
                boolean atomic = i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC");
                openBlocks.add(atomic);
            } else if (word.equalsIgnoreCase("END") && !openBlocks.isEmpty()) {
                // END closes the most recently opened BEGIN.
                int top = openBlocks.size() - 1;
                openBlocks.remove(top);
            }
            continue;
        }

        if (!token.getType().equals(ParsedSql.TokenType.STATEMENT_END)) {
            continue;
        }

        // A terminator only ends the statement when no enclosing block is a BEGIN ATOMIC body.
        if (openBlocks.contains(Boolean.TRUE)) {
            continue;
        }

        statements.add(new ParsedSql.Statement(pending));
        pending = new ArrayList<>();
    }

    if (!pending.isEmpty()) {
        statements.add(new ParsedSql.Statement(pending));
    }

    return new ParsedSql(sql, statements);
}

private static ParsedSql.Token getDefaultToken(String sql, int beginIndex) {
for (int i = beginIndex + 1; i < sql.length(); i++) {
char c = sql.charAt(i);
Expand All @@ -142,7 +169,7 @@ private static ParsedSql.Token getDefaultToken(String sql, int beginIndex) {
}

private static boolean isSpecialOrOperatorChar(char c) {
return Arrays.binarySearch(SPECIAL_AND_OPERATOR_CHARS, c) >= 0;
return SPECIAL_AND_OPERATOR_CHARS.containsKey(c);
}

private static ParsedSql.Token getBlockCommentToken(String sql, int beginIndex) {
Expand Down

0 comments on commit b65db10

Please sign in to comment.