Skip to content

美化输出 #1075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions astrbot/core/core_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
from .exec_hook import ExtractException


class AstrBotCoreLifecycle:
Expand Down Expand Up @@ -106,9 +107,8 @@ async def _task_wrapper(self, task: asyncio.Task):
pass
except Exception as e:
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
err_msg = ExtractException(type(e), e, e.__traceback__, rich_printable=True)
logger.error(err_msg)

async def start(self):
self._load()
Expand All @@ -124,8 +124,9 @@ async def start(self):
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except BaseException:
logger.error(traceback.format_exc())
except BaseException as e:
err_msg = ExtractException(type(e), e, e.__traceback__, rich_printable=True)
logger.error(err_msg)

await asyncio.gather(*self.curr_tasks, return_exceptions=True)

Expand Down
105 changes: 105 additions & 0 deletions astrbot/core/exec_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
""" exception hook """

import traceback
import multiprocessing
import threading
import inspect
import sys
from rich.console import Console
from rich.text import Text
from rich.panel import Panel

console = Console()

def format_stack_trace(exctype, value, tb, max_depth=15, nested=False) -> Text:
tb_list = traceback.extract_tb(tb)
exception_info = Text()

if nested:
exception_info.append(f"{exctype.__name__}: {value}\n", style="bold red")
else:
# 获取当前进程和线程名称
process_name = multiprocessing.current_process().name
thread_name = threading.current_thread().name
exception_info.append(
f"Exception in process: {process_name}, thread: {thread_name}; {exctype.__name__}: {value}\n",
style="bold red"
)
exception_info.append("Traceback (most recent call last):\n", style="bold")

# 限制堆栈跟踪的深度
limited_tb_list = tb_list[:max_depth]
more_frames = len(tb_list) - max_depth

for i, (filename, lineno, funcname, line) in enumerate(limited_tb_list):
# 获取函数所在的模块名
module_name = inspect.getmodulename(filename)
exception_info.append(
f" at {module_name}.{funcname} in ({filename}:{lineno})\n",
style="yellow"
)

if more_frames > 0:
exception_info.append(f" ... {more_frames} more ...\n", style="dim")

# 检查是否有原因和其他信息
cause = getattr(value, "__cause__", None)
context = getattr(value, "__context__", None)

if cause:
exception_info.append("Caused by: \n", style="bold red")
exception_info.append(format_stack_trace(type(cause), cause, cause.__traceback__, nested=True))
if context and not cause:
exception_info.append("Original exception: \n", style="bold red")
exception_info.append(format_stack_trace(type(context), context, context.__traceback__, nested=True))

return exception_info

def ExtractException(exctype, value, tb, panel: bool = True, rich_printable: bool = False) -> Text | Panel | str | None:
"""
- panel: 是否以Panel形式返回异常信息
- rich_printable: 是否以可打印的格式返回异常信息 (把rich转换为普通print或者 stdout | stderr等控制台输出有效果的格式)
"""
# 获取回溯信息并格式化为字符串
_exc_info = None
if all(x is None for x in (exctype, value, tb)):
return None
tb_str = format_stack_trace(exctype, value, tb)
# 返回异常信息
if panel:
_exc_info = Panel(tb_str, title="[bold red]Exception Occurred[/bold red]", border_style="red")
if rich_printable:
with console.capture() as capture:
console.print(_exc_info)
return capture.get()
return _exc_info

def sys_excepthook(exctype, value, tb):
# 获取异常信息并打印到控制台
exception_info = ExtractException(exctype, value, tb , panel=True)
if exception_info:
console.print(exception_info)

def set_exechook():
"""
设置全局异常处理函数
"""
sys.excepthook = sys_excepthook

def GetStackTrace(vokedepth: int = 1) -> Text:
"""
获取堆栈跟踪信息
"""
# 获取当前调用栈信息的前两层
stack = traceback.extract_stack(limit=vokedepth)
stack_info = Text("Stack Trace:\n", style="bold")
for frame in stack[:-vokedepth+1]:
Copy link
Preview

Copilot AI Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When vokedepth is 1, this slice returns an empty list, potentially omitting useful stack trace information. Consider adjusting the slicing logic to ensure at least one stack frame is captured.

Suggested change
for frame in stack[:-vokedepth+1]:
if vokedepth == 1:
frames_to_include = stack[-1:]
else:
frames_to_include = stack[:-vokedepth+1]
for frame in frames_to_include:

Copilot uses AI. Check for mistakes.

filename = frame.filename
line = frame.lineno
funcname = frame.name
stack_info.append(f" at {funcname} in ({filename}:{line})\n", style="yellow")
return stack_info

if __name__ == "__main__":

set_exechook()
4 changes: 2 additions & 2 deletions astrbot/core/initial_loader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.exec_hook import ExtractException


class InitialLoader:
Expand All @@ -21,7 +21,7 @@ async def start(self):
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
except Exception as e:
logger.critical(traceback.format_exc())
logger.critical(ExtractException(type(e), e, e.__traceback__, rich_printable=True))
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")

