Skip to content

Commit

Permalink
SQL AST: fix aggregate functions with GROUP BY (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
hemberger committed Apr 13, 2023
1 parent b4e5054 commit e01951c
Show file tree
Hide file tree
Showing 8 changed files with 4,428 additions and 1,325 deletions.
16 changes: 15 additions & 1 deletion src/SqlAst/AvgReturnTypeExtension.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@

final class AvgReturnTypeExtension implements QueryFunctionReturnTypeExtension
{
/**
* @var bool
*/
private $hasGroupBy;

public function __construct(bool $hasGroupBy)
{
$this->hasGroupBy = $hasGroupBy;
}

public function isFunctionSupported(FunctionCall $expression): bool
{
return \in_array($expression->getFunction()->getName(), [BuiltInFunction::AVG], true);
Expand All @@ -33,6 +43,7 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ
if ($argType->isNull()->yes()) {
return $argType;
}
$containsNull = TypeCombinator::containsNull($argType);
$argType = TypeCombinator::removeNull($argType);

if ($argType instanceof UnionType) {
Expand All @@ -45,7 +56,10 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ
$newType = $this->convertToAvgType($argType);
}

return TypeCombinator::addNull($newType);
if ($containsNull || ! $this->hasGroupBy) {
$newType = TypeCombinator::addNull($newType);
}
return $newType;
}

private function convertToAvgType(Type $argType): Type
Expand Down
15 changes: 14 additions & 1 deletion src/SqlAst/MinMaxReturnTypeExtension.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@

final class MinMaxReturnTypeExtension implements QueryFunctionReturnTypeExtension
{
/**
* @var bool
*/
private $hasGroupBy;

public function __construct(bool $hasGroupBy)
{
$this->hasGroupBy = $hasGroupBy;
}

public function isFunctionSupported(FunctionCall $expression): bool
{
return \in_array($expression->getFunction()->getName(), [BuiltInFunction::MIN, BuiltInFunction::MAX], true);
Expand All @@ -26,6 +36,9 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ

$argType = $scope->getType($args[0]);

return TypeCombinator::addNull($argType);
if (! $this->hasGroupBy) {
$argType = TypeCombinator::addNull($argType);
}
return $argType;
}
}
4 changes: 3 additions & 1 deletion src/SqlAst/ParserInference.php
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
$selectColumns = null;
$fromTable = null;
$where = null;
$groupBy = null;
$joins = [];
foreach ($commands as [$command]) {
// Parser does not throw exceptions. this allows to parse partially invalid code and not fail on first error
Expand All @@ -64,6 +65,7 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
}
$from = $command->getFrom();
$where = $command->getWhere();
$groupBy = $command->getGroupBy();

if (null === $from) {
// no FROM clause, use an empty Table to signify this
Expand Down Expand Up @@ -136,7 +138,7 @@ public function narrowResultType(string $queryString, ConstantArrayType $resultT
throw new ShouldNotHappenException();
}

$queryScope = new QueryScope($fromTable, $joins, $where);
$queryScope = new QueryScope($fromTable, $joins, $where, $groupBy !== null);

// If we're selecting '*', get the selected columns from the table
if (\count($selectColumns) === 1 && $selectColumns[0]->getExpression() instanceof Asterisk) {
Expand Down
8 changes: 4 additions & 4 deletions src/SqlAst/QueryScope.php
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ final class QueryScope
/**
* @param list<Join> $joinedTables
*/
public function __construct(Table $fromTable, array $joinedTables, ?SqlSerializable $whereCondition)
public function __construct(Table $fromTable, array $joinedTables, ?SqlSerializable $whereCondition, bool $hasGroupBy)
{
$this->fromTable = $fromTable;
$this->joinedTables = $joinedTables;
Expand All @@ -73,12 +73,12 @@ public function __construct(Table $fromTable, array $joinedTables, ?SqlSerializa
new InstrReturnTypeExtension(),
new StrCaseReturnTypeExtension(),
new ReplaceReturnTypeExtension(),
new AvgReturnTypeExtension(),
new SumReturnTypeExtension(),
new AvgReturnTypeExtension($hasGroupBy),
new SumReturnTypeExtension($hasGroupBy),
new IsNullReturnTypeExtension(),
new AbsReturnTypeExtension(),
new RoundReturnTypeExtension(),
new MinMaxReturnTypeExtension(),
new MinMaxReturnTypeExtension($hasGroupBy),
];
}

Expand Down
13 changes: 11 additions & 2 deletions src/SqlAst/SumReturnTypeExtension.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

final class SumReturnTypeExtension implements QueryFunctionReturnTypeExtension
{
/**
* @var bool
*/
private $hasGroupBy;

public function __construct(bool $hasGroupBy)
{
$this->hasGroupBy = $hasGroupBy;
}

public function isFunctionSupported(FunctionCall $expression): bool
{
return \in_array($expression->getFunction()->getName(), [BuiltInFunction::SUM], true);
Expand Down Expand Up @@ -43,10 +53,9 @@ public function getReturnType(FunctionCall $expression, QueryScope $scope): ?Typ
$result = $argType;
}

if ($containsNull) {
if ($containsNull || ! $this->hasGroupBy) {
$result = TypeCombinator::addNull($result);
}

return $result;
}
}

0 comments on commit e01951c

Please sign in to comment.