Swap Dag Parsing to use the TaskSDK machinery.
As part of Airflow 3 DAG definition files will have to use the Task SDK for
all their classes, and anything involving running user code will need to be
de-coupled from the database in the user-code process.

This change moves all of the "serialization" work up to the
DagFileProcessorManager, using the new function introduced in #44898 and the
"subprocess" machinery introduced in #44874.

**Important Note**: this change does not remove the ability for dag processes
to access the DB for Variables etc. That will come in a future change.

Some key parts of this change:

- It builds upon the WatchedSubprocess from the TaskSDK. Right now this puts a
  nasty/unwanted dependency from the Dag Parsing code on the TaskSDK.
  This will be addressed before release (we have talked about introducing a
  new "apache-airflow-base-executor" dist where this subprocess+supervisor
  could live, as the "execution_time" folder in the Task SDK is more a feature
  of the executor than of the TaskSDK itself)
- A number of classes that we need to send between processes have been
  converted to Pydantic for ease of serialization.
- To avoid serializing everything in the subprocess and deserializing
  everything in the parent Manager process, we have created a
  `LazyDeserializedDAG` class that provides lazy access to many of the
  properties needed to create or update the DAG-related DB objects, without
  fully deserializing the entire DAG structure (see the sketch after this
  list).
- Classes switched to being attrs based for less boilerplate in constructors
  (also sketched after this list).
- Internal timers converted to `time.monotonic` where possible, and `time.time`
  where not; we only need the difference in seconds between two points, not
  datetime objects.
- With the earlier removal of "sync mode" for SQLite in #44839, the need for
  separate TERMINATE and END messages over the control socket goes away.
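
For illustration, here is a minimal sketch of the lazy-access idea, not the actual implementation: it assumes the parsing subprocess hands the manager the already-serialized DAG dict (roughly what `SerializedDAG.to_dict()` produces) and that the wrapper reads only the fields the DB sync needs. The class name, envelope layout and field access below are assumptions.

```python
# Minimal sketch only; the exact serialized envelope layout is an assumption.
from __future__ import annotations

from typing import Any


class LazyDAGSketch:
    """Expose a few DAG attributes straight from the serialized dict."""

    def __init__(self, data: dict[str, Any]):
        # e.g. {"dag": {"dag_id": ..., "fileloc": ..., "tasks": [...]}}
        self.data = data

    @property
    def dag_id(self) -> str:
        return self.data["dag"]["dag_id"]

    @property
    def fileloc(self) -> str:
        return self.data["dag"]["fileloc"]

    @property
    def has_task_concurrency_limits(self) -> bool:
        # Answer the question the DB sync asks without rebuilding Task objects.
        return any(
            task.get("max_active_tis_per_dag") is not None
            or task.get("max_active_tis_per_dagrun") is not None
            for task in self.data["dag"].get("tasks", [])
        )
```

The real class exposes more of the fields consumed by `update_dag_parsing_results_in_db`, but the point is the same: properties are computed on demand from the serialized dict rather than from a fully deserialized DAG.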
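
And a hedged illustration of the attrs switch; the class and its fields below are hypothetical, the point is only that `@attrs.define` generates `__init__`, `__repr__` and `__eq__` for us.

```python
# Hypothetical class, purely to illustrate the attrs-based style.
from __future__ import annotations

import attrs


@attrs.define
class ParseStatSketch:
    fileloc: str
    run_count: int = 0
    last_error: str | None = None


# Generated __init__ and __eq__, no hand-written boilerplate:
assert ParseStatSketch(fileloc="/dags/example.py") == ParseStatSketch(fileloc="/dags/example.py")
```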

Co-authored-by: Jed Cunningham <[email protected]>
Co-authored-by: Daniel Imberman <[email protected]>
3 people committed Dec 16, 2024
1 parent 0228832 commit 32176b9
Showing 18 changed files with 913 additions and 1,764 deletions.
102 changes: 23 additions & 79 deletions airflow/callbacks/callback_requests.py
@@ -16,49 +16,37 @@
# under the License.
from __future__ import annotations

import json
from typing import TYPE_CHECKING

from pydantic import BaseModel

from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel # noqa: TC001
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.typing_compat import Self


class CallbackRequest:
class CallbackRequest(BaseModel):
"""
Base Class with information about the callback to be executed.
:param full_filepath: File Path to use to run the callback
:param msg: Additional Message that can be used for logging
:param processor_subdir: Directory used by Dag Processor when parsed the dag.
"""

def __init__(
self,
full_filepath: str,
processor_subdir: str | None = None,
msg: str | None = None,
):
self.full_filepath = full_filepath
self.processor_subdir = processor_subdir
self.msg = msg

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return NotImplemented
full_filepath: str
"""File Path to use to run the callback"""
processor_subdir: str | None = None
"""Directory used by Dag Processor when parsed the dag"""
msg: str | None = None
"""Additional Message that can be used for logging to determine failure/zombie"""

def __repr__(self):
return str(self.__dict__)

def to_json(self) -> str:
return json.dumps(self.__dict__)
to_json = BaseModel.model_dump_json

@classmethod
def from_json(cls, json_str: str):
json_object = json.loads(json_str)
return cls(**json_object)
def from_json(cls, data: str | bytes | bytearray) -> Self:
return cls.model_validate_json(data)


class TaskCallbackRequest(CallbackRequest):
@@ -67,25 +55,12 @@ class TaskCallbackRequest(CallbackRequest):
A Class with information about the success/failure TI callback to be executed. Currently, only failure
callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess.
:param full_filepath: File Path to use to run the callback
:param simple_task_instance: Simplified Task Instance representation
:param msg: Additional Message that can be used for logging to determine failure/zombie
:param processor_subdir: Directory used by Dag Processor when parsed the dag.
:param task_callback_type: e.g. whether on success, on failure, on retry.
"""

def __init__(
self,
full_filepath: str,
simple_task_instance: SimpleTaskInstance,
processor_subdir: str | None = None,
msg: str | None = None,
task_callback_type: TaskInstanceState | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.task_callback_type = task_callback_type
ti: ti_datamodel.TaskInstance
"""Simplified Task Instance representation"""
task_callback_type: TaskInstanceState | None = None
"""Whether on success, on failure, on retry"""

@property
def is_failure_callback(self) -> bool:
@@ -98,42 +73,11 @@ def is_failure_callback(self) -> bool:
TaskInstanceState.UPSTREAM_FAILED,
}

def to_json(self) -> str:
from airflow.serialization.serialized_objects import BaseSerialization

val = BaseSerialization.serialize(self.__dict__, strict=True)
return json.dumps(val)

@classmethod
def from_json(cls, json_str: str):
from airflow.serialization.serialized_objects import BaseSerialization

val = json.loads(json_str)
return cls(**BaseSerialization.deserialize(val))


class DagCallbackRequest(CallbackRequest):
"""
A Class with information about the success/failure DAG callback to be executed.
:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
:param run_id: Run ID for the DagRun
:param processor_subdir: Directory used by Dag Processor when parsed the dag.
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging
"""
"""A Class with information about the success/failure DAG callback to be executed."""

def __init__(
self,
full_filepath: str,
dag_id: str,
run_id: str,
processor_subdir: str | None,
is_failure_callback: bool | None = True,
msg: str | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.dag_id = dag_id
self.run_id = run_id
self.is_failure_callback = is_failure_callback
dag_id: str
run_id: str
is_failure_callback: bool | None = True
"""Flag to determine whether it is a Failure Callback or Success Callback"""
8 changes: 1 addition & 7 deletions airflow/cli/commands/local_commands/dag_processor_command.py
@@ -19,7 +19,6 @@
from __future__ import annotations

import logging
from datetime import timedelta
from typing import Any

from airflow.cli.commands.local_commands.daemon_utils import run_command_with_daemon_option
@@ -36,11 +35,10 @@
def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner:
"""Create DagFileProcessorProcess instance."""
processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
processor_timeout = timedelta(seconds=processor_timeout_seconds)
return DagProcessorJobRunner(
job=Job(),
processor=DagFileProcessorManager(
processor_timeout=processor_timeout,
processor_timeout=processor_timeout_seconds,
dag_directory=args.subdir,
max_runs=args.num_runs,
),
@@ -54,10 +52,6 @@ def dag_processor(args):
if not conf.getboolean("scheduler", "standalone_dag_processor"):
raise SystemExit("The option [scheduler/standalone_dag_processor] must be True.")

sql_conn: str = conf.get("database", "sql_alchemy_conn").lower()
if sql_conn.startswith("sqlite"):
raise SystemExit("Standalone DagProcessor is not supported when using sqlite.")

job_runner = _create_dag_processor_job_runner(args)

reload_configuration_for_dag_processing()
32 changes: 20 additions & 12 deletions airflow/dag_processing/collection.py
@@ -64,6 +64,7 @@
from sqlalchemy.sql import Select

from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import LazyDeserializedDAG
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
@@ -169,7 +170,7 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se
)


def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir: str | None):
def _serialize_dag_capturing_errors(dag: LazyDeserializedDAG, session: Session, processor_subdir: str | None):
"""
Try to serialize the dag to the DB, but make a note of any errors.
@@ -192,7 +193,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir
_sync_dag_perms(dag, session=session)
else:
# Check and update DagCode
DagCode.update_source_code(dag)
DagCode.update_source_code(dag.dag_id, dag.fileloc)
return []
except OperationalError:
raise
@@ -202,7 +203,7 @@ def _serialize_dag_capturing_errors(dag: DAG, session: Session, processor_subdir
return [(dag.fileloc, traceback.format_exc(limit=-dagbag_import_error_traceback_depth))]


def _sync_dag_perms(dag: DAG, session: Session):
def _sync_dag_perms(dag: LazyDeserializedDAG, session: Session):
"""Sync DAG specific permissions."""
dag_id = dag.dag_id

@@ -270,7 +271,7 @@ def _update_import_errors(


def update_dag_parsing_results_in_db(
dags: Collection[DAG],
dags: Collection[LazyDeserializedDAG],
import_errors: dict[str, str],
processor_subdir: str | None,
warnings: set[DagWarning],
@@ -393,19 +394,26 @@ def update_dags(
dm.is_active = True
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
if hasattr(dag, "_dag_display_property_value"):
dm._dag_display_property_value = dag._dag_display_property_value
elif dag.dag_display_name != dag.dag_id:
dm._dag_display_property_value = dag.dag_display_name
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
for t in dag.tasks
)
# TODO: this `if is not None` is maybe not the best. It's convenient though
if dag.max_active_tasks is not None:
dm.max_active_tasks = dag.max_active_tasks
if dag.max_active_runs is not None:
dm.max_active_runs = dag.max_active_runs
if dag.max_consecutive_failed_dag_runs is not None:
dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs

if hasattr(dag, "has_task_concurrency_limits"):
dm.has_task_concurrency_limits = dag.has_task_concurrency_limits
else:
dm.has_task_concurrency_limits = any(
t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
for t in dag.tasks
)
dm.timetable_summary = dag.timetable.summary
dm.timetable_description = dag.timetable.description
dm.asset_expression = dag.timetable.asset_condition.as_expression()