self.dashboard_server = AstrBotDashboard(
Expand Down
81 changes: 47 additions & 34 deletions astrbot/core/log.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
import logging
import colorlog
import asyncio
from rich.console import Console
from rich.logging import RichHandler
from rich.theme import Theme
import os
from collections import deque
from asyncio import Queue
from typing import List
import asyncio

CACHED_SIZE = 200
log_color_config = {
"DEBUG": "green",
"INFO": "bold_cyan",
"WARNING": "bold_yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
"RESET": "reset",
"asctime": "green",
}


def is_plugin_path(pathname):
"""
Expand Down Expand Up @@ -46,22 +37,28 @@ def get_short_level_name(level_name):
class LogBroker:
def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE)
self.subscribers: List[Queue] = []
self.subscribers: list[Queue] = []

def register(self) -> Queue:
"""给每个订阅者返回一个带有日志缓存的队列"""
"""
给每个订阅者返回一个带有日志缓存的队列
"""
q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
self.subscribers.append(q)
return q

def unregister(self, q: Queue):
"""取消订阅"""
"""
取消订阅
"""
self.subscribers.remove(q)

def publish(self, log_entry: str):
"""发布消息"""
"""
发布消息
"""
self.log_cache.append(log_entry)
for q in self.subscribers:
try:
Expand All @@ -81,18 +78,32 @@ def emit(self, record):


class LogManager:
_default_log_theme = Theme({
"log.time": "bold white",
"log.level": "bold",
"log.message": "white",
"log.level.debug": "bold blue",
"log.level.info": "bold cyan",
"log.level.warning": "bold yellow",
"log.level.error": "bold red",
"log.level.critical": "bold red"
})

_console = Console(theme=_default_log_theme)

@classmethod
def GetLogger(cls, log_name: str = "default"):
logger = logging.getLogger(log_name)
if logger.hasHandlers():
return logger
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)

console_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
datefmt="%H:%M:%S",
log_colors=log_color_config,
console_handler = RichHandler(console=cls._console)
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(
logging.Formatter(
datefmt="%X",
fmt="| %(plugin_tag)s | %(filename)s:%(lineno)d ===>> %(message)s"
)
)

class PluginFilter(logging.Filter):
Expand All @@ -103,7 +114,6 @@ def filter(self, record):
return True

class FileNameFilter(logging.Filter):
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
dirname = os.path.dirname(record.pathname)
record.filename = (
Expand All @@ -114,12 +124,10 @@ def filter(self, record):
return True

class LevelNameFilter(logging.Filter):
# 添加短日志级别名称
def filter(self, record):
record.short_levelname = get_short_level_name(record.levelname)
return True

console_handler.setFormatter(console_formatter)
logger.addFilter(PluginFilter())
logger.addFilter(FileNameFilter())
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
Expand All @@ -132,13 +140,18 @@ def filter(self, record):
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
if logger.handlers:
handler.setFormatter(logger.handlers[0].formatter)
else:
# 为队列处理器设置相同格式的formatter
handler.setFormatter(
logging.Formatter(
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
)
handler.setFormatter(
logging.Formatter(
datefmt="%X",
fmt=" %(asctime)s | %(short_levelname)s | %(plugin_tag)s | %(filename)s:%(lineno)d ===>> %(message)s"
)
)
logger.addHandler(handler)

if __name__ == "__main__":
logger = LogManager.GetLogger("test")
logger.debug("这是一个调试信息")
logger.info("这是一个信息")
logger.warning("这是一个警告")
logger.error("这是一个错误")
logger.critical("这是一个严重错误")
25 changes: 15 additions & 10 deletions astrbot/core/provider/func_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import asyncio
import copy

from typing import Dict, List, Awaitable, Literal, Any
from typing import Awaitable, Literal, Any
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger

try:
import mcp
import mcp # type: ignore
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")

Expand All @@ -34,8 +34,11 @@ class FuncTool:
"""

name: str
parameters: Dict
""" 调用函数的名字 """
parameters: dict
""" 调用的时候传入的参数 """
description: str
""" 函数工具的描述 """
handler: Awaitable = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
Expand Down Expand Up @@ -86,7 +89,7 @@ def __init__(self):

self.name = None
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.tools: list[mcp.Tool] = []

async def connect_to_server(self, mcp_server_config: dict):
"""Connect to an MCP server
Expand Down Expand Up @@ -124,15 +127,16 @@ async def cleanup(self):

class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
self.func_list: list[FuncTool] = []
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
self.mcp_client_dict: dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_service_queue = asyncio.Queue()
"""用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
self.mcp_client_event: dict[str, asyncio.Event] = {}

def empty(self) -> bool:
""" 是否为空 """
return len(self.func_list) == 0

def add_func(
Expand Down Expand Up @@ -179,7 +183,8 @@ def remove_func(self, name: str) -> None:
self.func_list.pop(i)
break

def get_func(self, name) -> FuncTool:
def get_func(self, name: str) -> FuncTool:
""" 获取函数工具 """
for f in self.func_list:
if f.name == name:
return f
Expand Down Expand Up @@ -215,7 +220,7 @@ async def _init_mcp_clients(self) -> None:
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return

mcp_server_json_obj: Dict[str, Dict] = json.load(
mcp_server_json_obj: dict[str, dict] = json.load(
open(mcp_json_file, "r", encoding="utf-8")
)["mcpServers"]

Expand Down Expand Up @@ -383,7 +388,7 @@ def get_func_desc_anthropic_style(self) -> list:
tools.append(tool)
return tools

def get_func_desc_google_genai_style(self) -> Dict:
def get_func_desc_google_genai_style(self) -> dict:
declarations = {}
tools = []
for f in self.func_list:
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/star/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

`AstrBot Star` 就是插件。

在 AstrBot v4.0 版本后,AstrBot 内部将插件命名为 `star`。插件的 handler 称作 `star_handler`。
在 AstrBot v4.0 版本后,AstrBot 内部将插件命名为 `star`。插件的 handler 称作 `star_handler`。
Loading