diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index c1bf588c2bbd4..1228e11cd510c 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -24,7 +24,7 @@ from pydantic import AwareDatetime, Discriminator, Field, Tag, TypeAdapter, WithJsonSchema, field_validator from airflow.api_fastapi.common.types import UtcDateTime -from airflow.api_fastapi.core_api.base import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel, ConfigDict from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState @@ -36,6 +36,8 @@ class TIEnterRunningPayload(BaseModel): """Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.""" + model_config = ConfigDict(extra="forbid") + state: Annotated[ Literal[TIState.RUNNING], # Specify a default in the schema, but not in code. 
@@ -54,6 +56,8 @@ class TIEnterRunningPayload(BaseModel): class TITerminalStatePayload(BaseModel): """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).""" + model_config = ConfigDict(extra="forbid") + state: TerminalTIState end_date: UtcDateTime @@ -63,12 +67,16 @@ class TITerminalStatePayload(BaseModel): class TITargetStatePayload(BaseModel): """Schema for updating TaskInstance to a target state, excluding terminal and running states.""" + model_config = ConfigDict(extra="forbid") + state: IntermediateTIState class TIDeferredStatePayload(BaseModel): """Schema for updating TaskInstance to a deferred state.""" + model_config = ConfigDict(extra="forbid") + state: Annotated[ Literal[IntermediateTIState.DEFERRED], # Specify a default in the schema, but not in code, so Pydantic marks it as required. @@ -148,6 +156,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: class TIHeartbeatInfo(BaseModel): """Schema for TaskInstance heartbeat endpoint.""" + model_config = ConfigDict(extra="forbid") + hostname: str pid: int diff --git a/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow/api_fastapi/execution_api/datamodels/variable.py index 6c597524763aa..546a06f09b83e 100644 --- a/airflow/api_fastapi/execution_api/datamodels/variable.py +++ b/airflow/api_fastapi/execution_api/datamodels/variable.py @@ -25,6 +25,8 @@ class VariableResponse(BaseModel): """Variable schema for responses with fields that are needed for Runtime.""" + model_config = ConfigDict(extra="forbid") + key: str val: str | None = Field(alias="value") diff --git a/airflow/api_fastapi/execution_api/datamodels/xcom.py b/airflow/api_fastapi/execution_api/datamodels/xcom.py index 1f913f9ac380e..6f897aa6966b9 100644 --- a/airflow/api_fastapi/execution_api/datamodels/xcom.py +++ b/airflow/api_fastapi/execution_api/datamodels/xcom.py @@ -19,12 +19,14 @@ from typing import Any -from airflow.api_fastapi.core_api.base import BaseModel +from 
airflow.api_fastapi.core_api.base import BaseModel, ConfigDict class XComResponse(BaseModel): """XCom schema for responses with fields that are needed for Runtime.""" + model_config = ConfigDict(extra="forbid") + key: str value: Any """The returned XCom value in a JSON-compatible format.""" diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 00187364c8669..a7a89ee7f4663 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -74,6 +74,9 @@ class TIDeferredStatePayload(BaseModel): Schema for updating TaskInstance to a deferred state. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred" classpath: Annotated[str, Field(title="Classpath")] trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None @@ -86,6 +89,9 @@ class TIEnterRunningPayload(BaseModel): Schema for updating TaskInstance to 'RUNNING' state with minimal required fields. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["running"] | None, Field(title="State")] = "running" hostname: Annotated[str, Field(title="Hostname")] unixname: Annotated[str, Field(title="Unixname")] @@ -98,6 +104,9 @@ class TIHeartbeatInfo(BaseModel): Schema for TaskInstance heartbeat endpoint. """ + model_config = ConfigDict( + extra="forbid", + ) hostname: Annotated[str, Field(title="Hostname")] pid: Annotated[int, Field(title="Pid")] @@ -117,6 +126,9 @@ class TITargetStatePayload(BaseModel): Schema for updating TaskInstance to a target state, excluding terminal and running states. """ + model_config = ConfigDict( + extra="forbid", + ) state: IntermediateTIState @@ -154,6 +166,9 @@ class VariableResponse(BaseModel): Variable schema for responses with fields that are needed for Runtime. 
""" + model_config = ConfigDict( + extra="forbid", + ) key: Annotated[str, Field(title="Key")] value: Annotated[str | None, Field(title="Value")] = None @@ -163,6 +178,9 @@ class XComResponse(BaseModel): XCom schema for responses with fields that are needed for Runtime. """ + model_config = ConfigDict( + extra="forbid", + ) key: Annotated[str, Field(title="Key")] value: Annotated[Any, Field(title="Value")] @@ -215,5 +233,8 @@ class TITerminalStatePayload(BaseModel): Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED). """ + model_config = ConfigDict( + extra="forbid", + ) state: TerminalTIState end_date: Annotated[datetime, Field(title="End Date")] diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 4ed5f8f1598f3..cd7a49971cdb6 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -136,6 +136,39 @@ def test_ti_run_state_conflict_if_not_queued( assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state + def test_ti_run_failed_with_extra(self, client, session, create_task_instance, time_machine): + """ + Test that a 422 error is returned when extra fields are included in the payload. 
+ """ + instant_str = "2024-12-19T00:00:00Z" + instant = timezone.parse(instant_str) + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_ti_run_failed_with_extra", + state=State.QUEUED, + session=session, + start_date=instant, + ) + + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": instant_str, + "foo": "bar", + }, + ) + + assert response.status_code == 422 + assert response.json()["detail"][0]["type"] == "extra_forbidden" + assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted" + class TestTIUpdateState: def setup_method(self): @@ -340,6 +373,27 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan assert trs[0].map_index == -1 assert trs[0].duration == 129600 + def test_ti_update_state_failed_with_extra(self, client, session, create_task_instance, time_machine): + """ + Test that a 422 error is returned when extra fields are included in the payload. 
+ """ + ti = create_task_instance( + task_id="test_ti_update_state_failed_with_extra", + state=State.RUNNING, + session=session, + start_date=DEFAULT_START_DATE, + ) + + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", json={"state": "scheduled", "foo": "bar"} + ) + + assert response.status_code == 422 + assert response.json()["detail"][0]["type"] == "extra_forbidden" + assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted" + class TestTIHealthEndpoint: def setup_method(self): @@ -536,6 +590,35 @@ def test_ti_update_state_to_failed_table_check(self, client, session, create_tas assert ti.next_kwargs is None assert ti.duration == 3600.00 + def test_ti_heartbeat_with_extra( + self, + client, + session, + create_task_instance, + time_machine, + ): + """ + Test that a 422 error is returned when extra fields are included in the payload. + """ + ti = create_task_instance( + task_id="test_ti_heartbeat_with_extra", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + session=session, + ) + session.commit() + task_instance_id = ti.id + + response = client.put( + f"/execution/task-instances/{task_instance_id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547, "foo": "bar"}, + ) + + assert response.status_code == 422 + assert response.json()["detail"][0]["type"] == "extra_forbidden" + assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted" + class TestTIPutRTIF: def setup_method(self):