Skip to content

Commit

Permalink
Add combined TreeManager stats endpoint (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf authored Jan 17, 2023
1 parent acaa56e commit 718faa0
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
32 changes: 32 additions & 0 deletions backend/oasst_backend/api/v1/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.tree_manager import TreeManager, TreeManagerStats, TreeMessageCountStats
from oasst_shared.schemas import protocol
from sqlmodel import Session

Expand All @@ -15,3 +16,34 @@ def get_message_stats(
):
pr = PromptRepository(db, api_client)
return pr.get_stats()


@router.get("/tree_manager/state_counts", response_model=dict[str, int])
def get_tree_manager__state_counts(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client)
tm = TreeManager(db, pr)
return tm.tree_counts_by_state()


@router.get("/tree_manager/message_counts", response_model=list[TreeMessageCountStats])
def get_tree_manager__message_counts(
only_active: bool = True,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client)
tm = TreeManager(db, pr)
return tm.tree_message_count_stats(only_active=only_active)


@router.get("/tree_manager", response_model=TreeManagerStats)
def get_tree_manager__stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client)
tm = TreeManager(db, pr)
return tm.stats()
56 changes: 55 additions & 1 deletion backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from datetime import datetime
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
Expand Down Expand Up @@ -69,6 +70,25 @@ class Config:
orm_mode = True


class TreeMessageCountStats(pydantic.BaseModel):
message_tree_id: UUID
state: str
depth: int
oldest: datetime
youngest: datetime
count: int
goal_tree_size: int

@property
def completed(self) -> int:
return self.count / self.goal_tree_size


class TreeManagerStats(pydantic.BaseModel):
state_counts: dict[str, int]
message_counts: list[TreeMessageCountStats]


class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))

Expand Down Expand Up @@ -924,6 +944,40 @@ def _insert_default_state(
active=True,
)

def tree_counts_by_state(self) -> dict[str, int]:
qry = self.db.query(
MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
).group_by(MessageTreeState.state)
return {x["state"]: x["count"] for x in qry}

def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
qry = (
self.db.query(
MessageTreeState.message_tree_id,
func.max(Message.depth).label("depth"),
func.min(Message.created_date).label("oldest"),
func.max(Message.created_date).label("youngest"),
func.count(Message.id).label("count"),
MessageTreeState.goal_tree_size,
MessageTreeState.state,
)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.filter(not_(Message.deleted))
.group_by(MessageTreeState.message_tree_id)
)

if only_active:
qry.filter(MessageTreeState.active)

return [TreeMessageCountStats(**x) for x in qry]

def stats(self) -> TreeManagerStats:
return TreeManagerStats(
state_counts=self.tree_counts_by_state(),
message_counts=self.tree_message_count_stats(only_active=True),
)


if __name__ == "__main__":
from oasst_backend.api.deps import api_auth
Expand All @@ -942,7 +996,7 @@ def _insert_default_state(

# print("query_num_active_trees", tm.query_num_active_trees())
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
print("query_replies_need_review", tm.query_replies_need_review())
# print("query_replies_need_review", tm.query_replies_need_review())
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
# print("query_extendible_trees", tm.query_extendible_trees())
# print("query_extendible_parents", tm.query_extendible_parents())
Expand Down

0 comments on commit 718faa0

Please sign in to comment.