Fix dag serialization (#34042)
mhenc committed Oct 27, 2023
1 parent 9b538b7 commit 64a64ab
Showing 5 changed files with 200 additions and 31 deletions.
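At a high level, the fix below makes BaseSerialization.serialize wrap a DAG in the same type-tagged envelope used for every other supported type, instead of returning the bare serialize_dag(...) dict that deserialize could not recognize. A minimal sketch of the round-trip this enables (not part of the commit; assumes dag is any constructed DAG):

    from airflow.serialization.serialized_objects import BaseSerialization

    # After this commit each DAG serializes to {"__type": "dag", "__var": {...}},
    # so it survives a round-trip even when nested inside another container.
    payload = BaseSerialization.serialize([dag])
    restored = BaseSerialization.deserialize(payload)  # a list holding a DAG again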
airflow/serialization/pydantic/dag.py (1 addition, 1 deletion)
@@ -60,7 +60,7 @@ class DagModelPydantic(BaseModelPydantic):
    is_paused_at_creation: bool = airflow_conf.getboolean("core", "dags_are_paused_at_creation")
    is_paused: bool = is_paused_at_creation
    is_subdag: Optional[bool] = False
-    is_active: bool = False
+    is_active: Optional[bool] = False
    last_parsed_time: Optional[datetime]
    last_pickled: Optional[datetime]
    last_expired: Optional[datetime]
airflow/serialization/pydantic/dag_run.py (1 addition, 1 deletion)
@@ -74,7 +74,7 @@ class DagRunPydantic(BaseModelPydantic):
    data_interval_end: Optional[datetime]
    last_scheduling_decision: Optional[datetime]
    dag_hash: Optional[str]
-    updated_at: datetime
+    updated_at: Optional[datetime]
    dag: Optional[PydanticDag]
    consumed_dataset_events: List[DatasetEventPydantic]  # noqa
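Both pydantic changes above (is_active and updated_at) follow one pattern: a field that can be NULL on the underlying ORM row must be declared Optional, otherwise pydantic validation rejects the object while serializing it. A self-contained toy sketch of the failure mode (model names are hypothetical, not Airflow classes):

    from datetime import datetime
    from typing import Optional

    from pydantic import BaseModel, ValidationError

    class StrictRun(BaseModel):  # mirrors the old annotation: updated_at: datetime
        updated_at: datetime

    class LenientRun(BaseModel):  # mirrors the new annotation: Optional[datetime]
        updated_at: Optional[datetime]

    try:
        StrictRun(updated_at=None)  # pydantic rejects None for a non-Optional field
    except ValidationError as err:
        print(err)

    print(LenientRun(updated_at=None))  # updated_at=None now validates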

airflow/serialization/serialized_objects.py (1 addition, 1 deletion)
@@ -447,7 +447,7 @@ def serialize(
            json_pod = PodGenerator.serialize_pod(var)
            return cls._encode(json_pod, type_=DAT.POD)
        elif isinstance(var, DAG):
-            return SerializedDAG.serialize_dag(var)
+            return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
        elif isinstance(var, Resources):
            return var.to_dict()
        elif isinstance(var, MappedOperator):
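This one-line change is the heart of the fix: serialize_dag returns a plain dict, and without the cls._encode wrapper that dict carries no "__type" tag for deserialize to dispatch on. A sketch of the envelope now produced, using the real Encoding and DAT enums from airflow.serialization.enums (Encoding.TYPE is "__type", Encoding.VAR is "__var"):

    from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
    from airflow.serialization.serialized_objects import BaseSerialization

    serialized = BaseSerialization.serialize(dag)  # dag: a DAG built elsewhere
    assert serialized[Encoding.TYPE] == DAT.DAG    # the new type tag
    dag_data = serialized[Encoding.VAR]            # the dict serialize_dag() used to return bare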
tests/serialization/test_dag_serialization.py (4 additions, 4 deletions)
@@ -2525,7 +2525,7 @@ def tg(a: str) -> None:
        tg.expand(a=[".", ".."])

    ser_dag = SerializedBaseOperator.serialize(dag)
-    assert ser_dag["_task_group"]["children"]["tg"] == (
+    assert ser_dag[Encoding.VAR]["_task_group"]["children"]["tg"] == (
        "taskgroup",
        {
            "_group_id": "tg",
@@ -2549,7 +2549,7 @@ def tg(a: str) -> None:
        },
    )

-    serde_dag = SerializedDAG.deserialize_dag(ser_dag)
+    serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])
    serde_tg = serde_dag.task_group.children["tg"]
    assert isinstance(serde_tg, MappedTaskGroup)
    assert serde_tg._expand_input == DictOfListsExpandInput({"a": [".", ".."]})
@@ -2568,7 +2568,7 @@ def operator_extra_links(self):
with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
_DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
serialized_dag = SerializedBaseOperator.serialize(dag)
assert serialized_dag["tasks"][0] == {
assert serialized_dag[Encoding.VAR]["tasks"][0] == {
"task_id": "task",
"expand_input": {
"type": "dict-of-lists",
@@ -2589,5 +2589,5 @@ def operator_extra_links(self):
"_is_empty": False,
"_is_mapped": True,
}
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag)
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR])
assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()]
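The assertion updates above follow mechanically from the new envelope: SerializedDAG.deserialize_dag still takes the bare DAG payload, so tests holding the encoded form unwrap it via Encoding.VAR, while BaseSerialization.deserialize accepts the envelope whole. A short sketch of the two entry points (reusing dag from the tests above):

    from airflow.serialization.enums import Encoding
    from airflow.serialization.serialized_objects import (
        BaseSerialization,
        SerializedBaseOperator,
        SerializedDAG,
    )

    ser_dag = SerializedBaseOperator.serialize(dag)                      # tagged envelope
    from_payload = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])  # unwrap first
    from_envelope = BaseSerialization.deserialize(ser_dag)               # handles the tag itself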
tests/serialization/test_serialized_objects.py (193 additions, 24 deletions)
@@ -18,16 +18,34 @@
from __future__ import annotations

import json
-from datetime import datetime
+from datetime import datetime, timedelta

import pytest
+from dateutil import relativedelta
from kubernetes.client import models as k8s
+from pendulum.tz.timezone import Timezone

+from airflow.datasets import Dataset
from airflow.exceptions import SerializationError
-from airflow.models.taskinstance import TaskInstance
+from airflow.jobs.job import Job
+from airflow.models.connection import Connection
+from airflow.models.dag import DAG, DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.param import Param
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
+from airflow.serialization.enums import DagAttributeTypes as DAT
+from airflow.serialization.pydantic.dag import DagModelPydantic
+from airflow.serialization.pydantic.dag_run import DagRunPydantic
+from airflow.serialization.pydantic.job import JobPydantic
+from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import _ENABLE_AIP_44
-from airflow.utils.state import State
+from airflow.utils.operator_resources import Resources
+from airflow.utils.state import DagRunState, State
+from airflow.utils.task_group import TaskGroup
+from airflow.utils.types import DagRunType
from tests import REPO_ROOT


@@ -82,31 +100,182 @@ class Test:
        BaseSerialization.serialize(obj, strict=True)  # now raises


+TI = TaskInstance(
+    task=EmptyOperator(task_id="test-task"),
+    run_id="fake_run",
+    state=State.RUNNING,
+)
+
+TI_WITH_START_DAY = TaskInstance(
+    task=EmptyOperator(task_id="test-task"),
+    run_id="fake_run",
+    state=State.RUNNING,
+)
+TI_WITH_START_DAY.start_date = datetime.utcnow()
+
+DAG_RUN = DagRun(
+    dag_id="test_dag_id",
+    run_id="test_dag_run_id",
+    run_type=DagRunType.MANUAL,
+    execution_date=datetime.utcnow(),
+    start_date=datetime.utcnow(),
+    external_trigger=True,
+    state=DagRunState.SUCCESS,
+)
+DAG_RUN.id = 1
+
+
+def equals(a, b) -> bool:
+    return a == b
+
+
+def equal_time(a: datetime, b: datetime) -> bool:
+    return a.strftime("%s") == b.strftime("%s")
+
+
+@pytest.mark.parametrize(
+    "input, encoded_type, cmp_func",
+    [
+        ("test_str", None, equals),
+        (1, None, equals),
+        (datetime.utcnow(), DAT.DATETIME, equal_time),
+        (timedelta(minutes=2), DAT.TIMEDELTA, equals),
+        (Timezone("UTC"), DAT.TIMEZONE, lambda a, b: a.name == b.name),
+        (relativedelta.relativedelta(hours=+1), DAT.RELATIVEDELTA, lambda a, b: a.hours == b.hours),
+        ({"test": "dict", "test-1": 1}, None, equals),
+        (["array_item", 2], None, equals),
+        (("tuple_item", 3), DAT.TUPLE, equals),
+        (set(["set_item", 3]), DAT.SET, equals),
+        (
+            k8s.V1Pod(
+                metadata=k8s.V1ObjectMeta(
+                    name="test", annotations={"test": "annotation"}, creation_timestamp=datetime.utcnow()
+                )
+            ),
+            DAT.POD,
+            equals,
+        ),
+        (
+            DAG(
+                "fake-dag",
+                schedule="*/10 * * * *",
+                default_args={"depends_on_past": True},
+                start_date=datetime.utcnow(),
+                catchup=False,
+            ),
+            DAT.DAG,
+            lambda a, b: a.dag_id == b.dag_id and equal_time(a.start_date, b.start_date),
+        ),
+        (Resources(cpus=0.1, ram=2048), None, None),
+        (EmptyOperator(task_id="test-task"), None, None),
+        (TaskGroup(group_id="test-group", dag=DAG(dag_id="test_dag", start_date=datetime.now())), None, None),
+        (
+            Param("test", "desc"),
+            DAT.PARAM,
+            lambda a, b: a.value == b.value and a.description == b.description,
+        ),
+        (
+            XComArg(
+                operator=PythonOperator(
+                    python_callable=int,
+                    task_id="test_xcom_op",
+                    do_xcom_push=True,
+                )
+            ),
+            DAT.XCOM_REF,
+            None,
+        ),
+        (Dataset(uri="test"), DAT.DATASET, equals),
+        (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
+        (
+            Connection(conn_id="TEST_ID", uri="mysql://"),
+            DAT.CONNECTION,
+            lambda a, b: a.get_uri() == b.get_uri(),
+        ),
+    ],
+)
+def test_serialize_deserialize(input, encoded_type, cmp_func):
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+    serialized = BaseSerialization.serialize(input)  # does not raise
+    json.dumps(serialized)  # does not raise
+    if encoded_type is not None:
+        assert serialized["__type"] == encoded_type
+        assert serialized["__var"] is not None
+    if cmp_func is not None:
+        deserialized = BaseSerialization.deserialize(serialized)
+        assert cmp_func(input, deserialized)
+
+    # Verify recursive behavior
+    obj = [[input]]
+    serialized = BaseSerialization.serialize(obj)  # does not raise
+    # Verify the result is JSON-serializable
+    json.dumps(serialized)  # does not raise
+
+
@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
-def test_use_pydantic_models():
-    """If use_pydantic_models=True the TaskInstance object should be serialized to TaskInstancePydantic."""
+@pytest.mark.parametrize(
+    "input, pydantic_class, encoded_type, cmp_func",
+    [
+        (
+            Job(state=State.RUNNING, latest_heartbeat=datetime.utcnow()),
+            JobPydantic,
+            DAT.BASE_JOB,
+            lambda a, b: equal_time(a.latest_heartbeat, b.latest_heartbeat),
+        ),
+        (
+            TI_WITH_START_DAY,
+            TaskInstancePydantic,
+            DAT.TASK_INSTANCE,
+            lambda a, b: equal_time(a.start_date, b.start_date),
+        ),
+        (
+            DAG_RUN,
+            DagRunPydantic,
+            DAT.DAG_RUN,
+            lambda a, b: equal_time(a.execution_date, b.execution_date)
+            and equal_time(a.start_date, b.start_date),
+        ),
+        # DataSet is already serialized by non-Pydantic serialization. Is DatasetPydantic needed then?
+        # (
+        #     Dataset(
+        #         uri="foo://bar",
+        #         extra={"foo": "bar"},
+        #     ),
+        #     DatasetPydantic,
+        #     DAT.DATA_SET,
+        #     lambda a, b: a.uri == b.uri and a.extra == b.extra,
+        # ),
+        (
+            DagModel(
+                dag_id="TEST_DAG_1",
+                fileloc="/tmp/dag_1.py",
+                schedule_interval="2 2 * * *",
+                is_paused=True,
+            ),
+            DagModelPydantic,
+            DAT.DAG_MODEL,
+            lambda a, b: a.fileloc == b.fileloc and a.schedule_interval == b.schedule_interval,
+        ),
+    ],
+)
+def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp_func):
+    """If use_pydantic_models=True the objects should be serialized to Pydantic objects."""

    from airflow.serialization.serialized_objects import BaseSerialization

-    ti = TaskInstance(
-        task=EmptyOperator(task_id="task"),
-        run_id="run_id",
-        state=State.RUNNING,
-    )
-    start_date = datetime.utcnow()
-    ti.start_date = start_date
-    obj = [[ti]]  # nested to verify recursive behavior
-
-    serialized = BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise
-    deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)  # does not raise
-    assert isinstance(deserialized[0][0], TaskInstancePydantic)
-
-    serialized_json = json.dumps(serialized)  # does not raise
-    deserialized_from_json = BaseSerialization.deserialize(
-        json.loads(serialized_json), use_pydantic_models=True
-    )  # does not raise
-    assert isinstance(deserialized_from_json[0][0], TaskInstancePydantic)
-    assert deserialized_from_json[0][0].start_date == start_date
+    serialized = BaseSerialization.serialize(input, use_pydantic_models=True)  # does not raise
+    # Verify the result is JSON-serializable
+    json.dumps(serialized)  # does not raise
+    assert serialized["__type"] == encoded_type
+    assert serialized["__var"] is not None
+    deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
+    assert isinstance(deserialized, pydantic_class)
+    assert cmp_func(input, deserialized)
+
+    # Verify recursive behavior
+    obj = [[input]]
+    BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise


def test_serialized_mapped_operator_unmap(dag_maker):
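Taken together, the new parametrized tests pin down the contract that AIP-44 depends on: with use_pydantic_models=True, ORM-backed objects come back from deserialization as their pydantic counterparts. A compact sketch mirroring the DAG_RUN case above (same fixture object assumed):

    from airflow.serialization.pydantic.dag_run import DagRunPydantic
    from airflow.serialization.serialized_objects import BaseSerialization

    serialized = BaseSerialization.serialize(DAG_RUN, use_pydantic_models=True)
    restored = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
    assert isinstance(restored, DagRunPydantic)  # a pydantic twin, not the ORM DagRun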
