Skip to content

Commit

Permalink
fix: group_by application_ids in GET /topology
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiryous committed Sep 26, 2024
1 parent a4af7fc commit 53bb99b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
21 changes: 18 additions & 3 deletions keep/api/core/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import pymysql
from dotenv import find_dotenv, load_dotenv
from google.cloud.sql.connector import Connector
from sqlalchemy import func
from sqlmodel import create_engine
from sqlalchemy import String, cast, func
from sqlalchemy.sql.elements import Label
from sqlalchemy.sql.sqltypes import String
from sqlmodel import Session, create_engine

# This import is required to create the tables
from keep.api.consts import RUNNING_IN_CLOUD_RUN
Expand Down Expand Up @@ -161,10 +163,23 @@ def create_db_engine():


def get_json_extract_field(session, base_field, key):

if session.bind.dialect.name == "postgresql":
return func.json_extract_path_text(base_field, key)
elif session.bind.dialect.name == "mysql":
return func.json_unquote(func.json_extract(base_field, "$.{}".format(key)))
else:
return func.json_extract(base_field, "$.{}".format(key))


def get_aggreated_field(session: Session, column_name: str, alias: str):
if session.bind is None:
raise ValueError("Session is not bound to a database")

if session.bind.dialect.name == "postgresql":
return func.array_agg(column_name).label(alias)
elif session.bind.dialect.name == "mysql":
return func.json_arrayagg(column_name).label(alias)
elif session.bind.dialect.name == "sqlite":
return func.group_concat(column_name).label(alias)
else:
return func.array_agg(column_name).label(alias)
44 changes: 30 additions & 14 deletions keep/topologies/topologies_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqlmodel import Session, select

from keep.api.core.db_utils import get_aggreated_field
from keep.api.models.db.topology import (
TopologyApplication,
TopologyApplicationDtoIn,
Expand Down Expand Up @@ -39,6 +40,33 @@ class ServiceNotFoundException(TopologyException):
"""Raised when a service is not found"""


def get_service_application_ids_dict(
session: Session, service_ids: List[int]
) -> dict[int, List[UUID]]:
# TODO: add proper types
query = (
select(
TopologyServiceApplication.service_id,
get_aggreated_field(
session,
TopologyServiceApplication.application_id, # type: ignore
"application_ids",
),
) # type: ignore
.where(TopologyServiceApplication.service_id.in_(service_ids))
.group_by(TopologyServiceApplication.service_id)
)
results = session.exec(query).all()
if session.bind is None:
raise ValueError("Session is not bound to a database")
if session.bind.dialect.name == "sqlite":
result = {}
for service_id, application_ids in results:
result[service_id] = [UUID(app_id) for app_id in application_ids.split(",")]
return result
return {service_id: application_ids for service_id, application_ids in results}


class TopologiesService:
@staticmethod
def get_all_topology_data(
Expand Down Expand Up @@ -83,20 +111,8 @@ def get_all_topology_data(
).all()

# Fetch application IDs for all services in a single query
service_ids = [service.id for service in services]
application_service_query = (
select(TopologyServiceApplication)
.where(TopologyServiceApplication.service_id.in_(service_ids))
.group_by(TopologyServiceApplication.service_id)
)
application_service_results = session.exec(application_service_query).all()

# Create a dictionary mapping service IDs to application IDs
service_to_app_ids = {}
for result in application_service_results:
if result.service_id not in service_to_app_ids:
service_to_app_ids[result.service_id] = []
service_to_app_ids[result.service_id].append(result.application_id)
service_ids = [service.id for service in services if service.id is not None]
service_to_app_ids = get_service_application_ids_dict(session, service_ids)

logger.info(f"Service to app ids: {service_to_app_ids}")

Expand Down

0 comments on commit 53bb99b

Please sign in to comment.