Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions airflow-core/src/airflow/serialization/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,11 +720,17 @@ def set_task_instance_state(
else:
tasks_to_set_state = [(task, map_index) for map_index in map_indexes]

# Only set the state on the targeted task instances here. We do not pass
# ``downstream`` through to ``set_state`` because that helper would mark
# downstream task instances as the same state, which would then prevent the
# explicit downstream-clearing block below from finding them in the
# FAILED/UPSTREAM_FAILED state. Downstream handling is therefore done
# explicitly in the ``if downstream:`` block below.
altered = set_state(
tasks=tasks_to_set_state,
run_id=run_id,
upstream=upstream,
downstream=downstream,
downstream=False,
future=future,
past=past,
state=state,
Expand All @@ -736,45 +742,49 @@ def set_task_instance_state(
return altered

# Clear downstream tasks that are in failed/upstream_failed state to resume them.
# Flush the session so that the tasks marked success are reflected in the db.
session.flush()
subset = self.partial_subset(
task_ids={task_id},
include_downstream=True,
include_upstream=False,
)

# Raises an error if not found
dr_id, logical_date = session.execute(
select(DagRun.id, DagRun.logical_date).where(
DagRun.run_id == run_id, DagRun.dag_id == self.dag_id
# Only clear downstreams when ``downstream=True`` is explicitly passed, so that
# marking a single task instance as success (e.g. from the Task Instances view)
# does not unexpectedly resume downstream tasks — restoring Airflow 2 behavior.
if downstream:
# Flush the session so that the tasks marked success are reflected in the db.
session.flush()
subset = self.partial_subset(
task_ids={task_id},
include_downstream=True,
include_upstream=False,
)
).one()

# Now we want to clear downstreams of tasks that had their state set...
clear_kwargs = {
"only_failed": True,
"session": session,
# Exclude the task itself from being cleared.
"exclude_task_ids": frozenset((task_id,)),
}
if not future and not past: # Simple case 1: we're only dealing with exactly one run.
clear_kwargs["run_id"] = run_id
subset.clear(**clear_kwargs)
elif future and past: # Simple case 2: we're clearing ALL runs.
subset.clear(**clear_kwargs)
else: # Complex cases: we may have more than one run, based on a date range.
# Make 'future' and 'past' make some sense when multiple runs exist
# for the same logical date. We order runs by their id and only
# clear runs have larger/smaller ids.
exclude_run_id_stmt = select(DagRun.run_id).where(DagRun.logical_date == logical_date)
if future:
clear_kwargs["start_date"] = logical_date
exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id > dr_id)
else:
clear_kwargs["end_date"] = logical_date
exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id < dr_id)
subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs)
# Raises an error if not found
dr_id, logical_date = session.execute(
select(DagRun.id, DagRun.logical_date).where(
DagRun.run_id == run_id, DagRun.dag_id == self.dag_id
)
).one()

# Now we want to clear downstreams of tasks that had their state set...
clear_kwargs = {
"only_failed": True,
"session": session,
# Exclude the task itself from being cleared.
"exclude_task_ids": frozenset((task_id,)),
}
if not future and not past: # Simple case 1: we're only dealing with exactly one run.
clear_kwargs["run_id"] = run_id
subset.clear(**clear_kwargs)
elif future and past: # Simple case 2: we're clearing ALL runs.
subset.clear(**clear_kwargs)
else: # Complex cases: we may have more than one run, based on a date range.
# Make 'future' and 'past' make some sense when multiple runs exist
# for the same logical date. We order runs by their id and only
# clear runs have larger/smaller ids.
exclude_run_id_stmt = select(DagRun.run_id).where(DagRun.logical_date == logical_date)
if future:
clear_kwargs["start_date"] = logical_date
exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id > dr_id)
else:
clear_kwargs["end_date"] = logical_date
exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id < dr_id)
subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs)
return altered

