Skip to content

Commit 2a9bc87

Browse files
committed
feat: add basic enhancer
1 parent 77d85cf commit 2a9bc87

File tree

6 files changed

+69
-22
lines changed

6 files changed

+69
-22
lines changed

.env.template

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,13 @@ DEBUG=false
5151
LOG_LEVEL=INFO
5252

5353
# PDF解析
54-
MINERU_MODEL_SOURCE=local
54+
MINERU_MODEL_SOURCE=local
55+
56+
# 信息增强
57+
LLM_MODEL_NAME=gpt-4o
58+
LLM_BASE_URL=http://192.168.120.2:4000
59+
LLM_API_KEY=ae
60+
61+
VLLM_MODEL_NAME=qwen2.5-vl-7b-instruct
62+
VLLM_API_KEY=sk-
63+
VLLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1

config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,13 @@ class Settings:
3333
MAX_FILES_PER_REQUEST: int = int(os.getenv("MAX_FILES_PER_REQUEST", "20"))
3434
TASK_TIMEOUT: int = int(os.getenv("TASK_TIMEOUT", "3600")) # 1小时
3535

36+
# 模型配置
37+
LLM_MODEL_NAME: str = os.getenv("LLM_MODEL_NAME", "gpt-4o")
38+
LLM_BASE_URL: str = os.getenv("LLM_BASE_URL", "http://192.168.120.2:4000")
39+
LLM_API_KEY: str = os.getenv("LLM_API_KEY", "sk-")
40+
41+
VLLM_MODEL_NAME: str = os.getenv("VLLM_MODEL_NAME", "qwen2.5-vl-7b-instruct")
42+
VLLM_API_KEY: str = os.getenv("VLLM_API_KEY", "sk-")
43+
VLLM_BASE_URL: str = os.getenv("VLLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
44+
3645
settings = Settings()

enhancers/base_models.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,44 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any
3+
4+
from openai import AsyncOpenAI
5+
from pydantic import BaseModel
6+
from tenacity import retry, stop_after_attempt, wait_exponential
27

38
from parsers.base_models import ChunkData
49

10+
MAX_RETRIES = 3
11+
WAIT_TIME = 4
12+
WAIT_MAX_TIME = 15
13+
MULTIPLIER = 1
14+
15+
class JsonResponseFormat(BaseModel):
16+
"""JSON 响应格式"""
17+
description:str
518

619
class InformationEnhancer(ABC):
720
"""信息增强器基类"""
21+
def __init__(self, model_name: str, base_url: str, api_key: str):
22+
self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
23+
self.model_name = model_name
24+
self.system_prompt = "You are a helpful assistant."
25+
826
@abstractmethod
927
async def enhance(self, information: ChunkData) -> ChunkData:
1028
"""增强信息"""
1129
pass
1230

13-
class TableInformationEnhancer(InformationEnhancer):
14-
"""表格信息增强器"""
15-
16-
async def enhance(self, information: ChunkData) -> ChunkData:
17-
"""增强信息"""
18-
return information
19-
20-
class FormulasInformationEnhancer(InformationEnhancer):
21-
"""公式信息增强器"""
22-
23-
async def enhance(self, information: ChunkData) -> ChunkData:
24-
"""增强信息"""
25-
return information
26-
27-
class ImageInformationEnhancer(InformationEnhancer):
28-
"""图片信息增强器"""
29-
30-
async def enhance(self, information: ChunkData) -> ChunkData:
31-
"""增强信息"""
32-
return information
31+
@retry(stop=stop_after_attempt(MAX_RETRIES), wait=wait_exponential(multiplier=MULTIPLIER, min=WAIT_TIME, max=WAIT_MAX_TIME))
32+
async def get_structured_response(self, user_prompt: list[dict[str, Any]], response_format: JsonResponseFormat) -> str|None:
33+
"""获取结构化响应"""
34+
response = await self.client.chat.completions.parse(
35+
model=self.model_name,
36+
messages=[
37+
{"role": "system", "content": self.system_prompt},
38+
{"role": "user", "content": user_prompt} # type: ignore
39+
],
40+
response_format=response_format # type: ignore
41+
)
42+
if response.choices[0].message.refusal:
43+
raise ValueError(f"模型拒绝了请求: {response.choices[0].message.refusal}")
44+
return response.choices[0].message.parsed

enhancers/enhancer_registry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
from collections.abc import Callable
99

10+
from config import settings
1011
from enhancers.base_models import InformationEnhancer
1112
from parsers.base_models import ChunkType
1213

@@ -67,7 +68,11 @@ def get_enhancer(modality: ChunkType) -> InformationEnhancer | None:
6768

6869
enhancer_class = ENHANCER_REGISTRY[modality_type]
6970
try:
70-
return enhancer_class()
71+
match modality_type:
72+
case ChunkType.IMAGE.value.lower():
73+
return enhancer_class(settings.VLLM_MODEL_NAME, settings.VLLM_BASE_URL, settings.VLLM_API_KEY)
74+
case _:
75+
return enhancer_class(settings.LLM_MODEL_NAME, settings.LLM_BASE_URL, settings.LLM_API_KEY)
7176
except Exception as e:
7277
logger.error(f"创建信息增强器实例失败: {enhancer_class.__name__}, 错误: {e}")
7378
return None

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"docling>=2.45.0",
1919
"mineru[core]>=2.1.11",
2020
"beautifulsoup4>=4.13.4",
21+
"tenacity>=9.1.2",
2122
]
2223

2324
[dependency-groups]

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)