diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py index 8d820344ca4..75abb1836df 100644 --- a/g4f/Provider/Blackbox.py +++ b/g4f/Provider/Blackbox.py @@ -20,17 +20,14 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): supports_system_message = True supports_message_history = True _last_validated_value = None - + default_model = 'blackboxai' - - image_models = ['Image Generation', 'repomap'] - - userSelectedModel = ['gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro'] - + default_image_model = 'generate_image' + image_models = [default_image_model, 'repomap'] + text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro'] agentMode = { 'Image Generation': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"}, } - trendingAgentMode = { "gemini-1.5-flash": {'mode': True, 'id': 'Gemini'}, "llama-3.1-8b": {'mode': True, 'id': "llama-3.1-8b"}, @@ -72,12 +69,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): 'Youtube Agent': {'mode': True, 'id': "Youtube Agent"}, 'builder Agent': {'mode': True, 'id': "builder Agent"}, } - model_prefixes = {mode: f"@{value['id']}" for mode, value in trendingAgentMode.items() if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]} - - - models = [default_model, *userSelectedModel, *list(agentMode.keys()), *list(trendingAgentMode.keys())] - + models = [*text_models, default_image_model, *list(trendingAgentMode.keys())] model_aliases = { "gemini-flash": "gemini-1.5-flash", "claude-3.5-sonnet": "claude-sonnet-3.5", @@ -118,12 +111,11 @@ async def fetch_validated(cls): return cls._last_validated_value - @staticmethod def generate_id(length=7): characters = string.ascii_letters + string.digits return ''.join(random.choice(characters) for _ in range(length)) - + @classmethod def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages: prefix = cls.model_prefixes.get(model, "") @@ -139,15 +131,6 @@ def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages: return new_messages - @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - @classmethod async def create_async_generator( cls, @@ -209,7 +192,7 @@ async def create_async_generator( "clickedForceWebSearch": False, "visitFromDelta": False, "mobileClient": False, - "userSelectedModel": model if model in cls.userSelectedModel else None, + "userSelectedModel": model if model in cls.text_models else None, "webSearchMode": web_search, "validated": validated_value, }