@provide_session
Expand Down
79 changes: 75 additions & 4 deletions airflow-core/tests/unit/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2999,7 +2999,7 @@ def test_count_number_queries(self, tasks_count, testing_dag_bundle):
["test-run-id"],
)
def test_set_task_instance_state(run_id, session, dag_maker):
"""Test that set_task_instance_state updates the TaskInstance state and clear downstream failed"""
"""Test that set_task_instance_state updates the TaskInstance state"""
start_date = datetime_tz(2020, 1, 1)
with dag_maker(
"test_set_task_instance_state",
Expand Down Expand Up @@ -3037,6 +3037,8 @@ def get_ti_from_db(task):

session.flush()

# When downstream=False (default), only the selected TI state is changed -
# downstream failed/upstream_failed tasks are NOT cleared (Airflow 2 semantics).
altered = dag.set_task_instance_state(
task_id=task_1.task_id,
run_id=run_id,
Expand All @@ -3050,6 +3052,73 @@ def get_ti_from_db(task):
assert isinstance(inspect(ti1).attrs.dag_run.loaded_value, DagRun)
# task_2 remains as SUCCESS
assert get_ti_from_db(task_2).state == State.SUCCESS
# task_3 and task_4 remain in their FAILED/UPSTREAM_FAILED state because downstream=False
assert get_ti_from_db(task_3).state == State.UPSTREAM_FAILED
assert get_ti_from_db(task_4).state == State.FAILED
# task_5 remains as SKIPPED
assert get_ti_from_db(task_5).state == State.SKIPPED

assert {tuple(t.key) for t in altered} == {
("test_set_task_instance_state", "task_1", dagrun.run_id, 0, -1)
}


@pytest.mark.need_serialized_dag
@pytest.mark.parametrize(
"run_id",
["test-run-id"],
)
def test_set_task_instance_state_downstream_clears_failed(run_id, session, dag_maker):
"""Test that set_task_instance_state with downstream=True clears downstream failed/upstream_failed"""
start_date = datetime_tz(2020, 1, 1)
with dag_maker(
"test_set_task_instance_state_downstream",
start_date=start_date,
session=session,
serialized=True,
) as dag:
task_1 = EmptyOperator(task_id="task_1")
task_2 = EmptyOperator(task_id="task_2")
task_3 = EmptyOperator(task_id="task_3")
task_4 = EmptyOperator(task_id="task_4")
task_5 = EmptyOperator(task_id="task_5")
task_1 >> [task_2, task_3, task_4, task_5]

dagrun = dag_maker.create_dagrun(
run_id=run_id,
state=State.FAILED,
run_type=DagRunType.SCHEDULED,
)

def get_ti_from_db(task):
return session.scalar(
select(TI).where(
TI.dag_id == dag.dag_id,
TI.task_id == task.task_id,
TI.run_id == dagrun.run_id,
)
)

get_ti_from_db(task_1).state = State.FAILED
get_ti_from_db(task_2).state = State.SUCCESS
get_ti_from_db(task_3).state = State.UPSTREAM_FAILED
get_ti_from_db(task_4).state = State.FAILED
get_ti_from_db(task_5).state = State.SKIPPED

session.flush()

# When downstream=True, downstream failed/upstream_failed tasks ARE cleared.
altered = dag.set_task_instance_state(
task_id=task_1.task_id,
downstream=True,
run_id=run_id,
state=State.SUCCESS,
session=session,
)
ti1 = get_ti_from_db(task_1)
assert ti1.state == State.SUCCESS
# task_2 remains as SUCCESS
assert get_ti_from_db(task_2).state == State.SUCCESS
# task_3 and task_4 are cleared because they were in FAILED/UPSTREAM_FAILED state
assert get_ti_from_db(task_3).state == State.NONE
assert get_ti_from_db(task_4).state == State.NONE
Expand All @@ -3060,7 +3129,7 @@ def get_ti_from_db(task):
assert dagrun.get_state() == State.QUEUED

assert {tuple(t.key) for t in altered} == {
("test_set_task_instance_state", "task_1", dagrun.run_id, 0, -1)
("test_set_task_instance_state_downstream", "task_1", dagrun.run_id, 0, -1)
}


Expand Down Expand Up @@ -3119,6 +3188,8 @@ def consumer(value):
(task_id, 1, dr2.run_id, TaskInstanceState.FAILED),
]

# When ``downstream`` is not passed, only the selected TI state is changed —
# downstream failed/upstream_failed tasks are NOT cleared (Airflow 2 semantics).
dag.set_task_instance_state(
task_id=task_id,
map_indexes=[1],
Expand All @@ -3130,10 +3201,10 @@ def consumer(value):
assert dr1 in session, "Check session is passed down all the way"

assert session.execute(ti_query).all() == [
("downstream", -1, dr1.run_id, None),
("downstream", -1, dr1.run_id, TaskInstanceState.FAILED),
(task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
(task_id, 1, dr1.run_id, TaskInstanceState.SUCCESS),
("downstream", -1, dr2.run_id, None),
("downstream", -1, dr2.run_id, TaskInstanceState.FAILED),
(task_id, 0, dr2.run_id, TaskInstanceState.FAILED),
(task_id, 1, dr2.run_id, TaskInstanceState.SUCCESS),
]
Expand Down
Loading