Skip to content

Commit 5514406

Browse files
committed
🔨Fix airflow scheduler
1 parent b73d0fa commit 5514406

File tree

2 files changed

+159
-167
lines changed

2 files changed

+159
-167
lines changed

‎airflow/jobs/scheduler_job.py

Lines changed: 156 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from datetime import timedelta
2929
from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
3030

31-
from sqlalchemy import func, not_, or_, text, select
31+
from sqlalchemy import func, and_, or_, text, select, desc
3232
from sqlalchemy.exc import OperationalError
3333
from sqlalchemy.orm import load_only, selectinload
3434
from sqlalchemy.orm.session import Session, make_transient
@@ -330,198 +330,189 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
330330
# dag and task ids that can't be queued because of concurrency limits
331331
starved_dags: Set[str] = self._get_starved_dags(session=session)
332332
starved_tasks: Set[Tuple[str, str]] = set()
333-
334333
pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
334+
335+
# Subquery to get the current active task count for each DAG
336+
# Only considering tasks from running DAG runs
337+
current_active_tasks = (
338+
session.query(
339+
TI.dag_id,
340+
func.count().label('active_count')
341+
)
342+
.join(DR, and_(DR.dag_id == TI.dag_id, DR.run_id == TI.run_id))
343+
.filter(DR.state == DagRunState.RUNNING)
344+
.filter(TI.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED]))
345+
.group_by(TI.dag_id)
346+
.subquery()
347+
)
335348

336-
for loop_count in itertools.count(start=1):
337-
338-
num_starved_pools = len(starved_pools)
339-
num_starved_dags = len(starved_dags)
340-
num_starved_tasks = len(starved_tasks)
341-
342-
# Get task instances associated with scheduled
343-
# DagRuns which are not backfilled, in the given states,
344-
# and the dag is not paused
345-
query = (
346-
session.query(TI)
347-
.with_hint(TI, 'USE INDEX (ti_state)', dialect_name='mysql')
348-
.join(TI.dag_run)
349-
.filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
350-
.join(TI.dag_model)
351-
.filter(not_(DM.is_paused))
352-
.filter(TI.state == TaskInstanceState.SCHEDULED)
353-
.options(selectinload('dag_model'))
354-
.order_by(-TI.priority_weight, DR.execution_date)
349+
# Get the limit for each DAG
350+
dag_limit_subquery = (
351+
session.query(
352+
DM.dag_id,
353+
func.greatest(DM.max_active_tasks - func.coalesce(current_active_tasks.c.active_count, 0), 0).label('dag_limit')
355354
)
355+
.outerjoin(current_active_tasks, DM.dag_id == current_active_tasks.c.dag_id)
356+
.subquery()
357+
)
356358

357-
if starved_pools:
358-
query = query.filter(not_(TI.pool.in_(starved_pools)))
359+
# Subquery to rank tasks within each DAG
360+
ranked_tis = (
361+
session.query(
362+
TI,
363+
func.row_number().over(
364+
partition_by=TI.dag_id,
365+
order_by=[desc(TI.priority_weight), TI.start_date]
366+
).label('row_number'),
367+
dag_limit_subquery.c.dag_limit
368+
)
369+
.join(TI.dag_run)
370+
.join(DM, TI.dag_id == DM.dag_id)
371+
.join(dag_limit_subquery, TI.dag_id == dag_limit_subquery.c.dag_id)
372+
.filter(
373+
DR.state == DagRunState.RUNNING,
374+
DR.run_type != DagRunType.BACKFILL_JOB,
375+
~DM.is_paused,
376+
~TI.dag_id.in_(starved_dags),
377+
~TI.pool.in_(starved_pools),
378+
TI.state == TaskInstanceState.SCHEDULED,
379+
)
380+
).subquery()
381+
382+
if starved_tasks:
383+
ranked_tis = ranked_tis.filter(
384+
~func.concat(TI.dag_id, TI.task_id).in_([f"{dag_id}{task_id}" for dag_id, task_id in starved_tasks])
385+
)
386+
387+
final_query = (
388+
session.query(TI)
389+
.join(
390+
ranked_tis,
391+
and_(
392+
TI.task_id == ranked_tis.c.task_id,
393+
TI.dag_id == ranked_tis.c.dag_id,
394+
TI.run_id == ranked_tis.c.run_id
395+
)
396+
)
397+
.filter(ranked_tis.c.row_number <= ranked_tis.c.dag_limit)
398+
.order_by(desc(ranked_tis.c.priority_weight), ranked_tis.c.start_date)
399+
.limit(max_tis)
400+
)
401+
402+
# Execute the query with row locks
403+
task_instances_to_examine: List[TI] = with_row_locks(
404+
final_query,
405+
of=TI,
406+
session=session,
407+
**skip_locked(session=session),
408+
).all()
359409

360-
if starved_dags:
361-
query = query.filter(not_(TI.dag_id.in_(starved_dags)))
410+
411+
if len(task_instances_to_examine) == 0:
412+
self.log.debug("No tasks to consider for execution.")
413+
return []
414+
# else:
415+
# print("---dag_limit_subquery")
416+
# print(str(dag_limit_subquery.select().compile(compile_kwargs={"literal_binds": True})))
417+
# print("---ranked_tis-query")
418+
# print(str(ranked_tis.select().compile(compile_kwargs={"literal_binds": True})))
419+
# print("---FINAL QUERY")
420+
# print(str(final_query.statement.compile(compile_kwargs={"literal_binds": True})))
421+
422+
# Put one task instance on each line
423+
task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
424+
self.log.info(
425+
"%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
426+
)
427+
428+
pool_slot_tracker = {pool_name: stats['open'] for pool_name, stats in pools.items()}
362429

363-
if starved_tasks:
364-
task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks)
365-
query = query.filter(not_(task_filter))
430+
for task_instance in task_instances_to_examine:
431+
pool_name = task_instance.pool
366432

