Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contextual type inference for high order function arg #7417

Merged
merged 13 commits into from Jan 20, 2022
Expand Up @@ -19,6 +19,7 @@
use Psalm\Internal\DataFlow\TaintSink;
use Psalm\Internal\MethodIdentifier;
use Psalm\Internal\Stubs\Generator\StubsGenerator;
use Psalm\Internal\Type\Comparator\CallableTypeComparator;
use Psalm\Internal\Type\Comparator\UnionTypeComparator;
use Psalm\Internal\Type\TemplateInferredTypeReplacer;
use Psalm\Internal\Type\TemplateResult;
Expand Down Expand Up @@ -196,7 +197,21 @@ public static function analyze(
$toggled_class_exists = true;
}

if (($arg->value instanceof PhpParser\Node\Expr\Closure
$high_order_template_result = null;

if (($arg->value instanceof PhpParser\Node\Expr\FuncCall
|| $arg->value instanceof PhpParser\Node\Expr\MethodCall
|| $arg->value instanceof PhpParser\Node\Expr\StaticCall)
&& $param
&& $function_storage = self::getHighOrderFuncStorage($context, $statements_analyzer, $arg->value)
) {
$high_order_template_result = self::handleHighOrderFuncCallArg(
$statements_analyzer,
$template_result ?? new TemplateResult([], []),
$function_storage,
$param
);
} elseif (($arg->value instanceof PhpParser\Node\Expr\Closure
|| $arg->value instanceof PhpParser\Node\Expr\ArrowFunction)
&& $param
&& !$arg->value->getDocComment()
Expand All @@ -217,7 +232,15 @@ public static function analyze(

$context->inside_call = true;

if (ExpressionAnalyzer::analyze($statements_analyzer, $arg->value, $context) === false) {
if (ExpressionAnalyzer::analyze(
$statements_analyzer,
$arg->value,
$context,
false,
null,
false,
$high_order_template_result
) === false) {
$context->inside_call = $was_inside_call;

return false;
Expand Down Expand Up @@ -315,6 +338,172 @@ private static function handleArrayMapFilterArrayArg(
}
}

private static function getHighOrderFuncStorage(
Context $context,
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\CallLike $function_like_call
): ?FunctionLikeStorage {
$codebase = $statements_analyzer->getCodebase();

try {
if ($function_like_call instanceof PhpParser\Node\Expr\FuncCall) {
$function_id = strtolower((string) $function_like_call->name->getAttribute('resolvedName'));

if (empty($function_id)) {
return null;
}

return $codebase->functions->getStorage($statements_analyzer, $function_id);
}

if ($function_like_call instanceof PhpParser\Node\Expr\MethodCall &&
$function_like_call->var instanceof PhpParser\Node\Expr\Variable &&
$function_like_call->name instanceof PhpParser\Node\Identifier &&
is_string($function_like_call->var->name) &&
isset($context->vars_in_scope['$' . $function_like_call->var->name])
) {
$lhs_type = $context->vars_in_scope['$' . $function_like_call->var->name]->getSingleAtomic();

if (!$lhs_type instanceof Type\Atomic\TNamedObject) {
return null;
}

$method_id = new MethodIdentifier(
$lhs_type->value,
strtolower((string)$function_like_call->name)
);

return $codebase->methods->getStorage($method_id);
}

if ($function_like_call instanceof PhpParser\Node\Expr\StaticCall &&
$function_like_call->name instanceof PhpParser\Node\Identifier
) {
$method_id = new MethodIdentifier(
(string)$function_like_call->class->getAttribute('resolvedName'),
strtolower($function_like_call->name->name)
);

return $codebase->methods->getStorage($method_id);
}
} catch (UnexpectedValueException $e) {
return null;
}

return null;
}

/**
* Compiles TemplateResult for high-order functions ($func_call)
* by previous template args ($inferred_template_result).
*
* It's need for proper template replacement:
*
* ```
* * template T
* * return Closure(T): T
* function id(): Closure { ... }
*
* * template A
* * template B
* *
* * param list<A> $_items
* * param callable(A): B $_ab
* * return list<B>
* function map(array $items, callable $ab): array { ... }
*
* // list<int>
* $numbers = [1, 2, 3];
*
* $result = map($numbers, id());
* // $result is list<int> because template T of id() was inferred by previous arg.
* ```
*/
private static function handleHighOrderFuncCallArg(
StatementsAnalyzer $statements_analyzer,
TemplateResult $inferred_template_result,
FunctionLikeStorage $storage,
FunctionLikeParameter $actual_func_param
): ?TemplateResult {
$codebase = $statements_analyzer->getCodebase();

$input_hof_atomic = $storage->return_type && $storage->return_type->isSingle()
? $storage->return_type->getSingleAtomic()
: null;

// Try upcast invokable to callable type.
if ($input_hof_atomic instanceof Type\Atomic\TNamedObject &&
$input_hof_atomic->value !== 'Closure' &&
$codebase->classExists($input_hof_atomic->value)
) {
$callable_from_invokable = CallableTypeComparator::getCallableFromAtomic(
$codebase,
$input_hof_atomic
);

if ($callable_from_invokable) {
$invoke_id = new MethodIdentifier($input_hof_atomic->value, '__invoke');
$declaring_invoke_id = $codebase->methods->getDeclaringMethodId($invoke_id);

$storage = $codebase->methods->getStorage($declaring_invoke_id ?? $invoke_id);
$input_hof_atomic = $callable_from_invokable;
}
}

if (!$input_hof_atomic instanceof TClosure && !$input_hof_atomic instanceof TCallable) {
return null;
}

$container_hof_atomic = $actual_func_param->type && $actual_func_param->type->isSingle()
? $actual_func_param->type->getSingleAtomic()
: null;

if (!$container_hof_atomic instanceof TClosure && !$container_hof_atomic instanceof TCallable) {
return null;
}

$replaced_container_hof_atomic = new Union([clone $container_hof_atomic]);

// Replaces all input args in container function.
//
// For example:
// The map function expects callable(A):B as second param
// We know that previous arg type is list<int> where the int is the A template.
// Then we can replace callable(A): B to callable(int):B using $inferred_template_result.
TemplateInferredTypeReplacer::replace(
$replaced_container_hof_atomic,
$inferred_template_result,
$codebase
);

/** @var TClosure|TCallable $container_hof_atomic */
$container_hof_atomic = $replaced_container_hof_atomic->getSingleAtomic();
$high_order_template_result = new TemplateResult($storage->template_types ?: [], []);

// We can replace each templated param for the input function.
// Example:
// map($numbers, id());
// We know that map expects callable(int):B because the $numbers is list<int>.
// We know that id() returns callable(T):T.
// Then we can replace templated params sequentially using the expected callable(int):B.
foreach ($input_hof_atomic->params ?? [] as $offset => $actual_func_param) {
if ($actual_func_param->type &&
$actual_func_param->type->getTemplateTypes() &&
isset($container_hof_atomic->params[$offset])
) {
TemplateStandinTypeReplacer::replace(
clone $actual_func_param->type,
$high_order_template_result,
$codebase,
null,
$container_hof_atomic->params[$offset]->type
);
}
}

return $high_order_template_result;
}

/**
* @param array<int, PhpParser\Node\Arg> $args
*/
Expand Down
Expand Up @@ -82,7 +82,8 @@ class FunctionCallAnalyzer extends CallAnalyzer
public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\FuncCall $stmt,
Context $context
Context $context,
?TemplateResult $template_result = null
): bool {
$function_name = $stmt->name;

Expand Down Expand Up @@ -166,10 +167,12 @@ public static function analyze(
}

if (!$is_first_class_callable) {
$template_result = null;

if (isset($function_call_info->function_storage->template_types)) {
$template_result = new TemplateResult($function_call_info->function_storage->template_types ?: [], []);
if (!$template_result) {
$template_result = new TemplateResult([], []);
}

$template_result->template_types += $function_call_info->function_storage->template_types ?: [];
}

ArgumentsAnalyzer::analyze(
Expand Down Expand Up @@ -205,6 +208,10 @@ public static function analyze(
}
}

$already_inferred_lower_bounds = $template_result
? $template_result->lower_bounds
: [];

$template_result = new TemplateResult([], []);

// do this here to allow closure param checks
Expand All @@ -229,6 +236,8 @@ public static function analyze(
$function_call_info->function_id
);

$template_result->lower_bounds += $already_inferred_lower_bounds;

if ($function_name instanceof PhpParser\Node\Name && $function_call_info->function_id) {
$stmt_type = FunctionCallReturnTypeFetcher::fetch(
$statements_analyzer,
Expand Down
Expand Up @@ -74,7 +74,8 @@ public static function analyze(
?Atomic $static_type,
bool $is_intersection,
?string $lhs_var_id,
AtomicMethodCallAnalysisResult $result
AtomicMethodCallAnalysisResult $result,
?TemplateResult $inferred_template_result = null
): void {
if ($lhs_type_part instanceof TTemplateParam
&& !$lhs_type_part->as->isMixed()
Expand Down Expand Up @@ -438,7 +439,8 @@ public static function analyze(
$static_type,
$lhs_var_id,
$method_id,
$result
$result,
$inferred_template_result
);

$statements_analyzer->node_data = $old_node_data;
Expand Down
Expand Up @@ -65,7 +65,8 @@ public static function analyze(
?Atomic $static_type,
?string $lhs_var_id,
MethodIdentifier $method_id,
AtomicMethodCallAnalysisResult $result
AtomicMethodCallAnalysisResult $result,
?TemplateResult $inferred_template_result = null
): Union {
$config = $codebase->config;

Expand Down Expand Up @@ -217,6 +218,10 @@ public static function analyze(
$template_result = new TemplateResult([], $class_template_params ?: []);
$template_result->lower_bounds += $method_template_params;

if ($inferred_template_result) {
$template_result->lower_bounds += $inferred_template_result->lower_bounds;
}

if ($codebase->store_node_types
&& !$context->collect_initializations
&& !$context->collect_mutations
Expand Down
Expand Up @@ -11,6 +11,7 @@
use Psalm\Internal\Analyzer\Statements\Expression\ExpressionIdentifier;
use Psalm\Internal\Analyzer\Statements\ExpressionAnalyzer;
use Psalm\Internal\Analyzer\StatementsAnalyzer;
use Psalm\Internal\Type\TemplateResult;
use Psalm\Issue\InvalidMethodCall;
use Psalm\Issue\InvalidScope;
use Psalm\Issue\NullReference;
Expand Down Expand Up @@ -43,7 +44,8 @@ public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\MethodCall $stmt,
Context $context,
bool $real_method_call = true
bool $real_method_call = true,
?TemplateResult $template_result = null
): bool {
$was_inside_call = $context->inside_call;

Expand Down Expand Up @@ -194,7 +196,8 @@ public static function analyze(
: null,
false,
$lhs_var_id,
$result
$result,
$template_result
);
if (isset($context->vars_in_scope[$lhs_var_id])
&& ($possible_new_class_type = $context->vars_in_scope[$lhs_var_id]) instanceof Union
Expand Down
Expand Up @@ -41,7 +41,8 @@ class StaticCallAnalyzer extends CallAnalyzer
public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\StaticCall $stmt,
Context $context
Context $context,
?TemplateResult $template_result = null
): bool {
$method_id = null;

Expand Down Expand Up @@ -219,7 +220,8 @@ public static function analyze(
$lhs_type->ignore_nullable_issues,
$moved_call,
$has_mock,
$has_existing_method
$has_existing_method,
$template_result
);
}

Expand Down