|
28 | 28 | from datetime import timedelta
|
29 | 29 | from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
|
30 | 30 |
|
31 |
| -from sqlalchemy import func, not_, or_, text, select |
| 31 | +from sqlalchemy import func, and_, or_, text, select, desc |
32 | 32 | from sqlalchemy.exc import OperationalError
|
33 | 33 | from sqlalchemy.orm import load_only, selectinload
|
34 | 34 | from sqlalchemy.orm.session import Session, make_transient
|
@@ -330,198 +330,189 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
|
330 | 330 | # dag and task ids that can't be queued because of concurrency limits
|
331 | 331 | starved_dags: Set[str] = self._get_starved_dags(session=session)
|
332 | 332 | starved_tasks: Set[Tuple[str, str]] = set()
|
333 |
| - |
334 | 333 | pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
|
| 334 | + |
| 335 | + # Subquery to get the current active task count for each DAG |
| 336 | + # Only considering tasks from running DAG runs |
| 337 | + current_active_tasks = ( |
| 338 | + session.query( |
| 339 | + TI.dag_id, |
| 340 | + func.count().label('active_count') |
| 341 | + ) |
| 342 | + .join(DR, and_(DR.dag_id == TI.dag_id, DR.run_id == TI.run_id)) |
| 343 | + .filter(DR.state == DagRunState.RUNNING) |
| 344 | + .filter(TI.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED])) |
| 345 | + .group_by(TI.dag_id) |
| 346 | + .subquery() |
| 347 | + ) |
335 | 348 |
|
336 |
| - for loop_count in itertools.count(start=1): |
337 |
| - |
338 |
| - num_starved_pools = len(starved_pools) |
339 |
| - num_starved_dags = len(starved_dags) |
340 |
| - num_starved_tasks = len(starved_tasks) |
341 |
| - |
342 |
| - # Get task instances associated with scheduled |
343 |
| - # DagRuns which are not backfilled, in the given states, |
344 |
| - # and the dag is not paused |
345 |
| - query = ( |
346 |
| - session.query(TI) |
347 |
| - .with_hint(TI, 'USE INDEX (ti_state)', dialect_name='mysql') |
348 |
| - .join(TI.dag_run) |
349 |
| - .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING) |
350 |
| - .join(TI.dag_model) |
351 |
| - .filter(not_(DM.is_paused)) |
352 |
| - .filter(TI.state == TaskInstanceState.SCHEDULED) |
353 |
| - .options(selectinload('dag_model')) |
354 |
| - .order_by(-TI.priority_weight, DR.execution_date) |
| 349 | + # Get the limit for each DAG |
| 350 | + dag_limit_subquery = ( |
| 351 | + session.query( |
| 352 | + DM.dag_id, |
| 353 | + func.greatest(DM.max_active_tasks - func.coalesce(current_active_tasks.c.active_count, 0), 0).label('dag_limit') |
355 | 354 | )
|
| 355 | + .outerjoin(current_active_tasks, DM.dag_id == current_active_tasks.c.dag_id) |
| 356 | + .subquery() |
| 357 | + ) |
356 | 358 |
|
357 |
| - if starved_pools: |
358 |
| - query = query.filter(not_(TI.pool.in_(starved_pools))) |
| 359 | + # Subquery to rank tasks within each DAG |
| 360 | + ranked_tis = ( |
| 361 | + session.query( |
| 362 | + TI, |
| 363 | + func.row_number().over( |
| 364 | + partition_by=TI.dag_id, |
| 365 | + order_by=[desc(TI.priority_weight), TI.start_date] |
| 366 | + ).label('row_number'), |
| 367 | + dag_limit_subquery.c.dag_limit |
| 368 | + ) |
| 369 | + .join(TI.dag_run) |
| 370 | + .join(DM, TI.dag_id == DM.dag_id) |
| 371 | + .join(dag_limit_subquery, TI.dag_id == dag_limit_subquery.c.dag_id) |
| 372 | + .filter( |
| 373 | + DR.state == DagRunState.RUNNING, |
| 374 | + DR.run_type != DagRunType.BACKFILL_JOB, |
| 375 | + ~DM.is_paused, |
| 376 | + ~TI.dag_id.in_(starved_dags), |
| 377 | + ~TI.pool.in_(starved_pools), |
| 378 | + TI.state == TaskInstanceState.SCHEDULED, |
| 379 | + ) |
| 380 | + ).subquery() |
| 381 | + |
| 382 | + if starved_tasks: |
| 383 | + ranked_tis = ranked_tis.filter( |
| 384 | + ~func.concat(TI.dag_id, TI.task_id).in_([f"{dag_id}{task_id}" for dag_id, task_id in starved_tasks]) |
| 385 | + ) |
| 386 | + |
| 387 | + final_query = ( |
| 388 | + session.query(TI) |
| 389 | + .join( |
| 390 | + ranked_tis, |
| 391 | + and_( |
| 392 | + TI.task_id == ranked_tis.c.task_id, |
| 393 | + TI.dag_id == ranked_tis.c.dag_id, |
| 394 | + TI.run_id == ranked_tis.c.run_id |
| 395 | + ) |
| 396 | + ) |
| 397 | + .filter(ranked_tis.c.row_number <= ranked_tis.c.dag_limit) |
| 398 | + .order_by(desc(ranked_tis.c.priority_weight), ranked_tis.c.start_date) |
| 399 | + .limit(max_tis) |
| 400 | + ) |
| 401 | + |
| 402 | + # Execute the query with row locks |
| 403 | + task_instances_to_examine: List[TI] = with_row_locks( |
| 404 | + final_query, |
| 405 | + of=TI, |
| 406 | + session=session, |
| 407 | + **skip_locked(session=session), |
| 408 | + ).all() |
359 | 409 |
|
360 |
| - if starved_dags: |
361 |
| - query = query.filter(not_(TI.dag_id.in_(starved_dags))) |
| 410 | + |
| 411 | + if len(task_instances_to_examine) == 0: |
| 412 | + self.log.debug("No tasks to consider for execution.") |
| 413 | + return [] |
| 414 | + # else: |
| 415 | + # print("---dag_limit_subquery") |
| 416 | + # print(str(dag_limit_subquery.select().compile(compile_kwargs={"literal_binds": True}))) |
| 417 | + # print("---ranked_tis-query") |
| 418 | + # print(str(ranked_tis.select().compile(compile_kwargs={"literal_binds": True}))) |
| 419 | + # print("---FINAL QUERY") |
| 420 | + # print(str(final_query.statement.compile(compile_kwargs={"literal_binds": True}))) |
| 421 | + |
| 422 | + # Put one task instance on each line |
| 423 | + task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine) |
| 424 | + self.log.info( |
| 425 | + "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str |
| 426 | + ) |
| 427 | + |
| 428 | + pool_slot_tracker = {pool_name: stats['open'] for pool_name, stats in pools.items()} |
362 | 429 |
|
363 |
| - if starved_tasks: |
364 |
| - task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks) |
365 |
| - query = query.filter(not_(task_filter)) |
| 430 | + for task_instance in task_instances_to_examine: |
| 431 | + pool_name = task_instance.pool |
366 | 432 |
|
367 |
| - query = query.limit(max_tis) |
| 433 | + pool_stats = pools.get(pool_name) |
| 434 | + if not pool_stats: |
| 435 | + self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name) |
| 436 | + starved_pools.add(pool_name) |
| 437 | + continue |
368 | 438 |
|
369 |
| - task_instances_to_examine: List[TI] = with_row_locks( |
370 |
| - query, |
371 |
| - of=TI, |
372 |
| - session=session, |
373 |
| - **skip_locked(session=session), |
374 |
| - ).all() |
375 |
| - # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything. |
376 |
| - # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) |
| 439 | + |
| 440 | + # # Make sure to emit metrics if pool has no starving tasks |
| 441 | + # # pool_num_starving_tasks.setdefault(pool_name, 0) |
| 442 | + # pool_total = pool_stats["total"] |
| 443 | + open_slots = pool_stats["open"] |
377 | 444 |
|
378 |
| - if len(task_instances_to_examine) == 0: |
379 |
| - self.log.debug("No tasks to consider for execution.") |
380 |
| - break |
| 445 | + # Check to make sure that the task max_active_tasks of the DAG hasn't been |
| 446 | + # reached. |
| 447 | + # This shouldn't happen anymore but still leaving it here for debugging purposes |
| 448 | + dag_id = task_instance.dag_id |
381 | 449 |
|
382 |
| - # Put one task instance on each line |
383 |
| - task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine) |
| 450 | + current_active_tasks_per_dag = dag_active_tasks_map[dag_id] |
| 451 | + max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks |
384 | 452 | self.log.info(
|
385 |
| - "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str |
| 453 | + "DAG %s has %s/%s running and queued tasks", |
| 454 | + dag_id, |
| 455 | + current_active_tasks_per_dag, |
| 456 | + max_active_tasks_per_dag_limit, |
386 | 457 | )
|
387 |
| - |
388 |
| - for task_instance in task_instances_to_examine: |
389 |
| - pool_name = task_instance.pool |
390 |
| - |
391 |
| - pool_stats = pools.get(pool_name) |
392 |
| - if not pool_stats: |
393 |
| - self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name) |
394 |
| - starved_pools.add(pool_name) |
395 |
| - continue |
396 |
| - |
397 |
| - # Make sure to emit metrics if pool has no starving tasks |
398 |
| - pool_num_starving_tasks.setdefault(pool_name, 0) |
399 |
| - |
400 |
| - pool_total = pool_stats["total"] |
401 |
| - open_slots = pool_stats["open"] |
402 |
| - |
403 |
| - if open_slots <= 0: |
404 |
| - self.log.info( |
405 |
| - "Not scheduling since there are %s open slots in pool %s", open_slots, pool_name |
406 |
| - ) |
407 |
| - # Can't schedule any more since there are no more open slots. |
408 |
| - pool_num_starving_tasks[pool_name] += 1 |
409 |
| - num_starving_tasks_total += 1 |
410 |
| - starved_pools.add(pool_name) |
411 |
| - continue |
412 |
| - |
413 |
| - if task_instance.pool_slots > pool_total: |
414 |
| - self.log.warning( |
415 |
| - "Not executing %s. Requested pool slots (%s) are greater than " |
416 |
| - "total pool slots: '%s' for pool: %s.", |
417 |
| - task_instance, |
418 |
| - task_instance.pool_slots, |
419 |
| - pool_total, |
420 |
| - pool_name, |
421 |
| - ) |
422 |
| - |
423 |
| - pool_num_starving_tasks[pool_name] += 1 |
424 |
| - num_starving_tasks_total += 1 |
425 |
| - starved_tasks.add((task_instance.dag_id, task_instance.task_id)) |
426 |
| - continue |
427 |
| - |
428 |
| - if task_instance.pool_slots > open_slots: |
429 |
| - self.log.info( |
430 |
| - "Not executing %s since it requires %s slots " |
431 |
| - "but there are %s open slots in the pool %s.", |
432 |
| - task_instance, |
433 |
| - task_instance.pool_slots, |
434 |
| - open_slots, |
435 |
| - pool_name, |
436 |
| - ) |
437 |
| - pool_num_starving_tasks[pool_name] += 1 |
438 |
| - num_starving_tasks_total += 1 |
439 |
| - starved_tasks.add((task_instance.dag_id, task_instance.task_id)) |
440 |
| - # Though we can execute tasks with lower priority if there's enough room |
441 |
| - continue |
442 |
| - |
443 |
| - # Check to make sure that the task max_active_tasks of the DAG hasn't been |
444 |
| - # reached. |
445 |
| - dag_id = task_instance.dag_id |
446 |
| - |
447 |
| - current_active_tasks_per_dag = dag_active_tasks_map[dag_id] |
448 |
| - max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks |
| 458 | + if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit: |
449 | 459 | self.log.info(
|
450 |
| - "DAG %s has %s/%s running and queued tasks", |
| 460 | + "Not executing %s since the number of tasks running or queued " |
| 461 | + "from DAG %s is >= to the DAG's max_active_tasks limit of %s", |
| 462 | + task_instance, |
451 | 463 | dag_id,
|
452 |
| - current_active_tasks_per_dag, |
453 | 464 | max_active_tasks_per_dag_limit,
|
454 | 465 | )
|
455 |
| - if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit: |
456 |
| - self.log.info( |
457 |
| - "Not executing %s since the number of tasks running or queued " |
458 |
| - "from DAG %s is >= to the DAG's max_active_tasks limit of %s", |
459 |
| - task_instance, |
| 466 | + starved_dags.add(dag_id) |
| 467 | + |
| 468 | + if task_instance.dag_model.has_task_concurrency_limits: |
| 469 | + # Many dags don't have a task_concurrency, so wherever we can avoid loading the full |
| 470 | + # serialized DAG, the better. |
| 471 | + serialized_dag = self.dagbag.get_dag(dag_id, session=session) |
| 472 | + # If the dag is missing, fail the task and continue to the next task. |
| 473 | + if not serialized_dag: |
| 474 | + self.log.error( |
| 475 | + "DAG '%s' for task instance %s not found in serialized_dag table", |
460 | 476 | dag_id,
|
461 |
| - max_active_tasks_per_dag_limit, |
| 477 | + task_instance, |
462 | 478 | )
|
463 |
| - starved_dags.add(dag_id) |
464 |
| - continue |
| 479 | + session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update( |
| 480 | + {TI.state: State.FAILED}, synchronize_session='fetch' |
| 481 | + ) |
| 482 | + # continue |
| 483 | + |
| 484 | + task_concurrency_limit: Optional[int] = None |
| 485 | + if serialized_dag.has_task(task_instance.task_id): |
| 486 | + task_concurrency_limit = serialized_dag.get_task( |
| 487 | + task_instance.task_id |
| 488 | + ).max_active_tis_per_dag |
465 | 489 |
|
466 |
| - if task_instance.dag_model.has_task_concurrency_limits: |
467 |
| - # Many dags don't have a task_concurrency, so where we can avoid loading the full |
468 |
| - # serialized DAG the better. |
469 |
| - serialized_dag = self.dagbag.get_dag(dag_id, session=session) |
470 |
| - # If the dag is missing, fail the task and continue to the next task. |
471 |
| - if not serialized_dag: |
472 |
| - self.log.error( |
473 |
| - "DAG '%s' for task instance %s not found in serialized_dag table", |
474 |
| - dag_id, |
| 490 | + if task_concurrency_limit is not None: |
| 491 | + current_task_concurrency = task_concurrency_map[ |
| 492 | + (task_instance.dag_id, task_instance.task_id) |
| 493 | + ] |
| 494 | + |
| 495 | + if current_task_concurrency >= task_concurrency_limit: |
| 496 | + self.log.info( |
| 497 | + "Not executing %s since the task concurrency for" |
| 498 | + " this task has been reached.", |
475 | 499 | task_instance,
|
476 | 500 | )
|
477 |
| - session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update( |
478 |
| - {TI.state: State.FAILED}, synchronize_session='fetch' |
479 |
| - ) |
| 501 | + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) |
480 | 502 | continue
|
481 |
| - |
482 |
| - task_concurrency_limit: Optional[int] = None |
483 |
| - if serialized_dag.has_task(task_instance.task_id): |
484 |
| - task_concurrency_limit = serialized_dag.get_task( |
485 |
| - task_instance.task_id |
486 |
| - ).max_active_tis_per_dag |
487 |
| - |
488 |
| - if task_concurrency_limit is not None: |
489 |
| - current_task_concurrency = task_concurrency_map[ |
490 |
| - (task_instance.dag_id, task_instance.task_id) |
491 |
| - ] |
492 |
| - |
493 |
| - if current_task_concurrency >= task_concurrency_limit: |
494 |
| - self.log.info( |
495 |
| - "Not executing %s since the task concurrency for" |
496 |
| - " this task has been reached.", |
497 |
| - task_instance, |
498 |
| - ) |
499 |
| - starved_tasks.add((task_instance.dag_id, task_instance.task_id)) |
500 |
| - continue |
501 |
| - |
| 503 | + |
| 504 | + # Check pool-specific slot availability |
| 505 | + if (pool_slot_tracker.get(pool_name, 0) >= task_instance.pool_slots): |
502 | 506 | executable_tis.append(task_instance)
|
503 | 507 | open_slots -= task_instance.pool_slots
|
504 | 508 | dag_active_tasks_map[dag_id] += 1
|
505 | 509 | task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
|
506 |
| - |
507 | 510 | pool_stats["open"] = open_slots
|
| 511 | + else: |
| 512 | + starved_tasks.add((task_instance.dag_id, task_instance.task_id)) |
| 513 | + pool_num_starving_tasks[pool_name] += 1 |
| 514 | + num_starving_tasks_total += 1 |
508 | 515 |
|
509 |
| - is_done = executable_tis or len(task_instances_to_examine) < max_tis |
510 |
| - # Check this to avoid accidental infinite loops |
511 |
| - found_new_filters = ( |
512 |
| - len(starved_pools) > num_starved_pools |
513 |
| - or len(starved_dags) > num_starved_dags |
514 |
| - or len(starved_tasks) > num_starved_tasks |
515 |
| - ) |
516 |
| - |
517 |
| - if is_done or not found_new_filters: |
518 |
| - break |
519 |
| - |
520 |
| - self.log.debug( |
521 |
| - "Found no task instances to queue on the %s. iteration " |
522 |
| - "but there could be more candidate task instances to check.", |
523 |
| - loop_count, |
524 |
| - ) |
525 | 516 |
|
526 | 517 | for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
|
527 | 518 | Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)
|
|
0 commit comments