Skip to content

Commit

Permalink
Fix type of "moment" when running an e2e example for deferred TI (apa…
Browse files Browse the repository at this point in the history
…che#45030)

While trying to run an e2e example of a task that defers and then launches a trigger:
```
from airflow import DAG

from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync
from airflow.utils import timezone
import datetime

with DAG(
    dag_id="demo_deferred",
    schedule=None,
    catchup=False,
) as dag:
    DateTimeSensorAsync(
            task_id="async",
            target_time=str(timezone.utcnow() + datetime.timedelta(seconds=3)),
            poke_interval=60,
            timeout=600,
        )
```

I realised that the "moment" inside "trigger_kwargs" is of `pendulum.DateTime` type, and since we have a "dict[str, ANY]`, defined here: https://github.com/apache/airflow/blob/main/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#L82
 
on its datamodel (we cant really have a `UtcDateTime` for one specific field, like we do [here](https://github.com/apache/airflow/blob/main/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#L57C15-L57C26)), it fails to match the type defined in the `Trigger` table which is datetime. 

So, I have added a "before" validator that checks for the type being string and if it is a string, translates it to a datetime object.
  • Loading branch information
amoghrajesh authored Dec 18, 2024
1 parent ef004de commit fc7d983
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
10 changes: 9 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import timedelta
from typing import Annotated, Any, Literal, Union

from pydantic import Discriminator, Field, Tag, WithJsonSchema
from pydantic import AwareDatetime, Discriminator, Field, Tag, TypeAdapter, WithJsonSchema, field_validator

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
Expand All @@ -30,6 +30,8 @@
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
from airflow.utils.types import DagRunType

AwareDatetimeAdapter = TypeAdapter(AwareDatetime)


class TIEnterRunningPayload(BaseModel):
"""Schema for updating TaskInstance to 'RUNNING' state with minimal required fields."""
Expand Down Expand Up @@ -83,6 +85,12 @@ class TIDeferredStatePayload(BaseModel):
next_method: str
trigger_timeout: timedelta | None = None

@field_validator("trigger_kwargs")
def validate_moment(cls, v):
if "moment" in v:
v["moment"] = AwareDatetimeAdapter.validate_strings(v["moment"])
return v


class TIRescheduleStatePayload(BaseModel):
"""Schema for updating TaskInstance to a up_for_reschedule state."""
Expand Down
1 change: 1 addition & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def ti_update_state(
kwargs=ti_patch_payload.trigger_kwargs,
)
session.add(trigger_row)
session.flush()

# TODO: HANDLE execution timeout later as it requires a call to the DB
# either get it from the serialised DAG or get it from the API
Expand Down
12 changes: 9 additions & 3 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance

payload = {
"state": "deferred",
"trigger_kwargs": {"key": "value"},
"trigger_kwargs": {"key": "value", "moment": "2024-12-18T00:00:00Z"},
"classpath": "my-classpath",
"next_method": "execute_callback",
"trigger_timeout": "P1D", # 1 day
Expand All @@ -277,14 +277,20 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance

assert tis[0].state == TaskInstanceState.DEFERRED
assert tis[0].next_method == "execute_callback"
assert tis[0].next_kwargs == {"key": "value"}
assert tis[0].next_kwargs == {
"key": "value",
"moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
}
assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024, 11, 23), timezone=timezone.utc)

t = session.query(Trigger).all()
assert len(t) == 1
assert t[0].created_date == instant
assert t[0].classpath == "my-classpath"
assert t[0].kwargs == {"key": "value"}
assert t[0].kwargs == {
"key": "value",
"moment": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
}

def test_ti_update_state_to_reschedule(self, client, session, create_task_instance, time_machine):
"""
Expand Down

0 comments on commit fc7d983

Please sign in to comment.