Skip to content

Commit

Permalink
Order triggers by - TI priority_weight when assign unassigned triggers (
Browse files Browse the repository at this point in the history
apache#32318)

* Order triggers by - TI priority_weight when assign unassigned triggers

Signed-off-by: Hussein Awala <hussein@awala.fr>

* Update airflow/models/trigger.py

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

* Replace outer join by inner join and use coalesce to handle None values

* fix unit tests

---------

Signed-off-by: Hussein Awala <hussein@awala.fr>
Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
Co-authored-by: eladkal <45845474+eladkal@users.noreply.github.com>
  • Loading branch information
3 people authored and ferruzzi committed Aug 17, 2023
1 parent 6e3d9e4 commit ae3fa6e
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 16 deletions.
4 changes: 3 additions & 1 deletion airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from sqlalchemy import Column, Integer, String, delete, func, or_, select, update
from sqlalchemy.orm import Session, joinedload, relationship
from sqlalchemy.sql.functions import coalesce

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base
Expand Down Expand Up @@ -244,8 +245,9 @@ def assign_unassigned(
def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session):
query = with_row_locks(
select(cls.id)
.join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=False)
.where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)))
.order_by(cls.created_date)
.order_by(coalesce(TaskInstance.priority_weight, 0).desc(), cls.created_date)
.limit(capacity),
session,
skip_locked=True,
Expand Down
18 changes: 16 additions & 2 deletions tests/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def handle_events(self):
assert len(instances) == 1


def test_trigger_from_dead_triggerer(session):
def test_trigger_from_dead_triggerer(session, create_task_instance):
"""
Checks that the triggerer will correctly claim a Trigger that is assigned to a
triggerer that does not exist.
Expand All @@ -425,6 +425,13 @@ def test_trigger_from_dead_triggerer(session):
trigger_orm.id = 1
trigger_orm.triggerer_id = 999 # Non-existent triggerer
session.add(trigger_orm)
ti_orm = create_task_instance(
task_id="ti_orm",
execution_date=datetime.datetime.utcnow(),
run_id="orm_run_id",
)
ti_orm.trigger_id = trigger_orm.id
session.add(trigger_orm)
session.commit()
# Make a TriggererJobRunner and have it retrieve DB tasks
job = Job()
Expand All @@ -434,7 +441,7 @@ def test_trigger_from_dead_triggerer(session):
assert [x for x, y in job_runner.trigger_runner.to_create] == [1]


def test_trigger_from_expired_triggerer(session):
def test_trigger_from_expired_triggerer(session, create_task_instance):
"""
Checks that the triggerer will correctly claim a Trigger that is assigned to a
triggerer that has an expired heartbeat.
Expand All @@ -445,6 +452,13 @@ def test_trigger_from_expired_triggerer(session):
trigger_orm.id = 1
trigger_orm.triggerer_id = 42
session.add(trigger_orm)
ti_orm = create_task_instance(
task_id="ti_orm",
execution_date=datetime.datetime.utcnow(),
run_id="orm_run_id",
)
ti_orm.trigger_id = trigger_orm.id
session.add(trigger_orm)
# Use a TriggererJobRunner with an expired heartbeat
triggerer_job_orm = Job(TriggererJobRunner.job_type)
triggerer_job_orm.id = 42
Expand Down
124 changes: 111 additions & 13 deletions tests/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,47 @@ def test_assign_unassigned(session, create_task_instance):
trigger_on_healthy_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_on_healthy_triggerer.id = 1
trigger_on_healthy_triggerer.triggerer_id = healthy_triggerer.id
session.add(trigger_on_healthy_triggerer)
ti_trigger_on_healthy_triggerer = create_task_instance(
task_id="ti_trigger_on_healthy_triggerer",
execution_date=time_now,
run_id="trigger_on_healthy_triggerer_run_id",
)
ti_trigger_on_healthy_triggerer.trigger_id = trigger_on_healthy_triggerer.id
session.add(ti_trigger_on_healthy_triggerer)
trigger_on_unhealthy_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_on_unhealthy_triggerer.id = 2
trigger_on_unhealthy_triggerer.triggerer_id = unhealthy_triggerer.id
session.add(trigger_on_unhealthy_triggerer)
ti_trigger_on_unhealthy_triggerer = create_task_instance(
task_id="ti_trigger_on_unhealthy_triggerer",
execution_date=time_now + datetime.timedelta(hours=1),
run_id="trigger_on_unhealthy_triggerer_run_id",
)
ti_trigger_on_unhealthy_triggerer.trigger_id = trigger_on_unhealthy_triggerer.id
session.add(ti_trigger_on_unhealthy_triggerer)
trigger_on_killed_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_on_killed_triggerer.id = 3
trigger_on_killed_triggerer.triggerer_id = finished_triggerer.id
session.add(trigger_on_killed_triggerer)
ti_trigger_on_killed_triggerer = create_task_instance(
task_id="ti_trigger_on_killed_triggerer",
execution_date=time_now + datetime.timedelta(hours=2),
run_id="trigger_on_killed_triggerer_run_id",
)
ti_trigger_on_killed_triggerer.trigger_id = trigger_on_killed_triggerer.id
session.add(ti_trigger_on_killed_triggerer)
trigger_unassigned_to_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
trigger_unassigned_to_triggerer.id = 4
assert trigger_unassigned_to_triggerer.triggerer_id is None
session.add(trigger_on_healthy_triggerer)
session.add(trigger_on_unhealthy_triggerer)
session.add(trigger_on_killed_triggerer)
session.add(trigger_unassigned_to_triggerer)
ti_trigger_unassigned_to_triggerer = create_task_instance(
task_id="ti_trigger_unassigned_to_triggerer",
execution_date=time_now + datetime.timedelta(hours=3),
run_id="trigger_unassigned_to_triggerer_run_id",
)
ti_trigger_unassigned_to_triggerer.trigger_id = trigger_unassigned_to_triggerer.id
session.add(ti_trigger_unassigned_to_triggerer)
assert trigger_unassigned_to_triggerer.triggerer_id is None
session.commit()
assert session.query(Trigger).count() == 4
Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
Expand All @@ -209,31 +237,101 @@ def test_assign_unassigned(session, create_task_instance):
)


