Skip to content
This repository has been archived by the owner on Nov 3, 2024. It is now read-only.

Commit

Permalink
[HN-299/HN-305] feat: fetch all chats for specific user (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
ToJen authored Jul 3, 2024
1 parent c26b38a commit e0969b2
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 23 deletions.
2 changes: 1 addition & 1 deletion hive_agent_client/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .chat import send_chat_message, get_chat_history
from .chat import send_chat_message, get_chat_history, get_all_chats
64 changes: 60 additions & 4 deletions hive_agent_client/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def get_log_level():


async def send_chat_message(
http_client: httpx.AsyncClient, base_url: str, user_id: str, session_id: str, content: str
http_client: httpx.AsyncClient,
base_url: str,
user_id: str,
session_id: str,
content: str,
) -> str:
"""
Sends a chat message to the Hive Agent API and returns the response.
Expand All @@ -41,7 +45,7 @@ async def send_chat_message(
payload = {
"user_id": user_id,
"session_id": session_id,
"chat_data": {"messages": [{"role": "user", "content": content}]}
"chat_data": {"messages": [{"role": "user", "content": content}]},
}

try:
Expand Down Expand Up @@ -94,7 +98,9 @@ async def get_chat_history(
response = await http_client.get(url, params=params)
response.raise_for_status()
chat_history = response.json()
logger.debug(f"Chat history for user {user_id} and session {session_id}: {chat_history}")
logger.debug(
f"Chat history for user {user_id} and session {session_id}: {chat_history}"
)
return chat_history
except httpx.HTTPStatusError as e:
logging.error(
Expand All @@ -104,7 +110,9 @@ async def get_chat_history(
f"HTTP error occurred when fetching chat history from the chat API: {e.response.status_code} - {e.response.text}"
)
except httpx.RequestError as e:
logging.error(f"Request error occurred when fetching chat history from {url}: {e}")
logging.error(
f"Request error occurred when fetching chat history from {url}: {e}"
)
raise Exception(
f"Request error occurred when fetching chat history from the chat API: {e}"
)
Expand All @@ -115,3 +123,51 @@ async def get_chat_history(
raise Exception(
f"An unexpected error occurred when fetching chat history from the chat API: {e}"
)


async def get_all_chats(
http_client: httpx.AsyncClient, base_url: str, user_id: str
) -> Dict[str, List[Dict]]:
"""
Retrieves all chat sessions for a specified user from the Hive Agent API.
:param http_client: An instance of httpx.AsyncClient to make HTTP requests.
:param base_url: The base URL of the Hive Agent API.
:param user_id: The user ID.
:return: All chat sessions as a dictionary with session IDs as keys and lists of messages as values.
:raises httpx.HTTPStatusError: If the request fails due to a network error or returns a 4xx/5xx response.
:raises Exception: For other types of errors.
"""

endpoint = "/all_chats"
url = f"{base_url}{endpoint}"
params = {"user_id": user_id}

try:
logging.debug(f"Fetching all chats from {url} with params: {params}")
response = await http_client.get(url, params=params)
response.raise_for_status()

all_chats = response.json()
logger.debug(f"All chats for user {user_id}: {all_chats}")

return all_chats
except httpx.HTTPStatusError as e:
logging.error(
f"HTTP error occurred when fetching all chats from {url}: {e.response.status_code} - {e.response.text}"
)
raise Exception(
f"HTTP error occurred when fetching all chats from the chat API: {e.response.status_code} - {e.response.text}"
)
except httpx.RequestError as e:
logging.error(f"Request error occurred when fetching all chats from {url}: {e}")
raise Exception(
f"Request error occurred when fetching all chats from the chat API: {e}"
)
except Exception as e:
logging.error(
f"An unexpected error occurred when fetching all chats from {url}: {e}"
)
raise Exception(
f"An unexpected error occurred when fetching all chats from the chat API: {e}"
)
28 changes: 24 additions & 4 deletions hive_agent_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from typing import Dict, List

from hive_agent_client.chat import send_chat_message, get_chat_history
from hive_agent_client.chat import send_chat_message, get_chat_history, get_all_chats
from hive_agent_client.database import (
create_table,
insert_data,
Expand Down Expand Up @@ -45,7 +45,9 @@ async def chat(self, user_id: str, session_id: str, content: str) -> str:
"""
try:
logger.debug(f"Sending message to chat endpoint: {content}")
return await send_chat_message(self.http_client, self.base_url, user_id, session_id, content)
return await send_chat_message(
self.http_client, self.base_url, user_id, session_id, content
)
except Exception as e:
logger.error(f"Failed to send chat message - {content}: {e}")
raise Exception(f"Failed to send chat message: {e}")
Expand All @@ -59,11 +61,29 @@ async def get_chat_history(self, user_id: str, session_id: str) -> List[Dict]:
:return: The chat history as a list of dictionaries.
"""
try:
return await get_chat_history(self.http_client, self.base_url, user_id, session_id)
return await get_chat_history(
self.http_client, self.base_url, user_id, session_id
)
except Exception as e:
logger.error(f"Failed to get chat history for user {user_id} and session {session_id}: {e}")
logger.error(
f"Failed to get chat history for user {user_id} and session {session_id}: {e}"
)
raise Exception(f"Failed to get chat history: {e}")

async def get_all_chats(self, user_id: str) -> Dict[str, List[Dict]]:
"""
Retrieve all chat sessions for a specified user.
:param user_id: The user ID.
:return: All chat sessions as a dictionary with session IDs as keys and lists of messages as values.
"""

try:
return await get_all_chats(self.http_client, self.base_url, user_id)
except Exception as e:
logger.error(f"Failed to get all chats for user {user_id}: {e}")
raise Exception(f"Failed to get all chats: {e}")

async def create_table(self, table_name: str, columns: dict) -> Dict:
"""
Create a new table in the database.
Expand Down
103 changes: 92 additions & 11 deletions tests/chat/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from unittest.mock import AsyncMock
from hive_agent_client.chat import send_chat_message, get_chat_history
from hive_agent_client.chat import send_chat_message, get_chat_history, get_all_chats


@pytest.mark.asyncio
Expand All @@ -18,15 +18,17 @@ async def test_send_chat_message_success():
session_id = "session123"
content = "Hello, how are you?"

result = await send_chat_message(mock_client, base_url, user_id, session_id, content)
result = await send_chat_message(
mock_client, base_url, user_id, session_id, content
)

assert result == "Hello, world!"
mock_client.post.assert_called_once_with(
"http://example.com/api/v1/chat",
json={
"user_id": user_id,
"session_id": session_id,
"chat_data": {"messages": [{"role": "user", "content": content}]}
"chat_data": {"messages": [{"role": "user", "content": content}]},
},
)

Expand Down Expand Up @@ -60,8 +62,8 @@ async def test_send_chat_message_http_error():
content = "Hello, how are you?"

with pytest.raises(
Exception,
match="HTTP error occurred when sending message to the chat API: 400 - Bad request",
Exception,
match="HTTP error occurred when sending message to the chat API: 400 - Bad request",
):
await send_chat_message(mock_client, base_url, user_id, session_id, content)

Expand All @@ -72,10 +74,20 @@ async def test_get_chat_history_success():
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200
expected_history = [
{"user_id": "user123", "session_id": "session123", "message": "Hello", "role": "user",
"timestamp": "2023-01-01T00:00:00Z"},
{"user_id": "user123", "session_id": "session123", "message": "Hi there", "role": "assistant",
"timestamp": "2023-01-01T00:00:01Z"}
{
"user_id": "user123",
"session_id": "session123",
"message": "Hello",
"role": "user",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"user_id": "user123",
"session_id": "session123",
"message": "Hi there",
"role": "assistant",
"timestamp": "2023-01-01T00:00:01Z",
},
]
mock_response.json.return_value = expected_history
mock_client.get.return_value = mock_response
Expand Down Expand Up @@ -109,7 +121,76 @@ async def test_get_chat_history_failure():
session_id = "session123"

with pytest.raises(
Exception,
match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request",
Exception,
match="HTTP error occurred when fetching chat history from the chat API: 400 - Bad request",
):
await get_chat_history(mock_client, base_url, user_id, session_id)


@pytest.mark.asyncio
async def test_get_all_chats_success():
mock_client = AsyncMock(spec=httpx.AsyncClient)
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 200

expected_all_chats = {
"session1": [
{
"message": "Hello in session1",
"role": "USER",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"message": "Response in session1",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:01Z",
},
],
"session2": [
{
"message": "Hello in session2",
"role": "USER",
"timestamp": "2023-01-01T00:00:02Z",
},
{
"message": "Response in session2",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:03Z",
},
],
}

mock_response.json.return_value = expected_all_chats
mock_client.get.return_value = mock_response

base_url = "http://example.com/api/v1"
user_id = "user123"

result = await get_all_chats(mock_client, base_url, user_id)
assert result == expected_all_chats

mock_client.get.assert_called_once_with(
f"http://example.com/api/v1/all_chats",
params={"user_id": user_id},
)


@pytest.mark.asyncio
async def test_get_all_chats_failure():
mock_client = AsyncMock(spec=httpx.AsyncClient)
mock_response = AsyncMock(spec=httpx.Response)
mock_response.status_code = 400
mock_response.text = "Bad request"
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
message="Bad request", request=mock_response.request, response=mock_response
)
mock_client.get.return_value = mock_response