367-
query = query.limit(max_tis)
433+
pool_stats = pools.get(pool_name)
434+
if not pool_stats:
435+
self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
436+
starved_pools.add(pool_name)
437+
continue
368438

369-
task_instances_to_examine: List[TI] = with_row_locks(
370-
query,
371-
of=TI,
372-
session=session,
373-
**skip_locked(session=session),
374-
).all()
375-
# TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
376-
# Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))
439+
440+
# # Make sure to emit metrics if pool has no starving tasks
441+
# # pool_num_starving_tasks.setdefault(pool_name, 0)
442+
# pool_total = pool_stats["total"]
443+
open_slots = pool_stats["open"]
377444

378-
if len(task_instances_to_examine) == 0:
379-
self.log.debug("No tasks to consider for execution.")
380-
break
445+
# Check to make sure that the task max_active_tasks of the DAG hasn't been
446+
# reached.
447+
# This shoulnd't happen anymore but still leaving it here for debugging purposes
448+
dag_id = task_instance.dag_id
381449

382-
# Put one task instance on each line
383-
task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
450+
current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
451+
max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
384452
self.log.info(
385-
"%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
453+
"DAG %s has %s/%s running and queued tasks",
454+
dag_id,
455+
current_active_tasks_per_dag,
456+
max_active_tasks_per_dag_limit,
386457
)
387-
388-
for task_instance in task_instances_to_examine:
389-
pool_name = task_instance.pool
390-
391-
pool_stats = pools.get(pool_name)
392-
if not pool_stats:
393-
self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
394-
starved_pools.add(pool_name)
395-
continue
396-
397-
# Make sure to emit metrics if pool has no starving tasks
398-
pool_num_starving_tasks.setdefault(pool_name, 0)
399-
400-
pool_total = pool_stats["total"]
401-
open_slots = pool_stats["open"]
402-
403-
if open_slots <= 0:
404-
self.log.info(
405-
"Not scheduling since there are %s open slots in pool %s", open_slots, pool_name
406-
)
407-
# Can't schedule any more since there are no more open slots.
408-
pool_num_starving_tasks[pool_name] += 1
409-
num_starving_tasks_total += 1
410-
starved_pools.add(pool_name)
411-
continue
412-
413-
if task_instance.pool_slots > pool_total:
414-
self.log.warning(
415-
"Not executing %s. Requested pool slots (%s) are greater than "
416-
"total pool slots: '%s' for pool: %s.",
417-
task_instance,
418-
task_instance.pool_slots,
419-
pool_total,
420-
pool_name,
421-
)
422-
423-
pool_num_starving_tasks[pool_name] += 1
424-
num_starving_tasks_total += 1
425-
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
426-
continue
427-
428-
if task_instance.pool_slots > open_slots:
429-
self.log.info(
430-
"Not executing %s since it requires %s slots "
431-
"but there are %s open slots in the pool %s.",
432-
task_instance,
433-
task_instance.pool_slots,
434-
open_slots,
435-
pool_name,
436-
)
437-
pool_num_starving_tasks[pool_name] += 1
438-
num_starving_tasks_total += 1
439-
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
440-
# Though we can execute tasks with lower priority if there's enough room
441-
continue
442-
443-
# Check to make sure that the task max_active_tasks of the DAG hasn't been
444-
# reached.
445-
dag_id = task_instance.dag_id
446-
447-
current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
448-
max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
458+
if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
449459
self.log.info(
450-
"DAG %s has %s/%s running and queued tasks",
460+
"Not executing %s since the number of tasks running or queued "
461+
"from DAG %s is >= to the DAG's max_active_tasks limit of %s",
462+
task_instance,
451463
dag_id,
452-
current_active_tasks_per_dag,
453464
max_active_tasks_per_dag_limit,
454465
)
455-
if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
456-
self.log.info(
457-
"Not executing %s since the number of tasks running or queued "
458-
"from DAG %s is >= to the DAG's max_active_tasks limit of %s",
459-
task_instance,
466+
starved_dags.add(dag_id)
467+
468+
if task_instance.dag_model.has_task_concurrency_limits:
469+
# Many dags don't have a task_concurrency, so where we can avoid loading the full
470+
# serialized DAG the better.
471+
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
472+
# If the dag is missing, fail the task and continue to the next task.
473+
if not serialized_dag:
474+
self.log.error(
475+
"DAG '%s' for task instance %s not found in serialized_dag table",
460476
dag_id,
461-
max_active_tasks_per_dag_limit,
477+
task_instance,
462478
)
463-
starved_dags.add(dag_id)
464-
continue
479+
session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
480+
{TI.state: State.FAILED}, synchronize_session='fetch'
481+
)
482+
# continue
483+
484+
task_concurrency_limit: Optional[int] = None
485+
if serialized_dag.has_task(task_instance.task_id):
486+
task_concurrency_limit = serialized_dag.get_task(
487+
task_instance.task_id
488+
).max_active_tis_per_dag
465489

