Skip to content

Commit

Permalink
Add ability to clear downstream tis in "List Task Instances" view (#3…
Browse files Browse the repository at this point in the history
…4529)

* Add "clear including downstream" action in task instance view

* Extract logic into helper + support dynamic tasks

* Add unit test

* Restore quick path for ti clear without downstream

* Fix wording

* Call clear_task_instances once per dag + split cleared ti count

* Handle plural

* Update airflow/www/views.py

---------

Co-authored-by: Jean-Eudes Peloye <jean-eudes.peloye@adevinta.com>
Co-authored-by: Hussein Awala <hussein@awala.fr>
  • Loading branch information
3 people committed Sep 22, 2023
1 parent 541c9ad commit 5b0ce3d
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 10 deletions.
105 changes: 96 additions & 9 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5657,6 +5657,7 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView):
"list": "read",
"delete": "delete",
"action_clear": "edit",
"action_clear_downstream": "edit",
"action_muldelete": "delete",
"action_set_running": "edit",
"action_set_failed": "edit",
Expand Down Expand Up @@ -5793,6 +5794,68 @@ def duration_f(self):
"duration": duration_f,
}

def _clear_task_instances(
self, task_instances: list[TaskInstance], session: Session, clear_downstream: bool = False
) -> tuple[int, int]:
"""
Clears task instances, optionally including their downstream dependencies.
:param task_instances: list of TIs to clear
:param clear_downstream: should downstream task instances be cleared as well?
:return: a tuple with:
- count of cleared task instances actually selected by the user
- count of downstream task instances that were additionally cleared
"""
cleared_tis_count = 0
cleared_downstream_tis_count = 0

# Group TIs by dag id in order to call `get_dag` only once per dag
tis_grouped_by_dag_id = itertools.groupby(task_instances, lambda ti: ti.dag_id)

for dag_id, dag_tis in tis_grouped_by_dag_id:
dag = get_airflow_app().dag_bag.get_dag(dag_id)

tis_to_clear = list(dag_tis)
downstream_tis_to_clear = []

if clear_downstream:
tis_to_clear_grouped_by_dag_run = itertools.groupby(tis_to_clear, lambda ti: ti.dag_run)

for dag_run, dag_run_tis in tis_to_clear_grouped_by_dag_run:
# Determine tasks that are downstream of the cleared TIs and fetch associated TIs
# This has to be run for each dag run because the user may clear different TIs across runs
task_ids_to_clear = [ti.task_id for ti in dag_run_tis]

partial_dag = dag.partial_subset(
task_ids_or_regex=task_ids_to_clear, include_downstream=True, include_upstream=False
)

downstream_task_ids_to_clear = [
task_id for task_id in partial_dag.task_dict if task_id not in task_ids_to_clear
]

# dag.clear returns TIs when in dry run mode
downstream_tis_to_clear.extend(
dag.clear(
start_date=dag_run.execution_date,
end_date=dag_run.execution_date,
task_ids=downstream_task_ids_to_clear,
include_subdags=False,
include_parentdag=False,
session=session,
dry_run=True,
)
)

# Once all TIs are fetched, perform the actual clearing
models.clear_task_instances(tis=tis_to_clear + downstream_tis_to_clear, session=session, dag=dag)

cleared_tis_count += len(tis_to_clear)
cleared_downstream_tis_count += len(downstream_tis_to_clear)

return cleared_tis_count, cleared_downstream_tis_count

