🔨Fix airflow scheduler

luisglft committed Oct 21, 2024
1 parent b73d0fa commit 5514406
Showing 2 changed files with 159 additions and 167 deletions.
321 changes: 156 additions & 165 deletions airflow/jobs/scheduler_job.py
@@ -28,7 +28,7 @@
from datetime import timedelta
from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple

from sqlalchemy import func, not_, or_, text, select
from sqlalchemy import func, and_, or_, text, select, desc
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import load_only, selectinload
from sqlalchemy.orm.session import Session, make_transient
@@ -330,198 +330,189 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
# dag and task ids that can't be queued because of concurrency limits
starved_dags: Set[str] = self._get_starved_dags(session=session)
starved_tasks: Set[Tuple[str, str]] = set()

pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)

# Subquery to get the current active task count for each DAG
# Only considering tasks from running DAG runs
current_active_tasks = (
session.query(
TI.dag_id,
func.count().label('active_count')
)
.join(DR, and_(DR.dag_id == TI.dag_id, DR.run_id == TI.run_id))
.filter(DR.state == DagRunState.RUNNING)
.filter(TI.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED]))
.group_by(TI.dag_id)
.subquery()
)
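# Roughly the SQL this subquery is expected to compile to (a sketch only; the
# exact text depends on the database dialect and table definitions):
#   SELECT task_instance.dag_id, COUNT(*) AS active_count
#   FROM task_instance
#   JOIN dag_run ON dag_run.dag_id = task_instance.dag_id
#               AND dag_run.run_id = task_instance.run_id
#   WHERE dag_run.state = 'running'
#     AND task_instance.state IN ('running', 'queued')
#   GROUP BY task_instance.dag_id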

for loop_count in itertools.count(start=1):

num_starved_pools = len(starved_pools)
num_starved_dags = len(starved_dags)
num_starved_tasks = len(starved_tasks)

# Get task instances associated with scheduled
# DagRuns which are not backfilled, in the given states,
# and the dag is not paused
query = (
session.query(TI)
.with_hint(TI, 'USE INDEX (ti_state)', dialect_name='mysql')
.join(TI.dag_run)
.filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
.filter(not_(DM.is_paused))
.filter(TI.state == TaskInstanceState.SCHEDULED)
.options(selectinload('dag_model'))
.order_by(-TI.priority_weight, DR.execution_date)
# Get the limit for each DAG
dag_limit_subquery = (
session.query(
DM.dag_id,
func.greatest(DM.max_active_tasks - func.coalesce(current_active_tasks.c.active_count, 0), 0).label('dag_limit')
)
.outerjoin(current_active_tasks, DM.dag_id == current_active_tasks.c.dag_id)
.subquery()
)
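# dag_limit is how many more task instances each DAG may still start:
# greatest(max_active_tasks - current active count, 0). For example
# (hypothetical numbers), a DAG with max_active_tasks=16 and 10 running or
# queued tasks gets a dag_limit of 6; a DAG already at or over its limit gets
# 0 and contributes nothing to the ranked selection below.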

if starved_pools:
query = query.filter(not_(TI.pool.in_(starved_pools)))
# Subquery to rank tasks within each DAG
ranked_tis = (
session.query(
TI,
func.row_number().over(
partition_by=TI.dag_id,
order_by=[desc(TI.priority_weight), TI.start_date]
).label('row_number'),
dag_limit_subquery.c.dag_limit
)
.join(TI.dag_run)
.join(DM, TI.dag_id == DM.dag_id)
.join(dag_limit_subquery, TI.dag_id == dag_limit_subquery.c.dag_id)
.filter(
DR.state == DagRunState.RUNNING,
DR.run_type != DagRunType.BACKFILL_JOB,
~DM.is_paused,
~TI.dag_id.in_(starved_dags),
~TI.pool.in_(starved_pools),
TI.state == TaskInstanceState.SCHEDULED,
)
)

# Apply the starved-tasks exclusion before materialising the subquery,
# since a Subquery object no longer exposes .filter().
if starved_tasks:
ranked_tis = ranked_tis.filter(
~func.concat(TI.dag_id, TI.task_id).in_([f"{dag_id}{task_id}" for dag_id, task_id in starved_tasks])
)

ranked_tis = ranked_tis.subquery()
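# Sketch of the window-function semantics (hypothetical values): with
# row_number() OVER (PARTITION BY dag_id ORDER BY priority_weight DESC, start_date),
# the scheduled TIs of dag_a with priorities [10, 5, 1] are numbered 1, 2, 3.
# If dag_a's dag_limit is 2, only the priority-10 and priority-5 tasks survive
# the row_number <= dag_limit filter applied in final_query below.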

final_query = (
session.query(TI)
.join(
ranked_tis,
and_(
TI.task_id == ranked_tis.c.task_id,
TI.dag_id == ranked_tis.c.dag_id,
TI.run_id == ranked_tis.c.run_id
)
)
.filter(ranked_tis.c.row_number <= ranked_tis.c.dag_limit)
.order_by(desc(ranked_tis.c.priority_weight), ranked_tis.c.start_date)
.limit(max_tis)
)

# Execute the query with row locks
task_instances_to_examine: List[TI] = with_row_locks(
final_query,
of=TI,
session=session,
**skip_locked(session=session),
).all()
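# with_row_locks + skip_locked is intended to translate to
# SELECT ... FOR UPDATE SKIP LOCKED on databases that support it, so a
# concurrent scheduler examining the same rows skips the ones locked here
# instead of blocking on them.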

if starved_dags:
query = query.filter(not_(TI.dag_id.in_(starved_dags)))

if len(task_instances_to_examine) == 0:
self.log.debug("No tasks to consider for execution.")
return []
# else:
# print("---dag_limit_subquery")
# print(str(dag_limit_subquery.select().compile(compile_kwargs={"literal_binds": True})))
# print("---ranked_tis-query")
# print(str(ranked_tis.select().compile(compile_kwargs={"literal_binds": True})))
# print("---FINAL QUERY")
# print(str(final_query.statement.compile(compile_kwargs={"literal_binds": True})))

# Put one task instance on each line
task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
self.log.info(
"%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
)

pool_slot_tracker = {pool_name: stats['open'] for pool_name, stats in pools.items()}

if starved_tasks:
task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks)
query = query.filter(not_(task_filter))
for task_instance in task_instances_to_examine:
pool_name = task_instance.pool

query = query.limit(max_tis)
pool_stats = pools.get(pool_name)
if not pool_stats:
self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
starved_pools.add(pool_name)
continue

task_instances_to_examine: List[TI] = with_row_locks(
query,
of=TI,
session=session,
**skip_locked(session=session),
).all()
# TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
# Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))

# # Make sure to emit metrics if pool has no starving tasks
# # pool_num_starving_tasks.setdefault(pool_name, 0)
# pool_total = pool_stats["total"]
open_slots = pool_stats["open"]

if len(task_instances_to_examine) == 0:
self.log.debug("No tasks to consider for execution.")
break
# Check to make sure that the task max_active_tasks of the DAG hasn't been
# reached.
# This shouldn't happen anymore, but it is left here for debugging purposes
dag_id = task_instance.dag_id

# Put one task instance on each line
task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
self.log.info(
"%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
"DAG %s has %s/%s running and queued tasks",
dag_id,
current_active_tasks_per_dag,
max_active_tasks_per_dag_limit,
)

for task_instance in task_instances_to_examine:
pool_name = task_instance.pool

pool_stats = pools.get(pool_name)
if not pool_stats:
self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
starved_pools.add(pool_name)
continue

# Make sure to emit metrics if pool has no starving tasks
pool_num_starving_tasks.setdefault(pool_name, 0)

pool_total = pool_stats["total"]
open_slots = pool_stats["open"]

if open_slots <= 0:
self.log.info(
"Not scheduling since there are %s open slots in pool %s", open_slots, pool_name
)
# Can't schedule any more since there are no more open slots.
pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
starved_pools.add(pool_name)
continue

if task_instance.pool_slots > pool_total:
self.log.warning(
"Not executing %s. Requested pool slots (%s) are greater than "
"total pool slots: '%s' for pool: %s.",
task_instance,
task_instance.pool_slots,
pool_total,
pool_name,
)

pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
continue

if task_instance.pool_slots > open_slots:
self.log.info(
"Not executing %s since it requires %s slots "
"but there are %s open slots in the pool %s.",
task_instance,
task_instance.pool_slots,
open_slots,
pool_name,
)
pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
# Though we can execute tasks with lower priority if there's enough room
continue

# Check to make sure that the task max_active_tasks of the DAG hasn't been
# reached.
dag_id = task_instance.dag_id

current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
self.log.info(
"DAG %s has %s/%s running and queued tasks",
"Not executing %s since the number of tasks running or queued "
"from DAG %s is >= to the DAG's max_active_tasks limit of %s",
task_instance,
dag_id,
current_active_tasks_per_dag,
max_active_tasks_per_dag_limit,
)
if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
self.log.info(
"Not executing %s since the number of tasks running or queued "
"from DAG %s is >= to the DAG's max_active_tasks limit of %s",
task_instance,
starved_dags.add(dag_id)

if task_instance.dag_model.has_task_concurrency_limits:
# Many dags don't have a task_concurrency, so the more we can avoid loading
# the full serialized DAG, the better.
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
# If the dag is missing, fail the task and continue to the next task.
if not serialized_dag:
self.log.error(
"DAG '%s' for task instance %s not found in serialized_dag table",
dag_id,
max_active_tasks_per_dag_limit,
task_instance,
)
starved_dags.add(dag_id)
continue
session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
{TI.state: State.FAILED}, synchronize_session='fetch'
)
# continue

task_concurrency_limit: Optional[int] = None
if serialized_dag.has_task(task_instance.task_id):
task_concurrency_limit = serialized_dag.get_task(
task_instance.task_id
).max_active_tis_per_dag

if task_instance.dag_model.has_task_concurrency_limits:
# Many dags don't have a task_concurrency, so the more we can avoid loading
# the full serialized DAG, the better.
serialized_dag = self.dagbag.get_dag(dag_id, session=session)
# If the dag is missing, fail the task and continue to the next task.
if not serialized_dag:
self.log.error(
"DAG '%s' for task instance %s not found in serialized_dag table",
dag_id,
if task_concurrency_limit is not None:
current_task_concurrency = task_concurrency_map[
(task_instance.dag_id, task_instance.task_id)
]

if current_task_concurrency >= task_concurrency_limit:
self.log.info(
"Not executing %s since the task concurrency for"
" this task has been reached.",
task_instance,
)
session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
{TI.state: State.FAILED}, synchronize_session='fetch'
)
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
continue

task_concurrency_limit: Optional[int] = None
if serialized_dag.has_task(task_instance.task_id):
task_concurrency_limit = serialized_dag.get_task(
task_instance.task_id
).max_active_tis_per_dag

if task_concurrency_limit is not None:
current_task_concurrency = task_concurrency_map[
(task_instance.dag_id, task_instance.task_id)
]

if current_task_concurrency >= task_concurrency_limit:
self.log.info(
"Not executing %s since the task concurrency for"
" this task has been reached.",
task_instance,
)
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
continue


# Check pool-specific slot availability
if (pool_slot_tracker.get(pool_name, 0) >= task_instance.pool_slots):
executable_tis.append(task_instance)
open_slots -= task_instance.pool_slots
dag_active_tasks_map[dag_id] += 1
task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1

pool_stats["open"] = open_slots
else:
starved_tasks.add((task_instance.dag_id, task_instance.task_id))
pool_num_starving_tasks[pool_name] += 1
num_starving_tasks_total += 1
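# Admission bookkeeping: a queued task consumes its slots from the pool's
# open count and increments both the per-DAG active count and the per-task
# concurrency count, so later candidates in the same batch see the updated
# totals. A task turned away purely for lack of pool capacity is remembered
# in starved_tasks so the next query iteration can exclude it up front.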

is_done = executable_tis or len(task_instances_to_examine) < max_tis
# Check this to avoid accidental infinite loops
found_new_filters = (
len(starved_pools) > num_starved_pools
or len(starved_dags) > num_starved_dags
or len(starved_tasks) > num_starved_tasks
)
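# Termination reasoning: the loop repeats only when nothing could be queued,
# the previous query was truncated at max_tis, and this pass discovered at
# least one new starved pool/dag/task to filter out. Without the
# found_new_filters check, an iteration that adds no new exclusions would
# re-fetch the exact same candidates and spin forever.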

if is_done or not found_new_filters:
break

self.log.debug(
"Found no task instances to queue on the %s. iteration "
"but there could be more candidate task instances to check.",
loop_count,
)

for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)
5 changes: 3 additions & 2 deletions airflow/www/static/js/ti_log.js
@@ -140,7 +140,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {
}
recurse().then(() => autoTailingLog(tryNumber, res.metadata, autoTailing));
}).catch((error) => {
console.error(`Error while retrieving log: ${error}`);
console.error(`Error while retrieving log`, error);
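// Passing the Error object as a separate console.error argument (instead of
// interpolating it into the template string) lets the browser print the full
// stack trace rather than just the error's string form.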

const externalLogUrl = getMetaValue('external_log_url');
const fullExternalUrl = `${externalLogUrl
@@ -151,7 +151,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {

document.getElementById(`loading-${tryNumber}`).style.display = 'none';

const logBlockElementId = `try-${tryNumber}-${item[0]}`;
const logBlockElementId = `try-${tryNumber}-error`;
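// A fixed "-error" suffix keeps the element id stable across retries (the
// log-stream `item` used in the success path is presumably not in scope in
// this catch handler), so repeated failures append to one error block
// instead of creating a new block per attempt.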
let logBlock = document.getElementById(logBlockElementId);
if (!logBlock) {
const logDivBlock = document.createElement('div');
@@ -164,6 +164,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) {

logBlock.innerHTML += "There was an error while retrieving the log from S3. Please use Kibana to view the logs.";
logBlock.innerHTML += `<a href="${fullExternalUrl}" target="_blank">View logs in Kibana</a>`;

});
}
