diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 90480d8a3343b..82d4c6e31c3cb 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -24,7 +24,7 @@
 from typing import Collection
 
 import sqlalchemy_jsonfield
-from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_
+from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_, select
 from sqlalchemy.orm import Session, backref, foreign, relationship
 from sqlalchemy.sql.expression import func, literal
 
@@ -144,23 +144,19 @@ def write_dag(
         # If Yes, does nothing
         # If No or the DAG does not exists, updates / writes Serialized DAG to DB
         if min_update_interval is not None:
-            if (
-                session.query(literal(True))
-                .filter(
-                    and_(
-                        cls.dag_id == dag.dag_id,
-                        (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
-                    )
+            if session.scalar(
+                select(literal(True)).where(
+                    cls.dag_id == dag.dag_id,
+                    (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
                 )
-                .scalar()
             ):
                 return False
 
         log.debug("Checking if DAG (%s) changed", dag.dag_id)
         new_serialized_dag = cls(dag, processor_subdir)
-        serialized_dag_db = (
-            session.query(cls.dag_hash, cls.processor_subdir).filter(cls.dag_id == dag.dag_id).first()
-        )
+        serialized_dag_db = session.execute(
+            select(cls.dag_hash, cls.processor_subdir).where(cls.dag_id == dag.dag_id)
+        ).first()
 
         if (
             serialized_dag_db is not None
@@ -183,7 +179,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
         :param session: ORM Session
         :returns: a dict of DAGs read from database
         """
-        serialized_dags = session.query(cls)
+        serialized_dags = session.scalars(select(cls))
 
         dags = {}
         for row in serialized_dags:
@@ -275,7 +271,7 @@ def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool:
         :param dag_id: the DAG to check
         :param session: ORM Session
         """
-        return session.query(literal(True)).filter(cls.dag_id == dag_id).first() is not None
+        return session.scalar(select(literal(True)).where(cls.dag_id == dag_id).limit(1)) is not None
 
     @classmethod
     @provide_session
@@ -296,15 +292,15 @@ def get(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDagModel
         :param dag_id: the DAG to fetch
         :param session: ORM Session
         """
-        row = session.query(cls).filter(cls.dag_id == dag_id).one_or_none()
+        row = session.scalar(select(cls).where(cls.dag_id == dag_id))
         if row:
             return row
 
         # If we didn't find a matching DAG id then ask the DAG table to find
         # out the root dag
-        root_dag_id = session.query(DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar()
+        root_dag_id = session.scalar(select(DagModel.root_dag_id).where(DagModel.dag_id == dag_id))
 
-        return session.query(cls).filter(cls.dag_id == root_dag_id).one_or_none()
+        return session.scalar(select(cls).where(cls.dag_id == root_dag_id))
 
     @staticmethod
     @provide_session
@@ -340,7 +336,7 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = NEW_SESSION)
         :param dag_id: DAG ID
         :param session: ORM Session
         """
-        return session.query(cls.last_updated).filter(cls.dag_id == dag_id).scalar()
+        return session.scalar(select(cls.last_updated).where(cls.dag_id == dag_id))
 
     @classmethod
     @provide_session
@@ -350,7 +346,7 @@ def get_max_last_updated_datetime(cls, session: Session = NEW_SESSION) -> dateti
 
         :param session: ORM Session
         """
-        return session.query(func.max(cls.last_updated)).scalar()
+        return session.scalar(select(func.max(cls.last_updated)))
 
     @classmethod
     @provide_session
@@ -362,7 +358,7 @@ def get_latest_version_hash(cls, dag_id: str, session: Session = NEW_SESSION) ->
         :param session: ORM Session
         :return: DAG Hash, or None if the DAG is not found
         """
-        return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar()
+        return session.scalar(select(cls.dag_hash).where(cls.dag_id == dag_id))
 
     @classmethod
     def get_latest_version_hash_and_updated_datetime(
@@ -379,7 +375,9 @@ def get_latest_version_hash_and_updated_datetime(
         :param session: ORM Session
         :return: A tuple of DAG Hash and last updated datetime, or None if the DAG is not found
         """
-        return session.query(cls.dag_hash, cls.last_updated).filter(cls.dag_id == dag_id).one_or_none()
+        return session.execute(
+            select(cls.dag_hash, cls.last_updated).where(cls.dag_id == dag_id)
+        ).one_or_none()
 
     @classmethod
     @provide_session
@@ -390,11 +388,15 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
         :param session: ORM Session
         """
         if session.bind.dialect.name in ["sqlite", "mysql"]:
-            query = session.query(cls.dag_id, func.json_extract(cls._data, "$.dag.dag_dependencies"))
+            query = session.execute(
+                select(cls.dag_id, func.json_extract(cls._data, "$.dag.dag_dependencies"))
+            )
             iterator = ((dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query)
         elif session.bind.dialect.name == "mssql":
-            query = session.query(cls.dag_id, func.json_query(cls._data, "$.dag.dag_dependencies"))
+            query = session.execute(select(cls.dag_id, func.json_query(cls._data, "$.dag.dag_dependencies")))
             iterator = ((dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query)
         else:
-            iterator = session.query(cls.dag_id, func.json_extract_path(cls._data, "dag", "dag_dependencies"))
+            iterator = session.execute(
+                select(cls.dag_id, func.json_extract_path(cls._data, "dag", "dag_dependencies"))
+            )
         return {dag_id: [DagDependency(**d) for d in (deps_data or [])] for dag_id, deps_data in iterator}
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 4f50918a09d48..bba60c94493f2 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -20,7 +20,7 @@
 from traceback import format_exception
 from typing import Any, Iterable
 
-from sqlalchemy import Column, Integer, String, delete, func, or_
+from sqlalchemy import Column, Integer, String, delete, func, or_, select, update
 from sqlalchemy.orm import Session, joinedload, relationship
 
 from airflow.api_internal.internal_api_call import internal_api_call
@@ -93,9 +93,9 @@ def from_object(cls, trigger: BaseTrigger) -> Trigger:
     @provide_session
     def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[int, Trigger]:
         """Fetches all the Triggers by ID and returns a dict mapping ID -> Trigger instance."""
-        query = (
-            session.query(cls)
-            .filter(cls.id.in_(ids))
+        query = session.scalars(
+            select(cls)
+            .where(cls.id.in_(ids))
             .options(
                 joinedload("task_instance"),
                 joinedload("task_instance.trigger"),
@@ -117,19 +117,21 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
         # Update all task instances with trigger IDs that are not DEFERRED to remove them
         for attempt in run_with_db_retries():
             with attempt:
-                session.query(TaskInstance).filter(
-                    TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.isnot(None)
-                ).update({TaskInstance.trigger_id: None})
+                session.execute(
+                    update(TaskInstance)
+                    .where(
+                        TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.is_not(None)
+                    )
+                    .values(trigger_id=None)
+                )
+
         # Get all triggers that have no task instances depending on them...
-        ids = [
-            trigger_id
-            for (trigger_id,) in (
-                session.query(cls.id)
-                .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True)
-                .group_by(cls.id)
-                .having(func.count(TaskInstance.trigger_id) == 0)
-            )
-        ]
+        ids = session.scalars(
+            select(cls.id)
+            .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True)
+            .group_by(cls.id)
+            .having(func.count(TaskInstance.trigger_id) == 0)
+        ).all()
         # ...and delete them (we can't do this in one query due to MySQL)
         session.execute(
             delete(Trigger).where(Trigger.id.in_(ids)).execution_options(synchronize_session=False)
@@ -140,8 +142,10 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
     @provide_session
     def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None:
         """Takes an event from an instance of itself, and triggers all dependent tasks to resume."""
-        for task_instance in session.query(TaskInstance).filter(
-            TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
+        for task_instance in session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
+            )
         ):
             # Add the event's payload into the kwargs for the task
             next_kwargs = task_instance.next_kwargs or {}
@@ -171,8 +175,10 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) ->
         workers as first-class concepts, we can run the failure code here
         in-process, but we can't do that right now.
         """
-        for task_instance in session.query(TaskInstance).filter(
-            TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
+        for task_instance in session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
+            )
         ):
             # Add the error and set the next_method to the fail state
             traceback = format_exception(type(exc), exc, exc.__traceback__) if exc else None
@@ -188,7 +194,7 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) ->
     @provide_session
     def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]:
         """Retrieves a list of triggerer_ids."""
-        return [row[0] for row in session.query(cls.id).filter(cls.triggerer_id == triggerer_id)]
+        return session.scalars(select(cls.id).where(cls.triggerer_id == triggerer_id)).all()
 
     @classmethod
     @internal_api_call
@@ -203,7 +209,7 @@ def assign_unassigned(cls, triggerer_id, capacity, heartrate, session: Session =
         """
         from airflow.jobs.job import Job  # To avoid circular import
 
-        count = session.query(func.count(cls.id)).filter(cls.triggerer_id == triggerer_id).scalar()
+        count = session.scalar(select(func.count(cls.id)).filter(cls.triggerer_id == triggerer_id))
         capacity -= count
 
         if capacity <= 0:
@@ -211,14 +217,13 @@ def assign_unassigned(cls, triggerer_id, capacity, heartrate, session: Session =
         # we multiply heartrate by a grace_multiplier to give the triggerer
         # a chance to heartbeat before we consider it dead
         health_check_threshold = heartrate * 2.1
-        alive_triggerer_ids = [
-            row[0]
-            for row in session.query(Job.id).filter(
+        alive_triggerer_ids = session.scalars(
+            select(Job.id).where(
                 Job.end_date.is_(None),
                 Job.latest_heartbeat > timezone.utcnow() - datetime.timedelta(seconds=health_check_threshold),
                 Job.job_type == "TriggererJob",
             )
-        ]
+        ).all()
 
         # Find triggers who do NOT have an alive triggerer_id, and then assign
         # up to `capacity` of those to us.
@@ -226,19 +231,23 @@ def assign_unassigned(cls, triggerer_id, capacity, heartrate, session: Session =
             capacity=capacity, alive_triggerer_ids=alive_triggerer_ids, session=session
         )
         if trigger_ids_query:
-            session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update(
-                {cls.triggerer_id: triggerer_id},
-                synchronize_session=False,
+            session.execute(
+                update(cls)
+                .where(cls.id.in_([i.id for i in trigger_ids_query]))
+                .values(triggerer_id=triggerer_id)
+                .execution_options(synchronize_session=False)
             )
+            session.commit()
 
     @classmethod
     def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session):
-        return with_row_locks(
-            session.query(cls.id)
-            .filter(or_(cls.triggerer_id.is_(None), cls.triggerer_id.notin_(alive_triggerer_ids)))
+        query = with_row_locks(
+            select(cls.id)
+            .where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)))
             .order_by(cls.created_date)
             .limit(capacity),
             session,
             skip_locked=True,
-        ).all()
+        )
+        return session.execute(query).all()
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index 354c2b236c07f..b601c4c07e2c3 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -32,6 +32,7 @@
 from typing import TYPE_CHECKING, Callable, TypeVar, cast
 
 import re2
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from airflow import settings
@@ -267,7 +268,7 @@ def get_dag_by_pickle(pickle_id: int, session: Session = NEW_SESSION) -> DAG:
     """Fetch DAG from the database using pickling."""
     from airflow.models import DagPickle
 
-    dag_pickle = session.query(DagPickle).filter(DagPickle.id == pickle_id).first()
+    dag_pickle = session.scalar(select(DagPickle).where(DagPickle.id == pickle_id))
     if not dag_pickle:
         raise AirflowException(f"pickle_id could not be found in DagPickle.id list: {pickle_id}")
     pickle_dag = dag_pickle.pickle
diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py
index 36bced947348d..f9dae372429da 100644
--- a/airflow/utils/db_cleanup.py
+++ b/airflow/utils/db_cleanup.py
@@ -30,7 +30,7 @@
 from typing import Any
 
 from pendulum import DateTime
-from sqlalchemy import and_, column, false, func, inspect, table, text
+from sqlalchemy import and_, column, false, func, inspect, select, table, text
 from sqlalchemy.exc import OperationalError, ProgrammingError
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.orm import Query, Session, aliased
@@ -179,7 +179,7 @@ def _do_delete(*, query, orm_model, skip_archive, session):
         pk_cols = source_table.primary_key.columns
         delete = source_table.delete().where(
             tuple_(*pk_cols).in_(
-                session.query(*[target_table.c[x.name] for x in source_table.primary_key.columns]).subquery()
+                select(*[target_table.c[x.name] for x in source_table.primary_key.columns]).subquery()
             )
         )
     else:
@@ -196,7 +196,7 @@ def _do_delete(*, query, orm_model, skip_archive, session):
 
 
 def _subquery_keep_last(*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session):
-    subquery = session.query(*group_by_columns, func.max(recency_column).label(max_date_colname))
+    subquery = select(*group_by_columns, func.max(recency_column).label(max_date_colname))
 
     if keep_last_filters is not None:
         for entry in keep_last_filters: