From 823a6f0f9de3cca58b8b48cfc95a1bededd69e24 Mon Sep 17 00:00:00 2001 From: jx2lee Date: Wed, 18 Dec 2024 22:59:57 +0900 Subject: [PATCH] add mising --- .../execution_api/datamodels/taskinstance.py | 12 ++++++++++-- .../airflow/sdk/api/datamodels/_generated.py | 18 +++++++++++++++--- .../execution_api/routes/test_variables.py | 2 +- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 498d94fcc9de4..1228e11cd510c 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -36,6 +36,8 @@ class TIEnterRunningPayload(BaseModel): """Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.""" + model_config = ConfigDict(extra="forbid") + state: Annotated[ Literal[TIState.RUNNING], # Specify a default in the schema, but not in code. @@ -54,6 +56,8 @@ class TIEnterRunningPayload(BaseModel): class TITerminalStatePayload(BaseModel): """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).""" + model_config = ConfigDict(extra="forbid") + state: TerminalTIState end_date: UtcDateTime @@ -63,12 +67,16 @@ class TITerminalStatePayload(BaseModel): class TITargetStatePayload(BaseModel): """Schema for updating TaskInstance to a target state, excluding terminal and running states.""" + model_config = ConfigDict(extra="forbid") + state: IntermediateTIState class TIDeferredStatePayload(BaseModel): """Schema for updating TaskInstance to a deferred state.""" + model_config = ConfigDict(extra="forbid") + state: Annotated[ Literal[IntermediateTIState.DEFERRED], # Specify a default in the schema, but not in code, so Pydantic marks it as required. @@ -148,6 +156,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: class TIHeartbeatInfo(BaseModel): """Schema for TaskInstance heartbeat endpoint.""" + model_config = ConfigDict(extra="forbid") + hostname: str pid: int @@ -187,8 +197,6 @@ class DagRun(BaseModel): class TIRunContext(BaseModel): """Response schema for TaskInstance run context.""" - model_config = ConfigDict(extra="forbid") - dag_run: DagRun """DAG run information for the task instance.""" diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 3ee086096c2e6..a7a89ee7f4663 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -74,6 +74,9 @@ class TIDeferredStatePayload(BaseModel): Schema for updating TaskInstance to a deferred state. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred" classpath: Annotated[str, Field(title="Classpath")] trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None @@ -86,6 +89,9 @@ class TIEnterRunningPayload(BaseModel): Schema for updating TaskInstance to 'RUNNING' state with minimal required fields. """ + model_config = ConfigDict( + extra="forbid", + ) state: Annotated[Literal["running"] | None, Field(title="State")] = "running" hostname: Annotated[str, Field(title="Hostname")] unixname: Annotated[str, Field(title="Unixname")] @@ -98,6 +104,9 @@ class TIHeartbeatInfo(BaseModel): Schema for TaskInstance heartbeat endpoint. """ + model_config = ConfigDict( + extra="forbid", + ) hostname: Annotated[str, Field(title="Hostname")] pid: Annotated[int, Field(title="Pid")] @@ -117,6 +126,9 @@ class TITargetStatePayload(BaseModel): Schema for updating TaskInstance to a target state, excluding terminal and running states. """ + model_config = ConfigDict( + extra="forbid", + ) state: IntermediateTIState @@ -211,9 +223,6 @@ class TIRunContext(BaseModel): Response schema for TaskInstance run context. """ - model_config = ConfigDict( - extra="forbid", - ) dag_run: DagRun variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None @@ -224,5 +233,8 @@ class TITerminalStatePayload(BaseModel): Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED). """ + model_config = ConfigDict( + extra="forbid", + ) state: TerminalTIState end_date: Annotated[datetime, Field(title="End Date")] diff --git a/tests/api_fastapi/execution_api/routes/test_variables.py b/tests/api_fastapi/execution_api/routes/test_variables.py index 20a9b43c07ace..45868e2a6092e 100644 --- a/tests/api_fastapi/execution_api/routes/test_variables.py +++ b/tests/api_fastapi/execution_api/routes/test_variables.py @@ -54,7 +54,7 @@ def test_variable_get_from_db(self, client, session): {"AIRFLOW_VAR_KEY1": "VALUE"}, ) def test_variable_get_from_env_var(self, client, session): - response = client.get("/execution/variables/key1") + response = client.get("/execution/variables/key1", params={"foo": "bar"}) assert response.status_code == 200 assert response.json() == {"key": "key1", "value": "VALUE"}