Skip to content

Commit

Permalink
Update Blackbox.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hlohaus authored Nov 18, 2024
1 parent 4526dd4 commit 56beb19
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions g4f/Provider/Blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
supports_message_history = True
_last_validated_value = None

default_model = 'blackboxai'

image_models = ['Image Generation', 'repomap']

userSelectedModel = ['gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']

default_image_model = 'generate_image'
image_models = [default_image_model, 'repomap']
text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
agentMode = {
'Image Generation': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
}

trendingAgentMode = {
"gemini-1.5-flash": {'mode': True, 'id': 'Gemini'},
"llama-3.1-8b": {'mode': True, 'id': "llama-3.1-8b"},
Expand Down Expand Up @@ -72,12 +69,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
'Youtube Agent': {'mode': True, 'id': "Youtube Agent"},
'builder Agent': {'mode': True, 'id': "builder Agent"},
}

model_prefixes = {mode: f"@{value['id']}" for mode, value in trendingAgentMode.items() if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]}


models = [default_model, *userSelectedModel, *list(agentMode.keys()), *list(trendingAgentMode.keys())]

models = [*text_models, default_image_model, *list(trendingAgentMode.keys())]
model_aliases = {
"gemini-flash": "gemini-1.5-flash",
"claude-3.5-sonnet": "claude-sonnet-3.5",
Expand Down Expand Up @@ -118,12 +111,11 @@ async def fetch_validated(cls):

return cls._last_validated_value


@staticmethod
def generate_id(length=7):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))

@classmethod
def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages:
prefix = cls.model_prefixes.get(model, "")
Expand All @@ -139,15 +131,6 @@ def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages:

return new_messages

@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
return model
elif model in cls.model_aliases:
return cls.model_aliases[model]
else:
return cls.default_model

@classmethod
async def create_async_generator(
cls,
Expand Down Expand Up @@ -209,7 +192,7 @@ async def create_async_generator(
"clickedForceWebSearch": False,
"visitFromDelta": False,
"mobileClient": False,
"userSelectedModel": model if model in cls.userSelectedModel else None,
"userSelectedModel": model if model in cls.text_models else None,
"webSearchMode": web_search,
"validated": validated_value,
}
Expand Down

0 comments on commit 56beb19

Please sign in to comment.