def test_get_sorted_triggers(session, create_task_instance):
def test_get_sorted_triggers_same_priority_weight(session, create_task_instance):
"""
Tests that triggers are sorted by the creation_date.
Tests that triggers are sorted by the creation_date if they have the same priority.
"""
old_execution_date = datetime.datetime(
2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
)
trigger_old = Trigger(
classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={},
created_date=datetime.datetime(
2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
),
created_date=old_execution_date + datetime.timedelta(seconds=30),
)
trigger_old.id = 1
session.add(trigger_old)
TI_old = create_task_instance(
task_id="old",
execution_date=old_execution_date,
run_id="old_run_id",
)
TI_old.priority_weight = 1
TI_old.trigger_id = trigger_old.id
session.add(TI_old)

new_execution_date = datetime.datetime(
2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
)
trigger_new = Trigger(
classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={},
created_date=datetime.datetime(
2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
),
created_date=new_execution_date + datetime.timedelta(seconds=30),
)
trigger_new.id = 2
session.add(trigger_old)
session.add(trigger_new)
TI_new = create_task_instance(
task_id="new",
execution_date=new_execution_date,
run_id="new_run_id",
)
TI_new.priority_weight = 1
TI_new.trigger_id = trigger_new.id
session.add(TI_new)

session.commit()
assert session.query(Trigger).count() == 2

trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)

assert trigger_ids_query == [(1,), (2,)]


def test_get_sorted_triggers_different_priority_weights(session, create_task_instance):
"""
Tests that triggers are sorted by the priority_weight.
"""
old_execution_date = datetime.datetime(
2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
)
trigger_old = Trigger(
classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={},
created_date=old_execution_date + datetime.timedelta(seconds=30),
)
trigger_old.id = 1
session.add(trigger_old)
TI_old = create_task_instance(
task_id="old",
execution_date=old_execution_date,
run_id="old_run_id",
)
TI_old.priority_weight = 1
TI_old.trigger_id = trigger_old.id
session.add(TI_old)

new_execution_date = datetime.datetime(
2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan")
)
trigger_new = Trigger(
classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={},
created_date=new_execution_date + datetime.timedelta(seconds=30),
)
trigger_new.id = 2
session.add(trigger_new)
TI_new = create_task_instance(
task_id="new",
execution_date=new_execution_date,
run_id="new_run_id",
)
TI_new.priority_weight = 2
TI_new.trigger_id = trigger_new.id
session.add(TI_new)

session.commit()
assert session.query(Trigger).count() == 2

trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)

assert trigger_ids_query == [(2,), (1,)]

0 comments on commit ae3fa6e

Please sign in to comment.