Skip to content

Commit 6a20ae7

Browse files
committed
feat(sdk): add learning status api
1 parent 9bd29ab commit 6a20ae7

File tree

10 files changed

+454
-321
lines changed

10 files changed

+454
-321
lines changed

src/client/acontext-py/src/acontext/resources/async_sessions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..types.session import (
1212
GetMessagesOutput,
1313
GetTasksOutput,
14+
LearningStatus,
1415
ListSessionsOutput,
1516
Message,
1617
Session,
@@ -303,3 +304,20 @@ async def flush(self, session_id: str) -> dict[str, Any]:
303304
"""
304305
data = await self._requester.request("POST", f"/session/{session_id}/flush")
305306
return data # type: ignore
307+
308+
async def get_learning_status(self, session_id: str) -> LearningStatus:
309+
"""Get learning status for a session.
310+
311+
Returns the count of space digested tasks and not space digested tasks.
312+
If the session is not connected to a space, returns 0 and 0.
313+
314+
Args:
315+
session_id: The UUID of the session.
316+
317+
Returns:
318+
LearningStatus object containing space_digested_count and not_space_digested_count.
319+
"""
320+
data = await self._requester.request(
321+
"GET", f"/session/{session_id}/get_learning_status"
322+
)
323+
return LearningStatus.model_validate(data)

src/client/acontext-py/src/acontext/resources/sessions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..types.session import (
1212
GetMessagesOutput,
1313
GetTasksOutput,
14+
LearningStatus,
1415
ListSessionsOutput,
1516
Message,
1617
Session,
@@ -303,3 +304,20 @@ def flush(self, session_id: str) -> dict[str, Any]:
303304
"""
304305
data = self._requester.request("POST", f"/session/{session_id}/flush")
305306
return data # type: ignore
307+
308+
def get_learning_status(self, session_id: str) -> LearningStatus:
309+
"""Get learning status for a session.
310+
311+
Returns the count of space digested tasks and not space digested tasks.
312+
If the session is not connected to a space, returns 0 and 0.
313+
314+
Args:
315+
session_id: The UUID of the session.
316+
317+
Returns:
318+
LearningStatus object containing space_digested_count and not_space_digested_count.
319+
"""
320+
data = self._requester.request(
321+
"GET", f"/session/{session_id}/get_learning_status"
322+
)
323+
return LearningStatus.model_validate(data)

src/client/acontext-py/src/acontext/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Asset,
1414
GetMessagesOutput,
1515
GetTasksOutput,
16+
LearningStatus,
1617
ListSessionsOutput,
1718
Message,
1819
Part,
@@ -47,6 +48,7 @@
4748
"Asset",
4849
"GetMessagesOutput",
4950
"GetTasksOutput",
51+
"LearningStatus",
5052
"ListSessionsOutput",
5153
"Message",
5254
"Part",

src/client/acontext-py/src/acontext/types/session.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,14 @@ class GetTasksOutput(BaseModel):
122122
items: list[Task] = Field(..., description="List of tasks")
123123
next_cursor: str | None = Field(None, description="Cursor for pagination")
124124
has_more: bool = Field(..., description="Whether there are more items")
125+
126+
127+
class LearningStatus(BaseModel):
128+
"""Response model for learning status."""
129+
130+
space_digested_count: int = Field(
131+
..., description="Number of tasks that are space digested"
132+
)
133+
not_space_digested_count: int = Field(
134+
..., description="Number of tasks that are not space digested"
135+
)

src/client/acontext-py/tests/test_async_client.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,30 @@ async def test_async_sessions_get_tasks_with_filters(
438438
assert hasattr(result, "has_more")
439439

440440

441+
@patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock)
442+
@pytest.mark.asyncio
443+
async def test_async_sessions_get_learning_status(
444+
mock_request, async_client: AcontextAsyncClient
445+
) -> None:
446+
mock_request.return_value = {
447+
"space_digested_count": 5,
448+
"not_space_digested_count": 3,
449+
}
450+
451+
result = await async_client.sessions.get_learning_status("session-id")
452+
453+
mock_request.assert_called_once()
454+
args, kwargs = mock_request.call_args
455+
method, path = args
456+
assert method == "GET"
457+
assert path == "/session/session-id/get_learning_status"
458+
# Verify it returns a Pydantic model
459+
assert hasattr(result, "space_digested_count")
460+
assert hasattr(result, "not_space_digested_count")
461+
assert result.space_digested_count == 5
462+
assert result.not_space_digested_count == 3
463+
464+
441465
@patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock)
442466
@pytest.mark.asyncio
443467
async def test_async_blocks_list_without_filters(

src/client/acontext-py/tests/test_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,27 @@ def test_sessions_get_tasks_with_filters(mock_request, client: AcontextClient) -
448448
assert hasattr(result, "has_more")
449449

450450

451+
@patch("acontext.client.AcontextClient.request")
452+
def test_sessions_get_learning_status(mock_request, client: AcontextClient) -> None:
453+
mock_request.return_value = {
454+
"space_digested_count": 5,
455+
"not_space_digested_count": 3,
456+
}
457+
458+
result = client.sessions.get_learning_status("session-id")
459+
460+
mock_request.assert_called_once()
461+
args, kwargs = mock_request.call_args
462+
method, path = args
463+
assert method == "GET"
464+
assert path == "/session/session-id/get_learning_status"
465+
# Verify it returns a Pydantic model
466+
assert hasattr(result, "space_digested_count")
467+
assert hasattr(result, "not_space_digested_count")
468+
assert result.space_digested_count == 5
469+
assert result.not_space_digested_count == 3
470+
471+
451472
@patch("acontext.client.AcontextClient.request")
452473
def test_blocks_list_without_filters(mock_request, client: AcontextClient) -> None:
453474
mock_request.return_value = []

0 commit comments

Comments
 (0)