diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index b4d494b343..4b17470de6 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -202,10 +202,12 @@ async def get_reply_parts( async def convert_message(self, data: tuple) -> AstrBotMessage: username, cid, payload = data + sender_id = str(payload.get("sender_id") or username) + sender_name = str(payload.get("sender_name") or username) abm = AstrBotMessage() abm.self_id = "webchat" - abm.sender = MessageMember(username, username) + abm.sender = MessageMember(sender_id, sender_name) abm.type = MessageType.FRIEND_MESSAGE diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 5ff1913b9e..cb6c009088 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -752,6 +752,19 @@ async def chat(self, post_data: dict | None = None): enable_streaming = post_data.get("enable_streaming", True) platform_history_id = post_data.get("_platform_history_id") or "webchat" thread_selected_text = post_data.get("_thread_selected_text") + use_internal_sender = request.path.startswith("/api/v1/") and bool( + g.get("api_key_id", None) + ) + sender_id = ( + str(post_data.get("_sender_id") or username) + if use_internal_sender + else username + ) + sender_name = ( + str(post_data.get("_sender_name") or username) + if use_internal_sender + else username + ) if not session_id: return Response().error("session_id is empty").__dict__ @@ -1012,6 +1025,8 @@ def build_attachment_saved_event(part: dict | None) -> str | None: "message_id": message_id, "llm_checkpoint_id": llm_checkpoint_id, "thread_selected_text": thread_selected_text, + "sender_id": sender_id, + "sender_name": sender_name, }, ), ) diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index 52b412b2b5..684862e9de 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -65,6 +65,11 @@ def _resolve_open_username( return None, "username is empty" return username, None + @staticmethod + def _build_openapi_sender_id(api_key_id: str | None, username: str) -> str: + key_id = str(api_key_id or "unknown").strip() or "unknown" + return f"openapi:{key_id}:{username}" + def _get_chat_config_list(self) -> list[dict]: conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list() @@ -170,6 +175,11 @@ async def chat_send(self): original_username = g.get("username", "guest") g.username = effective_username + post_data["_sender_id"] = self._build_openapi_sender_id( + g.get("api_key_id", None), + effective_username, + ) + post_data["_sender_name"] = effective_username if config_id: umo = f"webchat:FriendMessage:webchat!{effective_username}!{session_id}" try: @@ -213,10 +223,12 @@ def _extract_ws_api_key() -> str | None: return auth_header.removeprefix("ApiKey ").strip() return None - async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]: + async def _authenticate_chat_ws_api_key( + self, + ) -> tuple[bool, str | None, str | None]: raw_key = self._extract_ws_api_key() if not raw_key: - return False, "Missing API key" + return False, "Missing API key", None key_hash = hashlib.pbkdf2_hmac( "sha256", @@ -226,7 +238,7 @@ async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]: ).hex() api_key = await self.db.get_active_api_key_by_hash(key_hash) if not api_key: - return False, "Invalid API key" + return False, "Invalid API key", None if isinstance(api_key.scopes, list): scopes = api_key.scopes @@ -234,10 +246,10 @@ async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]: scopes = list(ALL_OPEN_API_SCOPES) if "*" not in scopes and "chat" not in scopes: - return False, "Insufficient API key scope" + return False, "Insufficient API key scope", None await self.db.touch_api_key(api_key.key_id) - return True, None + return True, None, api_key.key_id async def _send_chat_ws_error(self, message: str, code: str) -> None: await websocket.send_json( @@ -277,7 +289,11 @@ async def _update_session_config_route( return f"Failed to update chat config route: {e}" return None - async def _handle_chat_ws_send(self, post_data: dict) -> None: + async def _handle_chat_ws_send( + self, + post_data: dict, + api_key_id: str | None, + ) -> None: effective_username, username_err = self._resolve_open_username( post_data.get("username") ) @@ -331,6 +347,7 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: selected_provider = post_data.get("selected_provider") selected_model = post_data.get("selected_model") enable_streaming = post_data.get("enable_streaming", True) + sender_id = self._build_openapi_sender_id(api_key_id, effective_username) back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id) try: @@ -345,6 +362,8 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: "selected_model": selected_model, "enable_streaming": enable_streaming, "message_id": message_id, + "sender_id": sender_id, + "sender_name": effective_username, }, ) ) @@ -493,7 +512,7 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: webchat_queue_mgr.remove_back_queue(message_id) async def chat_ws(self) -> None: - authed, auth_err = await self._authenticate_chat_ws_api_key() + authed, auth_err, api_key_id = await self._authenticate_chat_ws_api_key() if not authed: await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED") await websocket.close(1008, auth_err or "Unauthorized") @@ -520,7 +539,7 @@ async def chat_ws(self) -> None: ) continue - await self._handle_chat_ws_send(message) + await self._handle_chat_ws_send(message, api_key_id) except Exception as e: logger.debug("Open API WS connection closed: %s", e) diff --git a/tests/test_api_key_open_api.py b/tests/test_api_key_open_api.py index 8b90e2ff48..00dddea7a6 100644 --- a/tests/test_api_key_open_api.py +++ b/tests/test_api_key_open_api.py @@ -8,9 +8,18 @@ from quart import Quart, g, request from werkzeug.datastructures import FileStorage +import astrbot.dashboard.routes.chat as chat_module +import astrbot.dashboard.routes.open_api as open_api_module from astrbot.core import LogBroker +from astrbot.core.config.default import DEFAULT_CONFIG from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db.sqlite import SQLiteDatabase +from astrbot.core.pipeline.context import PipelineContext +from astrbot.core.pipeline.waking_check.stage import WakingCheckStage +from astrbot.core.platform import AstrBotMessage, MessageMember, MessageType +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.platform.sources.webchat.webchat_adapter import WebChatAdapter +from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent from astrbot.core.utils.auth_password import ( hash_dashboard_password, hash_legacy_dashboard_password, @@ -208,7 +217,7 @@ async def test_open_chat_send_auto_session_id_and_username( ): test_client = app.test_client() - raw_key, _ = await _create_api_key( + raw_key, key_id = await _create_api_key( app, authenticated_header, scopes=["chat"], @@ -226,6 +235,8 @@ async def fake_chat(post_data: dict | None = None): data={ "session_id": payload.get("session_id"), "creator": g.get("username"), + "sender_id": payload.get("_sender_id"), + "sender_name": payload.get("_sender_name"), } ) .__dict__ @@ -237,7 +248,9 @@ async def fake_chat(post_data: dict | None = None): "/api/v1/chat", json={ "message": "hello", - "username": "alice_auto_session", + "username": "astrbot", + "_sender_id": "evil-admin", + "_sender_name": "evil-admin", "enable_streaming": False, }, headers={"X-API-Key": raw_key}, @@ -251,12 +264,15 @@ async def fake_chat(post_data: dict | None = None): created_session_id = send_data["data"]["session_id"] assert isinstance(created_session_id, str) uuid.UUID(created_session_id) - assert send_data["data"]["creator"] == "alice_auto_session" + assert send_data["data"]["creator"] == "astrbot" + assert send_data["data"]["sender_id"] == f"openapi:{key_id}:astrbot" + assert send_data["data"]["sender_id"] not in DEFAULT_CONFIG["admins_id"] + assert send_data["data"]["sender_name"] == "astrbot" created_session = await core_lifecycle_td.db.get_platform_session_by_id( created_session_id ) assert created_session is not None - assert created_session.creator == "alice_auto_session" + assert created_session.creator == "astrbot" assert created_session.platform_id == "webchat" await core_lifecycle_td.db.create_platform_session( @@ -291,6 +307,196 @@ async def fake_chat(post_data: dict | None = None): assert missing_username_data["message"] == "Missing key: username" +@pytest.mark.asyncio +async def test_openapi_sender_id_is_used_for_webchat_event_identity(): + adapter = WebChatAdapter({}, {}, asyncio.Queue()) + message = await adapter.convert_message( + ( + "astrbot", + "openapi-repro-session", + { + "message": [{"type": "plain", "text": "/provider"}], + "message_id": "openapi-repro-message", + "sender_id": "openapi:key-123:astrbot", + "sender_name": "astrbot", + }, + ) + ) + + assert message.sender.user_id == "openapi:key-123:astrbot" + assert message.sender.user_id not in DEFAULT_CONFIG["admins_id"] + assert message.sender.nickname == "astrbot" + assert message.session_id == "webchat!astrbot!openapi-repro-session" + + +@pytest.mark.asyncio +async def test_openapi_sender_id_does_not_match_default_admin_in_waking_stage(): + stage = WakingCheckStage() + await stage.initialize( + PipelineContext( + astrbot_config={ + "admins_id": DEFAULT_CONFIG["admins_id"], + "wake_prefix": ["/"], + "platform_settings": {"friend_message_needs_wake_prefix": True}, + "disable_builtin_commands": False, + }, + plugin_manager=None, + astrbot_config_id="test", + ) + ) + platform_meta = PlatformMetadata( + name="webchat", + description="webchat", + id="webchat", + ) + + async def role_for(sender_id: str) -> str: + message_obj = AstrBotMessage() + message_obj.type = MessageType.FRIEND_MESSAGE + message_obj.self_id = "webchat" + message_obj.sender = MessageMember(sender_id, "astrbot") + message_obj.message = [] + message_obj.message_str = "hello" + message_obj.session_id = "webchat!astrbot!waking-repro" + event = WebChatMessageEvent( + message_str="hello", + message_obj=message_obj, + platform_meta=platform_meta, + session_id=message_obj.session_id, + ) + await stage.process(event) + return event.role + + assert await role_for("astrbot") == "admin" + assert await role_for("openapi:key-123:astrbot") == "member" + + +@pytest.mark.asyncio +async def test_open_chat_ws_send_uses_openapi_sender_id( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + open_api_route = _get_open_api_route(app) + session_id = f"openapi_ws_sender_{uuid.uuid4().hex[:8]}" + message_id = f"openapi-ws-message-{uuid.uuid4().hex[:8]}" + chat_queue_items = [] + websocket_messages = [] + + class CaptureQueue: + async def put(self, item): + chat_queue_items.append(item) + + back_queue = asyncio.Queue() + await back_queue.put( + { + "type": "end", + "data": "", + "message_id": message_id, + } + ) + + class FakeWebsocket: + async def send_json(self, payload): + websocket_messages.append(payload) + + monkeypatch.setattr( + open_api_module.webchat_queue_mgr, + "get_or_create_back_queue", + lambda *_args, **_kwargs: back_queue, + ) + monkeypatch.setattr( + open_api_module.webchat_queue_mgr, + "get_or_create_queue", + lambda *_args, **_kwargs: CaptureQueue(), + ) + monkeypatch.setattr( + open_api_module.webchat_queue_mgr, + "remove_back_queue", + lambda *_args, **_kwargs: None, + ) + monkeypatch.setattr(open_api_module, "websocket", FakeWebsocket()) + + await open_api_route._handle_chat_ws_send( + { + "message": "hello", + "username": "astrbot", + "session_id": session_id, + "message_id": message_id, + "enable_streaming": False, + "sender_id": "astrbot", + }, + "ws-key-123", + ) + + assert len(chat_queue_items) == 1 + queued_username, queued_session_id, queued_payload = chat_queue_items[0] + assert queued_username == "astrbot" + assert queued_session_id == session_id + assert queued_payload["sender_id"] == "openapi:ws-key-123:astrbot" + assert queued_payload["sender_id"] not in DEFAULT_CONFIG["admins_id"] + assert queued_payload["sender_name"] == "astrbot" + assert any(message["type"] == "session_id" for message in websocket_messages) + + +@pytest.mark.asyncio +async def test_dashboard_chat_send_ignores_internal_sender_fields( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch: pytest.MonkeyPatch, +): + test_client = app.test_client() + session_id = f"dashboard_sender_{uuid.uuid4().hex[:8]}" + chat_queue_items = [] + + class CaptureQueue: + async def put(self, item): + chat_queue_items.append(item) + + def fake_back_queue(*_args, **_kwargs): + queue = asyncio.Queue() + queue.put_nowait({"type": "end", "data": ""}) + return queue + + monkeypatch.setattr( + chat_module.webchat_queue_mgr, + "get_or_create_back_queue", + fake_back_queue, + ) + monkeypatch.setattr( + chat_module.webchat_queue_mgr, + "get_or_create_queue", + lambda *_args, **_kwargs: CaptureQueue(), + ) + monkeypatch.setattr( + chat_module.webchat_queue_mgr, + "remove_back_queue", + lambda *_args, **_kwargs: None, + ) + + send_res = await test_client.post( + "/api/chat/send", + json={ + "message": "hello", + "session_id": session_id, + "_sender_id": "astrbot", + "_sender_name": "astrbot", + }, + headers=authenticated_header, + ) + + assert send_res.status_code == 200 + body = await send_res.get_data(as_text=True) + assert '"type": "session_id"' in body + assert len(chat_queue_items) == 1 + queued_username, queued_session_id, queued_payload = chat_queue_items[0] + assert queued_username == core_lifecycle_td.astrbot_config["dashboard"]["username"] + assert queued_session_id == session_id + assert queued_payload["sender_id"] == queued_username + assert queued_payload["sender_name"] == queued_username + + @pytest.mark.asyncio async def test_open_chat_sessions_pagination( app: Quart,