Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forbid extra fields on execution api #44986

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pydantic import AwareDatetime, Discriminator, Field, Tag, TypeAdapter, WithJsonSchema, field_validator

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, ConfigDict
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
Expand All @@ -36,6 +36,8 @@
class TIEnterRunningPayload(BaseModel):
"""Schema for updating TaskInstance to 'RUNNING' state with minimal required fields."""

model_config = ConfigDict(extra="forbid")

state: Annotated[
Literal[TIState.RUNNING],
# Specify a default in the schema, but not in code.
Expand All @@ -54,6 +56,8 @@ class TIEnterRunningPayload(BaseModel):
class TITerminalStatePayload(BaseModel):
"""Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED)."""

model_config = ConfigDict(extra="forbid")

state: TerminalTIState

end_date: UtcDateTime
Expand All @@ -63,12 +67,16 @@ class TITerminalStatePayload(BaseModel):
class TITargetStatePayload(BaseModel):
"""Schema for updating TaskInstance to a target state, excluding terminal and running states."""

# extra="forbid": a payload carrying keys not declared on this model fails validation
# instead of having those keys ignored.
model_config = ConfigDict(extra="forbid")

# The requested target state; per the docstring this excludes terminal and running states.
state: IntermediateTIState


class TIDeferredStatePayload(BaseModel):
"""Schema for updating TaskInstance to a deferred state."""

model_config = ConfigDict(extra="forbid")

state: Annotated[
Literal[IntermediateTIState.DEFERRED],
# Specify a default in the schema, but not in code, so Pydantic marks it as required.
Expand Down Expand Up @@ -148,6 +156,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
class TIHeartbeatInfo(BaseModel):
"""Schema for TaskInstance heartbeat endpoint."""

model_config = ConfigDict(extra="forbid")

hostname: str
pid: int

Expand Down
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
class VariableResponse(BaseModel):
"""Variable schema for responses with fields that are needed for Runtime."""

model_config = ConfigDict(extra="forbid")

key: str
val: str | None = Field(alias="value")

Expand Down
4 changes: 3 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

from typing import Any

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, ConfigDict


class XComResponse(BaseModel):
"""XCom schema for responses with fields that are needed for Runtime."""

model_config = ConfigDict(extra="forbid")

key: str
value: Any
"""The returned XCom value in a JSON-compatible format."""
21 changes: 21 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class TIDeferredStatePayload(BaseModel):
Schema for updating TaskInstance to a deferred state.
"""

model_config = ConfigDict(
extra="forbid",
)
state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred"
classpath: Annotated[str, Field(title="Classpath")]
trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None
Expand All @@ -86,6 +89,9 @@ class TIEnterRunningPayload(BaseModel):
Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.
"""

model_config = ConfigDict(
extra="forbid",
)
state: Annotated[Literal["running"] | None, Field(title="State")] = "running"
hostname: Annotated[str, Field(title="Hostname")]
unixname: Annotated[str, Field(title="Unixname")]
Expand All @@ -98,6 +104,9 @@ class TIHeartbeatInfo(BaseModel):
Schema for TaskInstance heartbeat endpoint.
"""

model_config = ConfigDict(
extra="forbid",
)
hostname: Annotated[str, Field(title="Hostname")]
pid: Annotated[int, Field(title="Pid")]

Expand All @@ -117,6 +126,9 @@ class TITargetStatePayload(BaseModel):
Schema for updating TaskInstance to a target state, excluding terminal and running states.
"""

model_config = ConfigDict(
extra="forbid",
)
state: IntermediateTIState


Expand Down Expand Up @@ -154,6 +166,9 @@ class VariableResponse(BaseModel):
Variable schema for responses with fields that are needed for Runtime.
"""

model_config = ConfigDict(
extra="forbid",
)
key: Annotated[str, Field(title="Key")]
value: Annotated[str | None, Field(title="Value")] = None

Expand All @@ -163,6 +178,9 @@ class XComResponse(BaseModel):
XCom schema for responses with fields that are needed for Runtime.
"""

model_config = ConfigDict(
extra="forbid",
)
key: Annotated[str, Field(title="Key")]
value: Annotated[Any, Field(title="Value")]

Expand Down Expand Up @@ -215,5 +233,8 @@ class TITerminalStatePayload(BaseModel):
Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).
"""

model_config = ConfigDict(
extra="forbid",
)
state: TerminalTIState
end_date: Annotated[datetime, Field(title="End Date")]
83 changes: 83 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,39 @@ def test_ti_run_state_conflict_if_not_queued(

assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state

def test_ti_run_failed_with_extra(self, client, session, create_task_instance, time_machine):
    """
    The ``/run`` endpoint must answer 422 when the payload carries an undeclared key.
    """
    start_date_str = "2024-12-19T00:00:00Z"
    frozen_now = timezone.parse(start_date_str)
    # Pin the clock to the same instant that is sent as ``start_date`` below.
    time_machine.move_to(frozen_now, tick=False)

    queued_ti = create_task_instance(
        task_id="test_ti_run_failed_with_extra",
        state=State.QUEUED,
        session=session,
        start_date=frozen_now,
    )
    session.commit()

    payload = {
        "state": "running",
        "hostname": "random-hostname",
        "unixname": "random-unixname",
        "pid": 100,
        "start_date": start_date_str,
        # Undeclared key: this is the field expected to trigger the rejection.
        "foo": "bar",
    }
    response = client.patch(f"/execution/task-instances/{queued_ti.id}/run", json=payload)

    assert response.status_code == 422
    first_error = response.json()["detail"][0]
    assert first_error["type"] == "extra_forbidden"
    assert first_error["msg"] == "Extra inputs are not permitted"


class TestTIUpdateState:
def setup_method(self):
Expand Down Expand Up @@ -340,6 +373,27 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
assert trs[0].map_index == -1
assert trs[0].duration == 129600

def test_ti_update_state_failed_with_extra(self, client, session, create_task_instance, time_machine):
    """
    The ``/state`` endpoint must answer 422 when the payload carries an undeclared key.
    """
    running_ti = create_task_instance(
        task_id="test_ti_update_state_failed_with_extra",
        state=State.RUNNING,
        session=session,
        start_date=DEFAULT_START_DATE,
    )
    session.commit()

    # "foo" is not a declared field; the request is expected to be rejected because of it.
    payload = {"state": "scheduled", "foo": "bar"}
    response = client.patch(f"/execution/task-instances/{running_ti.id}/state", json=payload)

    assert response.status_code == 422
    first_error = response.json()["detail"][0]
    assert first_error["type"] == "extra_forbidden"
    assert first_error["msg"] == "Extra inputs are not permitted"


class TestTIHealthEndpoint:
def setup_method(self):
Expand Down Expand Up @@ -536,6 +590,35 @@ def test_ti_update_state_to_failed_table_check(self, client, session, create_tas
assert ti.next_kwargs is None
assert ti.duration == 3600.00

def test_ti_heartbeat_with_extra(
    self,
    client,
    session,
    create_task_instance,
    time_machine,
):
    """
    Test that a 422 error is returned when extra fields are included in the heartbeat payload.
    """
    ti = create_task_instance(
        # Named after this test, matching the sibling "*_with_extra" tests; the previous
        # value was copied from a different test and wrongly described the (RUNNING) task
        # as "not running".
        task_id="test_ti_heartbeat_with_extra",
        state=State.RUNNING,
        hostname="random-hostname",
        pid=1547,
        session=session,
    )
    session.commit()
    task_instance_id = ti.id

    response = client.put(
        f"/execution/task-instances/{task_instance_id}/heartbeat",
        # hostname/pid match the task instance created above; "foo" is the undeclared key
        # that should cause the rejection.
        json={"hostname": "random-hostname", "pid": 1547, "foo": "bar"},
    )

    assert response.status_code == 422
    first_error = response.json()["detail"][0]
    assert first_error["type"] == "extra_forbidden"
    assert first_error["msg"] == "Extra inputs are not permitted"


class TestTIPutRTIF:
def setup_method(self):
Expand Down