466-
if task_instance.dag_model.has_task_concurrency_limits:
467-
# Many dags don't have a task_concurrency, so where we can avoid loading the full
468-
# serialized DAG the better.
469-
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
470-
# If the dag is missing, fail the task and continue to the next task.
471-
if not serialized_dag:
472-
self.log.error(
473-
"DAG '%s' for task instance %s not found in serialized_dag table",
474-
dag_id,
490+
if task_concurrency_limit is not None:
491+
current_task_concurrency = task_concurrency_map[
492+
(task_instance.dag_id, task_instance.task_id)
493+
]
494+
495+
if current_task_concurrency >= task_concurrency_limit:
496+
self.log.info(
497+
"Not executing %s since the task concurrency for"
498+
" this task has been reached.",
475499
task_instance,
476500
)
477-
session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
478-
{TI.state: State.FAILED}, synchronize_session='fetch'
479-
)
501+
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
480502
continue
481-
482-
task_concurrency_limit: Optional[int] = None
483-
if serialized_dag.has_task(task_instance.task_id):
484-
task_concurrency_limit = serialized_dag.get_task(
485-
task_instance.task_id
486-
).max_active_tis_per_dag
487-
488-
if task_concurrency_limit is not None:
489-
current_task_concurrency = task_concurrency_map[
490-
(task_instance.dag_id, task_instance.task_id)
491-
]
492-
493-
if current_task_concurrency >= task_concurrency_limit:
494-
self.log.info(
495-
"Not executing %s since the task concurrency for"
496-
" this task has been reached.",
497-
task_instance,
498-
)
499-
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
500-
continue
501-
503+
504+
# Check pool-specific slot availability
505+
if (pool_slot_tracker.get(pool_name, 0) >= task_instance.pool_slots):
502506
executable_tis.append(task_instance)
503507
open_slots -= task_instance.pool_slots
504508
dag_active_tasks_map[dag_id] += 1
505509
task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
506-
507510
pool_stats["open"] = open_slots
511+
else:
512+
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
513+
pool_num_starving_tasks[pool_name] += 1
514+
num_starving_tasks_total += 1
508515

509-
is_done = executable_tis or len(task_instances_to_examine) < max_tis
510-
# Check this to avoid accidental infinite loops
511-
found_new_filters = (
512-
len(starved_pools) > num_starved_pools
513-
or len(starved_dags) > num_starved_dags
514-
or len(starved_tasks) > num_starved_tasks
515-
)
516-
517-
if is_done or not found_new_filters:
518-
break
519-
520-
self.log.debug(
521-
"Found no task instances to queue on the %s. iteration "
522-
"but there could be more candidate task instances to check.",
523-
loop_count,
524-
)
525516

526517
for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
527518
Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)

‎airflow/www/static/js/ti_log.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {
140140
}
141141
recurse().then(() => autoTailingLog(tryNumber, res.metadata, autoTailing));
142142
}).catch((error) => {
143-
console.error(`Error while retrieving log: ${error}`);
143+
console.error(`Error while retrieving log`, error);
144144

145145
const externalLogUrl = getMetaValue('external_log_url');
146146
const fullExternalUrl = `${externalLogUrl
@@ -151,7 +151,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {
151151

152152
document.getElementById(`loading-${tryNumber}`).style.display = 'none';
153153

154-
const logBlockElementId = `try-${tryNumber}-${item[0]}`;
154+
const logBlockElementId = `try-${tryNumber}-error`;
155155
let logBlock = document.getElementById(logBlockElementId);
156156
if (!logBlock) {
157157
const logDivBlock = document.createElement('div');
@@ -164,6 +164,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {
164164

165165
logBlock.innerHTML += "There was an error while retrieving the log from S3. Please use Kibana to view the logs.";
166166
logBlock.innerHTML += `<a href="${fullExternalUrl}" target="_blank">View logs in Kibana</a>`;
167+
167168
});
168169
}
169170

0 commit comments

Comments
 (0)