Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]: Optimize SQL Queries in SQLAlchemy to Eliminate N+1 Problem #1902

Merged
merged 11 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 19 additions & 40 deletions agenta-backend/agenta_backend/models/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,13 @@ async def human_evaluation_db_to_simple_evaluation_output(
async def evaluation_db_to_pydantic(
evaluation_db: EvaluationDB,
) -> Evaluation:
variant = await db_manager.get_app_variant_instance_by_id(
str(evaluation_db.variant_id)
variant_name = (
evaluation_db.variant.variant_name
if evaluation_db.variant.variant_name
else str(evaluation_db.variant_id)
)
variant_name = variant.variant_name if variant else str(evaluation_db.variant_id)
variant_revision = await db_manager.get_app_variant_revision_by_id(
str(evaluation_db.variant_revision_id)
)
revision = str(variant_revision.revision)
aggregated_results = await aggregated_result_of_evaluation_to_pydantic(
str(evaluation_db.id)
aggregated_results = aggregated_result_of_evaluation_to_pydantic(
evaluation_db.aggregated_results
)

return Evaluation(
Expand All @@ -139,7 +136,7 @@ async def evaluation_db_to_pydantic(
status=evaluation_db.status,
variant_ids=[str(evaluation_db.variant_id)],
variant_revision_ids=[str(evaluation_db.variant_revision_id)],
revisions=[revision],
revisions=[str(evaluation_db.variant_revision.revision)],
variant_names=[variant_name],
testset_id=str(evaluation_db.testset_id),
testset_name=evaluation_db.testset.name,
Expand Down Expand Up @@ -213,12 +210,11 @@ def human_evaluation_scenario_db_to_pydantic(
)


async def aggregated_result_of_evaluation_to_pydantic(evaluation_id: str) -> List[dict]:
def aggregated_result_of_evaluation_to_pydantic(
evaluation_aggregated_results: List,
) -> List[dict]:
transformed_results = []
aggregated_results = await db_manager.fetch_eval_aggregated_results(
evaluation_id=evaluation_id
)
for aggregated_result in aggregated_results:
for aggregated_result in evaluation_aggregated_results:
evaluator_config_dict = (
{
"id": str(aggregated_result.evaluator_config.id),
Expand All @@ -242,27 +238,16 @@ async def aggregated_result_of_evaluation_to_pydantic(evaluation_id: str) -> Lis
return transformed_results


async def evaluation_scenarios_results_to_pydantic(
evaluation_scenario_id: str,
) -> List[dict]:
scenario_results = await db_manager.fetch_evaluation_scenario_results(
evaluation_scenario_id
)
return [
async def evaluation_scenario_db_to_pydantic(
evaluation_scenario_db: EvaluationScenarioDB, evaluation_id: str
) -> EvaluationScenario:
scenario_results = [
{
"evaluator_config": str(scenario_result.evaluator_config_id),
"result": scenario_result.result,
}
for scenario_result in scenario_results
for scenario_result in evaluation_scenario_db.results
]


async def evaluation_scenario_db_to_pydantic(
evaluation_scenario_db: EvaluationScenarioDB, evaluation_id: str
) -> EvaluationScenario:
scenario_results = await evaluation_scenarios_results_to_pydantic(
str(evaluation_scenario_db.id)
)
return EvaluationScenario(
id=str(evaluation_scenario_db.id),
evaluation_id=evaluation_id,
Expand Down Expand Up @@ -308,17 +293,11 @@ async def app_variant_db_to_output(app_variant_db: AppVariantDB) -> AppVariantRe
if isinstance(app_variant_db.base_id, uuid.UUID) and isinstance(
app_variant_db.base.deployment_id, uuid.UUID
):
deployment = await db_manager.get_deployment_by_id(
str(app_variant_db.base.deployment_id)
)
uri = deployment.uri
uri = app_variant_db.base.deployment.uri
else:
deployment = None
uri = None

logger.info(
f"uri: {uri} deployment: {str(app_variant_db.base.deployment_id)} {deployment}"
)
logger.info(f"uri: {uri} deployment: {str(app_variant_db.base.deployment_id)}")
variant_response = AppVariantResponse(
app_id=str(app_variant_db.app_id),
app_name=str(app_variant_db.app.app_name),
Expand All @@ -329,7 +308,7 @@ async def app_variant_db_to_output(app_variant_db: AppVariantDB) -> AppVariantRe
base_name=app_variant_db.base_name, # type: ignore
base_id=str(app_variant_db.base_id),
config_name=app_variant_db.config_name, # type: ignore
uri=uri,
uri=uri, # type: ignore
revision=app_variant_db.revision, # type: ignore
)

Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/models/db_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self) -> None:
self.mode = os.environ.get("DATABASE_MODE", "v2")
self.postgres_uri = os.environ.get("POSTGRES_URI")
self.mongo_uri = os.environ.get("MONGODB_URI")
self.engine = create_async_engine(url=self.postgres_uri) # type: ignore
self.engine = create_async_engine(url=self.postgres_uri, echo=True) # type: ignore
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
self.async_session_maker = async_sessionmaker(
bind=self.engine, class_=AsyncSession, expire_on_commit=False
)
Expand Down
4 changes: 2 additions & 2 deletions agenta-backend/agenta_backend/routers/evaluation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ async def fetch_evaluation_results(evaluation_id: str, request: Request):
status_code=403,
)

results = await converters.aggregated_result_of_evaluation_to_pydantic(
str(evaluation.id)
results = converters.aggregated_result_of_evaluation_to_pydantic(
evaluation.aggregated_results # type: ignore
)
return {"results": results, "evaluation_id": evaluation_id}
except Exception as exc:
Expand Down
92 changes: 35 additions & 57 deletions agenta-backend/agenta_backend/services/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,10 @@ async def list_app_variants(app_id: str):
async with db_engine.get_session() as session:
result = await session.execute(
select(AppVariantDB)
.options(joinedload(AppVariantDB.app), joinedload(AppVariantDB.base))
.options(
joinedload(AppVariantDB.app).load_only(AppDB.id, AppDB.app_name), # type: ignore
joinedload(AppVariantDB.base).joinedload(VariantBaseDB.deployment).load_only(DeploymentDB.uri), # type: ignore
)
.filter_by(app_id=uuid.UUID(app_uuid))
)
app_variants = result.scalars().all()
Expand Down Expand Up @@ -1827,26 +1830,6 @@ async def get_app_variant_instance_by_id(variant_id: str) -> AppVariantDB:
return app_variant_db


async def get_app_variant_revision_by_id(
    variant_revision_id: str,
) -> Optional[AppVariantRevisionsDB]:
    """Get the app variant revision object from the database with the provided id.

    Arguments:
        variant_revision_id (str): The app variant revision unique identifier

    Returns:
        Optional[AppVariantRevisionsDB]: the matching app variant revision, or
            None when no row has the given id.

    Raises:
        ValueError: if ``variant_revision_id`` is not a valid UUID string.
    """

    async with db_engine.get_session() as session:
        result = await session.execute(
            select(AppVariantRevisionsDB).filter_by(id=uuid.UUID(variant_revision_id))
        )
        # `.first()` yields None instead of raising when the query is empty,
        # so callers must handle a missing revision themselves.
        variant_revision_db = result.scalars().first()
        return variant_revision_db


async def fetch_testset_by_id(testset_id: str) -> Optional[TestSetDB]:
"""Fetches a testset by its ID.
Args:
Expand Down Expand Up @@ -1959,8 +1942,16 @@ async def fetch_evaluation_by_id(evaluation_id: str) -> Optional[EvaluationDB]:
joinedload(EvaluationDB.user).load_only(UserDB.username), # type: ignore
joinedload(EvaluationDB.testset).load_only(TestSetDB.id, TestSetDB.name), # type: ignore
)
result = await session.execute(query)
evaluation = result.scalars().first()
result = await session.execute(
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
query.options(
joinedload(EvaluationDB.variant).load_only(AppVariantDB.id, AppVariantDB.variant_name), # type: ignore
joinedload(EvaluationDB.variant_revision).load_only(AppVariantRevisionsDB.revision), # type: ignore
joinedload(EvaluationDB.aggregated_results).joinedload(
EvaluationAggregatedResultDB.evaluator_config
),
)
)
evaluation = result.unique().scalars().first()
return evaluation


Expand Down Expand Up @@ -2289,40 +2280,14 @@ async def fetch_evaluation_scenarios(evaluation_id: str):

async with db_engine.get_session() as session:
result = await session.execute(
select(EvaluationScenarioDB).filter_by(
evaluation_id=uuid.UUID(evaluation_id)
)
select(EvaluationScenarioDB)
.filter_by(evaluation_id=uuid.UUID(evaluation_id))
.options(joinedload(EvaluationScenarioDB.results))
)
evaluation_scenarios = result.scalars().all()
evaluation_scenarios = result.unique().scalars().all()
return evaluation_scenarios


async def fetch_evaluation_scenario_results(evaluation_scenario_id: str):
    """
    Load the result rows that belong to one evaluation scenario.

    Only the ``evaluator_config_id`` and ``result`` columns are requested
    via ``load_only``.

    Args:
        evaluation_scenario_id (str): The evaluation scenario identifier

    Returns:
        All matching evaluation scenario result rows.
    """

    async with db_engine.get_session() as session:
        # Restrict the SELECT to the two columns requested by this query.
        column_filter = load_only(
            EvaluationScenarioResultDB.evaluator_config_id,  # type: ignore
            EvaluationScenarioResultDB.result,  # type: ignore
        )
        statement = (
            select(EvaluationScenarioResultDB)
            .options(column_filter)
            .filter_by(evaluation_scenario_id=uuid.UUID(evaluation_scenario_id))
        )
        query_outcome = await session.execute(statement)
        return query_outcome.scalars().all()


async def fetch_evaluation_scenario_by_id(
evaluation_scenario_id: str,
) -> Optional[EvaluationScenarioDB]:
Expand Down Expand Up @@ -2691,7 +2656,14 @@ async def create_new_evaluation(
session.add(evaluation)
await session.commit()
await session.refresh(
evaluation, attribute_names=["user", "testset", "aggregated_results"]
evaluation,
attribute_names=[
"user",
"testset",
"variant",
"variant_revision",
"aggregated_results",
],
)

return evaluation
Expand All @@ -2718,7 +2690,13 @@ async def list_evaluations(app_id: str):
)

result = await session.execute(
query.options(joinedload(EvaluationDB.aggregated_results))
query.options(
joinedload(EvaluationDB.variant).load_only(AppVariantDB.id, AppVariantDB.variant_name), # type: ignore
joinedload(EvaluationDB.variant_revision).load_only(AppVariantRevisionsDB.revision), # type: ignore
joinedload(EvaluationDB.aggregated_results).joinedload(
EvaluationAggregatedResultDB.evaluator_config
),
)
)
evaluations = result.unique().scalars().all()
return evaluations
Expand Down Expand Up @@ -2807,7 +2785,7 @@ async def delete_evaluations(evaluation_ids: List[str]) -> None:

async def create_new_evaluation_scenario(
user_id: str,
evaluation: EvaluationDB,
evaluation_id: str,
variant_id: str,
inputs: List[EvaluationScenarioInput],
outputs: List[EvaluationScenarioOutput],
Expand All @@ -2827,7 +2805,7 @@ async def create_new_evaluation_scenario(
async with db_engine.get_session() as session:
evaluation_scenario = EvaluationScenarioDB(
user_id=uuid.UUID(user_id),
evaluation_id=evaluation.id,
evaluation_id=uuid.UUID(evaluation_id),
variant_id=uuid.UUID(variant_id),
inputs=[input.dict() for input in inputs],
outputs=[output.dict() for output in outputs],
Expand Down
22 changes: 7 additions & 15 deletions agenta-backend/agenta_backend/tasks/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ def evaluate(
), f"App variant with id {variant_id} not found!"
app_variant_parameters = app_variant_db.config_parameters
testset_db = loop.run_until_complete(fetch_testset_by_id(testset_id))
new_evaluation_db = loop.run_until_complete(
fetch_evaluation_by_id(evaluation_id)
)
evaluator_config_dbs = []
for evaluator_config_id in evaluators_config_ids:
evaluator_config = loop.run_until_complete(
Expand Down Expand Up @@ -195,7 +192,7 @@ def evaluate(
loop.run_until_complete(
create_new_evaluation_scenario(
user_id=str(app.user_id),
evaluation=new_evaluation_db,
evaluation_id=evaluation_id,
variant_id=variant_id,
inputs=inputs,
outputs=[
Expand Down Expand Up @@ -274,7 +271,7 @@ def evaluate(
loop.run_until_complete(
create_new_evaluation_scenario(
user_id=str(app.user_id),
evaluation=new_evaluation_db,
evaluation_id=evaluation_id,
variant_id=variant_id,
inputs=inputs,
outputs=[
Expand Down Expand Up @@ -341,15 +338,11 @@ def evaluate(
)

loop.run_until_complete(
update_evaluation_with_aggregated_results(
str(new_evaluation_db.id), aggregated_results
)
update_evaluation_with_aggregated_results(evaluation_id, aggregated_results)
)

failed_evaluation_scenarios = loop.run_until_complete(
check_if_evaluation_contains_failed_evaluation_scenarios(
str(new_evaluation_db.id)
)
check_if_evaluation_contains_failed_evaluation_scenarios(evaluation_id)
)

evaluation_status = Result(
Expand All @@ -365,7 +358,7 @@ def evaluate(

loop.run_until_complete(
update_evaluation(
evaluation_id=str(new_evaluation_db.id),
evaluation_id=evaluation_id,
updates={"status": evaluation_status.model_dump()},
)
)
Expand All @@ -384,7 +377,7 @@ def evaluate(
message="Evaluation Aggregation Failed",
stacktrace=str(traceback.format_exc()),
),
)
).model_dump()
},
)
)
Expand Down Expand Up @@ -441,9 +434,8 @@ async def aggregate_evaluator_results(
type="error", value=None, error=Error(message="Aggregation failed")
)

evaluator_config = await fetch_evaluator_config(config_id)
aggregated_result = AggregatedResult(
evaluator_config=str(evaluator_config.id), # type: ignore
evaluator_config=config_id,
result=result,
)
aggregated_results.append(aggregated_result)
Expand Down
Loading