base_url = "http://example.com/api/v1"
user_id = "user123"

with pytest.raises(
Exception,
match="HTTP error occurred when fetching all chats from the chat API: 400 - Bad request",
):
await get_all_chats(mock_client, base_url, user_id)
81 changes: 78 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,20 @@ async def test_get_chat_history_success():
user_id = "user123"
session_id = "session123"
expected_history = [
{"user_id": user_id, "session_id": session_id, "message": "Hello", "role": "user", "timestamp": "2023-01-01T00:00:00Z"},
{"user_id": user_id, "session_id": session_id, "message": "Hi there", "role": "assistant", "timestamp": "2023-01-01T00:00:01Z"}
{
"user_id": user_id,
"session_id": session_id,
"message": "Hello",
"role": "user",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"user_id": user_id,
"session_id": session_id,
"message": "Hi there",
"role": "assistant",
"timestamp": "2023-01-01T00:00:01Z",
},
]

with respx.mock() as mock:
Expand All @@ -90,7 +102,70 @@ async def test_get_chat_history_failure():
client = HiveAgentClient(base_url, version)
with pytest.raises(Exception) as excinfo:
await client.get_chat_history(user_id, session_id)
assert "HTTP error occurred when fetching chat history from the chat API: 400" in str(excinfo.value)
assert (
"HTTP error occurred when fetching chat history from the chat API: 400"
in str(excinfo.value)
)


