diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 80bf67067a707..5ef5cccd3c592 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -28,7 +28,7 @@
 from datetime import timedelta
 from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
 
-from sqlalchemy import func, not_, or_, text, select
+from sqlalchemy import func, and_, or_, text, select, desc
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import load_only, selectinload
 from sqlalchemy.orm.session import Session, make_transient
@@ -330,198 +330,189 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
         # dag and task ids that can't be queued because of concurrency limits
         starved_dags: Set[str] = self._get_starved_dags(session=session)
         starved_tasks: Set[Tuple[str, str]] = set()
         pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
-        for loop_count in itertools.count(start=1):
-
-            num_starved_pools = len(starved_pools)
-            num_starved_dags = len(starved_dags)
-            num_starved_tasks = len(starved_tasks)
-
-            # Get task instances associated with scheduled
-            # DagRuns which are not backfilled, in the given states,
-            # and the dag is not paused
-            query = (
-                session.query(TI)
-                .with_hint(TI, 'USE INDEX (ti_state)', dialect_name='mysql')
-                .join(TI.dag_run)
-                .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
-                .join(TI.dag_model)
-                .filter(not_(DM.is_paused))
-                .filter(TI.state == TaskInstanceState.SCHEDULED)
-                .options(selectinload('dag_model'))
-                .order_by(-TI.priority_weight, DR.execution_date)
-            )
-
-            if starved_pools:
-                query = query.filter(not_(TI.pool.in_(starved_pools)))
-
-            if starved_dags:
-                query = query.filter(not_(TI.dag_id.in_(starved_dags)))
-
-            if starved_tasks:
-                task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks)
-                query = query.filter(not_(task_filter))
-
-            query = query.limit(max_tis)
-
-            task_instances_to_examine: List[TI] = with_row_locks(
-                query,
-                of=TI,
-                session=session,
-                **skip_locked(session=session),
-            ).all()
-            # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
-            # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))
-
-            if len(task_instances_to_examine) == 0:
-                self.log.debug("No tasks to consider for execution.")
-                break
-
-            # Put one task instance on each line
-            task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
-            self.log.info(
-                "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
-            )
-
-            for task_instance in task_instances_to_examine:
-                pool_name = task_instance.pool
-
-                pool_stats = pools.get(pool_name)
-                if not pool_stats:
-                    self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
-                    starved_pools.add(pool_name)
-                    continue
-
-                # Make sure to emit metrics if pool has no starving tasks
-                pool_num_starving_tasks.setdefault(pool_name, 0)
-
-                pool_total = pool_stats["total"]
-                open_slots = pool_stats["open"]
-
-                if open_slots <= 0:
-                    self.log.info(
-                        "Not scheduling since there are %s open slots in pool %s", open_slots, pool_name
-                    )
-                    # Can't schedule any more since there are no more open slots.
-                    pool_num_starving_tasks[pool_name] += 1
-                    num_starving_tasks_total += 1
-                    starved_pools.add(pool_name)
-                    continue
-
-                if task_instance.pool_slots > pool_total:
-                    self.log.warning(
-                        "Not executing %s. Requested pool slots (%s) are greater than "
-                        "total pool slots: '%s' for pool: %s.",
-                        task_instance,
-                        task_instance.pool_slots,
-                        pool_total,
-                        pool_name,
-                    )
-
-                    pool_num_starving_tasks[pool_name] += 1
-                    num_starving_tasks_total += 1
-                    starved_tasks.add((task_instance.dag_id, task_instance.task_id))
-                    continue
-
-                if task_instance.pool_slots > open_slots:
-                    self.log.info(
-                        "Not executing %s since it requires %s slots "
-                        "but there are %s open slots in the pool %s.",
-                        task_instance,
-                        task_instance.pool_slots,
-                        open_slots,
-                        pool_name,
-                    )
-                    pool_num_starving_tasks[pool_name] += 1
-                    num_starving_tasks_total += 1
-                    starved_tasks.add((task_instance.dag_id, task_instance.task_id))
-                    # Though we can execute tasks with lower priority if there's enough room
-                    continue
-
-                # Check to make sure that the task max_active_tasks of the DAG hasn't been
-                # reached.
-                dag_id = task_instance.dag_id
-
-                current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
-                max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
-                self.log.info(
-                    "DAG %s has %s/%s running and queued tasks",
-                    dag_id,
-                    current_active_tasks_per_dag,
-                    max_active_tasks_per_dag_limit,
-                )
-                if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
-                    self.log.info(
-                        "Not executing %s since the number of tasks running or queued "
-                        "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
-                        task_instance,
-                        dag_id,
-                        max_active_tasks_per_dag_limit,
-                    )
-                    starved_dags.add(dag_id)
-                    continue
-
-                if task_instance.dag_model.has_task_concurrency_limits:
-                    # Many dags don't have a task_concurrency, so where we can avoid loading the full
-                    # serialized DAG the better.
-                    serialized_dag = self.dagbag.get_dag(dag_id, session=session)
-                    # If the dag is missing, fail the task and continue to the next task.
-                    if not serialized_dag:
-                        self.log.error(
-                            "DAG '%s' for task instance %s not found in serialized_dag table",
-                            dag_id,
-                            task_instance,
-                        )
-                        session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
-                            {TI.state: State.FAILED}, synchronize_session='fetch'
-                        )
-                        continue
-
-                    task_concurrency_limit: Optional[int] = None
-                    if serialized_dag.has_task(task_instance.task_id):
-                        task_concurrency_limit = serialized_dag.get_task(
-                            task_instance.task_id
-                        ).max_active_tis_per_dag
-
-                    if task_concurrency_limit is not None:
-                        current_task_concurrency = task_concurrency_map[
-                            (task_instance.dag_id, task_instance.task_id)
-                        ]
-
-                        if current_task_concurrency >= task_concurrency_limit:
-                            self.log.info(
-                                "Not executing %s since the task concurrency for"
-                                " this task has been reached.",
-                                task_instance,
-                            )
-                            starved_tasks.add((task_instance.dag_id, task_instance.task_id))
-                            continue
-
-                executable_tis.append(task_instance)
-                open_slots -= task_instance.pool_slots
-                dag_active_tasks_map[dag_id] += 1
-                task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
-                pool_stats["open"] = open_slots
-
-            is_done = executable_tis or len(task_instances_to_examine) < max_tis
-            # Check this to avoid accidental infinite loops
-            found_new_filters = (
-                len(starved_pools) > num_starved_pools
-                or len(starved_dags) > num_starved_dags
-                or len(starved_tasks) > num_starved_tasks
-            )
-
-            if is_done or not found_new_filters:
-                break
-
-            self.log.debug(
-                "Found no task instances to queue on the %s. iteration "
-                "but there could be more candidate task instances to check.",
-                loop_count,
-            )
+
+        # Subquery to get the current active task count for each DAG.
+        # Only running or queued tasks that belong to running DAG runs are counted.
+        current_active_tasks = (
+            session.query(
+                TI.dag_id,
+                func.count().label('active_count')
+            )
+            .join(DR, and_(DR.dag_id == TI.dag_id, DR.run_id == TI.run_id))
+            .filter(DR.state == DagRunState.RUNNING)
+            .filter(TI.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.QUEUED]))
+            .group_by(TI.dag_id)
+            .subquery()
+        )
+
+        # Get the remaining limit for each DAG: max_active_tasks minus the currently
+        # active count, floored at zero.
+        dag_limit_subquery = (
+            session.query(
+                DM.dag_id,
+                func.greatest(DM.max_active_tasks - func.coalesce(current_active_tasks.c.active_count, 0), 0).label('dag_limit')
+            )
+            .outerjoin(current_active_tasks, DM.dag_id == current_active_tasks.c.dag_id)
+            .subquery()
+        )
+
+        # Query to rank tasks within each DAG
+        ranked_tis_query = (
+            session.query(
+                TI,
+                func.row_number().over(
+                    partition_by=TI.dag_id,
+                    order_by=[desc(TI.priority_weight), TI.start_date]
+                ).label('row_number'),
+                dag_limit_subquery.c.dag_limit
+            )
+            .join(TI.dag_run)
+            .join(DM, TI.dag_id == DM.dag_id)
+            .join(dag_limit_subquery, TI.dag_id == dag_limit_subquery.c.dag_id)
+            .filter(
+                DR.state == DagRunState.RUNNING,
+                DR.run_type != DagRunType.BACKFILL_JOB,
+                ~DM.is_paused,
+                ~TI.dag_id.in_(starved_dags),
+                ~TI.pool.in_(starved_pools),
+                TI.state == TaskInstanceState.SCHEDULED,
+            )
+        )
+
+        if starved_tasks:
+            # Exclude (dag_id, task_id) pairs that are already known to be starved
+            ranked_tis_query = ranked_tis_query.filter(
+                ~func.concat(TI.dag_id, TI.task_id).in_([f"{dag_id}{task_id}" for dag_id, task_id in starved_tasks])
+            )
+
+        ranked_tis = ranked_tis_query.subquery()
+
+        final_query = (
+            session.query(TI)
+            .join(
+                ranked_tis,
+                and_(
+                    TI.task_id == ranked_tis.c.task_id,
+                    TI.dag_id == ranked_tis.c.dag_id,
+                    TI.run_id == ranked_tis.c.run_id
+                )
+            )
+            .filter(ranked_tis.c.row_number <= ranked_tis.c.dag_limit)
+            .order_by(desc(ranked_tis.c.priority_weight), ranked_tis.c.start_date)
+            .limit(max_tis)
+        )
+
+        # Execute the query with row locks
+        task_instances_to_examine: List[TI] = with_row_locks(
+            final_query,
+            of=TI,
+            session=session,
+            **skip_locked(session=session),
+        ).all()
+
+        if len(task_instances_to_examine) == 0:
+            self.log.debug("No tasks to consider for execution.")
+            return []
+
+        # Put one task instance on each line
+        task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
+        self.log.info(
+            "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
+        )
+
+        # Track the remaining open slots per pool while this batch is being queued
+        pool_slot_tracker = {pool_name: stats['open'] for pool_name, stats in pools.items()}
+
+        for task_instance in task_instances_to_examine:
+            pool_name = task_instance.pool
+
+            pool_stats = pools.get(pool_name)
+            if not pool_stats:
+                self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
+                starved_pools.add(pool_name)
+                continue
+
+            # Make sure to emit metrics if pool has no starving tasks
+            pool_num_starving_tasks.setdefault(pool_name, 0)
iteration " - "but there could be more candidate task instances to check.", - loop_count, - ) for pool_name, num_starving_tasks in pool_num_starving_tasks.items(): Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks) diff --git a/airflow/www/static/js/ti_log.js b/airflow/www/static/js/ti_log.js index d4b7ca031c371..5919b03b5d844 100644 --- a/airflow/www/static/js/ti_log.js +++ b/airflow/www/static/js/ti_log.js @@ -140,7 +140,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) { } recurse().then(() => autoTailingLog(tryNumber, res.metadata, autoTailing)); }).catch((error) => { - console.error(`Error while retrieving log: ${error}`); + console.error(`Error while retrieving log`, error); const externalLogUrl = getMetaValue('external_log_url'); const fullExternalUrl = `${externalLogUrl @@ -151,7 +151,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) { document.getElementById(`loading-${tryNumber}`).style.display = 'none'; - const logBlockElementId = `try-${tryNumber}-${item[0]}`; + const logBlockElementId = `try-${tryNumber}-error`; let logBlock = document.getElementById(logBlockElementId); if (!logBlock) { const logDivBlock = document.createElement('div'); @@ -164,6 +164,7 @@ function autoTailingLog(tryNumber, metadata = null, autoTailing = false) { logBlock.innerHTML += "There was an error while retrieving the log from S3. Please use Kibana to view the logs."; logBlock.innerHTML += `View logs in Kibana`; + }); }