diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py index 44947255787..52ab8ff8a63 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -147,6 +147,7 @@ .append_default_model_info(model_info_list[3]) .append_default_model_info(model_info_list[4]) .append_default_model_info(model_info_list[0]) + .append_default_model_info(model_info_list[2]) .append_model_info_list(model_info_ttv_list) .append_default_model_info(model_info_ttv_list[0]) .append_model_info_list(module_info_itv_list) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py index 28127c055f3..a8d7224defb 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py @@ -92,4 +92,5 @@ def encryption_dict(self, model: Dict[str, Any]) -> Dict[str, Any]: def get_model_params_setting_form(self, model_name): return BaiLianEmbeddingModelParams() + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') dashscope_api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py index 24657f6f447..d74f8af7b7e 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py @@ -84,6 +84,7 @@ def is_valid( def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = forms.PasswordInputField('API Key', required=True) def get_model_params_setting_form(self, model_name): diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py index 7f38f49c416..413d8eb1d5d 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py @@ -5,6 +5,7 @@ from django.utils.translation import gettext_lazy as _, gettext from common.exception.app_exception import AppApiException +from common import forms from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from common.forms.switch_field import SwitchField from models_provider.base_model_provider import BaseModelCredential, ValidCode @@ -41,6 +42,7 @@ class ImageToVideoModelCredential(BaseForm, BaseModelCredential): Provides validation and encryption for the model credentials. """ + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = PasswordInputField('API Key', required=True) def is_valid( diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py index 9511e2ef0de..104f98f901a 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py @@ -39,7 +39,7 @@ class BaiLianLLMModelParams(BaseForm): class BaiLianLLMModelCredential(BaseForm, BaseModelCredential): - api_base = forms.TextInputField(_('API URL'), required=True) + api_base = forms.TextInputField(_('API URL'), required=True, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = forms.PasswordInputField(_('API Key'), required=True) def is_valid( diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py index a43b4007b6d..b3264e882ef 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py @@ -6,6 +6,7 @@ from langchain_core.documents import Document from common.exception.app_exception import AppApiException +from common import forms from common.forms import BaseForm, PasswordInputField from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker @@ -17,6 +18,7 @@ class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential): Provides validation and encryption for the model credentials. """ + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') dashscope_api_key = PasswordInputField('API Key', required=True) def is_valid( diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/asr_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/asr_stt.py index 6ca485006fc..5909dcf43bf 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/asr_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/asr_stt.py @@ -10,7 +10,7 @@ class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential): - api_url = forms.TextInputField(_('API URL'), required=True) + api_base = forms.TextInputField(_('API URL'), required=True, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = forms.PasswordInputField(_('API Key'), required=True) def is_valid(self, diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py index 85a1802e2ad..fc0f4d6fb8b 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py @@ -27,7 +27,7 @@ class AliyunBaiLianDefaultSTTModelCredential(BaseForm, BaseModelCredential): {'label': _('Real-time speech recognition - Fun-ASR/Paraformer'), 'value': 'other'} ]) - api_url = forms.TextInputField(_('API URL'), required=True, relation_show_field_dict={'type': ['qwen', 'omni']}) + api_base = forms.TextInputField(_('API URL'), required=True, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1', relation_show_field_dict={'type': ['qwen', 'omni']}) api_key = forms.PasswordInputField(_('API Key'), required=True) def is_valid(self, diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/omni_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/omni_stt.py index 34c737017bc..0588c05b9ad 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/omni_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/omni_stt.py @@ -17,7 +17,7 @@ class AliyunBaiLianOmiSTTModelParams(BaseForm): class AliyunBaiLianOmiSTTModelCredential(BaseForm, BaseModelCredential): - api_url = forms.TextInputField(_('API URL'), required=True) + api_base = forms.TextInputField(_('API URL'), required=True, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = forms.PasswordInputField(_('API Key'), required=True) def is_valid(self, diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/stt.py index dd2f56c239a..b12e14bed32 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/stt.py @@ -24,6 +24,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): Provides validation and encryption for the model credentials. """ + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = PasswordInputField("API Key", required=True) def is_valid( diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py index 7825501e4b3..42d7869cfac 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py @@ -5,6 +5,7 @@ from django.utils.translation import gettext_lazy as _, gettext from common.exception.app_exception import AppApiException +from common import forms from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from common.utils.logger import maxkb_logger @@ -68,6 +69,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): Provides validation and encryption for the model credentials. """ + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = PasswordInputField('API Key', required=True) def is_valid( diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py index 089d259006f..1f88aeb4c85 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py @@ -5,6 +5,7 @@ from django.utils.translation import gettext_lazy as _, gettext from common.exception.app_exception import AppApiException +from common import forms from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from common.utils.logger import maxkb_logger @@ -53,12 +54,55 @@ class AliyunBaiLianTTSModelGeneralParams(BaseForm): ) +class AliyunBaiLianTTSQwenFlashParams(BaseForm): + """ + Parameters class for the qwen TTS models. + """ + voice = SingleSelect( + TooltipLabel(_('Voice'), _('Voice options for qwen TTS models')), + required=True, + default_value='Cherry', + text_field='value', + value_field='value', + option_list=[ + {'text': '芊悦', 'value': 'Cherry'}, + {'text': '苏瑶', 'value': 'Serena'}, + {'text': '晨煦', 'value': 'Ethan'}, + {'text': '千雪', 'value': 'Chelsie'}, + {'text': '月白', 'value': 'Moon'}, + {'text': '四月', 'value': 'Maia'}, + {'text': '凯', 'value': 'Kai'}, + ] + ) + + language_type = SingleSelect( + TooltipLabel(_('Language Type'), _('Language type for the speech synthesis')), + required=True, + default_value='Chinese', + text_field='value', + value_field='value', + option_list=[ + {'text': _('Chinese'), 'value': 'Chinese'}, + {'text': _('English'), 'value': 'English'}, + {'text': _('German'), 'value': 'German'}, + {'text': _('Italian'), 'value': 'Italian'}, + {'text': _('Portuguese'), 'value': 'Portuguese'}, + {'text': _('Spanish'), 'value': 'Spanish'}, + {'text': _('Japanese'), 'value': 'Japanese'}, + {'text': _('Korean'), 'value': 'Korean'}, + {'text': _('French'), 'value': 'French'}, + {'text': _('Russian'), 'value': 'Russian'}, + ] + ) + + class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential): """ Credential class for the Aliyun BaiLian TTS (Text-to-Speech) model. Provides validation and encryption for the model credentials. """ + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com') api_key = PasswordInputField("API Key", required=True) def is_valid( @@ -135,4 +179,9 @@ def get_model_params_setting_form(self, model_name: str): :param model_name: Name of the model. :return: Parameter setting form. """ - return AliyunBaiLianTTSModelGeneralParams() + # 根据模型名称返回不同的参数设置表单 + if model_name in ['cosyvoice-v1']: + return AliyunBaiLianTTSModelGeneralParams() + else: + # 其他模型(包括qwen-tts系列)使用新的参数表单 + return AliyunBaiLianTTSQwenFlashParams() diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py index b78f86ab910..1dd2c2011f1 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py @@ -5,6 +5,7 @@ from django.utils.translation import gettext_lazy as _, gettext from common.exception.app_exception import AppApiException +from common import forms from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from common.forms.switch_field import SwitchField from models_provider.base_model_provider import BaseModelCredential, ValidCode @@ -43,6 +44,7 @@ class TextToVideoModelCredential(BaseForm, BaseModelCredential): Provides validation and encryption for the model credentials. """ + api_base = forms.TextInputField(_('API URL'), required=False, default_value='https://dashscope.aliyuncs.com/compatible-mode/v1') api_key = PasswordInputField('API Key', required=True) def is_valid( diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py index c02f5dc52be..bac5789e367 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py @@ -17,8 +17,9 @@ class AliyunBaiLianEmbedding(MaxKBBaseModel): model_name: str optional_params: dict - def __init__(self, api_key, model_name: str, optional_params: dict): - self.client = OpenAI(api_key=api_key, base_url='https://dashscope.aliyuncs.com/compatible-mode/v1').embeddings + def __init__(self, api_key, api_base, model_name: str, optional_params: dict): + api_base = api_base or 'https://dashscope.aliyuncs.com/compatible-mode/v1' + self.client = OpenAI(api_key=api_key, base_url=api_base).embeddings self.model_name = model_name self.optional_params = optional_params @@ -30,6 +31,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) return AliyunBaiLianEmbedding( api_key=model_credential.get('dashscope_api_key'), + api_base=model_credential.get('api_base'), model_name=model_name, optional_params=optional_params ) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py index ac4f2085549..43f346845fe 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py @@ -24,10 +24,11 @@ def is_cache_model(): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + api_base = model_credential.get('api_base') or 'https://dashscope.aliyuncs.com/compatible-mode/v1' chat_tong_yi = QwenVLChatModel( model_name=model_name, openai_api_key=model_credential.get('api_key'), - openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', + openai_api_base=api_base, # stream_options={"include_usage": True}, streaming=True, stream_usage=True, @@ -41,7 +42,15 @@ def check_auth(self, api_key): def get_upload_policy(self, api_key, model_name): """获取文件上传凭证""" - url = "https://dashscope.aliyuncs.com/api/v1/uploads" + # 如果有自定义api_base,提取host部分,否则使用默认URL + if hasattr(self, 'openai_api_base') and self.openai_api_base: + # 从api_base中提取host,替换默认URL + from urllib.parse import urlparse + parsed_url = urlparse(self.openai_api_base) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + url = f"{base_url}/api/v1/uploads" + else: + url = "https://dashscope.aliyuncs.com/api/v1/uploads" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" @@ -109,7 +118,11 @@ def stream( stop: Optional[list[str]] = None, **kwargs: Any, ) -> Iterator[BaseMessageChunk]: - url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions" + # 如果有自定义api_base,使用它,否则使用默认URL + if hasattr(self, 'openai_api_base') and self.openai_api_base: + url = f"{self.openai_api_base}/chat/completions" + else: + url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions" headers = { "Authorization": f"Bearer {self.openai_api_key.get_secret_value()}", diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py index 9b5b121a05f..c20c55c18cd 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py @@ -20,6 +20,7 @@ class AliyunBaiLianReranker(MaxKBBaseModel, BaseDocumentCompressor): model: Optional[str] api_key: Optional[str] + api_base: Optional[str] top_n: Optional[int] = 3 # 取前 N 个最相关的结果 @@ -31,6 +32,7 @@ def is_cache_model(): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): return AliyunBaiLianReranker(model=model_name, api_key=model_credential.get('dashscope_api_key'), + api_base=model_credential.get('api_base'), top_n=model_kwargs.get('top_n', 3)) def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ @@ -39,6 +41,9 @@ def compress_documents(self, documents: Sequence[Document], query: str, callback return [] texts = [doc.page_content for doc in documents] + # 如果提供了api_base,则配置dashscope使用自定义endpoint + if self.api_base: + dashscope.base_http_url = self.api_base resp = dashscope.TextReRank.call( model=self.model, query=query, diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/asr_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/asr_stt.py index f902d0a7d96..7a6f0460ea7 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/asr_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/asr_stt.py @@ -32,7 +32,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianAsrSpeechToText( model=model_name, api_key=model_credential.get('api_key'), - api_url=model_credential.get('api_url'), + api_url=model_credential.get('api_base'), params=model_kwargs, **model_kwargs ) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/default_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/default_stt.py index 6607345bb24..d39a7f3d275 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/default_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/default_stt.py @@ -34,7 +34,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianAsrSpeechToText( model=model_name, api_key=model_credential.get('api_key'), - api_url=model_credential.get('api_url'), + api_url=model_credential.get('api_base'), params=model_kwargs, **model_kwargs ) @@ -42,7 +42,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianOmiSpeechToText( model=model_name, api_key=model_credential.get('api_key'), - api_url=model_credential.get('api_url'), + api_url=model_credential.get('api_base'), params=model_kwargs, **model_kwargs ) @@ -50,6 +50,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianSpeechToText( model=model_name, api_key=model_credential.get('api_key'), + api_base=model_credential.get('api_base'), params=model_kwargs, **model_kwargs, ) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py index db098c991e3..74e558216db 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py @@ -32,7 +32,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianOmiSpeechToText( model=model_name, api_key=model_credential.get('api_key'), - api_url=model_credential.get('api_url') , + api_url=model_credential.get('api_base') , params= model_kwargs, **model_kwargs ) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/stt.py index f48c0adf291..1cc7ff4e3c4 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/stt.py @@ -12,12 +12,14 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): api_key: str + api_base: str model: str params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') self.model = kwargs.get('model') self.params = kwargs.get('params') @@ -36,6 +38,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianSpeechToText( model=model_name, api_key=model_credential.get('api_key'), + api_base=model_credential.get('api_base'), params=model_kwargs, **optional_params, ) @@ -47,6 +50,9 @@ def check_auth(self): def speech_to_text(self, audio_file): dashscope.api_key = self.api_key + # 如果提供了api_base,则配置dashscope使用自定义endpoint + if self.api_base: + dashscope.base_http_url = self.api_base recognition_params = { 'model': self.model, 'format': 'mp3', diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tti.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tti.py index 2ca3696af9b..0056925f9f8 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tti.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tti.py @@ -15,12 +15,14 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage): api_key: str + api_base: str model_name: str params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') self.model_name = kwargs.get('model_name') self.params = kwargs.get('params') @@ -37,6 +39,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** chat_tong_yi = QwenTextToImageModel( model_name=model_name, api_key=model_credential.get('api_key'), + api_base=model_credential.get('api_base'), **optional_params, ) return chat_tong_yi @@ -47,9 +50,11 @@ def check_auth(self): def generate_image(self, prompt: str, negative_prompt: str = None): if self.model_name.startswith("wan"): + # 如果提供了api_base,则使用自定义base_url,否则使用默认URL + base_url = self.api_base or 'https://dashscope.aliyuncs.com/compatible-mode/v1' rsp = ImageSynthesis.call(api_key=self.api_key, model=self.model_name, - base_url='https://dashscope.aliyuncs.com/compatible-mode/v1', + base_url=base_url, prompt=prompt, negative_prompt=negative_prompt, **self.params) @@ -73,12 +78,14 @@ def generate_image(self, prompt: str, negative_prompt: str = None): ] } ] + # 如果提供了api_base,则使用自定义base_url,否则使用默认URL + base_url = self.api_base or 'https://dashscope.aliyuncs.com/v1' rsp = MultiModalConversation.call( api_key=self.api_key, model=self.model_name, messages=messages, result_format='message', - base_url='https://dashscope.aliyuncs.com/v1', + base_url=base_url, stream=False, negative_prompt=negative_prompt, **self.params diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py index bea3d584c20..53a9f6cbe72 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py @@ -1,6 +1,7 @@ from typing import Dict import dashscope +from dashscope.api_entities.dashscope_response import DashScopeAPIResponse from django.utils.translation import gettext as _ @@ -11,12 +12,14 @@ class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): api_key: str + api_base: str model: str params: dict def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') self.model = kwargs.get('model') self.params = kwargs.get('params') @@ -34,6 +37,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return AliyunBaiLianTextToSpeech( model=model_name, api_key=model_credential.get('api_key'), + api_base=model_credential.get('api_base'), **optional_params, ) @@ -42,14 +46,74 @@ def check_auth(self): def text_to_speech(self, text): dashscope.api_key = self.api_key + # 如果提供了api_base,则配置dashscope使用自定义endpoint + if self.api_base: + dashscope.base_http_url = self.api_base text = _remove_empty_lines(text) + + # 为sambert模型使用特定的API if 'sambert' in self.model: from dashscope.audio.tts import SpeechSynthesizer audio = SpeechSynthesizer.call(model=self.model, text=text, **self.params).get_audio_data() - else: + # 为cosyvoice-v1模型使用tts_v2 API + elif self.model in ['cosyvoice-v1']: from dashscope.audio.tts_v2 import SpeechSynthesizer synthesizer = SpeechSynthesizer(model=self.model, **self.params) audio = synthesizer.call(text) + # 其他模型(包括qwen-tts系列)使用multimodal-generation API + else: + import requests + import json + + headers = { + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json' + } + + # 设置默认参数 + voice = self.params.get('voice', 'Cherry') + language_type = self.params.get('language_type', 'Chinese') + + data = { + 'model': self.model, + 'input': { + 'text': text, + 'voice': voice, + 'language_type': language_type + } + } + + # 添加其他可能的参数 + for key, value in self.params.items(): + if key not in ['voice', 'language_type']: + data['input'][key] = value + + url = f"{self.api_base}/api/v1/services/aigc/multimodal-generation/generation" if self.api_base else "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation" + + response = requests.post(url, headers=headers, data=json.dumps(data)) + + if response.status_code != 200: + raise Exception(f'Failed to generate audio: {response.text}') + + response_data = response.json() + + # 提取音频数据,根据实际返回格式 + if 'output' in response_data and 'audio' in response_data['output']: + audio_data = response_data['output']['audio'] + audio_url = audio_data.get('url') + + if audio_url: + # 下载音频文件 + audio_response = requests.get(audio_url) + if audio_response.status_code == 200: + return audio_response.content + else: + raise Exception(f'Failed to download audio: {audio_response.text}') + else: + raise Exception(f'No audio URL in response: {response_data}') + else: + raise Exception(f'Unexpected response format: {response_data}') + if audio is None: raise Exception('Failed to generate audio') if type(audio) == str: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py index 234bdc21713..f40152a36a9 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py @@ -14,6 +14,7 @@ class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo): api_key: str + api_base: str model_name: str params: dict max_retries: int = 3 @@ -22,6 +23,7 @@ class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo): def __init__(self, **kwargs): super().__init__(**kwargs) self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') self.model_name = kwargs.get('model_name') self.params = kwargs.get('params', {}) self.max_retries = kwargs.get('max_retries', 3) @@ -40,6 +42,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return GenerationVideoModel( model_name=model_name, api_key=model_credential.get('api_key'), + api_base=model_credential.get('api_base'), **optional_params, ) @@ -83,6 +86,9 @@ def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, las params.update(self.params) # --- 异步提交任务 --- + # 如果提供了api_base,则配置dashscope使用自定义endpoint + if self.api_base: + params['base_url'] = self.api_base rsp = self._safe_call(VideoSynthesis.async_call, **params) if rsp.status_code != HTTPStatus.OK: maxkb_logger.info(f'提交任务失败,status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}')