Refactor Sqlalchemy queries to 2.0 style (Part 6) (#32645)
phanikumv authored Jul 21, 2023
1 parent 15d42b4 commit c7c0dee
Showing 4 changed files with 73 additions and 61 deletions.
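
Every hunk below applies the same conversion: legacy session.query(...).filter(...) chains are rewritten as 2.0-style select() statements executed through session.scalar(), session.scalars(), or session.execute(). The sketch that follows illustrates the pattern in isolation; it assumes SQLAlchemy 2.0 and a throwaway User model, neither of which is part of this commit.

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(name="alice"))
    session.commit()

    # Legacy 1.x style, the form this PR series removes:
    #   user = session.query(User).filter(User.name == "alice").one_or_none()
    # 2.0 style: compose a select() statement and hand it to the session.
    user = session.scalar(select(User).where(User.name == "alice"))

    # scalars() yields ORM objects or single-column values; execute() yields Row tuples.
    all_names = session.scalars(select(User.name)).all()
    id_name_rows = session.execute(select(User.id, User.name)).all()
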
50 changes: 26 additions & 24 deletions airflow/models/serialized_dag.py
@@ -24,7 +24,7 @@
from typing import Collection

import sqlalchemy_jsonfield
from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_
from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_, select
from sqlalchemy.orm import Session, backref, foreign, relationship
from sqlalchemy.sql.expression import func, literal

@@ -144,23 +144,19 @@ def write_dag(
# If Yes, does nothing
# If No or the DAG does not exist, updates / writes Serialized DAG to DB
if min_update_interval is not None:
if (
session.query(literal(True))
.filter(
and_(
cls.dag_id == dag.dag_id,
(timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
)
if session.scalar(
select(literal(True)).where(
cls.dag_id == dag.dag_id,
(timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
)
.scalar()
):
return False

log.debug("Checking if DAG (%s) changed", dag.dag_id)
new_serialized_dag = cls(dag, processor_subdir)
serialized_dag_db = (
session.query(cls.dag_hash, cls.processor_subdir).filter(cls.dag_id == dag.dag_id).first()
)
serialized_dag_db = session.execute(
select(cls.dag_hash, cls.processor_subdir).where(cls.dag_id == dag.dag_id)
).first()

if (
serialized_dag_db is not None
@@ -183,7 +179,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
:param session: ORM Session
:returns: a dict of DAGs read from database
"""
serialized_dags = session.query(cls)
serialized_dags = session.scalars(select(cls))

dags = {}
for row in serialized_dags:
@@ -275,7 +271,7 @@ def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool:
:param dag_id: the DAG to check
:param session: ORM Session
"""
return session.query(literal(True)).filter(cls.dag_id == dag_id).first() is not None
return session.scalar(select(literal(True)).where(cls.dag_id == dag_id).limit(1)) is not None

@classmethod
@provide_session
@@ -296,15 +292,15 @@ def get(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDagModel
:param dag_id: the DAG to fetch
:param session: ORM Session
"""
row = session.query(cls).filter(cls.dag_id == dag_id).one_or_none()
row = session.scalar(select(cls).where(cls.dag_id == dag_id))
if row:
return row

# If we didn't find a matching DAG id then ask the DAG table to find
# out the root dag
root_dag_id = session.query(DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar()
root_dag_id = session.scalar(select(DagModel.root_dag_id).where(DagModel.dag_id == dag_id))

return session.query(cls).filter(cls.dag_id == root_dag_id).one_or_none()
return session.scalar(select(cls).where(cls.dag_id == root_dag_id))

@staticmethod
@provide_session
@@ -340,7 +336,7 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = NEW_SESSION)
:param dag_id: DAG ID
:param session: ORM Session
"""
return session.query(cls.last_updated).filter(cls.dag_id == dag_id).scalar()
return session.scalar(select(cls.last_updated).where(cls.dag_id == dag_id))

@classmethod
@provide_session
@@ -350,7 +346,7 @@ def get_max_last_updated_datetime(cls, session: Session = NEW_SESSION) -> dateti
:param session: ORM Session
"""
return session.query(func.max(cls.last_updated)).scalar()
return session.scalar(select(func.max(cls.last_updated)))

@classmethod
@provide_session
@@ -362,7 +358,7 @@ def get_latest_version_hash(cls, dag_id: str, session: Session = NEW_SESSION) ->
:param session: ORM Session
:return: DAG Hash, or None if the DAG is not found
"""
return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar()
return session.scalar(select(cls.dag_hash).where(cls.dag_id == dag_id))

@classmethod
def get_latest_version_hash_and_updated_datetime(
@@ -379,7 +375,9 @@ def get_latest_version_hash_and_updated_datetime(
:param session: ORM Session
:return: A tuple of DAG Hash and last updated datetime, or None if the DAG is not found
"""
return session.query(cls.dag_hash, cls.last_updated).filter(cls.dag_id == dag_id).one_or_none()
return session.execute(
select(cls.dag_hash, cls.last_updated).where(cls.dag_id == dag_id)
).one_or_none()

@classmethod
@provide_session
@@ -390,11 +388,15 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
:param session: ORM Session
"""
if session.bind.dialect.name in ["sqlite", "mysql"]:
query = session.query(cls.dag_id, func.json_extract(cls._data, "$.dag.dag_dependencies"))
query = session.execute(
select(cls.dag_id, func.json_extract(cls._data, "$.dag.dag_dependencies"))
)
iterator = ((dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query)
elif session.bind.dialect.name == "mssql":
query = session.query(cls.dag_id, func.json_query(cls._data, "$.dag.dag_dependencies"))
query = session.execute(select(cls.dag_id, func.json_query(cls._data, "$.dag.dag_dependencies")))
iterator = ((dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query)
else:
iterator = session.query(cls.dag_id, func.json_extract_path(cls._data, "dag", "dag_dependencies"))
iterator = session.execute(
select(cls.dag_id, func.json_extract_path(cls._data, "dag", "dag_dependencies"))
)
return {dag_id: [DagDependency(**d) for d in (deps_data or [])] for dag_id, deps_data in iterator}
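
Two idioms recur in the serialized_dag.py hunks above: an existence probe built from select(literal(True)).limit(1), and multi-column selects that stay on session.execute() so callers keep receiving Row tuples. A rough sketch of both, using a stand-in Record model rather than Airflow's SerializedDagModel:

from sqlalchemy import literal, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Record(Base):
    __tablename__ = "record"
    id: Mapped[int] = mapped_column(primary_key=True)
    key: Mapped[str]
    payload: Mapped[str]


def record_exists(session: Session, key: str) -> bool:
    # Same shape as has_dag(): SELECT true ... LIMIT 1, then check for a non-None scalar.
    return session.scalar(select(literal(True)).where(Record.key == key).limit(1)) is not None


def key_and_payload(session: Session, record_id: int):
    # Same shape as get_latest_version_hash_and_updated_datetime(): a two-column
    # select goes through session.execute(); one_or_none() returns a Row or None.
    return session.execute(
        select(Record.key, Record.payload).where(Record.id == record_id)
    ).one_or_none()
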
75 changes: 42 additions & 33 deletions airflow/models/trigger.py
@@ -20,7 +20,7 @@
from traceback import format_exception
from typing import Any, Iterable

from sqlalchemy import Column, Integer, String, delete, func, or_
from sqlalchemy import Column, Integer, String, delete, func, or_, select, update
from sqlalchemy.orm import Session, joinedload, relationship

from airflow.api_internal.internal_api_call import internal_api_call
@@ -93,9 +93,9 @@ def from_object(cls, trigger: BaseTrigger) -> Trigger:
@provide_session
def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[int, Trigger]:
"""Fetches all the Triggers by ID and returns a dict mapping ID -> Trigger instance."""
query = (
session.query(cls)
.filter(cls.id.in_(ids))
query = session.scalars(
select(cls)
.where(cls.id.in_(ids))
.options(
joinedload("task_instance"),
joinedload("task_instance.trigger"),
@@ -117,19 +117,21 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
# Update all task instances with trigger IDs that are not DEFERRED to remove them
for attempt in run_with_db_retries():
with attempt:
session.query(TaskInstance).filter(
TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.isnot(None)
).update({TaskInstance.trigger_id: None})
session.execute(
update(TaskInstance)
.where(
TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.is_not(None)
)
.values(trigger_id=None)
)

# Get all triggers that have no task instances depending on them...
ids = [
trigger_id
for (trigger_id,) in (
session.query(cls.id)
.join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True)
.group_by(cls.id)
.having(func.count(TaskInstance.trigger_id) == 0)
)
]
ids = session.scalars(
select(cls.id)
.join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True)
.group_by(cls.id)
.having(func.count(TaskInstance.trigger_id) == 0)
).all()
# ...and delete them (we can't do this in one query due to MySQL)
session.execute(
delete(Trigger).where(Trigger.id.in_(ids)).execution_options(synchronize_session=False)
@@ -140,8 +142,10 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
@provide_session
def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None:
"""Takes an event from an instance of itself, and triggers all dependent tasks to resume."""
for task_instance in session.query(TaskInstance).filter(
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
for task_instance in session.scalars(
select(TaskInstance).where(
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
)
):
# Add the event's payload into the kwargs for the task
next_kwargs = task_instance.next_kwargs or {}
@@ -171,8 +175,10 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) ->
workers as first-class concepts, we can run the failure code here
in-process, but we can't do that right now.
"""
for task_instance in session.query(TaskInstance).filter(
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
for task_instance in session.scalars(
select(TaskInstance).where(
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
)
):
# Add the error and set the next_method to the fail state
traceback = format_exception(type(exc), exc, exc.__traceback__) if exc else None
@@ -188,7 +194,7 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) ->
@provide_session
def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]:
"""Retrieves a list of triggerer_ids."""
return [row[0] for row in session.query(cls.id).filter(cls.triggerer_id == triggerer_id)]
return session.scalars(select(cls.id).where(cls.triggerer_id == triggerer_id)).all()

@classmethod
@internal_api_call
@@ -203,42 +209,45 @@ def assign_unassigned(cls, triggerer_id, capacity, heartrate, session: Session =
"""
from airflow.jobs.job import Job # To avoid circular import

count = session.query(func.count(cls.id)).filter(cls.triggerer_id == triggerer_id).scalar()
count = session.scalar(select(func.count(cls.id)).filter(cls.triggerer_id == triggerer_id))
capacity -= count

if capacity <= 0:
return
# we multiply heartrate by a grace_multiplier to give the triggerer
# a chance to heartbeat before we consider it dead
health_check_threshold = heartrate * 2.1
alive_triggerer_ids = [
row[0]
for row in session.query(Job.id).filter(
alive_triggerer_ids = session.scalars(
select(Job.id).where(
Job.end_date.is_(None),
Job.latest_heartbeat > timezone.utcnow() - datetime.timedelta(seconds=health_check_threshold),
Job.job_type == "TriggererJob",
)
]
).all()

# Find triggers who do NOT have an alive triggerer_id, and then assign
# up to `capacity` of those to us.
trigger_ids_query = cls.get_sorted_triggers(
capacity=capacity, alive_triggerer_ids=alive_triggerer_ids, session=session
)
if trigger_ids_query:
session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update(
{cls.triggerer_id: triggerer_id},
synchronize_session=False,
session.execute(
update(cls)
.where(cls.id.in_([i.id for i in trigger_ids_query]))
.values(triggerer_id=triggerer_id)
.execution_options(synchronize_session=False)
)

session.commit()

@classmethod
def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session):
return with_row_locks(
session.query(cls.id)
.filter(or_(cls.triggerer_id.is_(None), cls.triggerer_id.notin_(alive_triggerer_ids)))
query = with_row_locks(
select(cls.id)
.where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids)))
.order_by(cls.created_date)
.limit(capacity),
session,
skip_locked=True,
).all()
)
return session.execute(query).all()
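
The trigger.py hunks also move bulk writes off Query.update() onto standalone update() statements executed by the session, and replace list comprehensions over (value,) row tuples with scalars().all(). A minimal sketch of both shapes, assuming SQLAlchemy 2.0 and an invented Task model rather than Airflow's TaskInstance or Trigger:

from typing import Optional

from sqlalchemy import select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Task(Base):
    __tablename__ = "task"
    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str]
    worker_id: Mapped[Optional[int]]


def release_non_deferred(session: Session) -> None:
    # Legacy form:
    #   session.query(Task).filter(Task.state != "deferred").update({Task.worker_id: None})
    # 2.0 form: a standalone update() statement handed to session.execute().
    session.execute(
        update(Task)
        .where(Task.state != "deferred", Task.worker_id.is_not(None))
        .values(worker_id=None)
        .execution_options(synchronize_session=False)
    )


def ids_for_worker(session: Session, worker_id: int) -> list[int]:
    # scalars().all() replaces the old [row[0] for row in session.query(...)] idiom.
    return session.scalars(select(Task.id).where(Task.worker_id == worker_id)).all()
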
3 changes: 2 additions & 1 deletion airflow/utils/cli.py
@@ -32,6 +32,7 @@
from typing import TYPE_CHECKING, Callable, TypeVar, cast

import re2
from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow import settings
@@ -267,7 +268,7 @@ def get_dag_by_pickle(pickle_id: int, session: Session = NEW_SESSION) -> DAG:
"""Fetch DAG from the database using pickling."""
from airflow.models import DagPickle

dag_pickle = session.query(DagPickle).filter(DagPickle.id == pickle_id).first()
dag_pickle = session.scalar(select(DagPickle).where(DagPickle.id == pickle_id))
if not dag_pickle:
raise AirflowException(f"pickle_id could not be found in DagPickle.id list: {pickle_id}")
pickle_dag = dag_pickle.pickle
6 changes: 3 additions & 3 deletions airflow/utils/db_cleanup.py
@@ -30,7 +30,7 @@
from typing import Any

from pendulum import DateTime
from sqlalchemy import and_, column, false, func, inspect, table, text
from sqlalchemy import and_, column, false, func, inspect, select, table, text
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Query, Session, aliased
@@ -179,7 +179,7 @@ def _do_delete(*, query, orm_model, skip_archive, session):
pk_cols = source_table.primary_key.columns
delete = source_table.delete().where(
tuple_(*pk_cols).in_(
session.query(*[target_table.c[x.name] for x in source_table.primary_key.columns]).subquery()
select(*[target_table.c[x.name] for x in source_table.primary_key.columns]).subquery()
)
)
else:
Expand All @@ -196,7 +196,7 @@ def _do_delete(*, query, orm_model, skip_archive, session):


def _subquery_keep_last(*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session):
subquery = session.query(*group_by_columns, func.max(recency_column).label(max_date_colname))
subquery = select(*group_by_columns, func.max(recency_column).label(max_date_colname))

if keep_last_filters is not None:
for entry in keep_last_filters:
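
The db_cleanup.py change shows that a 2.0 select() is just a statement object: it can be composed with no session in sight and then embedded as the subquery of a DELETE ... WHERE ... IN, the shape _do_delete() uses when purging rows that were already archived. A simplified, hedged sketch of that composition; the tables and columns here are invented for illustration:

from sqlalchemy import Column, Integer, MetaData, Table, select, tuple_

metadata = MetaData()
source = Table("event", metadata, Column("id", Integer, primary_key=True))
archive = Table("event_archive", metadata, Column("id", Integer, primary_key=True))


def build_delete_statement():
    # select() needs no session; it is composed first and executed (or compiled) later.
    already_archived = select(archive.c.id)
    # Delete every source row whose primary key already exists in the archive,
    # mirroring tuple_(*pk_cols).in_(select(...)) in _do_delete().
    return source.delete().where(tuple_(source.c.id).in_(already_archived))


if __name__ == "__main__":
    # Compile without executing, just to show the generated SQL.
    print(build_delete_statement())
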
