Refactor DAG.dataset_triggers into the timetable class #39321

Open · wants to merge 1 commit into main
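Not part of the diff: a minimal sketch of the access-pattern change the title describes, assuming this branch is applied. The dag_id and dataset URIs below are made up for illustration.

from airflow.datasets import Dataset
from airflow.models.dag import DAG

dag = DAG("example_dataset_dag", schedule=[Dataset("s3://bucket/a"), Dataset("s3://bucket/b")])

# Previously the trigger condition lived on the DAG object itself:
#     condition = dag.dataset_triggers
# With this refactor it is owned by the timetable; a DAG without dataset
# triggers returns a falsy sentinel instead of None:
condition = dag.timetable.dataset_condition
if condition:
    for uri, _dataset in condition.iter_datasets():
        print(uri)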
11 changes: 7 additions & 4 deletions airflow/datasets/__init__.py
@@ -111,12 +111,15 @@ class BaseDataset:
:meta private:
"""

def __or__(self, other: BaseDataset) -> DatasetAny:
def __bool__(self) -> bool:
return True

def __or__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
return DatasetAny(self, other)

def __and__(self, other: BaseDataset) -> DatasetAll:
def __and__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
return DatasetAll(self, other)
@@ -203,7 +206,7 @@ class DatasetAny(_DatasetBooleanCondition):

agg_func = any

def __or__(self, other: BaseDataset) -> DatasetAny:
def __or__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
# Optimization: X | (Y | Z) is equivalent to X | Y | Z.
@@ -225,7 +228,7 @@ class DatasetAll(_DatasetBooleanCondition):

agg_func = all

def __and__(self, other: BaseDataset) -> DatasetAll:
def __and__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
# Optimization: X & (Y & Z) is equivalent to X & Y & Z.
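As a hedged illustration of the annotation change above (__or__/__and__ now return the BaseDataset supertype) and the new __bool__ — not part of the diff; the URIs are placeholders and the example assumes evaluate() resolves each dataset by its URI in the statuses mapping:

from airflow.datasets import Dataset, DatasetAll

a = Dataset("s3://bucket/a")
b = Dataset("s3://bucket/b")
c = Dataset("s3://bucket/c")

# (a | b) builds a DatasetAny; "& c" wraps it in a DatasetAll. Annotating the
# operators as returning BaseDataset lets arbitrary nesting type-check cleanly.
cond = (a | b) & c
assert isinstance(cond, DatasetAll)

# Every real condition is truthy, so callers can distinguish it from the falsy
# "no datasets" sentinel introduced in airflow/timetables/base.py below.
assert bool(cond)
assert cond.evaluate({"s3://bucket/a": True, "s3://bucket/b": False, "s3://bucket/c": True})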
54 changes: 23 additions & 31 deletions airflow/models/dag.py
@@ -113,7 +113,6 @@
from airflow.settings import json
from airflow.stats import Stats
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.datasets import DatasetOrTimeSchedule
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
from airflow.timetables.simple import (
ContinuousTimetable,
@@ -631,35 +630,31 @@ def __init__(
stacklevel=2,
)

self.timetable: Timetable
if timetable is not None:
schedule = timetable
elif schedule_interval is not NOTSET:
schedule = schedule_interval

# Kept for compatibility. Do not use in new code.
self.schedule_interval: ScheduleInterval
self.dataset_triggers: BaseDataset | None = None
if isinstance(schedule, BaseDataset):
self.dataset_triggers = schedule
elif isinstance(schedule, Collection) and not isinstance(schedule, str):
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
self.dataset_triggers = DatasetAll(*schedule)
elif isinstance(schedule, Timetable):
timetable = schedule
elif schedule is not NOTSET and not isinstance(schedule, BaseDataset):
schedule_interval = schedule

if isinstance(schedule, DatasetOrTimeSchedule):
if isinstance(schedule, Timetable):
self.timetable = schedule
self.dataset_triggers = self.timetable.datasets
self.schedule_interval = schedule.summary
elif isinstance(schedule, BaseDataset):
self.timetable = DatasetTriggeredTimetable(schedule)
self.schedule_interval = self.timetable.summary
elif self.dataset_triggers:
self.timetable = DatasetTriggeredTimetable()
self.schedule_interval = self.timetable.summary
elif timetable:
self.timetable = timetable
elif isinstance(schedule, Collection) and not isinstance(schedule, str):
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
self.timetable = DatasetTriggeredTimetable(DatasetAll(*schedule))
self.schedule_interval = self.timetable.summary
elif isinstance(schedule, ArgNotSet):
self.timetable = create_timetable(schedule, self.timezone)
self.schedule_interval = DEFAULT_SCHEDULE_INTERVAL
else:
if isinstance(schedule_interval, ArgNotSet):
schedule_interval = DEFAULT_SCHEDULE_INTERVAL
self.schedule_interval = schedule_interval
self.timetable = create_timetable(schedule_interval, self.timezone)
self.timetable = create_timetable(schedule, self.timezone)
self.schedule_interval = schedule

if isinstance(template_searchpath, str):
template_searchpath = [template_searchpath]
@@ -3175,10 +3170,7 @@ def bulk_write_to_db(
)
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
if (dataset_triggers := dag.dataset_triggers) is None:
orm_dag.dataset_expression = None
else:
orm_dag.dataset_expression = dataset_triggers.as_expression()
orm_dag.dataset_expression = dag.timetable.dataset_condition.as_expression()

orm_dag.processor_subdir = processor_subdir

@@ -3234,11 +3226,11 @@ def bulk_write_to_db(
# later we'll persist them to the database.
for dag in dags:
curr_orm_dag = existing_dags.get(dag.dag_id)
if dag.dataset_triggers is None:
if not (dataset_condition := dag.timetable.dataset_condition):
if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
curr_orm_dag.schedule_dataset_references = []
else:
for _, dataset in dag.dataset_triggers.iter_datasets():
for _, dataset in dataset_condition.iter_datasets():
dag_references[dag.dag_id].add(dataset.uri)
input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
@@ -3889,7 +3881,7 @@ def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None:
for ser_dag in ser_dags:
dag_id = ser_dag.dag_id
statuses = dag_statuses[dag_id]
if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
if not dag_ready(dag_id, cond=ser_dag.dag.timetable.dataset_condition, statuses=statuses):
del by_dag[dag_id]
del dag_statuses[dag_id]
del dag_statuses
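Not part of the diff: a compact sketch of how the reworked constructor dispatch above behaves, assuming DatasetTriggeredTimetable accepts the condition and exposes it as dataset_condition. DAG ids and URIs are illustrative.

from airflow.datasets import Dataset, DatasetAll
from airflow.models.dag import DAG

# A dataset condition, a plain list of datasets, and a cron preset all end up
# as a Timetable; the condition is read back via timetable.dataset_condition.
dataset_dag = DAG("dataset_dag", schedule=Dataset("s3://bucket/a") | Dataset("s3://bucket/b"))
listed_dag = DAG("listed_dag", schedule=[Dataset("s3://bucket/a")])  # wrapped in DatasetAll
cron_dag = DAG("cron_dag", schedule="@daily")

assert dataset_dag.timetable.dataset_condition
assert isinstance(listed_dag.timetable.dataset_condition, DatasetAll)
assert not cron_dag.timetable.dataset_condition  # the falsy _NullDataset default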
4 changes: 0 additions & 4 deletions airflow/serialization/schema.json
@@ -148,10 +148,6 @@
{ "$ref": "#/definitions/typed_relativedelta" }
]
},
"dataset_triggers": {
"$ref": "#/definitions/typed_dataset_cond"

},
"owner_links": { "type": "object" },
"timetable": {
"type": "object",
64 changes: 36 additions & 28 deletions airflow/serialization/serialized_objects.py
@@ -37,7 +37,7 @@

from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset, DatasetAll, DatasetAny
from airflow.datasets import BaseDataset, Dataset, DatasetAll, DatasetAny
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred
from airflow.jobs.job import Job
from airflow.models.baseoperator import BaseOperator
@@ -229,6 +229,35 @@ def __str__(self) -> str:
)


def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]:
"""Encode a dataset condition.

:meta private:
"""
if isinstance(var, Dataset):
return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra}
if isinstance(var, DatasetAll):
return {"__type": DAT.DATASET_ALL, "objects": [encode_dataset_condition(x) for x in var.objects]}
if isinstance(var, DatasetAny):
return {"__type": DAT.DATASET_ANY, "objects": [encode_dataset_condition(x) for x in var.objects]}
raise ValueError(f"serialization not implemented for {type(var).__name__!r}")


def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset:
"""Decode a previously serialized dataset condition.

:meta private:
"""
dat = var["__type"]
if dat == DAT.DATASET:
return Dataset(var["uri"], extra=var["extra"])
if dat == DAT.DATASET_ALL:
return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
if dat == DAT.DATASET_ANY:
return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
raise ValueError(f"deserialization not implemented for DAT {dat!r}")


def encode_timetable(var: Timetable) -> dict[str, Any]:
"""
Encode a timetable instance.
@@ -487,8 +516,6 @@ def serialize_to_json(
serialized_object[key] = encode_timetable(value)
elif key == "weight_rule" and value is not None:
serialized_object[key] = encode_priority_weight_strategy(value)
elif key == "dataset_triggers":
serialized_object[key] = cls.serialize(value)
else:
value = cls.serialize(value)
if isinstance(value, dict) and Encoding.TYPE in value:
@@ -605,24 +632,9 @@ def serialize(
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, XComArg):
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode({"uri": var.uri, "extra": var.extra}, type_=DAT.DATASET)
elif isinstance(var, DatasetAll):
return cls._encode(
[
cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
for x in var.objects
],
type_=DAT.DATASET_ALL,
)
elif isinstance(var, DatasetAny):
return cls._encode(
[
cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
for x in var.objects
],
type_=DAT.DATASET_ANY,
)
elif isinstance(var, BaseDataset):
serialized_dataset = encode_dataset_condition(var)
return cls._encode(serialized_dataset, type_=serialized_dataset.pop("__type"))
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
@@ -738,9 +750,9 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
elif type_ == DAT.DATASET:
return Dataset(**var)
elif type_ == DAT.DATASET_ANY:
return DatasetAny(*(cls.deserialize(x) for x in var))
return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
elif type_ == DAT.DATASET_ALL:
return DatasetAll(*(cls.deserialize(x) for x in var))
return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif type_ == DAT.CONNECTION:
@@ -917,9 +929,7 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
"""Detect dependencies set directly on the DAG object."""
if not dag:
return
if not dag.dataset_triggers:
return
for uri, _ in dag.dataset_triggers.iter_datasets():
for uri, _ in dag.timetable.dataset_condition.iter_datasets():
yield DagDependency(
source="dataset",
target=dag.dag_id,
@@ -1569,8 +1579,6 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
v = cls.deserialize(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k == "dataset_triggers":
v = cls.deserialize(v)
# else use v as it is

setattr(dag, k, v)
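A hedged round trip through the two new module-level helpers added above (the exact DAT string values come from the DagAttributeTypes enum and are not spelled out here); not part of the diff:

from airflow.datasets import Dataset, DatasetAll, DatasetAny
from airflow.serialization.serialized_objects import (
    decode_dataset_condition,
    encode_dataset_condition,
)

cond = DatasetAll(
    Dataset("s3://bucket/a"),
    DatasetAny(Dataset("s3://bucket/b"), Dataset("s3://bucket/c")),
)

# Encoding yields a plain dict tree keyed by "__type" and, for composite
# conditions, an "objects" list; decoding rebuilds the same structure.
encoded = encode_dataset_condition(cond)
decoded = decode_dataset_condition(encoded)

assert isinstance(decoded, DatasetAll)
assert [uri for uri, _ in decoded.iter_datasets()] == [uri for uri, _ in cond.iter_datasets()]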
63 changes: 50 additions & 13 deletions airflow/timetables/base.py
@@ -16,17 +16,47 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, NamedTuple, Sequence
from typing import TYPE_CHECKING, Any, Iterator, NamedTuple, Sequence
from warnings import warn

from airflow.datasets import BaseDataset
from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
from pendulum import DateTime

from airflow.datasets import Dataset
from airflow.utils.types import DagRunType


class _NullDataset(BaseDataset):
"""Sentinel type that represents "no datasets".

This is implemented only to make typing easier in timetables, and is not
expected to be used anywhere else.

:meta private:
"""

def __bool__(self) -> bool:
return False

def __or__(self, other: BaseDataset) -> BaseDataset:
return NotImplemented

def __and__(self, other: BaseDataset) -> BaseDataset:
return NotImplemented

def as_expression(self) -> Any:
return None

def evaluate(self, statuses: dict[str, bool]) -> bool:
return False

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
return iter(())


class DataInterval(NamedTuple):
"""A data interval for a DagRun to operate over.

@@ -127,6 +157,12 @@ class Timetable(Protocol):

@property
def can_be_scheduled(self):
"""Whether this timetable can actually schedule runs in an automated manner.

This defaults to and should generally be *True* (including non-periodic
execution types like *@once* and data-triggered timetables), but
``NullTimetable`` sets this to *False*.
"""
if hasattr(self, "can_run"):
warn(
'can_run class variable is deprecated. Use "can_be_scheduled" instead.',
@@ -136,25 +172,26 @@ def can_be_scheduled(self):
return self.can_run
return self._can_be_scheduled

"""Whether this timetable can actually schedule runs in an automated manner.

This defaults to and should generally be *True* (including non periodic
execution types like *@once* and data triggered tables), but
``NullTimetable`` sets this to *False*.
"""

run_ordering: Sequence[str] = ("data_interval_end", "execution_date")
"""How runs triggered from this timetable should be ordered in UI.

This should be a list of field names on the DAG run object.
"""

active_runs_limit: int | None = None
"""Override the max_active_runs parameter of any DAGs using this timetable.
This is called during DAG initializing, and will set the max_active_runs if
it returns a value. In most cases this should return None, but in some cases
(for example, the ContinuousTimetable) there are good reasons for limiting
the DAGRun parallelism.
"""Maximum active runs that can be active at one time for a DAG.

This is called during DAG initialization, and the return value is used as
the DAG's default ``max_active_runs``. This should generally return *None*,
but there are good reasons to limit DAG run parallelism in some cases, such
as for :class:`~airflow.timetable.simple.ContinuousTimetable`.
"""

dataset_condition: BaseDataset = _NullDataset()
"""The dataset condition that triggers a DAG using this timetable.

If set, this should be a dataset, or a combination of datasets, that
controls the DAG's dataset triggers. The default is a falsy sentinel
representing "no datasets".
"""

@classmethod
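Finally, a hypothetical (not from this PR) dataset-aware timetable showing how the new class-level dataset_condition default is meant to be overridden:

from airflow.datasets import BaseDataset
from airflow.timetables.base import DataInterval, Timetable


class MyDatasetTimetable(Timetable):
    """Hypothetical timetable whose runs are created by dataset events only."""

    def __init__(self, condition: BaseDataset) -> None:
        # Shadows the falsy _NullDataset default declared on Timetable.
        self.dataset_condition = condition

    def infer_manual_data_interval(self, *, run_after):
        return DataInterval.exact(run_after)

    def next_dagrun_info(self, *, last_automated_data_interval, restriction):
        # Dataset-triggered runs are created by the scheduler's dataset logic,
        # not by the timetable, so nothing is scheduled here.
        return None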