AIP44 Fix DAG serialization #34042

Merged · 1 commit · Oct 27, 2023
2 changes: 1 addition & 1 deletion airflow/serialization/pydantic/dag.py
@@ -60,7 +60,7 @@ class DagModelPydantic(BaseModelPydantic):
     is_paused_at_creation: bool = airflow_conf.getboolean("core", "dags_are_paused_at_creation")
     is_paused: bool = is_paused_at_creation
     is_subdag: Optional[bool] = False
-    is_active: bool = False
+    is_active: Optional[bool] = False
     last_parsed_time: Optional[datetime]
     last_pickled: Optional[datetime]
     last_expired: Optional[datetime]
2 changes: 1 addition & 1 deletion airflow/serialization/pydantic/dag_run.py
@@ -74,7 +74,7 @@ class DagRunPydantic(BaseModelPydantic):
     data_interval_end: Optional[datetime]
     last_scheduling_decision: Optional[datetime]
     dag_hash: Optional[str]
-    updated_at: datetime
+    updated_at: Optional[datetime]
     dag: Optional[PydanticDag]
     consumed_dataset_events: List[DatasetEventPydantic]  # noqa

2 changes: 1 addition & 1 deletion airflow/serialization/serialized_objects.py
@@ -447,7 +447,7 @@ def serialize(
             json_pod = PodGenerator.serialize_pod(var)
             return cls._encode(json_pod, type_=DAT.POD)
         elif isinstance(var, DAG):
-            return SerializedDAG.serialize_dag(var)
+            return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
Member:
There is no need to encode this; we expect it to have a top-level "dag" unique key, as shown in

Similarly, we don't encode tasks or each operator, because we know the structure they would have within ["dag"]["tasks"].

Check L445 and L447.
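
For illustration, a minimal sketch of the two shapes being compared, with invented field values: the stored blob is recognized by its unique top-level "dag" key (cf. SerializedDAG.to_dict), while cls._encode wraps a value in the generic envelope that BaseSerialization uses for every other type.

    # Hypothetical values, shown for shape only.
    db_blob = {
        "__version": 1,
        "dag": {"dag_id": "example", "tasks": []},  # tasks stay plain dicts, not re-encoded
    }

    # What cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG) produces:
    encoded = {
        "__var": {"dag_id": "example", "tasks": []},  # Encoding.VAR
        "__type": "dag",  # Encoding.TYPE; "dag" assumed to be the value of DAT.DAG
    }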

Member:
When building DAG serialization, it was designed using Airflow-internal knowledge to avoid inflating the final blob and to optimize wherever possible. This is the reason we don't store defaults, or objects that are None, for DAG and operator objects.
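
As a hedged sketch of that optimization (illustrative only, not Airflow's actual implementation), attributes that equal their default, or that are None, are simply left out of the serialized blob:

    def serialize_attrs(obj, defaults: dict) -> dict:
        # Hypothetical helper: omit defaults and Nones to keep the blob small.
        out = {}
        for key, default in defaults.items():
            value = getattr(obj, key, None)
            if value is None or value == default:
                continue  # not stored; deserialization falls back to the default
            out[key] = value
        return out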

Member:
Also, be careful about backwards-incompatible changes here.

And if for AIP-44 we absolutely need this sort of change, it will probably be more involved than just the DAG object.

Collaborator (Author):
Right, but if I serialize a DAG (with BaseSerialization.serialize) and then try to deserialize it with BaseSerialization.deserialize, it fails - it never reaches
https://github.com/apache/airflow/blob/main/airflow/serialization/serialized_objects.py#L551
because

        var = encoded_var[Encoding.VAR]
        type_ = encoded_var[Encoding.TYPE]

fails (the keys do not exist).

In the code I cannot find any usage of BaseSerialization.serialize on a DAG (at least when searching for something like BaseSerialization.serialize.*dag).

We are currently working on making sure all other object types are serializable.
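
A minimal sketch of the failure mode described above, assuming a trivial DAG (example names invented):

    from airflow.models.dag import DAG
    from airflow.serialization.serialized_objects import BaseSerialization

    dag = DAG(dag_id="example")
    payload = BaseSerialization.serialize(dag)
    # Before this PR, payload was the raw serialize_dag() dict, so the
    # Encoding.VAR / Encoding.TYPE lookups above raised KeyError.
    # With the cls._encode(...) wrapper, the round trip succeeds:
    restored = BaseSerialization.deserialize(payload)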

Member (@potiuk, Sep 4, 2023):
Yeah. I think @mhenc is right - this part of the code seems not to have been used before. It's not DAG serialization that gets affected here; it's just serializing the whole DAG as part of a bigger structure, and I have not seen any place in the code where we'd do that before.

         elif isinstance(var, Resources):
             return var.to_dict()
         elif isinstance(var, MappedOperator):
8 changes: 4 additions & 4 deletions tests/serialization/test_dag_serialization.py
@@ -2525,7 +2525,7 @@ def tg(a: str) -> None:
     tg.expand(a=[".", ".."])

     ser_dag = SerializedBaseOperator.serialize(dag)
-    assert ser_dag["_task_group"]["children"]["tg"] == (
+    assert ser_dag[Encoding.VAR]["_task_group"]["children"]["tg"] == (
         "taskgroup",
         {
             "_group_id": "tg",
@@ -2549,7 +2549,7 @@ def tg(a: str) -> None:
         },
     )

-    serde_dag = SerializedDAG.deserialize_dag(ser_dag)
+    serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])
     serde_tg = serde_dag.task_group.children["tg"]
     assert isinstance(serde_tg, MappedTaskGroup)
     assert serde_tg._expand_input == DictOfListsExpandInput({"a": [".", ".."]})
@@ -2568,7 +2568,7 @@ def operator_extra_links(self):
     with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
         _DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
     serialized_dag = SerializedBaseOperator.serialize(dag)
-    assert serialized_dag["tasks"][0] == {
+    assert serialized_dag[Encoding.VAR]["tasks"][0] == {
         "task_id": "task",
         "expand_input": {
             "type": "dict-of-lists",
@@ -2589,5 +2589,5 @@ def operator_extra_links(self):
         "_is_empty": False,
         "_is_mapped": True,
     }
-    deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag)
+    deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
     assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]
217 changes: 193 additions & 24 deletions tests/serialization/test_serialized_objects.py
@@ -18,16 +18,34 @@
 from __future__ import annotations

 import json
-from datetime import datetime
+from datetime import datetime, timedelta

 import pytest
+from dateutil import relativedelta
 from kubernetes.client import models as k8s
+from pendulum.tz.timezone import Timezone

 from airflow.datasets import Dataset
 from airflow.exceptions import SerializationError
-from airflow.models.taskinstance import TaskInstance
+from airflow.jobs.job import Job
+from airflow.models.connection import Connection
+from airflow.models.dag import DAG, DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.param import Param
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.xcom_arg import XComArg
 from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
+from airflow.serialization.enums import DagAttributeTypes as DAT
+from airflow.serialization.pydantic.dag import DagModelPydantic
+from airflow.serialization.pydantic.dag_run import DagRunPydantic
+from airflow.serialization.pydantic.job import JobPydantic
+from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 from airflow.settings import _ENABLE_AIP_44
-from airflow.utils.state import State
+from airflow.utils.operator_resources import Resources
+from airflow.utils.state import DagRunState, State
+from airflow.utils.task_group import TaskGroup
+from airflow.utils.types import DagRunType
 from tests import REPO_ROOT

@@ -82,31 +100,182 @@ class Test:
     BaseSerialization.serialize(obj, strict=True)  # now raises


+TI = TaskInstance(
+    task=EmptyOperator(task_id="test-task"),
+    run_id="fake_run",
+    state=State.RUNNING,
+)
+
+TI_WITH_START_DAY = TaskInstance(
+    task=EmptyOperator(task_id="test-task"),
+    run_id="fake_run",
+    state=State.RUNNING,
+)
+TI_WITH_START_DAY.start_date = datetime.utcnow()
+
+DAG_RUN = DagRun(
+    dag_id="test_dag_id",
+    run_id="test_dag_run_id",
+    run_type=DagRunType.MANUAL,
+    execution_date=datetime.utcnow(),
+    start_date=datetime.utcnow(),
+    external_trigger=True,
+    state=DagRunState.SUCCESS,
+)
+DAG_RUN.id = 1
+
+
+def equals(a, b) -> bool:
+    return a == b
+
+
+def equal_time(a: datetime, b: datetime) -> bool:
+    return a.strftime("%s") == b.strftime("%s")
+
+
+@pytest.mark.parametrize(
+    "input, encoded_type, cmp_func",
+    [
+        ("test_str", None, equals),
+        (1, None, equals),
+        (datetime.utcnow(), DAT.DATETIME, equal_time),
+        (timedelta(minutes=2), DAT.TIMEDELTA, equals),
+        (Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
+        (relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours),
+        ({"test": "dict", "test-1": 1}, None, equals),
+        (["array_item", 2], None, equals),
+        (("tuple_item", 3), DAT.TUPLE, equals),
+        (set(["set_item", 3]), DAT.SET, equals),
+        (
+            k8s.V1Pod(
+                metadata=k8s.V1ObjectMeta(
+                    name="test", annotations={"test": "annotation"}, creation_timestamp=datetime.utcnow()
+                )
+            ),
+            DAT.POD,
+            equals,
+        ),
+        (
+            DAG(
+                "fake-dag",
+                schedule="*/10 * * * *",
+                default_args={"depends_on_past": True},
+                start_date=datetime.utcnow(),
+                catchup=False,
+            ),
+            DAT.DAG,
+            lambda a, b: a.dag_id == b.dag_id and equal_time(a.start_date, b.start_date),
+        ),
+        (Resources(cpus=0.1, ram=2048), None, None),
+        (EmptyOperator(task_id="test-task"), None, None),
+        (TaskGroup(group_id="test-group", dag=DAG(dag_id="test_dag", start_date=datetime.now())), None, None),
+        (
+            Param("test", "desc"),
+            DAT.PARAM,
+            lambda a, b: a.value == b.value and a.description == b.description,
+        ),
+        (
+            XComArg(
+                operator=PythonOperator(
+                    python_callable=int,
+                    task_id="test_xcom_op",
+                    do_xcom_push=True,
+                )
+            ),
+            DAT.XCOM_REF,
+            None,
+        ),
+        (Dataset(uri="test"), DAT.DATASET, equals),
+        (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
+        (
+            Connection(conn_id="TEST_ID", uri="mysql://"),
+            DAT.CONNECTION,
+            lambda a, b: a.get_uri() == b.get_uri(),
+        ),
+    ],
+)
+def test_serialize_deserialize(input, encoded_type, cmp_func):
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+    serialized = BaseSerialization.serialize(input)  # does not raise
+    json.dumps(serialized)  # does not raise
+    if encoded_type is not None:
+        assert serialized["__type"] == encoded_type
+        assert serialized["__var"] is not None
+    if cmp_func is not None:
+        deserialized = BaseSerialization.deserialize(serialized)
+        assert cmp_func(input, deserialized)
+
+    # Verify recursive behavior
+    obj = [[input]]
+    serialized = BaseSerialization.serialize(obj)  # does not raise
+    # Verify the result is JSON-serializable
+    json.dumps(serialized)  # does not raise
+
+
 @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
-def test_use_pydantic_models():
-    """If use_pydantic_models=True the TaskInstance object should be serialized to TaskInstancePydantic."""
+@pytest.mark.parametrize(
+    "input, pydantic_class, encoded_type, cmp_func",
+    [
+        (
+            Job(state=State.RUNNING, latest_heartbeat=datetime.utcnow()),
+            JobPydantic,
+            DAT.BASE_JOB,
+            lambda a, b: equal_time(a.latest_heartbeat, b.latest_heartbeat),
+        ),
+        (
+            TI_WITH_START_DAY,
+            TaskInstancePydantic,
+            DAT.TASK_INSTANCE,
+            lambda a, b: equal_time(a.start_date, b.start_date),
+        ),
+        (
+            DAG_RUN,
+            DagRunPydantic,
+            DAT.DAG_RUN,
+            lambda a, b: equal_time(a.execution_date, b.execution_date)
+            and equal_time(a.start_date, b.start_date),
+        ),
+        # DataSet is already serialized by non-Pydantic serialization. Is DatasetPydantic needed then?
+        # (
+        #     Dataset(
+        #         uri="foo://bar",
+        #         extra={"foo": "bar"},
+        #     ),
+        #     DatasetPydantic,
+        #     DAT.DATA_SET,
+        #     lambda a, b: a.uri == b.uri and a.extra == b.extra,
+        # ),
+        (
+            DagModel(
+                dag_id="TEST_DAG_1",
+                fileloc="/tmp/dag_1.py",
+                schedule_interval="2 2 * * *",
+                is_paused=True,
+            ),
+            DagModelPydantic,
+            DAT.DAG_MODEL,
+            lambda a, b: a.fileloc == b.fileloc and a.schedule_interval == b.schedule_interval,
+        ),
+    ],
+)
+def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp_func):
+    """If use_pydantic_models=True the objects should be serialized to Pydantic objects."""

     from airflow.serialization.serialized_objects import BaseSerialization

-    ti = TaskInstance(
-        task=EmptyOperator(task_id="task"),
-        run_id="run_id",
-        state=State.RUNNING,
-    )
-    start_date = datetime.utcnow()
-    ti.start_date = start_date
-    obj = [[ti]]  # nested to verify recursive behavior
-
-    serialized = BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise
-    deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)  # does not raise
-    assert isinstance(deserialized[0][0], TaskInstancePydantic)
-
-    serialized_json = json.dumps(serialized)  # does not raise
-    deserialized_from_json = BaseSerialization.deserialize(
-        json.loads(serialized_json), use_pydantic_models=True
-    )  # does not raise
-    assert isinstance(deserialized_from_json[0][0], TaskInstancePydantic)
-    assert deserialized_from_json[0][0].start_date == start_date
+    serialized = BaseSerialization.serialize(input, use_pydantic_models=True)  # does not raise
+    # Verify the result is JSON-serializable
+    json.dumps(serialized)  # does not raise
+    assert serialized["__type"] == encoded_type
+    assert serialized["__var"] is not None
+    deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
+    assert isinstance(deserialized, pydantic_class)
+    assert cmp_func(input, deserialized)
+
+    # Verify recursive behavior
+    obj = [[input]]
+    BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise


 def test_serialized_mapped_operator_unmap(dag_maker):