Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to clear downstream tis in "List Task Instances" view #34529

Merged
merged 8 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
105 changes: 96 additions & 9 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5657,6 +5657,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
"list": "read",
"delete": "delete",
"action_clear": "edit",
"action_clear_downstream": "edit",
"action_muldelete": "delete",
"action_set_running": "edit",
"action_set_failed": "edit",
Expand Down Expand Up @@ -5793,6 +5794,68 @@ def duration_f(self):
"duration": duration_f,
}

def _clear_task_instances(
self, task_instances: list[TaskInstance], session: Session, clear_downstream: bool = False
) -> tuple[int, int]:
"""
Clears task instances, optionally including their downstream dependencies.

:param task_instances: list of TIs to clear
:param clear_downstream: should downstream task instances be cleared as well?

:return: a tuple with:
- count of cleared task instances actually selected by the user
- count of downstream task instances that were additionally cleared
"""
cleared_tis_count = 0
cleared_downstream_tis_count = 0

# Group TIs by dag id in order to call `get_dag` only once per dag
tis_grouped_by_dag_id = itertools.groupby(task_instances, lambda ti: ti.dag_id)

for dag_id, dag_tis in tis_grouped_by_dag_id:
dag = get_airflow_app().dag_bag.get_dag(dag_id)

tis_to_clear = list(dag_tis)
downstream_tis_to_clear = []

if clear_downstream:
tis_to_clear_grouped_by_dag_run = itertools.groupby(tis_to_clear, lambda ti: ti.dag_run)

for dag_run, dag_run_tis in tis_to_clear_grouped_by_dag_run:
# Determine tasks that are downstream of the cleared TIs and fetch associated TIs
# This has to be run for each dag run because the user may clear different TIs across runs
task_ids_to_clear = [ti.task_id for ti in dag_run_tis]

partial_dag = dag.partial_subset(
task_ids_or_regex=task_ids_to_clear, include_downstream=True, include_upstream=False
)

downstream_task_ids_to_clear = [
task_id for task_id in partial_dag.task_dict if task_id not in task_ids_to_clear
]

# dag.clear returns TIs when in dry run mode
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

downstream_tis_to_clear.extend(
dag.clear(
start_date=dag_run.execution_date,
end_date=dag_run.execution_date,
task_ids=downstream_task_ids_to_clear,
include_subdags=False,
include_parentdag=False,
session=session,
dry_run=True,
)
)

# Once all TIs are fetched, perform the actual clearing
models.clear_task_instances(tis=tis_to_clear + downstream_tis_to_clear, session=session, dag=dag)

cleared_tis_count += len(tis_to_clear)
cleared_downstream_tis_count += len(downstream_tis_to_clear)

return cleared_tis_count, cleared_downstream_tis_count

@action(
"clear",
lazy_gettext("Clear"),
Expand All @@ -5806,21 +5869,45 @@ def duration_f(self):
@provide_session
@action_logging
def action_clear(self, task_instances, session: Session = NEW_SESSION):
"""Clears the action."""
"""Clears an arbitrary number of task instances."""
try:
dag_to_tis = collections.defaultdict(list)

for ti in task_instances:
dag = get_airflow_app().dag_bag.get_dag(ti.dag_id)
dag_to_tis[dag].append(ti)
count, _ = self._clear_task_instances(
task_instances=task_instances, session=session, clear_downstream=False
)
session.commit()
flash(f"{count} task instance{'s have' if count > 1 else ' has'} been cleared")
except Exception as e:
flash(f'Failed to clear task instances: "{e}"', "error")

for dag, task_instances_list in dag_to_tis.items():
models.clear_task_instances(task_instances_list, session, dag=dag)
self.update_redirect()
return redirect(self.get_redirect())

@action(
"clear_downstream",
lazy_gettext("Clear (including downstream tasks)"),
lazy_gettext(
"Are you sure you want to clear the state of the selected task"
" instance(s) and all their downstream dependencies, and set their dagruns to the QUEUED state?"
),
single=False,
)
@action_has_dag_edit_access
@provide_session
@action_logging
def action_clear_downstream(self, task_instances, session: Session = NEW_SESSION):
"""Clears an arbitrary number of task instances, including downstream dependencies."""
try:
selected_ti_count, downstream_ti_count = self._clear_task_instances(
task_instances=task_instances, session=session, clear_downstream=True
)
session.commit()
flash(f"{len(task_instances)} task instances have been cleared")
flash(
f"Cleared {selected_ti_count} selected task instance{'s' if selected_ti_count > 1 else ''} "
f"and {downstream_ti_count} downstream dependencies"
)
except Exception as e:
flash(f'Failed to clear task instances: "{e}"', "error")

self.update_redirect()
return redirect(self.get_redirect())

Expand Down
61 changes: 60 additions & 1 deletion tests/www/views/test_views_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import urllib.parse
from getpass import getuser

import pendulum
import pytest
import time_machine

Expand All @@ -32,12 +33,13 @@
from airflow.models import DAG, DagBag, DagModel, TaskFail, TaskInstance, TaskReschedule, XCom
from airflow.models.dagcode import DagCode
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import ExternalLoggingMixin
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunType
from airflow.www.views import TaskInstanceModelView
from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user
Expand Down Expand Up @@ -857,6 +859,63 @@ def test_task_instance_clear(session, request, client_fixture, should_succeed):
assert state == (State.NONE if should_succeed else initial_state)


def test_task_instance_clear_downstream(session, admin_client, dag_maker):
"""Ensures clearing a task instance clears its downstream dependencies exclusively"""
with dag_maker(
dag_id="test_dag_id",
serialized=True,
session=session,
start_date=pendulum.DateTime(2023, 1, 1, 0, 0, 0, tzinfo=pendulum.UTC),
):
EmptyOperator(task_id="task_1") >> EmptyOperator(task_id="task_2")
EmptyOperator(task_id="task_3")

run1 = dag_maker.create_dagrun(
run_id="run_1",
state=DagRunState.SUCCESS,
run_type=DagRunType.SCHEDULED,
execution_date=dag_maker.dag.start_date,
start_date=dag_maker.dag.start_date,
session=session,
)

run2 = dag_maker.create_dagrun(
run_id="run_2",
state=DagRunState.SUCCESS,
run_type=DagRunType.SCHEDULED,
execution_date=dag_maker.dag.start_date.add(days=1),
start_date=dag_maker.dag.start_date.add(days=1),
session=session,
)

for run in (run1, run2):
for ti in run.task_instances:
ti.state = State.SUCCESS

# Clear task_1 from dag run 1
run1_ti1 = run1.get_task_instance(task_id="task_1")
rowid = _get_appbuilder_pk_string(TaskInstanceModelView, run1_ti1)
resp = admin_client.post(
"/taskinstance/action_post",
data={"action": "clear_downstream", "rowid": rowid},
follow_redirects=True,
)
assert resp.status_code == 200

# Assert that task_1 and task_2 of dag run 1 are cleared, but task_3 is left untouched
run1_ti1.refresh_from_db(session=session)
run1_ti2 = run1.get_task_instance(task_id="task_2")
run1_ti3 = run1.get_task_instance(task_id="task_3")

assert run1_ti1.state == State.NONE
assert run1_ti2.state == State.NONE
assert run1_ti3.state == State.SUCCESS

# Assert that task_1 of dag run 2 is left untouched
run2_ti1 = run2.get_task_instance(task_id="task_1")
assert run2_ti1.state == State.SUCCESS


def test_task_instance_clear_failure(admin_client):
rowid = '["12345"]' # F.A.B. crashes if the rowid is *too* invalid.
resp = admin_client.post(
Expand Down