Skip to content

Commit

Permalink
Swap Dag Parsing to use the TaskSDK machinery.
Browse files Browse the repository at this point in the history
As part of Airflow 3 DAG definition files will have to use the Task SDK for
all their classes, and anything involving running user code will need to be
de-coupled from the database in the user-code process.

This change moves all of the "serialization" change up to the
DagFileProcessorManager, using the new function introduced in #44898 and the
"subprocess" machinery introduced in #44874.

**Important Note**: this change does not remove the ability for dag processes
to access the DB for Variables etc. That will come in a future change.

Some key parts of this change:

- It builds upon the WatchedSubprocess from the TaskSDK. Right now this puts a
  nasty/unwanted depenednecy between the Dag Parsing code upon the TaskSDK.
  This will be addressed before release (we have talked about introducing a
  new "apache-airflow-base-executor" dist where this subprocess+supervisor
  could live, as the "execution_time" folder in the Task SDK is more a feature
  of the executor, not of the TaskSDK itself.)
- A number of classes that we need to send between processes have been
  converted to Pydantic for ease of serialization.
- In order to not have to serialize everything in the subprocess and deserialize everything
  in the parent Manager process, we have created a `LazyDeserializedDAG` class
  that provides lazy access to much of the properties needed to create update
  the DAG related DB objects, without needing to fully deserialize the entire
  DAG structure.
- Classes switched to attrs based for less boilerplate in constructors.
- Internal timers convert to `time.monotonic` where possible, and `time.time`
  where not, we only need second diff between two points, not datetime
  objects.
- With the earlier removal of "sync mode" for SQLite in #44839 the need for
  separate TERMIANTE and END messages over the control socket can go.

Co-authored-by: Jed Cunningham <[email protected]>
Co-authored-by: Daniel Imberman <[email protected]>
  • Loading branch information
3 people committed Dec 19, 2024
1 parent 3c11168 commit 1e703c2
Show file tree
Hide file tree
Showing 25 changed files with 1,136 additions and 2,054 deletions.
105 changes: 25 additions & 80 deletions airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,38 @@
# under the License.
from __future__ import annotations

import json
from typing import TYPE_CHECKING

from pydantic import BaseModel

from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel # noqa: TC001
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.typing_compat import Self


class CallbackRequest:
class CallbackRequest(BaseModel):
"""
Base Class with information about the callback to be executed.
:param full_filepath: File Path to use to run the callback
:param msg: Additional Message that can be used for logging
:param processor_subdir: Directory used by Dag Processor when parsed the dag.
"""

def __init__(
self,
full_filepath: str,
processor_subdir: str | None = None,
msg: str | None = None,
):
self.full_filepath = full_filepath
self.processor_subdir = processor_subdir
self.msg = msg

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return NotImplemented

def __repr__(self):
return str(self.__dict__)

def to_json(self) -> str:
return json.dumps(self.__dict__)
full_filepath: str
"""File Path to use to run the callback"""
processor_subdir: str | None = None
"""Directory used by Dag Processor when parsed the dag"""
msg: str | None = None
"""Additional Message that can be used for logging to determine failure/zombie"""

@classmethod
def from_json(cls, json_str: str):
json_object = json.loads(json_str)
return cls(**json_object)
def from_json(cls, data: str | bytes | bytearray) -> Self:
return cls.model_validate_json(data)

def to_json(self, **kwargs) -> str:
return self.model_dump_json(**kwargs)


