Skip to content

Commit

Permalink
fix: reduce query memory usage in DatasetExampleRevisionsDataLoader (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Jan 18, 2025
1 parent d9303c2 commit 7412bb9
Showing 1 changed file with 26 additions and 50 deletions.
76 changes: 26 additions & 50 deletions src/phoenix/server/api/dataloaders/dataset_example_revisions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Union

from sqlalchemy import and_, case, func, null, or_, select
from sqlalchemy import Integer, case, func, or_, select, union
from sqlalchemy.sql.expression import literal
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias
Expand All @@ -20,74 +20,50 @@ class DatasetExampleRevisionsDataLoader(DataLoader[Key, Result]):
def __init__(self, db: DbSessionFactory) -> None:
super().__init__(
load_fn=self._load_fn,
max_batch_size=200, # needed to prevent the size of the query from getting too large
# Setting max_batch_size to prevent the size of the query from getting too large.
# The maximum number of terms is SQLITE_MAX_COMPOUND_SELECT which defaults to 500.
# This is needed because of the compound select query below used in transferring
# the input data to the database. SQLite in fact has better ways to transfer data,
# but unfortunately they're not made available in sqlalchemy yet.
max_batch_size=200,
)
self._db = db

async def _load_fn(self, keys: list[Key]) -> list[Union[Result, NotFound]]:
example_and_version_ids = tuple(
set(
(example_id, version_id)
for example_id, version_id in keys
if version_id is not None
)
)
versionless_example_ids = tuple(
set(example_id for example_id, version_id in keys if version_id is None)
)
resolved_example_and_version_ids = (
(
# sqlalchemy has limited SQLite support for VALUES, so build the keys as a compound
# SELECT combined with UNION instead (UNION, unlike UNION ALL, also drops duplicate keys).
# For details, see https://github.com/sqlalchemy/sqlalchemy/issues/7228
keys_subquery = union(
*(
select(
models.DatasetExample.id.label("example_id"),
models.DatasetVersion.id.label("version_id"),
)
.select_from(models.DatasetExample)
.join(
models.DatasetVersion,
onclause=literal(True), # cross join
)
.where(
or_(
*(
and_(
models.DatasetExample.id == example_id,
models.DatasetVersion.id == version_id,
)
for example_id, version_id in example_and_version_ids
)
)
literal(example_id, Integer).label("example_id"),
literal(version_id, Integer).label("version_id"),
)
for example_id, version_id in keys
)
.union(
select(
models.DatasetExample.id.label("example_id"), null().label("version_id")
).where(models.DatasetExample.id.in_(versionless_example_ids))
)
.subquery()
)
).subquery()
revision_ids = (
select(
resolved_example_and_version_ids.c.example_id,
resolved_example_and_version_ids.c.version_id,
keys_subquery.c.example_id,
keys_subquery.c.version_id,
func.max(models.DatasetExampleRevision.id).label("revision_id"),
)
.select_from(resolved_example_and_version_ids)
.select_from(keys_subquery)
.join(
models.DatasetExampleRevision,
onclause=resolved_example_and_version_ids.c.example_id
onclause=keys_subquery.c.example_id
== models.DatasetExampleRevision.dataset_example_id,
)
.where(
or_(
resolved_example_and_version_ids.c.version_id.is_(None),
models.DatasetExampleRevision.dataset_version_id
<= resolved_example_and_version_ids.c.version_id,
# This query gets the latest `revision_id` for each example:
# - If `version_id` is NOT given, it finds the maximum `revision_id`.
# - If `version_id` is given, it finds the maximum `revision_id` among the revisions
#   whose `dataset_version_id` is less than or equal to the one specified.
keys_subquery.c.version_id.is_(None),
models.DatasetExampleRevision.dataset_version_id <= keys_subquery.c.version_id,
)
)
.group_by(
resolved_example_and_version_ids.c.example_id,
resolved_example_and_version_ids.c.version_id,
)
.group_by(keys_subquery.c.example_id, keys_subquery.c.version_id)
).subquery()
query = (
select(
Expand Down

0 comments on commit 7412bb9

Please sign in to comment.