Skip to content

Commit ab91b29

Browse files
Merge pull request #222 from stacklok/non-block-dashboard-reqs
Use `def` in FastAPI dashboard calls.
2 parents 6184edd + 9271327 commit ab91b29

File tree

5 files changed

+53
-27
lines changed

5 files changed

+53
-27
lines changed

src/codegate/dashboard/dashboard.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from fastapi import APIRouter
66

77
from codegate.dashboard.post_processing import (
8-
match_conversations,
98
parse_get_alert_conversation,
10-
parse_get_prompt_with_output,
9+
parse_messages_in_conversations,
1110
)
1211
from codegate.dashboard.request_models import AlertConversation, Conversation
1312
from codegate.db.connection import DbReader
@@ -19,31 +18,19 @@
1918

2019

2120
@dashboard_router.get("/dashboard/messages")
22-
async def get_messages() -> List[Conversation]:
21+
def get_messages() -> List[Conversation]:
2322
"""
2423
Get all the messages from the database and return them as a list of conversations.
2524
"""
26-
prompts_outputs = await db_reader.get_prompts_with_output()
25+
prompts_outputs = asyncio.run(db_reader.get_prompts_with_output())
2726

28-
# Parse the prompts and outputs in parallel
29-
async with asyncio.TaskGroup() as tg:
30-
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
31-
partial_conversations = [task.result() for task in tasks]
32-
33-
conversations = await match_conversations(partial_conversations)
34-
return conversations
27+
return asyncio.run(parse_messages_in_conversations(prompts_outputs))
3528

3629

3730
@dashboard_router.get("/dashboard/alerts")
38-
async def get_alerts() -> List[AlertConversation]:
31+
def get_alerts() -> List[AlertConversation]:
3932
"""
4033
Get all the messages from the database and return them as a list of conversations.
4134
"""
42-
alerts_prompt_output = await db_reader.get_alerts_with_prompt_and_output()
43-
44-
# Parse the prompts and outputs in parallel
45-
async with asyncio.TaskGroup() as tg:
46-
tasks = [tg.create_task(parse_get_alert_conversation(row)) for row in alerts_prompt_output]
47-
alert_conversations = [task.result() for task in tasks if task.result() is not None]
48-
49-
return alert_conversations
35+
alerts_prompt_output = asyncio.run(db_reader.get_alerts_with_prompt_and_output())
36+
return asyncio.run(parse_get_alert_conversation(alerts_prompt_output))

src/codegate/dashboard/post_processing.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,23 @@ async def match_conversations(
200200
return conversations
201201

202202

203-
async def parse_get_alert_conversation(
203+
async def parse_messages_in_conversations(
204+
prompts_outputs: List[GetPromptWithOutputsRow],
205+
) -> List[Conversation]:
206+
"""
207+
Get all the messages from the database and return them as a list of conversations.
208+
"""
209+
210+
# Parse the prompts and outputs in parallel
211+
async with asyncio.TaskGroup() as tg:
212+
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
213+
partial_conversations = [task.result() for task in tasks]
214+
215+
conversations = await match_conversations(partial_conversations)
216+
return conversations
217+
218+
219+
async def parse_row_alert_conversation(
204220
row: GetAlertsWithPromptAndOutputRow,
205221
) -> Optional[AlertConversation]:
206222
"""
@@ -220,12 +236,33 @@ async def parse_get_alert_conversation(
220236
conversation_timestamp=row.timestamp,
221237
)
222238
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
239+
trigger_string = None
240+
if row.trigger_string:
241+
try:
242+
trigger_string = json.loads(row.trigger_string)
243+
except Exception:
244+
trigger_string = row.trigger_string
245+
223246
return AlertConversation(
224247
conversation=conversation,
225248
alert_id=row.id,
226249
code_snippet=code_snippet,
227-
trigger_string=row.trigger_string,
250+
trigger_string=trigger_string,
228251
trigger_type=row.trigger_type,
229252
trigger_category=row.trigger_category,
230253
timestamp=row.timestamp,
231254
)
255+
256+
257+
async def parse_get_alert_conversation(
258+
alerts_conversations: List[GetAlertsWithPromptAndOutputRow],
259+
) -> List[AlertConversation]:
260+
"""
261+
Parse a list of rows from the get_alerts_with_prompt_and_output query and return a list of
262+
AlertConversation
263+
264+
The rows contain the raw request and output strings from the pipeline.
265+
"""
266+
async with asyncio.TaskGroup() as tg:
267+
tasks = [tg.create_task(parse_row_alert_conversation(row)) for row in alerts_conversations]
268+
return [task.result() for task in tasks if task.result() is not None]

src/codegate/dashboard/request_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import datetime
2-
from typing import List, Optional
2+
from typing import List, Optional, Union
33

44
from pydantic import BaseModel
55

@@ -57,7 +57,7 @@ class AlertConversation(BaseModel):
5757
conversation: Conversation
5858
alert_id: str
5959
code_snippet: Optional[CodeSnippet]
60-
trigger_string: Optional[str]
60+
trigger_string: Optional[Union[str, dict]]
6161
trigger_type: str
6262
trigger_category: Optional[str]
6363
timestamp: datetime.datetime

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ async def process(
8686
# Look for matches in vector DB using list of packages as filter
8787
searched_objects = await self.get_objects_from_search(last_user_message_str, packages)
8888

89+
logger.info(
90+
f"Found {len(searched_objects)} matches in the database",
91+
searched_objects=searched_objects,
92+
)
8993
# If matches are found, add the matched content to context
9094
if len(searched_objects) > 0:
9195
# Remove searched objects that are not in packages. This is needed

src/codegate/pipeline/system_prompt/codegate.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,4 @@ async def process(
5656
context.add_alert(self.name, trigger_string=prepended_message)
5757
request_system_message["content"] = prepended_message
5858

59-
return PipelineResult(
60-
request=new_request,
61-
)
59+
return PipelineResult(request=new_request, context=context)

0 commit comments

Comments
 (0)