class TaskCallbackRequest(CallbackRequest):
Expand All @@ -67,25 +56,12 @@ class TaskCallbackRequest(CallbackRequest):
A Class with information about the success/failure TI callback to be executed. Currently, only failure
callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess.
:param full_filepath: File Path to use to run the callback
:param simple_task_instance: Simplified Task Instance representation
:param msg: Additional Message that can be used for logging to determine failure/zombie
:param processor_subdir: Directory used by Dag Processor when parsed the dag.
:param task_callback_type: e.g. whether on success, on failure, on retry.
"""

def __init__(
self,
full_filepath: str,
simple_task_instance: SimpleTaskInstance,
processor_subdir: str | None = None,
msg: str | None = None,
task_callback_type: TaskInstanceState | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.task_callback_type = task_callback_type
ti: ti_datamodel.TaskInstance
"""Simplified Task Instance representation"""
task_callback_type: TaskInstanceState | None = None
"""Whether on success, on failure, on retry"""

@property
def is_failure_callback(self) -> bool:
Expand All @@ -98,42 +74,11 @@ def is_failure_callback(self) -> bool:
TaskInstanceState.UPSTREAM_FAILED,
}

def to_json(self) -> str:
from airflow.serialization.serialized_objects import BaseSerialization

val = BaseSerialization.serialize(self.__dict__, strict=True)
return json.dumps(val)

@classmethod
def from_json(cls, json_str: str):
from airflow.serialization.serialized_objects import BaseSerialization

val = json.loads(json_str)
return cls(**BaseSerialization.deserialize(val))


class DagCallbackRequest(CallbackRequest):
"""
A Class with information about the success/failure DAG callback to be executed.
:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
:param run_id: Run ID for the DagRun
:param processor_subdir: Directory used by Dag Processor when parsed the dag.
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging
"""
"""A Class with information about the success/failure DAG callback to be executed."""

def __init__(
self,
full_filepath: str,
dag_id: str,
run_id: str,
processor_subdir: str | None,
is_failure_callback: bool | None = True,
msg: str | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.dag_id = dag_id
self.run_id = run_id
self.is_failure_callback = is_failure_callback
dag_id: str
run_id: str
is_failure_callback: bool | None = True
"""Flag to determine whether it is a Failure Callback or Success Callback"""
8 changes: 1 addition & 7 deletions airflow/cli/commands/local_commands/dag_processor_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from __future__ import annotations

import logging
from datetime import timedelta
from typing import Any

from airflow.cli.commands.local_commands.daemon_utils import run_command_with_daemon_option
Expand All @@ -36,11 +35,10 @@
def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:
"""Create DagFileProcessorProcess instance."""
processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
processor_timeout = timedelta(seconds=processor_timeout_seconds)
return DagProcessorJobRunner(
job=Job(),
processor=DagFileProcessorManager(
processor_timeout=processor_timeout,
processor_timeout=processor_timeout_seconds,
dag_directory=args.subdir,
max_runs=args.num_runs,
),
Expand All @@ -54,10 +52,6 @@ def dag_processor(args):
if not conf.getboolean("scheduler", "standalone_dag_processor"):
raise SystemExit("The option [scheduler/standalone_dag_processor] must be True.")

sql_conn: str = conf.get("database", "sql_alchemy_conn").lower()
if sql_conn.startswith("sqlite"):
raise SystemExit("Standalone DagProcessor is not supported when using sqlite.")

job_runner = _create_dag_processor_job_runner(args)

reload_configuration_for_dag_processing()
Expand Down
84 changes: 49 additions & 35 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

from __future__ import annotations

import itertools
import logging
import traceback
from typing import TYPE_CHECKING, NamedTuple
Expand Down Expand Up @@ -64,12 +63,13 @@
from sqlalchemy.sql import Select

from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import MaybeSerializedDAG
from airflow.typing_compat import Self

log = logging.getLogger(__name__)


def _create_orm_dags(dags: Iterable[DAG], *, session: Session) -> Iterator[DagModel]:
def _create_orm_dags(dags: Iterable[MaybeSerializedDAG], *, session: Session) -> Iterator[DagModel]:
for dag in dags:
orm_dag = DagModel(dag_id=dag.dag_id)
if dag.is_paused_upon_creation is not None:
Expand Down Expand Up @@ -124,7 +124,7 @@ class _RunInfo(NamedTuple):
num_active_runs: dict[str, int]

@classmethod
def calculate(cls, dags: dict[str, DAG], *, session: Session) -> Self:
def calculate(cls, dags: dict[str, MaybeSerializedDAG], *, session: Session) -> Self:
"""
Query the the run counts from the db.
Expand Down Expand Up @@ -169,7 +169,7 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se
)


