From de2b82e2a3115c4c39dc5458ffe1c8d5018df6dd Mon Sep 17 00:00:00 2001 From: Amogh Date: Fri, 20 Dec 2024 15:27:59 +0530 Subject: [PATCH 1/8] AIP-72: Handling task retries in task SDK + execution API --- .../execution_api/datamodels/taskinstance.py | 2 ++ .../execution_api/routes/task_instances.py | 29 +++++++++++++++++-- task_sdk/src/airflow/sdk/api/client.py | 5 +++- .../airflow/sdk/api/datamodels/_generated.py | 1 + .../src/airflow/sdk/execution_time/comms.py | 1 + .../airflow/sdk/execution_time/supervisor.py | 16 ++++++++-- .../airflow/sdk/execution_time/task_runner.py | 9 ++++-- 7 files changed, 56 insertions(+), 7 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index c1bf588c2bbd4..002f88bfa7ebb 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -59,6 +59,8 @@ class TITerminalStatePayload(BaseModel): end_date: UtcDateTime """When the task completed executing""" + task_retries: int | None + class TITargetStatePayload(BaseModel): """Schema for updating TaskInstance to a target state, excluding terminal and running states.""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 016f5222c79d8..cb160cc8bb12d 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -199,17 +199,22 @@ def ti_update_state( ) # We exclude_unset to avoid updating fields that are not set in the payload - data = ti_patch_payload.model_dump(exclude_unset=True) + # We do not need to deserialize "task_retries" -- it is used for dynamic decision making within failed state + data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"task_retries"}) query = update(TI).where(TI.id == ti_id_str).values(data) if isinstance(ti_patch_payload, TITerminalStatePayload): query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) query = query.values(state=ti_patch_payload.state) + updated_state = ti_patch_payload.state if ti_patch_payload.state == State.FAILED: # clear the next_method and next_kwargs query = query.values(next_method=None, next_kwargs=None) - updated_state = State.FAILED + task_instance = session.get(TI, ti_id_str) + if _is_eligible_to_retry(task_instance, ti_patch_payload.task_retries): + query = query.values(state=State.UP_FOR_RETRY) + updated_state = State.UP_FOR_RETRY elif isinstance(ti_patch_payload, TIDeferredStatePayload): # Calculate timeout if it was passed timeout = None @@ -359,3 +364,23 @@ def ti_put_rtif( _update_rtif(task_instance, put_rtif_payload, session) return {"message": "Rendered task instance fields successfully set"} + + +def _is_eligible_to_retry(task_instance: TI, task_retries: int | None): + """ + Is task instance is eligible for retry. 
+ + :param task_instance: the task instance + + :meta private: + """ + if task_instance.state == State.RESTARTING: + # If a task is cleared when running, it goes into RESTARTING state and is always + # eligible for retry + return True + + if task_retries == -1: + # task_runner indicated that it doesn't know number of retries, guess it from the table + return task_instance.try_number <= task_instance.max_tries + + return task_retries and task_instance.try_number <= task_instance.max_tries diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index da91c2bd98dd2..7d48968721258 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -124,11 +124,14 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext: resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) return TIRunContext.model_validate_json(resp.read()) - def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime): + def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime, task_retries: int | None): """Tell the API server that this TI has reached a terminal state.""" # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state)) + if task_retries: + body.task_retries = task_retries + self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) def heartbeat(self, id: uuid.UUID, pid: int): diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 00187364c8669..0e190792c7ed0 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -217,3 +217,4 @@ class TITerminalStatePayload(BaseModel): state: TerminalTIState end_date: Annotated[datetime, Field(title="End Date")] + task_retries: Annotated[int | None, Field(title="Task Retries")] = None diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 31690815af4f6..236be09c9fe2f 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -121,6 +121,7 @@ class TaskState(BaseModel): state: TerminalTIState end_date: datetime | None = None + task_retries: int | None = None type: Literal["TaskState"] = "TaskState" diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 932e6ead37947..b219d62379782 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -291,6 +291,8 @@ class WatchedSubprocess: _exit_code: int | None = attrs.field(default=None, init=False) _terminal_state: str | None = attrs.field(default=None, init=False) _final_state: str | None = attrs.field(default=None, init=False) + # denotes if a task `has` retries defined or not, helpful to send signals between the handle_requests and wait + _should_retry: bool = attrs.field(default=False, init=False) _last_successful_heartbeat: float = attrs.field(default=0, init=False) _last_heartbeat_attempt: float = attrs.field(default=0, init=False) @@ -515,13 +517,15 @@ def wait(self) -> int: # If it hasn't, assume it's failed self._exit_code = self._exit_code if self._exit_code is not None else 1 + print("The exit code is", self._exit_code) + # If the process 
has finished in a terminal state, update the state of the TaskInstance # to reflect the final state of the process. # For states like `deferred`, the process will exit with 0, but the state will be updated # by the subprocess in the `handle_requests` method. - if self.final_state in TerminalTIState: + if self.final_state in TerminalTIState and not self._should_retry: self.client.task_instances.finish( - id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc) + id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc), task_retries=None ) return self._exit_code @@ -710,6 +714,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): if isinstance(msg, TaskState): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() + if msg.task_retries: + self.client.task_instances.finish( + id=self.id, + state=self.final_state, + when=datetime.now(tz=timezone.utc), + task_retries=msg.task_retries, + ) + self._should_retry = True elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 78eab44df83c3..c4a2e9df19436 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -320,8 +320,13 @@ def run(ti: RuntimeTaskInstance, log: Logger): except SystemExit: ... except BaseException: - # TODO: Handle TI handle failure - raise + msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) + if not getattr(ti, "task", None): + # We do not know about retries, let's mark it -1, so that the execution api can make a guess + msg.task_retries = -1 + else: + # `None` indicates no retries provided, the default is anyway 0 which evaluates to false + msg.task_retries = ti.task.retries or None if msg: SUPERVISOR_COMMS.send_request(msg=msg, log=log) From b596ab7958c8413f05dde06b563a8c9abb359540 Mon Sep 17 00:00:00 2001 From: Amogh Date: Fri, 20 Dec 2024 16:19:57 +0530 Subject: [PATCH 2/8] fixing tests --- task_sdk/tests/api/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 16b8d6c9bfe15..5492918e25774 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -137,7 +137,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: return httpx.Response(status_code=400, json={"detail": "Bad Request"}) client = make_client(transport=httpx.MockTransport(handle_request)) - client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z") + client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z", task_retries=None) def test_task_instance_heartbeat(self): # Simulate a successful response from the server that sends a heartbeat for a ti From 320a090bb3e92ae578bb4df4e1d02894473c0d4e Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 23 Dec 2024 13:00:44 +0530 Subject: [PATCH 3/8] adding default to datamodel --- airflow/api_fastapi/execution_api/datamodels/taskinstance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 002f88bfa7ebb..0137f5ac88b91 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ 
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -59,7 +59,7 @@ class TITerminalStatePayload(BaseModel): end_date: UtcDateTime """When the task completed executing""" - task_retries: int | None + task_retries: int | None = None class TITargetStatePayload(BaseModel): From 1c51e38ce3c99aca9533d49d47e92a41e1a14c01 Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 23 Dec 2024 13:32:05 +0530 Subject: [PATCH 4/8] adding test coverage --- task_sdk/tests/api/test_client.py | 20 +++++ .../tests/execution_time/test_supervisor.py | 13 +++ .../tests/execution_time/test_task_runner.py | 81 +++++++++++++++++++ .../routes/test_task_instances.py | 65 +++++++++++++++ 4 files changed, 179 insertions(+) diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 5492918e25774..15ed0054158f1 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -139,6 +139,26 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_client(transport=httpx.MockTransport(handle_request)) client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z", task_retries=None) + def test_task_instance_finish_with_retries(self): + # Simulate a successful response from the server that finishes (moved to terminal state) a task when retries are present + ti_id = uuid6.uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti_id}/state": + actual_body = json.loads(request.read()) + assert actual_body["end_date"] == "2024-10-31T12:00:00Z" + assert actual_body["state"] == TerminalTIState.FAILED + assert actual_body["task_retries"] == 2 + return httpx.Response( + status_code=204, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + client.task_instances.finish( + ti_id, state=TerminalTIState.FAILED, when="2024-10-31T12:00:00Z", task_retries=2 + ) + def test_task_instance_heartbeat(self): # Simulate a successful response from the server that sends a heartbeat for a ti ti_id = uuid6.uuid7() diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 73ac6dea630ba..49b67b4c07df5 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -882,6 +882,19 @@ def watched_subprocess(self, mocker): "", id="patch_task_instance_to_skipped", ), + # checking if we are capable of handling if task_retries is passed + pytest.param( + TaskState( + state=TerminalTIState.FAILED, + end_date=timezone.parse("2024-10-31T12:00:00Z"), + task_retries=2, + ), + b"", + "", + (), + "", + id="patch_task_instance_to_failed_with_retries", + ), ], ) def test_handle_requests( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 96ac89db5cd9d..bb055ec96865b 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -248,6 +248,87 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context): ) +@pytest.mark.parametrize( + ["retries", "expected_msg"], + [ + # No retries configured + pytest.param(None, TaskState(state=TerminalTIState.FAILED, task_retries=None)), + # Retries configured + pytest.param(2, TaskState(state=TerminalTIState.FAILED, task_retries=2)), + # Retries configured but with 0 + pytest.param(0, TaskState(state=TerminalTIState.FAILED, 
task_retries=None)), + ], +) +def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, retries, expected_msg): + """Test running a basic task that raises a base exception.""" + from airflow.providers.standard.operators.python import PythonOperator + + task = PythonOperator( + task_id="zero_division_error", + retries=retries, + python_callable=lambda: 1 / 0, + ) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="zero_division_error", + dag_id="basic_dag_base_exception", + run_id="c", + try_number=1, + ), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + + ti = mocked_parse(what, "basic_dag_base_exception", task) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + run(ti, log=mock.MagicMock()) + expected_msg.end_date = instant + mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_msg, log=mock.ANY) + + +def test_run_raises_missing_task(time_machine, mocked_parse, make_ti_context): + """Test running a basic dag with missing ti.task.""" + from airflow.providers.standard.operators.python import PythonOperator + + task = PythonOperator( + task_id="missing_task", + python_callable=lambda: 1 / 0, + ) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), task_id="missing_task", dag_id="basic_dag_missing_task", run_id="c", try_number=1 + ), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + + ti = mocked_parse(what, "basic_dag_missing_task", task) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + # set ti.task as None + ti.task = None + run(ti, log=mock.MagicMock()) + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState(state=TerminalTIState.FAILED, task_retries=-1, end_date=instant), log=mock.ANY + ) + + def test_startup_basic_templated_dag(mocked_parse, make_ti_context): """Test running a DAG with templated task.""" from airflow.providers.standard.operators.bash import BashOperator diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 4ed5f8f1598f3..9e21c40e38ba5 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -340,6 +340,71 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan assert trs[0].map_index == -1 assert trs[0].duration == 129600 + @pytest.mark.parametrize( + ("retries", "expected_state"), + [ + # retries given + (2, State.UP_FOR_RETRY), + # retries not given + (None, State.FAILED), + # retries given but as 0 + (0, State.FAILED), + # retries not known, given as -1, calculates on table default + (-1, State.UP_FOR_RETRY), + ], + ) + def test_ti_update_state_to_retry(self, client, session, create_task_instance, retries, expected_state): + ti = create_task_instance( + task_id="test_ti_update_state_to_retry", + state=State.RUNNING, + ) + ti.retries = retries + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": State.FAILED, + "end_date": DEFAULT_END_DATE.isoformat(), + "task_retries": retries, + }, + ) + + assert response.status_code == 204 + assert 
response.text == "" + + session.expire_all() + + ti = session.get(TaskInstance, ti.id) + assert ti.state == expected_state + assert ti.next_method is None + assert ti.next_kwargs is None + + def test_ti_update_state_to_retry_when_restarting(self, client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_update_state_to_retry_when_restarting", + state=State.RESTARTING, + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": State.FAILED, + "end_date": DEFAULT_END_DATE.isoformat(), + }, + ) + + assert response.status_code == 204 + assert response.text == "" + + session.expire_all() + + ti = session.get(TaskInstance, ti.id) + assert ti.state == State.UP_FOR_RETRY + assert ti.next_method is None + assert ti.next_kwargs is None + class TestTIHealthEndpoint: def setup_method(self): From 8fb5fbe30f9176a68c2fb59862b835f4c663592e Mon Sep 17 00:00:00 2001 From: Amogh Date: Thu, 26 Dec 2024 16:51:37 +0530 Subject: [PATCH 5/8] reworking --- .../execution_api/datamodels/taskinstance.py | 3 +- .../execution_api/routes/task_instances.py | 17 ++--- task_sdk/src/airflow/sdk/api/client.py | 9 +-- .../airflow/sdk/api/datamodels/_generated.py | 2 +- .../src/airflow/sdk/execution_time/comms.py | 18 +++++- .../airflow/sdk/execution_time/supervisor.py | 29 +++++---- .../airflow/sdk/execution_time/task_runner.py | 16 ++--- task_sdk/tests/api/test_client.py | 14 ++--- .../tests/execution_time/test_supervisor.py | 17 +++-- .../tests/execution_time/test_task_runner.py | 62 +++++-------------- .../routes/test_task_instances.py | 12 ++-- 11 files changed, 92 insertions(+), 107 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 0137f5ac88b91..76400a43bade2 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -59,7 +59,8 @@ class TITerminalStatePayload(BaseModel): end_date: UtcDateTime """When the task completed executing""" - task_retries: int | None = None + """Indicates if the task should retry before failing or not.""" + should_retry: bool = False class TITargetStatePayload(BaseModel): diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index cb160cc8bb12d..5f7aeadbad162 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -200,7 +200,7 @@ def ti_update_state( # We exclude_unset to avoid updating fields that are not set in the payload # We do not need to deserialize "task_retries" -- it is used for dynamic decision making within failed state - data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"task_retries"}) + data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"should_retry"}) query = update(TI).where(TI.id == ti_id_str).values(data) @@ -212,7 +212,7 @@ def ti_update_state( # clear the next_method and next_kwargs query = query.values(next_method=None, next_kwargs=None) task_instance = session.get(TI, ti_id_str) - if _is_eligible_to_retry(task_instance, ti_patch_payload.task_retries): + if _is_eligible_to_retry(task_instance, ti_patch_payload.should_retry): query = query.values(state=State.UP_FOR_RETRY) updated_state = State.UP_FOR_RETRY elif isinstance(ti_patch_payload, TIDeferredStatePayload): @@ -366,7 +366,7 @@ def ti_put_rtif( return 
{"message": "Rendered task instance fields successfully set"} -def _is_eligible_to_retry(task_instance: TI, task_retries: int | None): +def _is_eligible_to_retry(task_instance: TI, should_retry: bool) -> bool: """ Is task instance is eligible for retry. @@ -374,13 +374,14 @@ def _is_eligible_to_retry(task_instance: TI, task_retries: int | None): :meta private: """ + if not should_retry: + return False + if task_instance.state == State.RESTARTING: # If a task is cleared when running, it goes into RESTARTING state and is always # eligible for retry return True - if task_retries == -1: - # task_runner indicated that it doesn't know number of retries, guess it from the table - return task_instance.try_number <= task_instance.max_tries - - return task_retries and task_instance.try_number <= task_instance.max_tries + # max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for + # retries from the task SDK now, we can handle using max_tries + return task_instance.max_tries and task_instance.try_number <= task_instance.max_tries diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 7d48968721258..001b8a7d2874e 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -124,14 +124,15 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext: resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json()) return TIRunContext.model_validate_json(resp.read()) - def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime, task_retries: int | None): + def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime): """Tell the API server that this TI has reached a terminal state.""" # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. 
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state)) + self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) - if task_retries: - body.task_retries = task_retries - + def fail(self, id: uuid.UUID, when: datetime, should_retry: bool): + """Tell the API server that this TI has to fail, with or without retries.""" + body = TITerminalStatePayload(end_date=when, state=TerminalTIState.FAILED, should_retry=should_retry) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) def heartbeat(self, id: uuid.UUID, pid: int): diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 0e190792c7ed0..d047c201d3847 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -217,4 +217,4 @@ class TITerminalStatePayload(BaseModel): state: TerminalTIState end_date: Annotated[datetime, Field(title="End Date")] - task_retries: Annotated[int | None, Field(title="Task Retries")] = None + should_retry: Annotated[bool | None, Field(title="Should Retry")] = False diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 58754a69663ff..e9458e619a7ad 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -145,12 +145,25 @@ class TaskState(BaseModel): - anything else = FAILED """ - state: TerminalTIState + state: Literal[TerminalTIState.SUCCESS, TerminalTIState.REMOVED, TerminalTIState.SKIPPED] end_date: datetime | None = None - task_retries: int | None = None type: Literal["TaskState"] = "TaskState" +class FailState(BaseModel): + """ + Update a task's state to FAILED. + + Contains attributes specific to FAILING a state like + ability to retry. 
+ """ + + should_retry: bool = True + end_date: datetime | None = None + state: Literal[TerminalTIState.FAILED] = TerminalTIState.FAILED + type: Literal["FailState"] = "FailState" + + class DeferTask(TIDeferredStatePayload): """Update a task instance state to deferred.""" @@ -233,6 +246,7 @@ class SetRenderedFields(BaseModel): ToSupervisor = Annotated[ Union[ TaskState, + FailState, GetXCom, GetConnection, GetVariable, diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 0587ce8c110cb..bc5abb610af93 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -63,6 +63,7 @@ ConnectionResult, DeferTask, ErrorResponse, + FailState, GetConnection, GetVariable, GetXCom, @@ -294,8 +295,9 @@ class WatchedSubprocess: _exit_code: int | None = attrs.field(default=None, init=False) _terminal_state: str | None = attrs.field(default=None, init=False) _final_state: str | None = attrs.field(default=None, init=False) - # denotes if a task `has` retries defined or not, helpful to send signals between the handle_requests and wait - _should_retry: bool = attrs.field(default=False, init=False) + # denotes if a request to `fail` has been sent from the _handle_requests or not, or it will be handled in wait() + # useful to synchronise the API requests for `fail` between handle_requests and wait + _fail_request_sent: bool = attrs.field(default=False, init=False) _last_successful_heartbeat: float = attrs.field(default=0, init=False) _last_heartbeat_attempt: float = attrs.field(default=0, init=False) @@ -520,15 +522,13 @@ def wait(self) -> int: # If it hasn't, assume it's failed self._exit_code = self._exit_code if self._exit_code is not None else 1 - print("The exit code is", self._exit_code) - # If the process has finished in a terminal state, update the state of the TaskInstance # to reflect the final state of the process. # For states like `deferred`, the process will exit with 0, but the state will be updated # by the subprocess in the `handle_requests` method. 
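# A minimal sketch, with assumed names, of the synchronisation the rewritten guard below
# provides: wait() reports the terminal state itself only when handle_requests() has not
# already sent the fail request for this task instance.
def _should_report_from_wait_sketch(fail_request_sent: bool, final_state: str) -> bool:
    terminal_states = {"success", "failed", "skipped", "removed"}
    # Skip the extra API call if the failure was already reported by the request handler.
    return (not fail_request_sent) and final_state in terminal_states

# e.g. _should_report_from_wait_sketch(True, "failed") -> False (no duplicate state update)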
- if self.final_state in TerminalTIState and not self._should_retry: + if (not self._fail_request_sent) and self.final_state in TerminalTIState: self.client.task_instances.finish( - id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc), task_retries=None + id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc) ) return self._exit_code @@ -715,17 +715,16 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): log.debug("Received message from task runner", msg=msg) resp = None - if isinstance(msg, TaskState): + if isinstance(msg, FailState): + self._terminal_state = TerminalTIState.FAILED + self._task_end_time_monotonic = time.monotonic() + self._fail_request_sent = True + log.debug("IN SIDE FAILSTATE.") + self.client.task_instances.fail(self.id, datetime.now(tz=timezone.utc), msg.should_retry) + elif isinstance(msg, TaskState): + log.debug("IN SIDE TaskState.") self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() - if msg.task_retries: - self.client.task_instances.finish( - id=self.id, - state=self.final_state, - when=datetime.now(tz=timezone.utc), - task_retries=msg.task_retries, - ) - self._should_retry = True elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 8e138fc18a2ed..baf913e4a5015 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -33,6 +33,7 @@ from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( DeferTask, + FailState, RescheduleTask, SetRenderedFields, StartupDetails, @@ -300,9 +301,10 @@ def run(ti: RuntimeTaskInstance, log: Logger): # TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951 # TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952 - msg = TaskState( + msg = FailState( state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc), + should_retry=False, ) # TODO: Run task failure callbacks here @@ -313,22 +315,16 @@ def run(ti: RuntimeTaskInstance, log: Logger): # External state updates are already handled with `ti_heartbeat` and will be # updated already be another UI API. So, these exceptions should ideally never be thrown. # If these are thrown, we should mark the TI state as failed. - msg = TaskState( + msg = FailState( state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc), + should_retry=False, ) # TODO: Run task failure callbacks here except SystemExit: ... 
except BaseException: - msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) - if not getattr(ti, "task", None): - # We do not know about retries, let's mark it -1, so that the execution api can make a guess - msg.task_retries = -1 - else: - # `None` indicates no retries provided, the default is anyway 0 which evaluates to false - msg.task_retries = ti.task.retries or None - + msg = FailState(should_retry=True, end_date=datetime.now(tz=timezone.utc)) if msg: SUPERVISOR_COMMS.send_request(msg=msg, log=log) diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 15ed0054158f1..2e9035e8df3aa 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -121,7 +121,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: resp = client.task_instances.start(ti_id, 100, start_date) assert resp == ti_context - @pytest.mark.parametrize("state", [state for state in TerminalTIState]) + @pytest.mark.parametrize("state", [state for state in TerminalTIState if state != TerminalTIState.FAILED]) def test_task_instance_finish(self, state): # Simulate a successful response from the server that finishes (moved to terminal state) a task ti_id = uuid6.uuid7() @@ -137,10 +137,10 @@ def handle_request(request: httpx.Request) -> httpx.Response: return httpx.Response(status_code=400, json={"detail": "Bad Request"}) client = make_client(transport=httpx.MockTransport(handle_request)) - client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z", task_retries=None) + client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z") - def test_task_instance_finish_with_retries(self): - # Simulate a successful response from the server that finishes (moved to terminal state) a task when retries are present + def test_task_instance_fail(self): + # Simulate a successful response from the server that fails a task with retry. 
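# For reference, a sketch of the request body this test expects fail() to PATCH to
# /task-instances/{id}/state; the field names follow the assertions below, the values
# here are just sample data.
_expected_fail_body_sketch = {
    "state": "failed",                    # TerminalTIState.FAILED serialises to "failed"
    "end_date": "2024-10-31T12:00:00Z",
    "should_retry": True,                 # lets the execution API choose FAILED vs UP_FOR_RETRY
}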
ti_id = uuid6.uuid7() def handle_request(request: httpx.Request) -> httpx.Response: @@ -148,16 +148,14 @@ def handle_request(request: httpx.Request) -> httpx.Response: actual_body = json.loads(request.read()) assert actual_body["end_date"] == "2024-10-31T12:00:00Z" assert actual_body["state"] == TerminalTIState.FAILED - assert actual_body["task_retries"] == 2 + assert actual_body["should_retry"] is True return httpx.Response( status_code=204, ) return httpx.Response(status_code=400, json={"detail": "Bad Request"}) client = make_client(transport=httpx.MockTransport(handle_request)) - client.task_instances.finish( - ti_id, state=TerminalTIState.FAILED, when="2024-10-31T12:00:00Z", task_retries=2 - ) + client.task_instances.fail(ti_id, when="2024-10-31T12:00:00Z", should_retry=True) def test_task_instance_heartbeat(self): # Simulate a successful response from the server that sends a heartbeat for a ti diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 6f03782768cd4..4c1af64c3931d 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -41,6 +41,7 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, + FailState, GetConnection, GetVariable, GetXCom, @@ -884,18 +885,18 @@ def watched_subprocess(self, mocker): "", id="patch_task_instance_to_skipped", ), - # checking if we are capable of handling if task_retries is passed + # testing to see if supervisor can handle FailState message pytest.param( - TaskState( + FailState( state=TerminalTIState.FAILED, end_date=timezone.parse("2024-10-31T12:00:00Z"), - task_retries=2, + should_retry=False, ), b"", + "task_instances.fail", + (TI_ID, timezone.parse("2024-11-7T12:00:00Z"), False), "", - (), - "", - id="patch_task_instance_to_failed_with_retries", + id="patch_task_instance_to_failed", ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), @@ -916,6 +917,7 @@ def test_handle_requests( client_attr_path, method_arg, mock_response, + time_machine, ): """ Test handling of different messages to the subprocess. For any new message type, add a @@ -929,6 +931,9 @@ def test_handle_requests( 4. Verifies that the response is correctly decoded. """ + instant = tz.datetime(2024, 11, 7, 12, 0, 0, 0) + time_machine.move_to(instant, tick=False) + # Mock the client method. E.g. 
`client.variables.get` or `client.connections.get` mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client) mock_client_method.return_value = mock_response diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 2fd2b1832dc38..a1af2745175e0 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -37,6 +37,7 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, + FailState, GetConnection, SetRenderedFields, StartupDetails, @@ -257,25 +258,19 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context): @pytest.mark.parametrize( - ["retries", "expected_msg"], - [ - # No retries configured - pytest.param(None, TaskState(state=TerminalTIState.FAILED, task_retries=None)), - # Retries configured - pytest.param(2, TaskState(state=TerminalTIState.FAILED, task_retries=2)), - # Retries configured but with 0 - pytest.param(0, TaskState(state=TerminalTIState.FAILED, task_retries=None)), - ], + "retries", + [None, 0, 3], ) -def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, retries, expected_msg): +def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, retries): """Test running a basic task that raises a base exception.""" from airflow.providers.standard.operators.python import PythonOperator task = PythonOperator( task_id="zero_division_error", - retries=retries, python_callable=lambda: 1 / 0, ) + if retries is not None: + task.retries = retries what = StartupDetails( ti=TaskInstance( @@ -299,41 +294,14 @@ def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as mock_supervisor_comms: run(ti, log=mock.MagicMock()) - expected_msg.end_date = instant - mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_msg, log=mock.ANY) - - -def test_run_raises_missing_task(time_machine, mocked_parse, make_ti_context): - """Test running a basic dag with missing ti.task.""" - from airflow.providers.standard.operators.python import PythonOperator - - task = PythonOperator( - task_id="missing_task", - python_callable=lambda: 1 / 0, - ) - - what = StartupDetails( - ti=TaskInstance( - id=uuid7(), task_id="missing_task", dag_id="basic_dag_missing_task", run_id="c", try_number=1 - ), - file="", - requests_fd=0, - ti_context=make_ti_context(), - ) - - ti = mocked_parse(what, "basic_dag_missing_task", task) - - instant = timezone.datetime(2024, 12, 3, 10, 0) - time_machine.move_to(instant, tick=False) - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as mock_supervisor_comms: - # set ti.task as None - ti.task = None - run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( - msg=TaskState(state=TerminalTIState.FAILED, task_retries=-1, end_date=instant), log=mock.ANY + msg=FailState( + should_retry=True, + state=TerminalTIState.FAILED, + end_date=instant, + ), + log=mock.ANY, ) @@ -467,7 +435,9 @@ def execute(self, context): ), ], ) -def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context): +def test_run_basic_failed_without_retries( + time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context +): """Test running a basic task that marks itself as failed by raising exception.""" class CustomOperator(BaseOperator): @@ 
-499,7 +469,7 @@ def execute(self, context): run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( - msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY + msg=FailState(state=TerminalTIState.FAILED, end_date=instant, should_retry=False), log=mock.ANY ) diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 9e21c40e38ba5..1e5e8ede318e7 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -353,12 +353,12 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan (-1, State.UP_FOR_RETRY), ], ) - def test_ti_update_state_to_retry(self, client, session, create_task_instance, retries, expected_state): + def test_ti_update_state_to_retry(self, client, session, create_task_instance): ti = create_task_instance( task_id="test_ti_update_state_to_retry", state=State.RUNNING, ) - ti.retries = retries + ti.max_tries = 3 session.commit() response = client.patch( @@ -366,7 +366,7 @@ def test_ti_update_state_to_retry(self, client, session, create_task_instance, r json={ "state": State.FAILED, "end_date": DEFAULT_END_DATE.isoformat(), - "task_retries": retries, + "should_retry": True, }, ) @@ -376,9 +376,9 @@ def test_ti_update_state_to_retry(self, client, session, create_task_instance, r session.expire_all() ti = session.get(TaskInstance, ti.id) - assert ti.state == expected_state - assert ti.next_method is None - assert ti.next_kwargs is None + # assert ti.state == expected_state + # assert ti.next_method is None + # assert ti.next_kwargs is None def test_ti_update_state_to_retry_when_restarting(self, client, session, create_task_instance): ti = create_task_instance( From e6529d2b7c34cd5c79dccef7b107ae2772d4a328 Mon Sep 17 00:00:00 2001 From: Amogh Date: Thu, 26 Dec 2024 16:57:14 +0530 Subject: [PATCH 6/8] fixup --- airflow/api_fastapi/execution_api/routes/task_instances.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 5f7aeadbad162..7d45d5f4071c2 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -199,7 +199,7 @@ def ti_update_state( ) # We exclude_unset to avoid updating fields that are not set in the payload - # We do not need to deserialize "task_retries" -- it is used for dynamic decision making within failed state + # We do not need to deserialize "should_retry" -- it is used for dynamic decision-making within failed state data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"should_retry"}) query = update(TI).where(TI.id == ti_id_str).values(data) From 980dfa318cdfd2782e20cdd905e08193acad638b Mon Sep 17 00:00:00 2001 From: Amogh Date: Fri, 27 Dec 2024 17:11:22 +0530 Subject: [PATCH 7/8] failed = fail after retrying, fail_without_retry = just fail --- .../execution_api/datamodels/taskinstance.py | 3 -- .../execution_api/routes/task_instances.py | 46 +++++++++---------- airflow/utils/state.py | 5 +- task_sdk/src/airflow/sdk/api/client.py | 5 -- .../airflow/sdk/api/datamodels/_generated.py | 6 ++- .../src/airflow/sdk/execution_time/comms.py | 17 +------ .../airflow/sdk/execution_time/supervisor.py | 15 +----- .../airflow/sdk/execution_time/task_runner.py | 14 +++--- 
task_sdk/tests/api/test_client.py | 20 +------- .../tests/execution_time/test_supervisor.py | 14 +++--- .../tests/execution_time/test_task_runner.py | 16 ++----- .../routes/test_task_instances.py | 22 ++++----- 12 files changed, 60 insertions(+), 123 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 76400a43bade2..c1bf588c2bbd4 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -59,9 +59,6 @@ class TITerminalStatePayload(BaseModel): end_date: UtcDateTime """When the task completed executing""" - """Indicates if the task should retry before failing or not.""" - should_retry: bool = False - class TITargetStatePayload(BaseModel): """Schema for updating TaskInstance to a target state, excluding terminal and running states.""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 7d45d5f4071c2..4956466ca707a 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -44,7 +44,7 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.utils import timezone -from airflow.utils.state import State +from airflow.utils.state import State, TerminalTIState # TODO: Add dependency on JWT token router = AirflowRouter() @@ -185,9 +185,13 @@ def ti_update_state( # We only use UUID above for validation purposes ti_id_str = str(task_instance_id) - old = select(TI.state).where(TI.id == ti_id_str).with_for_update() + old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == ti_id_str).with_for_update() try: - (previous_state,) = session.execute(old).one() + ( + previous_state, + try_number, + max_tries, + ) = session.execute(old).one() except NoResultFound: log.error("Task Instance %s not found", ti_id_str) raise HTTPException( @@ -199,22 +203,23 @@ def ti_update_state( ) # We exclude_unset to avoid updating fields that are not set in the payload - # We do not need to deserialize "should_retry" -- it is used for dynamic decision-making within failed state - data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"should_retry"}) + data = ti_patch_payload.model_dump(exclude_unset=True) query = update(TI).where(TI.id == ti_id_str).values(data) if isinstance(ti_patch_payload, TITerminalStatePayload): query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) - query = query.values(state=ti_patch_payload.state) updated_state = ti_patch_payload.state - if ti_patch_payload.state == State.FAILED: - # clear the next_method and next_kwargs - query = query.values(next_method=None, next_kwargs=None) - task_instance = session.get(TI, ti_id_str) - if _is_eligible_to_retry(task_instance, ti_patch_payload.should_retry): - query = query.values(state=State.UP_FOR_RETRY) + # if we get failed, we should attempt to retry, as it is a more + # normal state. Tasks with retries are more frequent than without retries. 
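# A self-contained sketch of the state-resolution rules applied in the branch below, with
# plain strings and ints standing in for the enum values and ORM columns used here:
def _resolve_terminal_state_sketch(requested: str, previous_state: str, try_number: int, max_tries: int) -> str:
    # An explicit fail_without_retry request always lands in FAILED.
    if requested == "fail_without_retry":
        return "failed"
    # A plain failed request is promoted to UP_FOR_RETRY while the task is still eligible.
    if requested == "failed":
        # Tasks cleared while running sit in RESTARTING and always retry; otherwise retry
        # only while try_number has not exhausted max_tries.
        eligible = previous_state == "restarting" or (max_tries != 0 and try_number <= max_tries)
        return "up_for_retry" if eligible else "failed"
    # success / skipped / removed pass through unchanged.
    return requested

# e.g. _resolve_terminal_state_sketch("failed", "running", try_number=1, max_tries=3) -> "up_for_retry"
#      _resolve_terminal_state_sketch("failed", "running", try_number=1, max_tries=0) -> "failed"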
+ if ti_patch_payload.state == TerminalTIState.FAIL_WITHOUT_RETRY: + updated_state = State.FAILED + elif ti_patch_payload.state == State.FAILED: + if _is_eligible_to_retry(previous_state, try_number, max_tries): updated_state = State.UP_FOR_RETRY + else: + updated_state = State.FAILED + query = query.values(state=updated_state) elif isinstance(ti_patch_payload, TIDeferredStatePayload): # Calculate timeout if it was passed timeout = None @@ -366,22 +371,13 @@ def ti_put_rtif( return {"message": "Rendered task instance fields successfully set"} -def _is_eligible_to_retry(task_instance: TI, should_retry: bool) -> bool: - """ - Is task instance is eligible for retry. - - :param task_instance: the task instance - - :meta private: - """ - if not should_retry: - return False - - if task_instance.state == State.RESTARTING: +def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: + """Is task instance is eligible for retry.""" + if state == State.RESTARTING: # If a task is cleared when running, it goes into RESTARTING state and is always # eligible for retry return True # max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for # retries from the task SDK now, we can handle using max_tries - return task_instance.max_tries and task_instance.try_number <= task_instance.max_tries + return max_tries != 0 and try_number <= max_tries diff --git a/airflow/utils/state.py b/airflow/utils/state.py index e4e2e9db8a587..cfdd015137aaa 100644 --- a/airflow/utils/state.py +++ b/airflow/utils/state.py @@ -36,9 +36,12 @@ class TerminalTIState(str, Enum): """States that a Task Instance can be in that indicate it has reached a terminal state.""" SUCCESS = "success" - FAILED = "failed" + FAILED = "failed" # This state indicates that we attempt to retry. SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task & it will be marked as skipped REMOVED = "removed" + FAIL_WITHOUT_RETRY = ( + "fail_without_retry" # This state is useful for when we want to terminate a task, without retrying. 
+ ) def __str__(self) -> str: return self.value diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index ba3b66eca90d8..907485902a76c 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -130,11 +130,6 @@ def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime): body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state)) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) - def fail(self, id: uuid.UUID, when: datetime, should_retry: bool): - """Tell the API server that this TI has to fail, with or without retries.""" - body = TITerminalStatePayload(end_date=when, state=TerminalTIState.FAILED, should_retry=should_retry) - self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) - def heartbeat(self, id: uuid.UUID, pid: int): body = TIHeartbeatInfo(pid=pid, hostname=get_hostname()) self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json()) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index d047c201d3847..6bd9067804d5d 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -126,9 +126,12 @@ class TerminalTIState(str, Enum): """ SUCCESS = "success" - FAILED = "failed" + FAILED = "failed" # This state indicates that we attempt to retry. SKIPPED = "skipped" REMOVED = "removed" + FAIL_WITHOUT_RETRY = ( + "fail_without_retry" # This state is useful for when we want to terminate a task, without retrying. + ) class ValidationError(BaseModel): @@ -217,4 +220,3 @@ class TITerminalStatePayload(BaseModel): state: TerminalTIState end_date: Annotated[datetime, Field(title="End Date")] - should_retry: Annotated[bool | None, Field(title="Should Retry")] = False diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 0c554260aaa29..b90787ca4cfc9 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -145,25 +145,11 @@ class TaskState(BaseModel): - anything else = FAILED """ - state: Literal[TerminalTIState.SUCCESS, TerminalTIState.REMOVED, TerminalTIState.SKIPPED] + state: TerminalTIState end_date: datetime | None = None type: Literal["TaskState"] = "TaskState" -class FailState(BaseModel): - """ - Update a task's state to FAILED. - - Contains attributes specific to FAILING a state like - ability to retry. 
- """ - - should_retry: bool = True - end_date: datetime | None = None - state: Literal[TerminalTIState.FAILED] = TerminalTIState.FAILED - type: Literal["FailState"] = "FailState" - - class DeferTask(TIDeferredStatePayload): """Update a task instance state to deferred.""" @@ -246,7 +232,6 @@ class SetRenderedFields(BaseModel): ToSupervisor = Annotated[ Union[ TaskState, - FailState, GetXCom, GetConnection, GetVariable, diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index ccfdce01be301..811d1ce86a60d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -62,7 +62,6 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, - FailState, GetConnection, GetVariable, GetXCom, @@ -294,9 +293,6 @@ class WatchedSubprocess: _exit_code: int | None = attrs.field(default=None, init=False) _terminal_state: str | None = attrs.field(default=None, init=False) _final_state: str | None = attrs.field(default=None, init=False) - # denotes if a request to `fail` has been sent from the _handle_requests or not, or it will be handled in wait() - # useful to synchronise the API requests for `fail` between handle_requests and wait - _fail_request_sent: bool = attrs.field(default=False, init=False) _last_successful_heartbeat: float = attrs.field(default=0, init=False) _last_heartbeat_attempt: float = attrs.field(default=0, init=False) @@ -525,7 +521,7 @@ def wait(self) -> int: # to reflect the final state of the process. # For states like `deferred`, the process will exit with 0, but the state will be updated # by the subprocess in the `handle_requests` method. - if (not self._fail_request_sent) and self.final_state in TerminalTIState: + if self.final_state in TerminalTIState: self.client.task_instances.finish( id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc) ) @@ -714,14 +710,7 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): log.debug("Received message from task runner", msg=msg) resp = None - if isinstance(msg, FailState): - self._terminal_state = TerminalTIState.FAILED - self._task_end_time_monotonic = time.monotonic() - self._fail_request_sent = True - log.debug("IN SIDE FAILSTATE.") - self.client.task_instances.fail(self.id, datetime.now(tz=timezone.utc), msg.should_retry) - elif isinstance(msg, TaskState): - log.debug("IN SIDE TaskState.") + if isinstance(msg, TaskState): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() elif isinstance(msg, GetConnection): diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 863ea66b8b36b..9379826ff22ef 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -34,7 +34,6 @@ from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( DeferTask, - FailState, GetXCom, RescheduleTask, SetRenderedFields, @@ -409,10 +408,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): # TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951 # TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952 - msg = FailState( - state=TerminalTIState.FAILED, + msg = TaskState( + state=TerminalTIState.FAIL_WITHOUT_RETRY, 
end_date=datetime.now(tz=timezone.utc), - should_retry=False, ) # TODO: Run task failure callbacks here @@ -423,16 +421,16 @@ def run(ti: RuntimeTaskInstance, log: Logger): # External state updates are already handled with `ti_heartbeat` and will be # updated already be another UI API. So, these exceptions should ideally never be thrown. # If these are thrown, we should mark the TI state as failed. - msg = FailState( - state=TerminalTIState.FAILED, + msg = TaskState( + state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=datetime.now(tz=timezone.utc), - should_retry=False, ) # TODO: Run task failure callbacks here except SystemExit: ... except BaseException: - msg = FailState(should_retry=True, end_date=datetime.now(tz=timezone.utc)) + # TODO: Run task failure callbacks here + msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc)) if msg: SUPERVISOR_COMMS.send_request(msg=msg, log=log) diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 1e912468fe424..279502793ee23 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -121,7 +121,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: resp = client.task_instances.start(ti_id, 100, start_date) assert resp == ti_context - @pytest.mark.parametrize("state", [state for state in TerminalTIState if state != TerminalTIState.FAILED]) + @pytest.mark.parametrize("state", [state for state in TerminalTIState]) def test_task_instance_finish(self, state): # Simulate a successful response from the server that finishes (moved to terminal state) a task ti_id = uuid6.uuid7() @@ -139,24 +139,6 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_client(transport=httpx.MockTransport(handle_request)) client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z") - def test_task_instance_fail(self): - # Simulate a successful response from the server that fails a task with retry. 
- ti_id = uuid6.uuid7() - - def handle_request(request: httpx.Request) -> httpx.Response: - if request.url.path == f"/task-instances/{ti_id}/state": - actual_body = json.loads(request.read()) - assert actual_body["end_date"] == "2024-10-31T12:00:00Z" - assert actual_body["state"] == TerminalTIState.FAILED - assert actual_body["should_retry"] is True - return httpx.Response( - status_code=204, - ) - return httpx.Response(status_code=400, json={"detail": "Bad Request"}) - - client = make_client(transport=httpx.MockTransport(handle_request)) - client.task_instances.fail(ti_id, when="2024-10-31T12:00:00Z", should_retry=True) - def test_task_instance_heartbeat(self): # Simulate a successful response from the server that sends a heartbeat for a ti ti_id = uuid6.uuid7() diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index ee96bb74ff7b7..7f181c4e2de6d 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -41,7 +41,6 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, - FailState, GetConnection, GetVariable, GetXCom, @@ -885,18 +884,17 @@ def watched_subprocess(self, mocker): "", id="patch_task_instance_to_skipped", ), - # testing to see if supervisor can handle FailState message + # testing to see if supervisor can handle TaskState message with state as fail_with_retry pytest.param( - FailState( - state=TerminalTIState.FAILED, + TaskState( + state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=timezone.parse("2024-10-31T12:00:00Z"), - should_retry=False, ), b"", - "task_instances.fail", - (TI_ID, timezone.parse("2024-11-7T12:00:00Z"), False), "", - id="patch_task_instance_to_failed", + (), + "", + id="patch_task_instance_to_failed_with_retries", ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index a1af2745175e0..0d1531ca2ef65 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -37,7 +37,6 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, - FailState, GetConnection, SetRenderedFields, StartupDetails, @@ -257,20 +256,14 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context): ) -@pytest.mark.parametrize( - "retries", - [None, 0, 3], -) -def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, retries): - """Test running a basic task that raises a base exception.""" +def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context): + """Test running a basic task that raises a base exception which should send fail_with_retry state.""" from airflow.providers.standard.operators.python import PythonOperator task = PythonOperator( task_id="zero_division_error", python_callable=lambda: 1 / 0, ) - if retries is not None: - task.retries = retries what = StartupDetails( ti=TaskInstance( @@ -296,8 +289,7 @@ def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( - msg=FailState( - should_retry=True, + msg=TaskState( state=TerminalTIState.FAILED, end_date=instant, ), @@ -469,7 +461,7 @@ def execute(self, context): run(ti, log=mock.MagicMock()) mock_supervisor_comms.send_request.assert_called_once_with( - 
msg=FailState(state=TerminalTIState.FAILED, end_date=instant, should_retry=False), log=mock.ANY + msg=TaskState(state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=instant), log=mock.ANY ) diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 681ac7b5ef985..379c879f3e05c 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -28,7 +28,7 @@ from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone -from airflow.utils.state import State, TaskInstanceState +from airflow.utils.state import State, TaskInstanceState, TerminalTIState from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields @@ -234,7 +234,7 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta with mock.patch( "airflow.api_fastapi.common.db.common.Session.execute", side_effect=[ - mock.Mock(one=lambda: ("running",)), # First call returns "queued" + mock.Mock(one=lambda: ("running", 1, 0)), # First call returns "queued" SQLAlchemyError("Database error"), # Second call raises an error ], ): @@ -341,30 +341,30 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan assert trs[0].duration == 129600 @pytest.mark.parametrize( - ("should_retry", "expected_state"), + ("retries", "expected_state"), [ - # retries given - (True, State.UP_FOR_RETRY), - # retries not given - (False, State.FAILED), + (0, State.FAILED), + (None, State.FAILED), + (3, State.UP_FOR_RETRY), ], ) def test_ti_update_state_to_failed_with_retries( - self, client, session, create_task_instance, should_retry, expected_state + self, client, session, create_task_instance, retries, expected_state ): ti = create_task_instance( task_id="test_ti_update_state_to_retry", state=State.RUNNING, ) - ti.max_tries = 3 + + if retries is not None: + ti.max_tries = retries session.commit() response = client.patch( f"/execution/task-instances/{ti.id}/state", json={ - "state": State.FAILED, + "state": TerminalTIState.FAILED, "end_date": DEFAULT_END_DATE.isoformat(), - "should_retry": should_retry, }, ) From b230ff61144ec3ebcf64edf005b4ea9ef931bef4 Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 30 Dec 2024 13:06:20 +0530 Subject: [PATCH 8/8] handling review comments from ash and kaxil --- airflow/utils/state.py | 6 +- .../airflow/sdk/api/datamodels/_generated.py | 6 +- .../tests/execution_time/test_supervisor.py | 16 ----- .../routes/test_task_instances.py | 64 ++++++++++++++++++- 4 files changed, 66 insertions(+), 26 deletions(-) diff --git a/airflow/utils/state.py b/airflow/utils/state.py index cfdd015137aaa..dca2c8fc93f31 100644 --- a/airflow/utils/state.py +++ b/airflow/utils/state.py @@ -36,12 +36,10 @@ class TerminalTIState(str, Enum): """States that a Task Instance can be in that indicate it has reached a terminal state.""" SUCCESS = "success" - FAILED = "failed" # This state indicates that we attempt to retry. + FAILED = "failed" SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task & it will be marked as skipped REMOVED = "removed" - FAIL_WITHOUT_RETRY = ( - "fail_without_retry" # This state is useful for when we want to terminate a task, without retrying. 
- ) + FAIL_WITHOUT_RETRY = "fail_without_retry" def __str__(self) -> str: return self.value diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 6bd9067804d5d..ff4cc588ff564 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -126,12 +126,10 @@ class TerminalTIState(str, Enum): """ SUCCESS = "success" - FAILED = "failed" # This state indicates that we attempt to retry. + FAILED = "failed" SKIPPED = "skipped" REMOVED = "removed" - FAIL_WITHOUT_RETRY = ( - "fail_without_retry" # This state is useful for when we want to terminate a task, without retrying. - ) + FAIL_WITHOUT_RETRY = "fail_without_retry" class ValidationError(BaseModel): diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 7f181c4e2de6d..9cfe456962bb9 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -884,18 +884,6 @@ def watched_subprocess(self, mocker): "", id="patch_task_instance_to_skipped", ), - # testing to see if supervisor can handle TaskState message with state as fail_with_retry - pytest.param( - TaskState( - state=TerminalTIState.FAIL_WITHOUT_RETRY, - end_date=timezone.parse("2024-10-31T12:00:00Z"), - ), - b"", - "", - (), - "", - id="patch_task_instance_to_failed_with_retries", - ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), b"", @@ -928,10 +916,6 @@ def test_handle_requests( 3. Checks that the buffer is updated with the expected response. 4. Verifies that the response is correctly decoded. """ - - instant = tz.datetime(2024, 11, 7, 12, 0, 0, 0) - time_machine.move_to(instant, tick=False) - # Mock the client method. E.g. 
`client.variables.get` or `client.connections.get` mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client) mock_client_method.return_value = mock_response diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 379c879f3e05c..497c5fbaf3f44 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -378,7 +378,67 @@ def test_ti_update_state_to_failed_with_retries( assert ti.next_method is None assert ti.next_kwargs is None - def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance): + def test_ti_update_state_when_ti_is_restarting(self, client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_update_state_when_ti_is_restarting", + state=State.RUNNING, + ) + # update state to restarting + ti.state = State.RESTARTING + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": TerminalTIState.FAILED, + "end_date": DEFAULT_END_DATE.isoformat(), + }, + ) + + assert response.status_code == 204 + assert response.text == "" + + session.expire_all() + + ti = session.get(TaskInstance, ti.id) + # restarting is always retried + assert ti.state == State.UP_FOR_RETRY + assert ti.next_method is None + assert ti.next_kwargs is None + + def test_ti_update_state_when_ti_has_higher_tries_than_retries( + self, client, session, create_task_instance + ): + ti = create_task_instance( + task_id="test_ti_update_state_when_ti_has_higher_tries_than_retries", + state=State.RUNNING, + ) + # two maximum tries defined, but third try going on + ti.max_tries = 2 + ti.try_number = 3 + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": TerminalTIState.FAILED, + "end_date": DEFAULT_END_DATE.isoformat(), + }, + ) + + assert response.status_code == 204 + assert response.text == "" + + session.expire_all() + + ti = session.get(TaskInstance, ti.id) + # all retries exhausted, marking as failed + assert ti.state == State.FAILED + assert ti.next_method is None + assert ti.next_kwargs is None + + def test_ti_update_state_to_failed_without_retry_table_check(self, client, session, create_task_instance): + # we just want to fail in this test, no need to retry ti = create_task_instance( task_id="test_ti_update_state_to_failed_table_check", state=State.RUNNING, @@ -389,7 +449,7 @@ def test_ti_update_state_to_failed_table_check(self, client, session, create_tas response = client.patch( f"/execution/task-instances/{ti.id}/state", json={ - "state": State.FAILED, + "state": TerminalTIState.FAIL_WITHOUT_RETRY, "end_date": DEFAULT_END_DATE.isoformat(), }, )
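
Reviewer note (illustration only, not part of the patch): the new and updated cases in
tests/api_fastapi/execution_api/routes/test_task_instances.py above pin down how a reported
terminal failure ends up stored on the task instance. The sketch below restates that matrix as a
pure function so the decision is easy to eyeball. The helper name, its signature, and the branch
order are invented for this note; the string values mirror TerminalTIState / TaskInstanceState
values, and the expected outcomes are taken directly from the assertions in the tests.

    # Minimal sketch, assuming only the behaviour asserted by the new tests.
    # The function name and signature are made up for illustration.
    def expected_ti_state_after_failure(
        reported: str, ti_state: str, try_number: int, max_tries: int
    ) -> str:
        if reported == "fail_without_retry":
            return "failed"          # terminate without consulting retries
        if ti_state == "restarting":
            return "up_for_retry"    # a task cleared while running is always retried
        if try_number <= max_tries:
            return "up_for_retry"    # retries remaining
        return "failed"              # retries exhausted

    # One check per test case above:
    assert expected_ti_state_after_failure("failed", "restarting", try_number=1, max_tries=0) == "up_for_retry"
    assert expected_ti_state_after_failure("failed", "running", try_number=3, max_tries=2) == "failed"
    assert expected_ti_state_after_failure("failed", "running", try_number=1, max_tries=3) == "up_for_retry"
    assert expected_ti_state_after_failure("fail_without_retry", "running", try_number=1, max_tries=3) == "failed"

In short, the tests encode that a plain "failed" report lets the server choose between
up_for_retry and failed from the instance's own try_number/max_tries bookkeeping, a restarting
task is always retried, and "fail_without_retry" always lands in failed regardless of retries.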