Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling task retries in task SDK + execution API #45106

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class TITerminalStatePayload(BaseModel):
end_date: UtcDateTime
"""When the task completed executing"""

task_retries: int | None


class TITargetStatePayload(BaseModel):
"""Schema for updating TaskInstance to a target state, excluding terminal and running states."""
Expand Down
29 changes: 27 additions & 2 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,22 @@ def ti_update_state(
)

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude_unset=True)
# We do not need to deserialize "task_retries" -- it is used for dynamic decision making within failed state
data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"task_retries"})

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
updated_state = ti_patch_payload.state
if ti_patch_payload.state == State.FAILED:
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
updated_state = State.FAILED
task_instance = session.get(TI, ti_id_str)
if _is_eligible_to_retry(task_instance, ti_patch_payload.task_retries):
query = query.values(state=State.UP_FOR_RETRY)
updated_state = State.UP_FOR_RETRY
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down Expand Up @@ -359,3 +364,23 @@ def ti_put_rtif(
_update_rtif(task_instance, put_rtif_payload, session)

return {"message": "Rendered task instance fields successfully set"}


def _is_eligible_to_retry(task_instance: TI, task_retries: int | None) -> bool:
    """
    Decide whether a failed task instance should be put up for retry.

    :param task_instance: the task instance whose ``state``, ``try_number`` and
        ``max_tries`` are inspected
    :param task_retries: retry count reported by the caller; ``-1`` means the
        count is unknown, a falsy value (``None`` or ``0``) means no retries
    :return: ``True`` if the task instance may be retried, ``False`` otherwise

    :meta private:
    """
    if task_instance.state == State.RESTARTING:
        # If a task is cleared when running, it goes into RESTARTING state and is always
        # eligible for retry
        return True

    if task_retries == -1:
        # task_runner indicated that it doesn't know number of retries, guess it from the table
        return task_instance.try_number <= task_instance.max_tries

    # Coerce to bool so a falsy ``task_retries`` yields ``False`` rather than leaking
    # ``None``/``0`` to the caller; the try-number comparison is still short-circuited.
    return bool(task_retries) and task_instance.try_number <= task_instance.max_tries
5 changes: 4 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,14 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
return TIRunContext.model_validate_json(resp.read())

def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime, task_retries: int | None = None):
    """
    Tell the API server that this TI has reached a terminal state.

    :param id: identifier of the task instance, interpolated into the request path
    :param state: terminal state to report
    :param when: value sent as the payload's ``end_date``
    :param task_retries: optional retry count forwarded in the payload; defaults to
        ``None`` so callers that do not track retries keep working unchanged
    """
    # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
    body = TITerminalStatePayload(
        end_date=when,
        state=TerminalTIState(state),
        # Any falsy value (``None`` or ``0``) is reported as ``None``.
        task_retries=task_retries or None,
    )

    self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
Expand Down
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,4 @@ class TITerminalStatePayload(BaseModel):

state: TerminalTIState
end_date: Annotated[datetime, Field(title="End Date")]
task_retries: Annotated[int | None, Field(title="Task Retries")] = None
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class TaskState(BaseModel):

state: TerminalTIState
end_date: datetime | None = None
task_retries: int | None = None
type: Literal["TaskState"] = "TaskState"


Expand Down
16 changes: 14 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ class WatchedSubprocess:
_exit_code: int | None = attrs.field(default=None, init=False)
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)
# Set once a retry-aware terminal state has been reported from `_handle_request`, so that `wait` does not report it again
_should_retry: bool = attrs.field(default=False, init=False)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
Expand Down Expand Up @@ -515,13 +517,15 @@ def wait(self) -> int:
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1

print("The exit code is", self._exit_code)

# If the process has finished in a terminal state, update the state of the TaskInstance
# to reflect the final state of the process.
# For states like `deferred`, the process will exit with 0, but the state will be updated
# by the subprocess in the `handle_requests` method.
if self.final_state in TerminalTIState:
if self.final_state in TerminalTIState and not self._should_retry:
self.client.task_instances.finish(
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc), task_retries=None
)
return self._exit_code

Expand Down Expand Up @@ -710,6 +714,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
if msg.task_retries:
self.client.task_instances.finish(
id=self.id,
state=self.final_state,
when=datetime.now(tz=timezone.utc),
task_retries=msg.task_retries,
)
self._should_retry = True
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
Expand Down
9 changes: 7 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,13 @@ def run(ti: RuntimeTaskInstance, log: Logger):
except SystemExit:
...
except BaseException:
# TODO: Handle TI handle failure
raise
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if not getattr(ti, "task", None):
# We do not know about retries, let's mark it -1, so that the execution api can make a guess
msg.task_retries = -1
else:
# `None` indicates no retries provided, the default is anyway 0 which evaluates to false
msg.task_retries = ti.task.retries or None

if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z")
client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z", task_retries=None)

def test_task_instance_heartbeat(self):
# Simulate a successful response from the server that sends a heartbeat for a ti
Expand Down
Loading