def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir: str | None):
def _serialize_dag_capturing_errors(dag: MaybeSerializedDAG, session: Session, processor_subdir: str | None):
"""
Try to serialize the dag to the DB, but make a note of any errors.
Expand All @@ -192,7 +192,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir
_sync_dag_perms(dag, session=session)
else:
# Check and update DagCode
DagCode.update_source_code(dag)
DagCode.update_source_code(dag.dag_id, dag.fileloc)
return []
except OperationalError:
raise
Expand All @@ -202,7 +202,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir
return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))]


def _sync_dag_perms(dag: DAG, session: Session):
def _sync_dag_perms(dag: MaybeSerializedDAG, session: Session):
"""Sync DAG specific permissions."""
dag_id = dag.dag_id

Expand Down Expand Up @@ -270,7 +270,7 @@ def _update_import_errors(


def update_dag_parsing_results_in_db(
dags: Collection[DAG],
dags: Collection[MaybeSerializedDAG],
import_errors: dict[str, str],
processor_subdir: str | None,
warnings: set[DagWarning],
Expand Down Expand Up @@ -347,7 +347,7 @@ def update_dag_parsing_results_in_db(
class DagModelOperation(NamedTuple):
"""Collect DAG objects and perform database operations for them."""

dags: dict[str, DAG]
dags: dict[str, MaybeSerializedDAG]

def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
"""Find existing DagModel objects from DAG objects."""
Expand Down Expand Up @@ -380,6 +380,8 @@ def update_dags(
processor_subdir: str | None = None,
session: Session,
) -> None:
from airflow.configuration import conf

# we exclude backfill from active run counts since their concurrency is separate
run_info = _RunInfo.calculate(
dags=self.dags,
Expand All @@ -393,19 +395,41 @@ def update_dags(
dm.is_active = True
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
dm.default_view = dag.default_view or conf.get("webserver", "dag_default_view").lower()
if hasattr(dag, "_dag_display_property_value"):
dm._dag_display_property_value = dag._dag_display_property_value
elif dag.dag_display_name != dag.dag_id:
dm._dag_display_property_value = dag.dag_display_name
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
for t in dag.tasks
)

# These "is not None" checks are because with a LazySerializedDag object where the user hasn't
# specified an explicit value, we don't get the default values from the config in the lazy
# serialized ver
# we just
if dag.max_active_tasks is not None:
dm.max_active_tasks = dag.max_active_tasks
elif dag.max_active_tasks is None and dm.max_active_tasks is None:
dm.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag")

if dag.max_active_runs is not None:
dm.max_active_runs = dag.max_active_runs
elif dag.max_active_runs is None and dm.max_active_runs is None:
dm.max_active_runs = conf.getint("core", "max_active_runs_per_dag")

if dag.max_consecutive_failed_dag_runs is not None:
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
elif dag.max_consecutive_failed_dag_runs is None and dm.max_consecutive_failed_dag_runs is None:
dm.max_consecutive_failed_dag_runs = conf.getint(
"core", "max_consecutive_failed_dag_runs_per_dag"
)

if hasattr(dag, "has_task_concurrency_limits"):
dm.has_task_concurrency_limits = dag.has_task_concurrency_limits
else:
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
for t in dag.tasks
)
dm.timetable_summary = dag.timetable.summary
dm.timetable_description = dag.timetable.description
dm.asset_expression = dag.timetable.asset_condition.as_expression()
Expand All @@ -419,7 +443,7 @@ def update_dags(
if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
dm.next_dagrun_create_after = None
else:
dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
dm.calculate_dagrun_date_fields(dag, last_automated_data_interval) # type: ignore[arg-type]

if not dag.timetable.asset_condition:
dm.schedule_asset_references = []
Expand All @@ -436,24 +460,20 @@ def update_dags(
dm.dag_owner_links = []


def _find_all_assets(dags: Iterable[DAG]) -> Iterator[Asset]:
def _find_all_assets(dags: Iterable[MaybeSerializedDAG]) -> Iterator[Asset]:
for dag in dags:
for _, asset in dag.timetable.asset_condition.iter_assets():
yield asset
for task in dag.task_dict.values():
for obj in itertools.chain(task.inlets, task.outlets):
if isinstance(obj, Asset):
yield obj
for _, alias in dag.get_task_assets(of_type=Asset):
yield alias


def _find_all_asset_aliases(dags: Iterable[DAG]) -> Iterator[AssetAlias]:
def _find_all_asset_aliases(dags: Iterable[MaybeSerializedDAG]) -> Iterator[AssetAlias]:
for dag in dags:
for _, alias in dag.timetable.asset_condition.iter_asset_aliases():
yield alias
for task in dag.task_dict.values():
for obj in itertools.chain(task.inlets, task.outlets):
if isinstance(obj, AssetAlias):
yield obj
for _, alias in dag.get_task_assets(of_type=AssetAlias):
yield alias


def _find_active_assets(name_uri_assets, session: Session):
Expand Down Expand Up @@ -500,7 +520,7 @@ class AssetModelOperation(NamedTuple):
asset_aliases: dict[str, AssetAlias]

@classmethod
def collect(cls, dags: dict[str, DAG]) -> Self:
def collect(cls, dags: dict[str, MaybeSerializedDAG]) -> Self:
coll = cls(
schedule_asset_references={
dag_id: [asset for _, asset in dag.timetable.asset_condition.iter_assets()]
Expand All @@ -511,13 +531,7 @@ def collect(cls, dags: dict[str, DAG]) -> Self:
for dag_id, dag in dags.items()
},
outlet_references={
dag_id: [
(task_id, outlet)
for task_id, task in dag.task_dict.items()
for outlet in task.outlets
if isinstance(outlet, Asset)
]
for dag_id, dag in dags.items()
dag_id: list(dag.get_task_assets(inlets=False, outlets=True)) for dag_id, dag in dags.items()
},
assets={(asset.name, asset.uri): asset for asset in _find_all_assets(dags.values())},
asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())},
Expand Down
Loading

0 comments on commit 1e703c2

Please sign in to comment.