Move dataset information into timetable
This allows us to remove a bunch of conditionals that check whether a DAG
is backed by a timetable or a dataset condition, and a weird edge case in
serialization where datasets inside a timetable were never actually
deserialized.

Now datasets are always serialized as part of the timetable, and we
always evaluate both the timetable and its dataset condition. Timetables
that do not actually contain datasets (most of them) simply expose a
condition that always evaluates to False.
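For illustration only (not part of this commit), here is a minimal sketch of the behaviour described above, assuming Airflow at this revision; the DAG ids, URIs, and dates are invented:

    import datetime

    from airflow import DAG
    from airflow.datasets import Dataset

    d1 = Dataset("s3://bucket/a")
    d2 = Dataset("s3://bucket/b")

    # Dataset-triggered DAG: the condition now lives on the timetable.
    ds_dag = DAG("ds_dag", schedule=(d1 & d2), start_date=datetime.datetime(2024, 1, 1))
    cond = ds_dag.timetable.dataset_condition           # a DatasetAll over d1 and d2
    assert cond and not cond.evaluate({d1.uri: True})   # d2 has not been updated yet

    # Time-scheduled DAG: instead of None, a falsy sentinel condition is present.
    cron_dag = DAG("cron_dag", schedule="@daily", start_date=datetime.datetime(2024, 1, 1))
    assert not cron_dag.timetable.dataset_condition
    assert cron_dag.timetable.dataset_condition.evaluate({}) is False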
uranusjr committed Apr 30, 2024
1 parent e979ecc commit 2dbc042
Showing 9 changed files with 152 additions and 110 deletions.
11 changes: 7 additions & 4 deletions airflow/datasets/__init__.py
@@ -111,12 +111,15 @@ class BaseDataset:
:meta private:
"""

def __or__(self, other: BaseDataset) -> DatasetAny:
def __bool__(self) -> bool:
return True

def __or__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
return DatasetAny(self, other)

def __and__(self, other: BaseDataset) -> DatasetAll:
def __and__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
return DatasetAll(self, other)
@@ -203,7 +206,7 @@ class DatasetAny(_DatasetBooleanCondition):

agg_func = any

def __or__(self, other: BaseDataset) -> DatasetAny:
def __or__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
# Optimization: X | (Y | Z) is equivalent to X | Y | Z.
@@ -225,7 +228,7 @@ class DatasetAll(_DatasetBooleanCondition):

agg_func = all

def __and__(self, other: BaseDataset) -> DatasetAll:
def __and__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
# Optimization: X & (Y & Z) is equivalent to X & Y & Z.
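A hedged sketch (not from the commit) of the composition API whose return annotations the hunks above widen to BaseDataset; the URIs are invented:

    from airflow.datasets import Dataset, DatasetAll, DatasetAny

    a, b, c = Dataset("s3://a"), Dataset("s3://b"), Dataset("s3://c")

    any_cond = a | b | c   # chained | collapses into a single DatasetAny, per the optimization noted above
    mixed = a & (b | c)    # DatasetAll(a, DatasetAny(b, c))

    assert isinstance(any_cond, DatasetAny) and isinstance(mixed, DatasetAll)
    assert any_cond.evaluate({"s3://b": True}) is True   # any one updated dataset suffices
    assert mixed.evaluate({"s3://a": True}) is False     # b and c are both still pending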
49 changes: 18 additions & 31 deletions airflow/models/dag.py
@@ -113,7 +113,6 @@
from airflow.settings import json
from airflow.stats import Stats
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.datasets import DatasetOrTimeSchedule
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
from airflow.timetables.simple import (
ContinuousTimetable,
@@ -631,35 +630,26 @@ def __init__(
stacklevel=2,
)

self.timetable: Timetable
# Kept for compatibility. Do not use in new code.
self.schedule_interval: ScheduleInterval
self.dataset_triggers: BaseDataset | None = None
if isinstance(schedule, BaseDataset):
self.dataset_triggers = schedule
elif isinstance(schedule, Collection) and not isinstance(schedule, str):
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
self.dataset_triggers = DatasetAll(*schedule)
elif isinstance(schedule, Timetable):
timetable = schedule
elif schedule is not NOTSET and not isinstance(schedule, BaseDataset):
schedule_interval = schedule

if isinstance(schedule, DatasetOrTimeSchedule):
if isinstance(schedule, Timetable):
self.timetable = schedule
self.dataset_triggers = self.timetable.datasets
self.schedule_interval = self.timetable.summary
elif self.dataset_triggers:
self.timetable = DatasetTriggeredTimetable()
self.schedule_interval = schedule.summary
elif isinstance(schedule, BaseDataset):
self.timetable = DatasetTriggeredTimetable(schedule)
self.schedule_interval = self.timetable.summary
elif timetable:
self.timetable = timetable
elif isinstance(schedule, Collection) and not isinstance(schedule, str):
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
self.timetable = DatasetTriggeredTimetable(DatasetAll(*schedule))
self.schedule_interval = self.timetable.summary
elif isinstance(schedule, ArgNotSet):
self.timetable = create_timetable(schedule, self.timezone)
self.schedule_interval = DEFAULT_SCHEDULE_INTERVAL
else:
if isinstance(schedule_interval, ArgNotSet):
schedule_interval = DEFAULT_SCHEDULE_INTERVAL
self.schedule_interval = schedule_interval
self.timetable = create_timetable(schedule_interval, self.timezone)
self.timetable = create_timetable(schedule, self.timezone)
self.schedule_interval = schedule

if isinstance(template_searchpath, str):
template_searchpath = [template_searchpath]
@@ -3175,10 +3165,7 @@ def bulk_write_to_db(
)
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
if (dataset_triggers := dag.dataset_triggers) is None:
orm_dag.dataset_expression = None
else:
orm_dag.dataset_expression = dataset_triggers.as_expression()
orm_dag.dataset_expression = dag.timetable.dataset_condition.as_expression()

orm_dag.processor_subdir = processor_subdir

@@ -3234,11 +3221,11 @@ def bulk_write_to_db(
# later we'll persist them to the database.
for dag in dags:
curr_orm_dag = existing_dags.get(dag.dag_id)
if dag.dataset_triggers is None:
if not (dataset_condition := dag.timetable.dataset_condition):
if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
curr_orm_dag.schedule_dataset_references = []
else:
for _, dataset in dag.dataset_triggers.iter_datasets():
for _, dataset in dataset_condition.iter_datasets():
dag_references[dag.dag_id].add(dataset.uri)
input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
@@ -3889,7 +3876,7 @@ def dag_ready(dag_id: str, cond: BaseDataset, statuses: dict) -> bool | None:
for ser_dag in ser_dags:
dag_id = ser_dag.dag_id
statuses = dag_statuses[dag_id]
if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
if not dag_ready(dag_id, cond=ser_dag.dag.timetable.dataset_condition, statuses=statuses):
del by_dag[dag_id]
del dag_statuses[dag_id]
del dag_statuses
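Illustrative only, assuming this revision: the dataset_expression assignment above no longer needs a None check because every timetable now carries a dataset_condition, so both kinds of DAG go through the same call (DAG ids and URI invented; the exact expression shape is whatever as_expression() produces):

    import datetime

    from airflow import DAG
    from airflow.datasets import Dataset

    ds_dag = DAG("ds", schedule=[Dataset("s3://x")], start_date=datetime.datetime(2024, 1, 1))
    cron_dag = DAG("cron", schedule="0 0 * * *", start_date=datetime.datetime(2024, 1, 1))

    # Both branches of the old conditional collapse into one unconditional call.
    ds_dag.timetable.dataset_condition.as_expression()    # a JSON-able expression for the condition
    cron_dag.timetable.dataset_condition.as_expression()  # None, via the null-condition sentinel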
4 changes: 0 additions & 4 deletions airflow/serialization/schema.json
@@ -148,10 +148,6 @@
{ "$ref": "#/definitions/typed_relativedelta" }
]
},
"dataset_triggers": {
"$ref": "#/definitions/typed_dataset_cond"

},
"owner_links": { "type": "object" },
"timetable": {
"type": "object",
56 changes: 31 additions & 25 deletions airflow/serialization/serialized_objects.py
@@ -80,6 +80,7 @@
if TYPE_CHECKING:
from inspect import Parameter

from airflow.datasets import BaseDataset
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.expandinput import ExpandInput
from airflow.models.operator import Operator
@@ -229,6 +230,35 @@ def __str__(self) -> str:
)


def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]:
"""Encode a dataset condition.
:meta private:
"""
if isinstance(var, Dataset):
return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra}
if isinstance(var, DatasetAll):
return {"__type": DAT.DATASET_ALL, "objects": [encode_dataset_condition(x) for x in var.objects]}
if isinstance(var, DatasetAny):
return {"__type": DAT.DATASET_ANY, "objects": [encode_dataset_condition(x) for x in var.objects]}
raise ValueError(f"serialization not implemented for {type(var).__name__!r}")


def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset:
"""Decode a previously serialized dataset condition.
:meta private:
"""
dat = var["__type"]
if dat == DAT.DATASET:
return Dataset(var["uri"], extra=var["extra"])
if dat == DAT.DATASET_ALL:
return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
if dat == DAT.DATASET_ANY:
return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
raise ValueError(f"deserialization not implemented for DAT {dat!r}")


def encode_timetable(var: Timetable) -> dict[str, Any]:
"""
Encode a timetable instance.
@@ -487,8 +517,6 @@ def serialize_to_json(
serialized_object[key] = encode_timetable(value)
elif key == "weight_rule" and value is not None:
serialized_object[key] = encode_priority_weight_strategy(value)
elif key == "dataset_triggers":
serialized_object[key] = cls.serialize(value)
else:
value = cls.serialize(value)
if isinstance(value, dict) and Encoding.TYPE in value:
@@ -605,24 +633,6 @@ def serialize(
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, XComArg):
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode({"uri": var.uri, "extra": var.extra}, type_=DAT.DATASET)
elif isinstance(var, DatasetAll):
return cls._encode(
[
cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
for x in var.objects
],
type_=DAT.DATASET_ALL,
)
elif isinstance(var, DatasetAny):
return cls._encode(
[
cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
for x in var.objects
],
type_=DAT.DATASET_ANY,
)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
@@ -917,9 +927,7 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
"""Detect dependencies set directly on the DAG object."""
if not dag:
return
if not dag.dataset_triggers:
return
for uri, _ in dag.dataset_triggers.iter_datasets():
for uri, _ in dag.timetable.dataset_condition.iter_datasets():
yield DagDependency(
source="dataset",
target=dag.dag_id,
@@ -1569,8 +1577,6 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
v = cls.deserialize(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k == "dataset_triggers":
v = cls.deserialize(v)
# else use v as it is

setattr(dag, k, v)
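A sketch of the round trip the two new helpers provide; they are marked :meta private:, so this is illustration rather than public API, and the URIs are invented:

    from airflow.datasets import Dataset, DatasetAll, DatasetAny
    from airflow.serialization.serialized_objects import (
        decode_dataset_condition,
        encode_dataset_condition,
    )

    cond = DatasetAll(Dataset("s3://a"), DatasetAny(Dataset("s3://b"), Dataset("s3://c")))
    encoded = encode_dataset_condition(cond)    # nested plain dicts keyed by "__type"
    restored = decode_dataset_condition(encoded)

    # The reconstructed tree exposes the same datasets in the same order.
    assert [uri for uri, _ in restored.iter_datasets()] == ["s3://a", "s3://b", "s3://c"]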
63 changes: 50 additions & 13 deletions airflow/timetables/base.py
@@ -16,17 +16,47 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, NamedTuple, Sequence
from typing import TYPE_CHECKING, Any, Iterator, NamedTuple, Sequence
from warnings import warn

from airflow.datasets import BaseDataset
from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
from pendulum import DateTime

from airflow.datasets import Dataset
from airflow.utils.types import DagRunType


class _NullDataset(BaseDataset):
"""Sentinel type that represents "no datasets".
This is only implemented to make typing easier in timetables, and not
expected to be used anywhere else.
:meta private:
"""

def __bool__(self) -> bool:
return False

def __or__(self, other: BaseDataset) -> BaseDataset:
return NotImplemented

def __and__(self, other: BaseDataset) -> BaseDataset:
return NotImplemented

def as_expression(self) -> Any:
return None

def evaluate(self, statuses: dict[str, bool]) -> bool:
return False

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
return iter(())


class DataInterval(NamedTuple):
"""A data interval for a DagRun to operate over.
@@ -127,6 +157,12 @@ class Timetable(Protocol):

@property
def can_be_scheduled(self):
"""Whether this timetable can actually schedule runs in an automated manner.
This defaults to and should generally be *True* (including non periodic
execution types like *@once* and data triggered tables), but
``NullTimetable`` sets this to *False*.
"""
if hasattr(self, "can_run"):
warn(
'can_run class variable is deprecated. Use "can_be_scheduled" instead.',
@@ -136,25 +172,26 @@ def can_be_scheduled(self):
return self.can_run
return self._can_be_scheduled

"""Whether this timetable can actually schedule runs in an automated manner.
This defaults to and should generally be *True* (including non periodic
execution types like *@once* and data triggered tables), but
``NullTimetable`` sets this to *False*.
"""

run_ordering: Sequence[str] = ("data_interval_end", "execution_date")
"""How runs triggered from this timetable should be ordered in UI.
This should be a list of field names on the DAG run object.
"""

active_runs_limit: int | None = None
"""Override the max_active_runs parameter of any DAGs using this timetable.
This is called during DAG initializing, and will set the max_active_runs if
it returns a value. In most cases this should return None, but in some cases
(for example, the ContinuousTimetable) there are good reasons for limiting
the DAGRun parallelism.
"""Maximum active runs that can be active at one time for a DAG.
This is called during DAG initialization, and the return value is used as
the DAG's default ``max_active_runs``. This should generally return *None*,
but there are good reasons to limit DAG run parallelism in some cases, such
as for :class:`~airflow.timetable.simple.ContinuousTimetable`.
"""

dataset_condition: BaseDataset = _NullDataset()
"""The dataset condition that triggers a DAG using this timetable.
If this is not *None*, this should be a dataset, or a combination of, that
controls the DAG's dataset triggers.
"""

@classmethod
20 changes: 10 additions & 10 deletions airflow/timetables/datasets.py
@@ -44,9 +44,9 @@ def __init__(
) -> None:
self.timetable = timetable
if isinstance(datasets, BaseDataset):
self.datasets = datasets
self.dataset_condition = datasets
else:
self.datasets = DatasetAll(*datasets)
self.dataset_condition = DatasetAll(*datasets)

self.description = f"Triggered by datasets or {timetable.description}"
self.periodic = timetable.periodic
@@ -55,25 +55,25 @@ def __init__(

@classmethod
def deserialize(cls, data: dict[str, typing.Any]) -> Timetable:
from airflow.serialization.serialized_objects import decode_timetable
from airflow.serialization.serialized_objects import decode_dataset_condition, decode_timetable

return cls(
datasets=decode_dataset_condition(data["dataset_condition"]),
timetable=decode_timetable(data["timetable"]),
# don't need the datasets after deserialization
# they are already stored on dataset_triggers attr on DAG
# and this is what scheduler looks at
datasets=[],
)

def serialize(self) -> dict[str, typing.Any]:
from airflow.serialization.serialized_objects import encode_timetable
from airflow.serialization.serialized_objects import encode_dataset_condition, encode_timetable

return {"timetable": encode_timetable(self.timetable)}
return {
"dataset_condition": encode_dataset_condition(self.dataset_condition),
"timetable": encode_timetable(self.timetable),
}

def validate(self) -> None:
if isinstance(self.timetable, DatasetTriggeredSchedule):
raise AirflowTimetableInvalid("cannot nest dataset timetables")
if not isinstance(self.datasets, BaseDataset):
if not isinstance(self.dataset_condition, BaseDataset):
raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets")

@property
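For context (not part of the diff), a sketch of how DatasetOrTimeSchedule round-trips after this change, assuming CronTriggerTimetable as the wrapped timetable and invented URIs:

    from airflow.datasets import Dataset
    from airflow.timetables.datasets import DatasetOrTimeSchedule
    from airflow.timetables.trigger import CronTriggerTimetable

    schedule = DatasetOrTimeSchedule(
        timetable=CronTriggerTimetable("0 0 * * *", timezone="UTC"),
        datasets=[Dataset("s3://a"), Dataset("s3://b")],  # wrapped into a DatasetAll
    )

    data = schedule.serialize()
    assert sorted(data) == ["dataset_condition", "timetable"]

    # Deserialization now rebuilds the condition instead of dropping it.
    restored = DatasetOrTimeSchedule.deserialize(data)
    assert [uri for uri, _ in restored.dataset_condition.iter_datasets()] == ["s3://a", "s3://b"]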
