From 07df66a7485b669689c63b66be45de51c28cf69f Mon Sep 17 00:00:00 2001 From: kqlio67 <> Date: Sun, 1 Dec 2024 12:01:42 +0200 Subject: [PATCH] refactor(g4f/Provider/Airforce.py): Enhance text generation with retry and timeout --- g4f/Provider/Airforce.py | 137 +++++++++++++++++++++------------------ 1 file changed, 73 insertions(+), 64 deletions(-) diff --git a/g4f/Provider/Airforce.py b/g4f/Provider/Airforce.py index f65cd953c35..ebe746423ee 100644 --- a/g4f/Provider/Airforce.py +++ b/g4f/Provider/Airforce.py @@ -3,6 +3,8 @@ import random import json import re +import aiohttp +import asyncio import requests from requests.packages.urllib3.exceptions import InsecureRequestWarning @@ -35,75 +37,57 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): supports_system_message = True supports_message_history = True - @classmethod - def fetch_completions_models(cls): - response = requests.get('https://api.airforce/models', verify=False) - response.raise_for_status() - data = response.json() - return [model['id'] for model in data['data']] - - @classmethod - def fetch_imagine_models(cls): - response = requests.get('https://api.airforce/imagine/models', verify=False) - response.raise_for_status() - return response.json() - default_model = "gpt-4o-mini" default_image_model = "flux" additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "flux-1.1-pro"] - @classmethod - def get_models(cls): - if not cls.models: - cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine] - cls.models = [ - *cls.fetch_completions_models(), - *cls.image_models - ] - return cls.models - model_aliases = { ### completions ### - # openchat "openchat-3.5": "openchat-3.5-0106", - - # deepseek-ai "deepseek-coder": "deepseek-coder-6.7b-instruct", - - # NousResearch "hermes-2-dpo": "Nous-Hermes-2-Mixtral-8x7B-DPO", "hermes-2-pro": "hermes-2-pro-mistral-7b", - - # teknium "openhermes-2.5": "openhermes-2.5-mistral-7b", - - # liquid "lfm-40b": "lfm-40b-moe", - - # DiscoResearch "german-7b": "discolm-german-7b-v1", - - # meta-llama "llama-2-7b": "llama-2-7b-chat-int8", "llama-2-7b": "llama-2-7b-chat-fp16", "llama-3.1-70b": "llama-3.1-70b-chat", "llama-3.1-8b": "llama-3.1-8b-chat", "llama-3.1-70b": "llama-3.1-70b-turbo", "llama-3.1-8b": "llama-3.1-8b-turbo", - - # inferless "neural-7b": "neural-chat-7b-v3-1", - - # HuggingFaceH4 "zephyr-7b": "zephyr-7b-beta", - ### imagine ### "sdxl": "stable-diffusion-xl-base", "sdxl": "stable-diffusion-xl-lightning", "flux-pro": "flux-1.1-pro", } + @classmethod + def fetch_completions_models(cls): + response = requests.get('https://api.airforce/models', verify=False) + response.raise_for_status() + data = response.json() + return [model['id'] for model in data['data']] + + @classmethod + def fetch_imagine_models(cls): + response = requests.get('https://api.airforce/imagine/models', verify=False) + response.raise_for_status() + return response.json() + + @classmethod + def get_models(cls): + if not cls.models: + cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine] + cls.models = [ + *cls.fetch_completions_models(), + *cls.image_models + ] + return cls.models + @classmethod def create_async_generator( cls, @@ -112,7 +96,7 @@ def create_async_generator( proxy: str = None, prompt: str = None, seed: int = None, - size: str = "1:1", # "1:1", "16:9", "9:16", "21:9", "9:21", "1:2", "2:1" + size: str = "1:1", # "1:1", "16:9", "9:16", "21:9", "9:21", "1:2", "2:1" stream: bool = False, **kwargs ) -> AsyncResult: @@ -168,9 +152,11 @@ async def _generate_text( messages: Messages, proxy: str = None, stream: bool = False, + timeout: int = 120, max_tokens: int = 4096, temperature: float = 1, top_p: float = 1, + max_retries: int = 3, **kwargs ) -> AsyncResult: headers = { @@ -187,7 +173,15 @@ async def _generate_text( message_chunks = split_message(full_message, max_length=1000) - async with StreamSession(headers=headers, proxy=proxy) as session: + async with StreamSession( + headers=headers, + proxy=proxy, + timeout=aiohttp.ClientTimeout( + total=timeout, + connect=30, + sock_connect=30 + ) + ) as session: full_response = "" for chunk in message_chunks: data = { @@ -199,26 +193,41 @@ async def _generate_text( "stream": stream } - async with session.post(cls.api_endpoint_completions, json=data) as response: - await raise_for_status(response) - content_type = response.headers.get('Content-Type', '').lower() - - if 'application/json' in content_type: - json_data = await response.json() - if json_data.get("model") == "error": - raise RuntimeError(json_data['choices'][0]['message'].get('content', '')) - if stream: - async for line in response.iter_lines(): - if line: - line = line.decode('utf-8').strip() - if line.startswith("data: ") and line != "data: [DONE]": - json_data = json.loads(line[6:]) - content = json_data['choices'][0]['delta'].get('content', '') - if content: - yield cls._filter_content(content) - else: - content = json_data['choices'][0]['message']['content'] - full_response += cls._filter_content(content) + for attempt in range(max_retries): + try: + async with session.post( + cls.api_endpoint_completions, + json=data, + timeout=timeout + ) as response: + await raise_for_status(response) + content_type = response.headers.get('Content-Type', '').lower() + + if 'application/json' in content_type: + json_data = await response.json() + if json_data.get("model") == "error": + raise RuntimeError(json_data['choices'][0]['message'].get('content', '')) + + if stream: + async for line in response.iter_lines(): + if line: + line = line.decode('utf-8').strip() + if line.startswith("data: ") and line != "data: [DONE]": + json_data = json.loads(line[6:]) + content = json_data['choices'][0]['delta'].get('content', '') + if content: + yield cls._filter_content(content) + else: + content = json_data['choices'][0]['message']['content'] + full_response += cls._filter_content(content) + + break + + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt == max_retries - 1: + raise RuntimeError(f"Request failed after {max_retries} attempts: {str(e)}") + + await asyncio.sleep(2 ** attempt) yield full_response @@ -237,7 +246,7 @@ def _filter_content(cls, part_response: str) -> str: ) part_response = re.sub( - r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", # any-uncensored + r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", # any-uncensored '', part_response )