Skip to content

Commit 2c39c98

Browse files
committed
feat: Support gemini stt model
1 parent 6169870 commit 2c39c98

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# coding=utf-8
2+
from typing import Dict
3+
4+
from common import forms
5+
from common.exception.app_exception import AppApiException
6+
from common.forms import BaseForm
7+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
8+
9+
10+
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
11+
api_key = forms.PasswordInputField('API Key', required=True)
12+
13+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
14+
raise_exception=False):
15+
model_type_list = provider.get_model_type_list()
16+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
17+
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
18+
19+
for key in ['api_key']:
20+
if key not in model_credential:
21+
if raise_exception:
22+
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
23+
else:
24+
return False
25+
try:
26+
model = provider.get_model(model_type, model_name, model_credential)
27+
model.check_auth()
28+
except Exception as e:
29+
if isinstance(e, AppApiException):
30+
raise e
31+
if raise_exception:
32+
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
33+
else:
34+
return False
35+
return True
36+
37+
def encryption_dict(self, model: Dict[str, object]):
38+
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
39+
40+
def get_model_params_setting_form(self, model_name):
41+
pass

apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
ModelInfoManage
1414
from setting.models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
1515
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
16+
from setting.models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential
1617
from setting.models_provider.impl.gemini_model_provider.model.image import GeminiImage
1718
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
19+
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
1820
from smartdoc.conf import PROJECT_DIR
1921

2022
gemini_llm_model_credential = GeminiLLMModelCredential()
2123
gemini_image_model_credential = GeminiImageModelCredential()
24+
gemini_stt_model_credential = GeminiSTTModelCredential()
2225

2326
model_info_list = [
2427
ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
@@ -42,14 +45,25 @@
4245
GeminiImage),
4346
]
4447

45-
48+
model_stt_info_list = [
49+
ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
50+
ModelTypeConst.STT,
51+
gemini_stt_model_credential,
52+
GeminiSpeechToText),
53+
ModelInfo('gemini-1.5-pro', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
54+
ModelTypeConst.STT,
55+
gemini_stt_model_credential,
56+
GeminiSpeechToText),
57+
]
4658

4759
model_info_manage = (
4860
ModelInfoManage.builder()
4961
.append_model_info_list(model_info_list)
5062
.append_model_info_list(model_image_info_list)
63+
.append_model_info_list(model_stt_info_list)
5164
.append_default_model_info(model_info_list[0])
5265
.append_default_model_info(model_image_info_list[0])
66+
.append_default_model_info(model_stt_info_list[0])
5367
.build()
5468
)
5569

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import asyncio
2+
import io
3+
from typing import Dict
4+
5+
from langchain_core.messages import HumanMessage
6+
from langchain_google_genai import ChatGoogleGenerativeAI
7+
from openai import OpenAI
8+
9+
from common.config.tokenizer_manage_config import TokenizerManage
10+
from setting.models_provider.base_model_provider import MaxKBBaseModel
11+
from setting.models_provider.impl.base_stt import BaseSpeechToText
12+
import google.generativeai as genai
13+
14+
15+
def custom_get_token_ids(text: str):
16+
tokenizer = TokenizerManage.get_tokenizer()
17+
return tokenizer.encode(text)
18+
19+
20+
class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
21+
api_key: str
22+
model: str
23+
24+
def __init__(self, **kwargs):
25+
super().__init__(**kwargs)
26+
self.api_key = kwargs.get('api_key')
27+
28+
@staticmethod
29+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
30+
optional_params = {}
31+
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
32+
optional_params['max_tokens'] = model_kwargs['max_tokens']
33+
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
34+
optional_params['temperature'] = model_kwargs['temperature']
35+
return GeminiSpeechToText(
36+
model=model_name,
37+
api_key=model_credential.get('api_key'),
38+
**optional_params,
39+
)
40+
41+
def check_auth(self):
42+
client = ChatGoogleGenerativeAI(
43+
model=self.model,
44+
google_api_key=self.api_key
45+
)
46+
response_list = client.invoke('你好')
47+
# print(response_list)
48+
49+
def speech_to_text(self, audio_file):
50+
client = ChatGoogleGenerativeAI(
51+
model=self.model,
52+
google_api_key=self.api_key
53+
)
54+
audio_data = audio_file.read()
55+
msg = HumanMessage(content=[
56+
{'type': 'text', 'text': '把音频转成文字'},
57+
{"type": "media", 'mime_type': 'audio/mp3', "data": audio_data}
58+
])
59+
res = client.invoke([msg])
60+
return res.content

0 commit comments

Comments
 (0)