New engine: add support for sync and async generator tasks #13138

Closed · wants to merge 33 commits

Commits (33)
e5e7a94  Ensure async utils copies context (jlowin, Apr 24, 2024)
44c7c92  add run_task_sync (jlowin, Apr 24, 2024)
59c6114  add run_flow_sync (jlowin, Apr 24, 2024)
f47099b  apply new engine paths when calling flows/tasks (jlowin, Apr 24, 2024)
a53a7e9  tests (jlowin, Apr 24, 2024)
cc1920f  tests (jlowin, Apr 25, 2024)
e3209fd  add timeout tests (jlowin, Apr 25, 2024)
985c68c  Fix timeout tests (jlowin, Apr 25, 2024)
f8b8fa4  Add engine flag (jlowin, Apr 25, 2024)
16c5d0a  Fix subflow handling (jlowin, Apr 25, 2024)
e2cd525  Add contextvar test (jlowin, Apr 25, 2024)
ea79c99  Clean up unused method (jlowin, Apr 25, 2024)
dfdc2f6  Remove unused attribute (jlowin, Apr 26, 2024)
ea499b0  Improve parent tracking (jlowin, Apr 26, 2024)
362e3ed  Update test_new_flow_engine.py (jlowin, Apr 26, 2024)
f555af4  Improve run_sync (jlowin, Apr 26, 2024)
048bc6e  Merge branch 'main' into task-run-parent-tracking (jlowin, Apr 26, 2024)
5303104  Use cancel scopes for timeouts, remove int restriction (jlowin, Apr 26, 2024)
31095ba  Merge branch 'main' into sync-engine (jlowin, Apr 26, 2024)
4d25253  Wrap cancelscopes to produce timeouterror (jlowin, Apr 26, 2024)
af43932  Merge branch 'main' into sync-engine (jlowin, Apr 26, 2024)
4f86c3c  Merge branch 'main' into sync-engine (jlowin, Apr 26, 2024)
f6e52cf  Fix typing for 3.8 (jlowin, Apr 26, 2024)
4d4300a  Track a single parent, not a list (jlowin, Apr 26, 2024)
8ed5f3b  Create tasks using task engine logic (jlowin, Apr 26, 2024)
58a7191  Yet more 3.8 typing (jlowin, Apr 26, 2024)
6dcdb76  Restore list of parents and add tests for subflow tracking (jlowin, Apr 27, 2024)
a54ef54  Merge branch 'task-run-parent-tracking' into generators (jlowin, Apr 27, 2024)
0f78a0c  Improve parent tracking (jlowin, Apr 27, 2024)
8a5b016  Merge branch 'task-run-parent-tracking' into generators (jlowin, Apr 27, 2024)
1d0d1d0  Add support for sync and async task generators (jlowin, Apr 27, 2024)
0a58633  Clean up handler names for clarity (jlowin, Apr 27, 2024)
ba2a11d  Merge upstream (jlowin, Apr 29, 2024)
4 changes: 3 additions & 1 deletion src/prefect/client/schemas/objects.py
@@ -687,7 +687,9 @@ class TaskRun(ObjectBaseModel):
task_inputs: Dict[str, List[Union[TaskRunResult, Parameter, Constant]]] = Field(
default_factory=dict,
description=(
"Tracks the source of inputs to a task run. Used for internal bookkeeping."
"Tracks the source of inputs to a task run. Used for internal bookkeeping. "
"Note the special __parents__ key, used to indicate a parent/child "
"relationship that may or may not include an input or wait_for semantic."
),
)
state_type: Optional[StateType] = Field(
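Reviewer note: a minimal sketch of the resulting `task_inputs` shape, assuming `TaskRunResult` is importable from this module as the diff suggests; the parameter name and IDs are placeholders:

```python
from uuid import uuid4

from prefect.client.schemas.objects import TaskRunResult

# Illustrative only: "x" records a real data dependency, while the special
# __parents__ key records a parent/child link that carries no input or
# wait_for semantic.
task_inputs = {
    "x": [TaskRunResult(id=uuid4())],
    "__parents__": [TaskRunResult(id=uuid4())],
}
```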
9 changes: 2 additions & 7 deletions src/prefect/flows.py
@@ -1229,19 +1229,14 @@ def __call__(
return track_viz_task(self.isasync, self.name, parameters)

if PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE.value():
from prefect.new_flow_engine import run_flow, run_flow_sync
from prefect.new_flow_engine import run_flow

run_kwargs = dict(
return run_flow(
flow=self,
parameters=parameters,
wait_for=wait_for,
return_type=return_type,
)
if self.isasync:
# this returns an awaitable coroutine
return run_flow(**run_kwargs)
else:
return run_flow_sync(**run_kwargs)

return enter_flow_run_engine_from_flow_call(
self,
54 changes: 29 additions & 25 deletions src/prefect/new_flow_engine.py
@@ -28,6 +28,7 @@
from prefect.context import FlowRunContext
from prefect.futures import PrefectFuture, resolve_futures_to_states
from prefect.logging.loggers import flow_run_logger
from prefect.new_task_engine import TaskRunEngine
from prefect.results import ResultFactory
from prefect.states import (
Pending,
@@ -38,9 +39,7 @@
)
from prefect.utilities.asyncutils import A, Async, run_sync
from prefect.utilities.engine import (
_dynamic_key_for_task_run,
_resolve_custom_flow_run_name,
collect_task_run_inputs,
propose_state,
)

@@ -167,9 +166,7 @@ async def load_subflow_run(
if flow_runs:
return flow_runs[-1]

async def create_subflow_task_run(
self, client: PrefectClient, context: FlowRunContext
) -> TaskRun:
async def create_subflow_task_run(self, client: PrefectClient) -> TaskRun:
"""
Adds a task to a parent flow run that represents the execution of a subflow run.

@@ -179,19 +176,8 @@ async def create_subflow_task_run(
dummy_task = Task(
name=self.flow.name, fn=self.flow.fn, version=self.flow.version
)
task_inputs = {
k: await collect_task_run_inputs(v) for k, v in self.parameters.items()
}
parent_task_run = await client.create_task_run(
task=dummy_task,
flow_run_id=(
context.flow_run.id if getattr(context, "flow_run", None) else None
),
dynamic_key=_dynamic_key_for_task_run(context, dummy_task),
task_inputs=task_inputs,
state=Pending(),
)
return parent_task_run
task_engine = TaskRunEngine(task=dummy_task, parameters=self.parameters)
return await task_engine.create_task_run(client)

async def create_flow_run(self, client: PrefectClient) -> FlowRun:
flow_run_ctx = FlowRunContext.get()
@@ -201,9 +187,7 @@ async def create_flow_run(self, client: PrefectClient) -> FlowRun:
# this is a subflow run
if flow_run_ctx:
# get the parent task run
parent_task_run = await self.create_subflow_task_run(
client=client, context=flow_run_ctx
)
parent_task_run = await self.create_subflow_task_run(client=client)

# check if there is already a flow run for this subflow
if subflow_run := await self.load_subflow_run(
@@ -349,13 +333,13 @@ def is_pending(self) -> bool:
return getattr(self, "flow_run").state.is_pending()


async def run_flow(
async def run_async_flow(
flow: Task[P, Coroutine[Any, Any, R]],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, None]:
) -> Union[R, State, None]:
"""
Runs a flow against the API.

@@ -384,13 +368,13 @@ async def run_flow(
return await run.result()


def run_flow_sync(
def run_sync_flow(
flow: Task[P, Coroutine[Any, Any, R]],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, None]:
) -> Union[R, State, None]:
engine = FlowRunEngine[P, R](flow, parameters, flow_run)
# This is a context manager that keeps track of the state of the flow run.
with engine.start_sync() as run:
@@ -411,3 +395,23 @@ def run_flow_sync(
if return_type == "state":
return run.state
return run_sync(run.result())


def run_flow(
flow: Task[P, Coroutine[Any, Any, R]],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
kwargs = dict(
flow=flow,
flow_run=flow_run,
parameters=parameters,
wait_for=wait_for,
return_type=return_type,
)
if inspect.iscoroutinefunction(flow.fn):
return run_async_flow(**kwargs)
else:
return run_sync_flow(**kwargs)
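For callers, the practical effect is a single entry point that dispatches on the wrapped function type. A minimal sketch of the behavior, assuming the experimental engine flag is enabled and using hypothetical flows:

```python
import asyncio

from prefect import flow


@flow
def sync_flow() -> int:
    return 1


@flow
async def async_flow() -> int:
    return 2


# run_flow branches on inspect.iscoroutinefunction(flow.fn):
assert sync_flow() == 1                # run_sync_flow executes eagerly
assert asyncio.run(async_flow()) == 2  # run_async_flow returns an awaitable
```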
187 changes: 180 additions & 7 deletions src/prefect/new_task_engine.py
@@ -5,11 +5,14 @@
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
Generator,
Generic,
Iterable,
List,
Literal,
Optional,
TypeVar,
@@ -51,6 +54,7 @@
_get_hook_name,
_resolve_custom_task_run_name,
collect_task_run_inputs,
link_state_to_result,
propose_state,
)

@@ -250,22 +254,72 @@ async def handle_crash(self, exc: BaseException) -> None:
self.logger.debug("Crash details:", exc_info=exc)
await self.set_state(state, force=True)

def infer_parent_task_runs(
self, flow_run_ctx: FlowRunContext
) -> List[TaskRunResult]:
"""
Infer parent task runs.

1. Check if this task is running inside an existing task run context in
the same flow run. If so, this is a nested task and the existing task
run is a parent.
2. Check if any of the inputs to this task correspond to states that are
still running. If so, consider those task runs as parents. Note that
under normal circumstances this probably only applies to values that
were yielded from generator tasks.
"""
parents = []

# check if this task has a parent task run based on running in another
# task run's existing context. A task run is only considered a parent if
# it is in the same flow run (because otherwise presumably the child is
# in a subflow, so the subflow serves as the parent) or if there is no
# flow run
task_run_ctx = TaskRunContext.get()
if task_run_ctx:
# there is no flow run
if not flow_run_ctx:
parents.append(TaskRunResult(id=task_run_ctx.task_run.id))
# there is a flow run and the task run is in the same flow run
elif (
flow_run_ctx
and task_run_ctx.task_run.flow_run_id == flow_run_ctx.flow_run.id
):
parents.append(TaskRunResult(id=task_run_ctx.task_run.id))

# parent dependency tracking: for every provided parameter value, try to
# load the corresponding task run state. If the task run state is still
# running, we consider it a parent task run. Note this is only done if
# there is an active flow run context because dependencies are only
# tracked within the same flow run.
if flow_run_ctx:
for v in self.parameters.values():
if isinstance(v, State):
upstream_state = v
else:
upstream_state = flow_run_ctx.task_run_results.get(id(v))
if upstream_state and upstream_state.is_running():
parents.append(
TaskRunResult(id=upstream_state.state_details.task_run_id)
)

return parents
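A sketch of the two inference paths in `infer_parent_task_runs`, using hypothetical tasks: the nested call covers case 1, and consuming a generator task's yields covers case 2.

```python
from prefect import flow, task


@task
def plus_one(x: int) -> int:
    return x + 1


@task
def wrapper() -> int:
    # Case 1: plus_one runs inside wrapper's TaskRunContext within the same
    # flow run, so wrapper's task run is recorded under __parents__.
    return plus_one(1)


@task
def numbers():
    yield 1
    yield 2


@flow
def demo():
    wrapper()
    for value in numbers():
        # Case 2: `value` resolves to a state that is still Running because
        # the generator task has not finished, so it is inferred as a parent.
        plus_one(value)
```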

async def create_task_run(self, client: PrefectClient) -> TaskRun:
flow_run_ctx = FlowRunContext.get()
try:
task_run_name = _resolve_custom_task_run_name(self.task, self.parameters)
except TypeError:
task_run_name = None

# prep input tracking
# upstream dependency tracking: for every provided parameter value, try
# to load the corresponding task run state
task_inputs = {
k: await collect_task_run_inputs(v) for k, v in self.parameters.items()
}

# anticipate nested runs
task_run_ctx = TaskRunContext.get()
if task_run_ctx:
task_inputs["wait_for"] = [TaskRunResult(id=task_run_ctx.task_run.id)]
if task_parents := self.infer_parent_task_runs(flow_run_ctx):
task_inputs["__parents__"] = task_parents

# TODO: implement wait_for
# if wait_for:
@@ -393,7 +447,7 @@ def is_pending(self) -> bool:
return getattr(self, "task_run").state.is_pending()


async def run_task(
async def run_async_task(
task: Task[P, Coroutine[Any, Any, R]],
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
@@ -430,7 +484,7 @@ async def run_task(
return await run.result()


def run_task_sync(
def run_sync_task(
task: Task[P, Coroutine[Any, Any, R]],
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
@@ -460,3 +514,122 @@ def run_task_sync(
if return_type == "state":
return run.state
return run_sync(run.result())


async def run_async_generator_task(
task: Task[P, Coroutine[Any, Any, R]],
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> AsyncGenerator[Union[R, State, None], None]:
if return_type != "result":
raise ValueError("Return type cannot be specified for async generators.")
engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)
# This is a context manager that keeps track of the run of the task run.
async with engine.start() as run:
await run.begin_run()

while run.is_running():
async with run.enter_run_context():
try:
# This is where the task is actually run.
async with timeout(run.task.timeout_seconds):
try:
gen = task.fn(**(parameters or {}))
while True:
# the anext() builtin is not available before Python 3.10
gen_result = await gen.__anext__()
# link the current state to the result for dependency tracking
#
# TODO: this could grow the task_run_result
# dictionary in an unbounded way, so finding a
# way to periodically clean it up (using
# weakrefs or similar) would be good
link_state_to_result(run.state, gen_result)
yield gen_result

except (StopAsyncIteration, GeneratorExit):
await run.handle_success(None)

except Exception as exc:
await run.handle_exception(exc)

# call this to raise exceptions after retries are exhausted
await run.result()
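A usage sketch for the async generator path; the task and flow here are hypothetical:

```python
from prefect import flow, task


@task
async def stream():
    for i in range(3):
        yield i


@flow
async def consume() -> list:
    items = []
    async for item in stream():
        # each yielded value is linked to the task's Running state above,
        # so downstream consumers can be tracked as children
        items.append(item)
    return items  # [0, 1, 2]
```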


def run_sync_generator_task(
task: Task[P, Coroutine[Any, Any, R]],
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Generator[Union[R, State, None], None, None]:
engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)
# This is a context manager that keeps track of the run of the task run.
with engine.start_sync() as run:
run_sync(run.begin_run())

while run.is_running():
with run.enter_run_context_sync():
try:
# This is where the task is actually run.
with timeout_sync(run.task.timeout_seconds):
try:
gen = task.fn(**(parameters or {}))
# yield in a while loop instead of `yield from` to
# handle StopIteration
while True:
gen_result = next(gen)
# link the current state to the result for dependency tracking
#
# TODO: this could grow the task_run_result
# dictionary in an unbounded way, so finding a
# way to periodically clean it up (using
# weakrefs or similar) would be good
link_state_to_result(run.state, gen_result)
yield gen_result

except (StopIteration, GeneratorExit) as exc:
if isinstance(exc, StopIteration):
result = exc.value
else:
result = None
run_sync(run.handle_success(result))
if return_type == "result":
return result

except Exception as exc:
run_sync(run.handle_exception(exc))

if return_type == "state":
return run.state
return run_sync(run.result())
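A sketch of the sync generator path, including the `StopIteration.value` handling above; the definitions are hypothetical:

```python
from prefect import flow, task


@task
def countdown():
    yield 2
    yield 1
    # surfaced via StopIteration.value and recorded by handle_success
    return "liftoff"


@flow
def launch():
    for n in countdown():
        print(n)  # prints 2, then 1
```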


def run_task(
task: Task[P, Coroutine[Any, Any, R]],
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
"""
Runs a task by choosing the appropriate handler based on the task's function type.
"""
kwargs = dict(
task=task,
task_run=task_run,
parameters=parameters,
wait_for=wait_for,
return_type=return_type,
)
if inspect.isasyncgenfunction(task.fn):
return run_async_generator_task(**kwargs)
elif inspect.isgeneratorfunction(task.fn):
return run_sync_generator_task(**kwargs)
elif inspect.iscoroutinefunction(task.fn):
return run_async_task(**kwargs)
else:
return run_sync_task(**kwargs)
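The four dispatch predicates are standard library checks; a quick self-contained reference:

```python
import inspect


def sync_gen():
    yield 1


async def async_gen():
    yield 1


async def coro():
    return 1


def plain():
    return 1


assert inspect.isgeneratorfunction(sync_gen)  # -> run_sync_generator_task
assert inspect.isasyncgenfunction(async_gen)  # -> run_async_generator_task
assert inspect.iscoroutinefunction(coro)      # -> run_async_task
assert not any(                               # -> run_sync_task
    check(plain)
    for check in (
        inspect.isgeneratorfunction,
        inspect.isasyncgenfunction,
        inspect.iscoroutinefunction,
    )
)
```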