diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 401a97262c..aa106e2f47 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,11 +7,11 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.2 + rev: v0.9.3 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 24.10.0 + rev: 25.1.0 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-prettier diff --git a/app/backend/chat_history/cosmosdb.py b/app/backend/chat_history/cosmosdb.py index 49760970f7..fca1585b93 100644 --- a/app/backend/chat_history/cosmosdb.py +++ b/app/backend/chat_history/cosmosdb.py @@ -4,12 +4,13 @@ from azure.cosmos.aio import ContainerProxy, CosmosClient from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential -from quart import Blueprint, current_app, jsonify, request +from quart import Blueprint, current_app, jsonify, make_response, request from config import ( CONFIG_CHAT_HISTORY_COSMOS_ENABLED, CONFIG_COSMOS_HISTORY_CLIENT, CONFIG_COSMOS_HISTORY_CONTAINER, + CONFIG_COSMOS_HISTORY_VERSION, CONFIG_CREDENTIAL, ) from decorators import authenticated @@ -34,23 +35,50 @@ async def post_chat_history(auth_claims: Dict[str, Any]): try: request_json = await request.get_json() - id = request_json.get("id") - answers = request_json.get("answers") - title = answers[0][0][:50] + "..." if len(answers[0][0]) > 50 else answers[0][0] + session_id = request_json.get("id") + message_pairs = request_json.get("answers") + first_question = message_pairs[0][0] + title = first_question + "..." if len(first_question) > 50 else first_question timestamp = int(time.time() * 1000) - await container.upsert_item( - {"id": id, "entra_oid": entra_oid, "title": title, "answers": answers, "timestamp": timestamp} - ) + # Insert the session item: + session_item = { + "id": session_id, + "version": current_app.config[CONFIG_COSMOS_HISTORY_VERSION], + "session_id": session_id, + "entra_oid": entra_oid, + "type": "session", + "title": title, + "timestamp": timestamp, + } + + message_pair_items = [] + # Now insert a message item for each question/response pair: + for ind, message_pair in enumerate(message_pairs): + message_pair_items.append( + { + "id": f"{session_id}-{ind}", + "version": current_app.config[CONFIG_COSMOS_HISTORY_VERSION], + "session_id": session_id, + "entra_oid": entra_oid, + "type": "message_pair", + "question": message_pair[0], + "response": message_pair[1], + } + ) + batch_operations = [("upsert", (session_item,))] + [ + ("upsert", (message_pair_item,)) for message_pair_item in message_pair_items + ] + await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id]) return jsonify({}), 201 except Exception as error: return error_response(error, "/chat_history") -@chat_history_cosmosdb_bp.post("/chat_history/items") +@chat_history_cosmosdb_bp.get("/chat_history/sessions") @authenticated -async def get_chat_history(auth_claims: Dict[str, Any]): +async def get_chat_history_sessions(auth_claims: Dict[str, Any]): if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: return jsonify({"error": "Chat history not enabled"}), 400 @@ -63,27 +91,26 @@ async def get_chat_history(auth_claims: Dict[str, Any]): return jsonify({"error": "User OID not found"}), 401 try: - request_json = await request.get_json() - count = request_json.get("count", 20) - continuation_token = request_json.get("continuation_token") + count = int(request.args.get("count", 10)) + continuation_token = request.args.get("continuation_token") res = container.query_items( - query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC", - parameters=[dict(name="@entra_oid", value=entra_oid)], + query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid AND c.type = @type ORDER BY c.timestamp DESC", + parameters=[dict(name="@entra_oid", value=entra_oid), dict(name="@type", value="session")], + partition_key=[entra_oid], max_item_count=count, ) - # set the continuation token for the next page pager = res.by_page(continuation_token) # Get the first page, and the continuation token + sessions = [] try: page = await pager.__anext__() continuation_token = pager.continuation_token # type: ignore - items = [] async for item in page: - items.append( + sessions.append( { "id": item.get("id"), "entra_oid": item.get("entra_oid"), @@ -94,18 +121,17 @@ async def get_chat_history(auth_claims: Dict[str, Any]): # If there are no more pages, StopAsyncIteration is raised except StopAsyncIteration: - items = [] continuation_token = None - return jsonify({"items": items, "continuation_token": continuation_token}), 200 + return jsonify({"sessions": sessions, "continuation_token": continuation_token}), 200 except Exception as error: - return error_response(error, "/chat_history/items") + return error_response(error, "/chat_history/sessions") -@chat_history_cosmosdb_bp.get("/chat_history/items/") +@chat_history_cosmosdb_bp.get("/chat_history/sessions/") @authenticated -async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str): +async def get_chat_history_session(auth_claims: Dict[str, Any], session_id: str): if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: return jsonify({"error": "Chat history not enabled"}), 400 @@ -118,26 +144,34 @@ async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str): return jsonify({"error": "User OID not found"}), 401 try: - res = await container.read_item(item=item_id, partition_key=entra_oid) + res = container.query_items( + query="SELECT * FROM c WHERE c.session_id = @session_id AND c.type = @type", + parameters=[dict(name="@session_id", value=session_id), dict(name="@type", value="message_pair")], + partition_key=[entra_oid, session_id], + ) + + message_pairs = [] + async for page in res.by_page(): + async for item in page: + message_pairs.append([item["question"], item["response"]]) + return ( jsonify( { - "id": res.get("id"), - "entra_oid": res.get("entra_oid"), - "title": res.get("title", "untitled"), - "timestamp": res.get("timestamp"), - "answers": res.get("answers", []), + "id": session_id, + "entra_oid": entra_oid, + "answers": message_pairs, } ), 200, ) except Exception as error: - return error_response(error, f"/chat_history/items/{item_id}") + return error_response(error, f"/chat_history/sessions/{session_id}") -@chat_history_cosmosdb_bp.delete("/chat_history/items/") +@chat_history_cosmosdb_bp.delete("/chat_history/sessions/") @authenticated -async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str): +async def delete_chat_history_session(auth_claims: Dict[str, Any], session_id: str): if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]: return jsonify({"error": "Chat history not enabled"}), 400 @@ -150,10 +184,22 @@ async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str) return jsonify({"error": "User OID not found"}), 401 try: - await container.delete_item(item=item_id, partition_key=entra_oid) - return jsonify({}), 204 + res = container.query_items( + query="SELECT c.id FROM c WHERE c.session_id = @session_id", + parameters=[dict(name="@session_id", value=session_id)], + partition_key=[entra_oid, session_id], + ) + + ids_to_delete = [] + async for page in res.by_page(): + async for item in page: + ids_to_delete.append(item["id"]) + + batch_operations = [("delete", (id,)) for id in ids_to_delete] + await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id]) + return await make_response("", 204) except Exception as error: - return error_response(error, f"/chat_history/items/{item_id}") + return error_response(error, f"/chat_history/sessions/{session_id}") @chat_history_cosmosdb_bp.before_app_serving @@ -183,6 +229,7 @@ async def setup_clients(): current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] = cosmos_client current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] = cosmos_container + current_app.config[CONFIG_COSMOS_HISTORY_VERSION] = os.environ["AZURE_CHAT_HISTORY_VERSION"] @chat_history_cosmosdb_bp.after_app_serving diff --git a/app/backend/config.py b/app/backend/config.py index eaba154116..a9315df6c0 100644 --- a/app/backend/config.py +++ b/app/backend/config.py @@ -26,3 +26,4 @@ CONFIG_CHAT_HISTORY_COSMOS_ENABLED = "chat_history_cosmos_enabled" CONFIG_COSMOS_HISTORY_CLIENT = "cosmos_history_client" CONFIG_COSMOS_HISTORY_CONTAINER = "cosmos_history_container" +CONFIG_COSMOS_HISTORY_VERSION = "cosmos_history_version" diff --git a/app/backend/requirements.txt b/app/backend/requirements.txt index 59277f9941..22194d07cd 100644 --- a/app/backend/requirements.txt +++ b/app/backend/requirements.txt @@ -47,7 +47,7 @@ azure-core==1.30.2 # msrest azure-core-tracing-opentelemetry==1.0.0b11 # via azure-monitor-opentelemetry -azure-cosmos==4.7.0 +azure-cosmos==4.9.0 # via -r requirements.in azure-identity==1.17.1 # via diff --git a/app/frontend/src/api/api.ts b/app/frontend/src/api/api.ts index 76636d4d05..df95f801b5 100644 --- a/app/frontend/src/api/api.ts +++ b/app/frontend/src/api/api.ts @@ -1,6 +1,6 @@ const BACKEND_URI = ""; -import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse, HistoryListApiResponse, HistroyApiResponse } from "./models"; +import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse, HistoryListApiResponse, HistoryApiResponse } from "./models"; import { useLogin, getToken, isUsingAppServicesLogin } from "../authConfig"; export async function getHeaders(idToken: string | undefined): Promise> { @@ -145,10 +145,14 @@ export async function postChatHistoryApi(item: any, idToken: string): Promise { const headers = await getHeaders(idToken); - const response = await fetch("/chat_history/items", { - method: "POST", - headers: { ...headers, "Content-Type": "application/json" }, - body: JSON.stringify({ count: count, continuation_token: continuationToken }) + let url = `${BACKEND_URI}/chat_history/sessions?count=${count}`; + if (continuationToken) { + url += `&continuationToken=${continuationToken}`; + } + + const response = await fetch(url.toString(), { + method: "GET", + headers: { ...headers, "Content-Type": "application/json" } }); if (!response.ok) { @@ -159,9 +163,9 @@ export async function getChatHistoryListApi(count: number, continuationToken: st return dataResponse; } -export async function getChatHistoryApi(id: string, idToken: string): Promise { +export async function getChatHistoryApi(id: string, idToken: string): Promise { const headers = await getHeaders(idToken); - const response = await fetch(`/chat_history/items/${id}`, { + const response = await fetch(`/chat_history/sessions/${id}`, { method: "GET", headers: { ...headers, "Content-Type": "application/json" } }); @@ -170,13 +174,13 @@ export async function getChatHistoryApi(id: string, idToken: string): Promise { const headers = await getHeaders(idToken); - const response = await fetch(`/chat_history/items/${id}`, { + const response = await fetch(`/chat_history/sessions/${id}`, { method: "DELETE", headers: { ...headers, "Content-Type": "application/json" } }); @@ -184,7 +188,4 @@ export async function deleteChatHistoryApi(id: string, idToken: string): Promise if (!response.ok) { throw new Error(`Deleting chat history failed: ${response.statusText}`); } - - const dataResponse: any = await response.json(); - return dataResponse; } diff --git a/app/frontend/src/api/models.ts b/app/frontend/src/api/models.ts index ef1fa154b0..f560271325 100644 --- a/app/frontend/src/api/models.ts +++ b/app/frontend/src/api/models.ts @@ -107,7 +107,7 @@ export interface SpeechConfig { } export type HistoryListApiResponse = { - items: { + sessions: { id: string; entra_oid: string; title: string; @@ -116,10 +116,8 @@ export type HistoryListApiResponse = { continuation_token?: string; }; -export type HistroyApiResponse = { +export type HistoryApiResponse = { id: string; entra_oid: string; - title: string; answers: any; - timestamp: number; }; diff --git a/app/frontend/src/components/HistoryProviders/CosmosDB.ts b/app/frontend/src/components/HistoryProviders/CosmosDB.ts index 4d613b28a8..5da9df5493 100644 --- a/app/frontend/src/components/HistoryProviders/CosmosDB.ts +++ b/app/frontend/src/components/HistoryProviders/CosmosDB.ts @@ -23,10 +23,10 @@ export class CosmosDBProvider implements IHistoryProvider { if (!this.continuationToken) { this.isItemEnd = true; } - return response.items.map(item => ({ - id: item.id, - title: item.title, - timestamp: item.timestamp + return response.sessions.map(session => ({ + id: session.id, + title: session.title, + timestamp: session.timestamp })); } catch (e) { console.error(e); diff --git a/infra/main.bicep b/infra/main.bicep index 214e41767c..bae630d6f0 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -70,7 +70,8 @@ param cosmosDbLocation string = '' param cosmosDbAccountName string = '' param cosmosDbThroughput int = 400 param chatHistoryDatabaseName string = 'chat-database' -param chatHistoryContainerName string = 'chat-history' +param chatHistoryContainerName string = 'chat-history-v2' +param chatHistoryVersion string = 'cosmosdb-v2' // https://learn.microsoft.com/azure/ai-services/openai/concepts/models?tabs=python-secure%2Cstandard%2Cstandard-chat-completions#standard-deployment-model-availability @description('Location for the OpenAI resource group') @@ -375,6 +376,7 @@ var appEnvVariables = { AZURE_COSMOSDB_ACCOUNT: (useAuthentication && useChatHistoryCosmos) ? cosmosDb.outputs.name : '' AZURE_CHAT_HISTORY_DATABASE: chatHistoryDatabaseName AZURE_CHAT_HISTORY_CONTAINER: chatHistoryContainerName + AZURE_CHAT_HISTORY_VERSION: chatHistoryVersion // Shared by all OpenAI deployments OPENAI_HOST: openAiHost AZURE_OPENAI_EMB_MODEL_NAME: embedding.modelName @@ -799,26 +801,31 @@ module cosmosDb 'br/public:avm/res/document-db/database-account:0.6.1' = if (use containers: [ { name: chatHistoryContainerName + kind: 'MultiHash' paths: [ '/entra_oid' + '/session_id' ] indexingPolicy: { indexingMode: 'consistent' automatic: true includedPaths: [ { - path: '/*' + path: '/entra_oid/?' + } + { + path: '/session_id/?' } - ] - excludedPaths: [ { - path: '/title/?' + path: '/timestamp/?' } { - path: '/answers/*' + path: '/type/?' } + ] + excludedPaths: [ { - path: '/"_etag"/?' + path: '/*' } ] } @@ -1210,6 +1217,7 @@ output AZURE_SEARCH_SERVICE_ASSIGNED_USERID string = searchService.outputs.princ output AZURE_COSMOSDB_ACCOUNT string = (useAuthentication && useChatHistoryCosmos) ? cosmosDb.outputs.name : '' output AZURE_CHAT_HISTORY_DATABASE string = chatHistoryDatabaseName output AZURE_CHAT_HISTORY_CONTAINER string = chatHistoryContainerName +output AZURE_CHAT_HISTORY_VERSION string = chatHistoryVersion output AZURE_STORAGE_ACCOUNT string = storage.outputs.name output AZURE_STORAGE_CONTAINER string = storageContainerName diff --git a/tests/conftest.py b/tests/conftest.py index 90d6e112aa..431ac4b9bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -423,6 +423,7 @@ async def auth_public_documents_client( monkeypatch.setenv("AZURE_COSMOSDB_ACCOUNT", "test-cosmosdb-account") monkeypatch.setenv("AZURE_CHAT_HISTORY_DATABASE", "test-cosmosdb-database") monkeypatch.setenv("AZURE_CHAT_HISTORY_CONTAINER", "test-cosmosdb-container") + monkeypatch.setenv("AZURE_CHAT_HISTORY_VERSION", "cosmosdb-v2") for key, value in request.param.items(): monkeypatch.setenv(key, value) diff --git a/tests/snapshots/test_cosmosdb/test_chathistory_getitem/auth_public_documents_client0/result.json b/tests/snapshots/test_cosmosdb/test_chathistory_getitem/auth_public_documents_client0/result.json index 62f6f0f0f5..ab59969d83 100644 --- a/tests/snapshots/test_cosmosdb/test_chathistory_getitem/auth_public_documents_client0/result.json +++ b/tests/snapshots/test_cosmosdb/test_chathistory_getitem/auth_public_documents_client0/result.json @@ -1,11 +1,19 @@ { "answers": [ [ - "This is a test message" + "What does a Product Manager do?", + { + "delta": { + "role": "assistant" + }, + "message": { + "content": "A Product Manager is responsible for leading the product management team and providing guidance on product strategy, design, development, and launch. They collaborate with internal teams and external partners to ensure successful product execution. They also develop and implement product life-cycle management processes, monitor industry trends, develop product marketing plans, research customer needs, collaborate with internal teams, develop pricing strategies, oversee product portfolio, analyze product performance, and identify areas for improvement [role_library.pdf#page=29][role_library.pdf#page=12][role_library.pdf#page=23].", + "role": "assistant" + }, + "session_state": "143c0240-b2ee-4090-8e90-2a1c58124894" + } ] ], "entra_oid": "OID_X", - "id": "123", - "timestamp": 123456789, - "title": "This is a test message" + "id": "123" } \ No newline at end of file diff --git a/tests/snapshots/test_cosmosdb/test_chathistory_query/auth_public_documents_client0/result.json b/tests/snapshots/test_cosmosdb/test_chathistory_query/auth_public_documents_client0/result.json index 1c3422fbca..1d78dee8b3 100644 --- a/tests/snapshots/test_cosmosdb/test_chathistory_query/auth_public_documents_client0/result.json +++ b/tests/snapshots/test_cosmosdb/test_chathistory_query/auth_public_documents_client0/result.json @@ -1,6 +1,6 @@ { "continuation_token": "next", - "items": [ + "sessions": [ { "entra_oid": "OID_X", "id": "123", diff --git a/tests/snapshots/test_cosmosdb/test_chathistory_query_continuation/auth_public_documents_client0/result.json b/tests/snapshots/test_cosmosdb/test_chathistory_query_continuation/auth_public_documents_client0/result.json index 2983ef0689..9dca87c8e0 100644 --- a/tests/snapshots/test_cosmosdb/test_chathistory_query_continuation/auth_public_documents_client0/result.json +++ b/tests/snapshots/test_cosmosdb/test_chathistory_query_continuation/auth_public_documents_client0/result.json @@ -1,4 +1,4 @@ { "continuation_token": null, - "items": [] + "sessions": [] } \ No newline at end of file diff --git a/tests/test_cosmosdb.py b/tests/test_cosmosdb.py index 5495d8447a..6efc9b7f3a 100644 --- a/tests/test_cosmosdb.py +++ b/tests/test_cosmosdb.py @@ -1,3 +1,4 @@ +import copy import json import pytest @@ -5,23 +6,77 @@ from .mocks import MockAsyncPageIterator +for_sessions_query = [ + [ + { + "id": "123", + "session_id": "123", + "entra_oid": "OID_X", + "title": "This is a test message", + "timestamp": 123456789, + "type": "session", + } + ] +] + +for_deletion_query = [ + [ + { + "id": "123", + "session_id": "123", + "entra_oid": "OID_X", + "title": "This is a test message", + "timestamp": 123456789, + "type": "session", + }, + { + "id": "123-0", + "version": "cosmosdb-v2", + "session_id": "123", + "entra_oid": "OID_X", + "type": "message_pair", + "question": "What does a Product Manager do?", + "response": { + "delta": {"role": "assistant"}, + "session_state": "143c0240-b2ee-4090-8e90-2a1c58124894", + "message": { + "content": "A Product Manager is responsible for leading the product management team and providing guidance on product strategy, design, development, and launch. They collaborate with internal teams and external partners to ensure successful product execution. They also develop and implement product life-cycle management processes, monitor industry trends, develop product marketing plans, research customer needs, collaborate with internal teams, develop pricing strategies, oversee product portfolio, analyze product performance, and identify areas for improvement [role_library.pdf#page=29][role_library.pdf#page=12][role_library.pdf#page=23].", + "role": "assistant", + }, + }, + "order": 0, + "timestamp": None, + }, + ] +] + +for_message_pairs_query = [ + [ + { + "id": "123-0", + "version": "cosmosdb-v2", + "session_id": "123", + "entra_oid": "OID_X", + "type": "message_pair", + "question": "What does a Product Manager do?", + "response": { + "delta": {"role": "assistant"}, + "session_state": "143c0240-b2ee-4090-8e90-2a1c58124894", + "message": { + "content": "A Product Manager is responsible for leading the product management team and providing guidance on product strategy, design, development, and launch. They collaborate with internal teams and external partners to ensure successful product execution. They also develop and implement product life-cycle management processes, monitor industry trends, develop product marketing plans, research customer needs, collaborate with internal teams, develop pricing strategies, oversee product portfolio, analyze product performance, and identify areas for improvement [role_library.pdf#page=29][role_library.pdf#page=12][role_library.pdf#page=23].", + "role": "assistant", + }, + }, + "order": 0, + "timestamp": None, + }, + ] +] + class MockCosmosDBResultsIterator: - def __init__(self, empty=False): - if empty: - self.data = [] - else: - self.data = [ - [ - { - "id": "123", - "entra_oid": "OID_X", - "title": "This is a test message", - "timestamp": 123456789, - "answers": [["This is a test message"]], - } - ] - ] + def __init__(self, data=[]): + self.data = copy.deepcopy(data) def __aiter__(self): return self @@ -45,20 +100,33 @@ def by_page(self, continuation_token=None): @pytest.mark.asyncio async def test_chathistory_newitem(auth_public_documents_client, monkeypatch): - async def mock_upsert_item(container_proxy, item, **kwargs): - assert item["id"] == "123" - assert item["answers"] == [["This is a test message"]] - assert item["entra_oid"] == "OID_X" - assert item["title"] == "This is a test message" - - monkeypatch.setattr(ContainerProxy, "upsert_item", mock_upsert_item) + async def mock_execute_item_batch(container_proxy, **kwargs): + partition_key = kwargs["partition_key"] + assert partition_key == ["OID_X", "123"] + operations = kwargs["batch_operations"] + assert len(operations) == 2 + assert operations[0][0] == "upsert" + assert operations[1][0] == "upsert" + session = operations[0][1][0] + assert session["id"] == "123" + assert session["session_id"] == "123" + assert session["entra_oid"] == "OID_X" + assert session["title"] == "This is a test message" + message = operations[1][1][0] + assert message["id"] == "123-0" + assert message["session_id"] == "123" + assert message["entra_oid"] == "OID_X" + assert message["question"] == "This is a test message" + assert message["response"] == "This is a test answer" + + monkeypatch.setattr(ContainerProxy, "execute_item_batch", mock_execute_item_batch) response = await auth_public_documents_client.post( "/chat_history", headers={"Authorization": "Bearer MockToken"}, json={ "id": "123", - "answers": [["This is a test message"]], + "answers": [["This is a test message", "This is a test answer"]], }, ) assert response.status_code == 201 @@ -122,7 +190,7 @@ async def mock_upsert_item(container_proxy, item, **kwargs): ) assert response.status_code == 500 assert (await response.get_json()) == { - "error": "The app encountered an error processing your request.\nIf you are an administrator of the app, view the full error in the logs. See aka.ms/appservice-logs for more information.\nError type: \n" + "error": "The app encountered an error processing your request.\nIf you are an administrator of the app, view the full error in the logs. See aka.ms/appservice-logs for more information.\nError type: \n" } @@ -130,14 +198,12 @@ async def mock_upsert_item(container_proxy, item, **kwargs): async def test_chathistory_query(auth_public_documents_client, monkeypatch, snapshot): def mock_query_items(container_proxy, query, **kwargs): - return MockCosmosDBResultsIterator() + return MockCosmosDBResultsIterator(for_sessions_query) monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) - response = await auth_public_documents_client.post( - "/chat_history/items", - headers={"Authorization": "Bearer MockToken"}, - json={"count": 20}, + response = await auth_public_documents_client.get( + "/chat_history/sessions?count=20", headers={"Authorization": "Bearer MockToken"} ) assert response.status_code == 200 result = await response.get_json() @@ -148,14 +214,12 @@ def mock_query_items(container_proxy, query, **kwargs): async def test_chathistory_query_continuation(auth_public_documents_client, monkeypatch, snapshot): def mock_query_items(container_proxy, query, **kwargs): - return MockCosmosDBResultsIterator(empty=True) + return MockCosmosDBResultsIterator() monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) - response = await auth_public_documents_client.post( - "/chat_history/items", - headers={"Authorization": "Bearer MockToken"}, - json={"count": 20}, + response = await auth_public_documents_client.get( + "/chat_history/sessions?count=20&continuation_token=123", headers={"Authorization": "Bearer MockToken"} ) assert response.status_code == 200 result = await response.get_json() @@ -165,40 +229,22 @@ def mock_query_items(container_proxy, query, **kwargs): @pytest.mark.asyncio async def test_chathistory_query_error_disabled(client, monkeypatch): - response = await client.post( - "/chat_history/items", - headers={"Authorization": "Bearer MockToken"}, - json={ - "id": "123", - "answers": [["This is a test message"]], - }, - ) + response = await client.get("/chat_history/sessions", headers={"Authorization": "Bearer MockToken"}) assert response.status_code == 400 @pytest.mark.asyncio async def test_chathistory_query_error_container(auth_public_documents_client, monkeypatch): auth_public_documents_client.app.config["cosmos_history_container"] = None - response = await auth_public_documents_client.post( - "/chat_history/items", - headers={"Authorization": "Bearer MockToken"}, - json={ - "id": "123", - "answers": [["This is a test message"]], - }, + response = await auth_public_documents_client.get( + "/chat_history/sessions", headers={"Authorization": "Bearer MockToken"} ) assert response.status_code == 400 @pytest.mark.asyncio async def test_chathistory_query_error_entra(auth_public_documents_client, monkeypatch): - response = await auth_public_documents_client.post( - "/chat_history/items", - json={ - "id": "123", - "answers": [["This is a test message"]], - }, - ) + response = await auth_public_documents_client.get("/chat_history/sessions") assert response.status_code == 401 @@ -210,10 +256,8 @@ def mock_query_items(container_proxy, query, **kwargs): monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) - response = await auth_public_documents_client.post( - "/chat_history/items", - headers={"Authorization": "Bearer MockToken"}, - json={"count": 20}, + response = await auth_public_documents_client.get( + "/chat_history/sessions?count=20", headers={"Authorization": "Bearer MockToken"} ) assert response.status_code == 500 assert (await response.get_json()) == { @@ -225,19 +269,13 @@ def mock_query_items(container_proxy, query, **kwargs): @pytest.mark.asyncio async def test_chathistory_getitem(auth_public_documents_client, monkeypatch, snapshot): - async def mock_read_item(container_proxy, item, partition_key, **kwargs): - return { - "id": "123", - "entra_oid": "OID_X", - "title": "This is a test message", - "timestamp": 123456789, - "answers": [["This is a test message"]], - } + def mock_query_items(container_proxy, query, **kwargs): + return MockCosmosDBResultsIterator(for_message_pairs_query) - monkeypatch.setattr(ContainerProxy, "read_item", mock_read_item) + monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) response = await auth_public_documents_client.get( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "Bearer MockToken"}, ) assert response.status_code == 200 @@ -250,7 +288,7 @@ async def mock_read_item(container_proxy, item, partition_key, **kwargs): async def test_chathistory_getitem_error_disabled(client, monkeypatch): response = await client.get( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "BearerMockToken"}, ) assert response.status_code == 400 @@ -260,7 +298,7 @@ async def test_chathistory_getitem_error_disabled(client, monkeypatch): async def test_chathistory_getitem_error_container(auth_public_documents_client, monkeypatch): auth_public_documents_client.app.config["cosmos_history_container"] = None response = await auth_public_documents_client.get( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "BearerMockToken"}, ) assert response.status_code == 400 @@ -269,7 +307,7 @@ async def test_chathistory_getitem_error_container(auth_public_documents_client, @pytest.mark.asyncio async def test_chathistory_getitem_error_entra(auth_public_documents_client, monkeypatch): response = await auth_public_documents_client.get( - "/chat_history/items/123", + "/chat_history/sessions/123", ) assert response.status_code == 401 @@ -283,7 +321,7 @@ async def mock_read_item(container_proxy, item, partition_key, **kwargs): monkeypatch.setattr(ContainerProxy, "read_item", mock_read_item) response = await auth_public_documents_client.get( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "Bearer MockToken"}, ) assert response.status_code == 500 @@ -293,14 +331,26 @@ async def mock_read_item(container_proxy, item, partition_key, **kwargs): @pytest.mark.asyncio async def test_chathistory_deleteitem(auth_public_documents_client, monkeypatch): - async def mock_delete_item(container_proxy, item, partition_key, **kwargs): - assert item == "123" - assert partition_key == "OID_X" + def mock_query_items(container_proxy, query, **kwargs): + return MockCosmosDBResultsIterator(for_deletion_query) - monkeypatch.setattr(ContainerProxy, "delete_item", mock_delete_item) + monkeypatch.setattr(ContainerProxy, "query_items", mock_query_items) + + # mock the batch delete operation + async def mock_execute_item_batch(container_proxy, **kwargs): + partition_key = kwargs["partition_key"] + assert partition_key == ["OID_X", "123"] + operations = kwargs["batch_operations"] + assert len(operations) == 2 + assert operations[0][0] == "delete" + assert operations[1][0] == "delete" + assert operations[0][1][0] == "123" + assert operations[1][1][0] == "123-0" + + monkeypatch.setattr(ContainerProxy, "execute_item_batch", mock_execute_item_batch) response = await auth_public_documents_client.delete( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "Bearer MockToken"}, ) assert response.status_code == 204 @@ -310,7 +360,7 @@ async def mock_delete_item(container_proxy, item, partition_key, **kwargs): async def test_chathistory_deleteitem_error_disabled(client, monkeypatch): response = await client.delete( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "Bearer MockToken"}, ) assert response.status_code == 400 @@ -320,7 +370,7 @@ async def test_chathistory_deleteitem_error_disabled(client, monkeypatch): async def test_chathistory_deleteitem_error_container(auth_public_documents_client, monkeypatch): auth_public_documents_client.app.config["cosmos_history_container"] = None response = await auth_public_documents_client.delete( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "Bearer MockToken"}, ) assert response.status_code == 400 @@ -329,7 +379,7 @@ async def test_chathistory_deleteitem_error_container(auth_public_documents_clie @pytest.mark.asyncio async def test_chathistory_deleteitem_error_entra(auth_public_documents_client, monkeypatch): response = await auth_public_documents_client.delete( - "/chat_history/items/123", + "/chat_history/sessions/123", ) assert response.status_code == 401 @@ -343,7 +393,7 @@ async def mock_delete_item(container_proxy, item, partition_key, **kwargs): monkeypatch.setattr(ContainerProxy, "delete_item", mock_delete_item) response = await auth_public_documents_client.delete( - "/chat_history/items/123", + "/chat_history/sessions/123", headers={"Authorization": "Bearer MockToken"}, ) assert response.status_code == 500 diff --git a/tests/test_listfilestrategy.py b/tests/test_listfilestrategy.py index c0247f55ae..4937a3aa26 100644 --- a/tests/test_listfilestrategy.py +++ b/tests/test_listfilestrategy.py @@ -41,7 +41,7 @@ def test_file_filename_to_id(): # test ascii filename assert File(empty).filename_to_id() == "file-foo_pdf-666F6F2E706466" # test filename containing unicode - empty.name = "foo\u00A9.txt" + empty.name = "foo\u00a9.txt" assert File(empty).filename_to_id() == "file-foo__txt-666F6FC2A92E747874" # test filenaming starting with unicode empty.name = "ファイル名.pdf"