@pytest.mark.asyncio
async def test_get_all_chats_success():
user_id = "user123"

expected_all_chats = {
"session1": [
{
"message": "Hello in session1",
"role": "USER",
"timestamp": "2023-01-01T00:00:00Z",
},
{
"message": "Response in session1",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:01Z",
},
],
"session2": [
{
"message": "Hello in session2",
"role": "USER",
"timestamp": "2023-01-01T00:00:02Z",
},
{
"message": "Response in session2",
"role": "ASSISTANT",
"timestamp": "2023-01-01T00:00:03Z",
},
],
}

with respx.mock() as mock:
mock.get(f"{base_url}/v1/all_chats").mock(
return_value=httpx.Response(200, json=expected_all_chats)
)

client = HiveAgentClient(base_url, version)

all_chats = await client.get_all_chats(user_id)
assert all_chats == expected_all_chats


@pytest.mark.asyncio
async def test_get_all_chats_failure():
user_id = "user123"

with respx.mock() as mock:
mock.get(f"{base_url}/v1/all_chats").mock(return_value=httpx.Response(400))

client = HiveAgentClient(base_url, version)

with pytest.raises(Exception) as excinfo:
await client.get_all_chats(user_id)

assert (
"HTTP error occurred when fetching all chats from the chat API: 400"
in str(excinfo.value)
)


@pytest.mark.asyncio
Expand Down
Loading

0 comments on commit e0969b2

Please sign in to comment.