@action(
"clear",
lazy_gettext("Clear"),
Expand All @@ -5806,21 +5869,45 @@ def duration_f(self):
@provide_session
@action_logging
def action_clear(self, task_instances, session: Session = NEW_SESSION):
"""Clears the action."""
"""Clears an arbitrary number of task instances."""
try:
dag_to_tis = collections.defaultdict(list)

for ti in task_instances:
dag = get_airflow_app().dag_bag.get_dag(ti.dag_id)
dag_to_tis[dag].append(ti)
count, _ = self._clear_task_instances(
task_instances=task_instances, session=session, clear_downstream=False
)
session.commit()
flash(f"{count} task instance{'s have' if count > 1 else ' has'} been cleared")
except Exception as e:
flash(f'Failed to clear task instances: "{e}"', "error")

for dag, task_instances_list in dag_to_tis.items():
models.clear_task_instances(task_instances_list, session, dag=dag)
self.update_redirect()
return redirect(self.get_redirect())

@action(
"clear_downstream",
lazy_gettext("Clear (including downstream tasks)"),
lazy_gettext(
"Are you sure you want to clear the state of the selected task"
" instance(s) and all their downstream dependencies, and set their dagruns to the QUEUED state?"
),
single=False,
)
@action_has_dag_edit_access
@provide_session
@action_logging
def action_clear_downstream(self, task_instances, session: Session = NEW_SESSION):
"""Clears an arbitrary number of task instances, including downstream dependencies."""
try:
selected_ti_count, downstream_ti_count = self._clear_task_instances(
task_instances=task_instances, session=session, clear_downstream=True
)
session.commit()
flash(f"{len(task_instances)} task instances have been cleared")
flash(
f"Cleared {selected_ti_count} selected task instance{'s' if selected_ti_count > 1 else ''} "
f"and {downstream_ti_count} downstream dependencies"
)
except Exception as e:
flash(f'Failed to clear task instances: "{e}"', "error")

self.update_redirect()
return redirect(self.get_redirect())

Expand Down
61 changes: 60 additions & 1 deletion tests/www/views/test_views_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import urllib.parse
from getpass import getuser

import pendulum
import pytest
import time_machine

Expand All @@ -32,12 +33,13 @@
from airflow.models import DAG, DagBag, DagModel, TaskFail, TaskInstance, TaskReschedule, XCom
from airflow.models.dagcode import DagCode
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import ExternalLoggingMixin
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunType
from airflow.www.views import TaskInstanceModelView
from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user
Expand Down Expand Up @@ -857,6 +859,63 @@ def test_task_instance_clear(session, request, client_fixture, should_succeed):
assert state == (State.NONE if should_succeed else initial_state)


def test_task_instance_clear_downstream(session, admin_client, dag_maker):
"""Ensures clearing a task instance clears its downstream dependencies exclusively"""
with dag_maker(
dag_id="test_dag_id",
serialized=True,
session=session,
start_date=pendulum.DateTime(2023, 1, 1, 0, 0, 0, tzinfo=pendulum.UTC),
):
EmptyOperator(task_id="task_1") >> EmptyOperator(task_id="task_2")
EmptyOperator(task_id="task_3")

run1 = dag_maker.create_dagrun(
run_id="run_1",
state=DagRunState.SUCCESS,
run_type=DagRunType.SCHEDULED,
execution_date=dag_maker.dag.start_date,
start_date=dag_maker.dag.start_date,
session=session,
)

run2 = dag_maker.create_dagrun(
run_id="run_2",
state=DagRunState.SUCCESS,
run_type=DagRunType.SCHEDULED,
execution_date=dag_maker.dag.start_date.add(days=1),
start_date=dag_maker.dag.start_date.add(days=1),
session=session,
)

for run in (run1, run2):
for ti in run.task_instances:
ti.state = State.SUCCESS

# Clear task_1 from dag run 1
run1_ti1 = run1.get_task_instance(task_id="task_1")
rowid = _get_appbuilder_pk_string(TaskInstanceModelView, run1_ti1)
resp = admin_client.post(
"/taskinstance/action_post",
data={"action": "clear_downstream", "rowid": rowid},
follow_redirects=True,
)
assert resp.status_code == 200

# Assert that task_1 and task_2 of dag run 1 are cleared, but task_3 is left untouched
run1_ti1.refresh_from_db(session=session)
run1_ti2 = run1.get_task_instance(task_id="task_2")
run1_ti3 = run1.get_task_instance(task_id="task_3")

assert run1_ti1.state == State.NONE
assert run1_ti2.state == State.NONE
assert run1_ti3.state == State.SUCCESS

# Assert that task_1 of dag run 2 is left untouched
run2_ti1 = run2.get_task_instance(task_id="task_1")
assert run2_ti1.state == State.SUCCESS


def test_task_instance_clear_failure(admin_client):
rowid = '["12345"]' # F.A.B. crashes if the rowid is *too* invalid.
resp = admin_client.post(
Expand Down

0 comments on commit 5b0ce3d

Please sign in to comment.