Skip to content

Commit

Permalink
refactor(g4f/Provider/Airforce.py): Enhance text generation with retr…
Browse files Browse the repository at this point in the history
…y and timeout
  • Loading branch information
kqlio67 committed Dec 1, 2024
1 parent 5645b1a commit 07df66a
Showing 1 changed file with 73 additions and 64 deletions.
137 changes: 73 additions & 64 deletions g4f/Provider/Airforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import random
import json
import re
import aiohttp
import asyncio

import requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning
Expand Down Expand Up @@ -35,75 +37,57 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
supports_message_history = True

@classmethod
def fetch_completions_models(cls):
response = requests.get('https://api.airforce/models', verify=False)
response.raise_for_status()
data = response.json()
return [model['id'] for model in data['data']]

@classmethod
def fetch_imagine_models(cls):
response = requests.get('https://api.airforce/imagine/models', verify=False)
response.raise_for_status()
return response.json()

default_model = "gpt-4o-mini"
default_image_model = "flux"
additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "flux-1.1-pro"]

@classmethod
def get_models(cls):
if not cls.models:
cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine]
cls.models = [
*cls.fetch_completions_models(),
*cls.image_models
]
return cls.models

model_aliases = {
### completions ###
# openchat
"openchat-3.5": "openchat-3.5-0106",

# deepseek-ai
"deepseek-coder": "deepseek-coder-6.7b-instruct",

# NousResearch
"hermes-2-dpo": "Nous-Hermes-2-Mixtral-8x7B-DPO",
"hermes-2-pro": "hermes-2-pro-mistral-7b",

# teknium
"openhermes-2.5": "openhermes-2.5-mistral-7b",

# liquid
"lfm-40b": "lfm-40b-moe",

# DiscoResearch
"german-7b": "discolm-german-7b-v1",

# meta-llama
"llama-2-7b": "llama-2-7b-chat-int8",
"llama-2-7b": "llama-2-7b-chat-fp16",
"llama-3.1-70b": "llama-3.1-70b-chat",
"llama-3.1-8b": "llama-3.1-8b-chat",
"llama-3.1-70b": "llama-3.1-70b-turbo",
"llama-3.1-8b": "llama-3.1-8b-turbo",

# inferless
"neural-7b": "neural-chat-7b-v3-1",

# HuggingFaceH4
"zephyr-7b": "zephyr-7b-beta",


### imagine ###
"sdxl": "stable-diffusion-xl-base",
"sdxl": "stable-diffusion-xl-lightning",
"flux-pro": "flux-1.1-pro",
}

@classmethod
def fetch_completions_models(cls):
response = requests.get('https://api.airforce/models', verify=False)
response.raise_for_status()
data = response.json()
return [model['id'] for model in data['data']]

@classmethod
def fetch_imagine_models(cls):
response = requests.get('https://api.airforce/imagine/models', verify=False)
response.raise_for_status()
return response.json()

@classmethod
def get_models(cls):
if not cls.models:
cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine]
cls.models = [
*cls.fetch_completions_models(),
*cls.image_models
]
return cls.models

@classmethod
def create_async_generator(
cls,
Expand All @@ -112,7 +96,7 @@ def create_async_generator(
proxy: str = None,
prompt: str = None,
seed: int = None,
size: str = "1:1", # "1:1", "16:9", "9:16", "21:9", "9:21", "1:2", "2:1"
size: str = "1:1", # "1:1", "16:9", "9:16", "21:9", "9:21", "1:2", "2:1"
stream: bool = False,
**kwargs
) -> AsyncResult:
Expand Down Expand Up @@ -168,9 +152,11 @@ async def _generate_text(
messages: Messages,
proxy: str = None,
stream: bool = False,
timeout: int = 120,
max_tokens: int = 4096,
temperature: float = 1,
top_p: float = 1,
max_retries: int = 3,
**kwargs
) -> AsyncResult:
headers = {
Expand All @@ -187,7 +173,15 @@ async def _generate_text(

message_chunks = split_message(full_message, max_length=1000)

async with StreamSession(headers=headers, proxy=proxy) as session:
async with StreamSession(
headers=headers,
proxy=proxy,
timeout=aiohttp.ClientTimeout(
total=timeout,
connect=30,
sock_connect=30
)
) as session:
full_response = ""
for chunk in message_chunks:
data = {
Expand All @@ -199,26 +193,41 @@ async def _generate_text(
"stream": stream
}

async with session.post(cls.api_endpoint_completions, json=data) as response:
await raise_for_status(response)
content_type = response.headers.get('Content-Type', '').lower()

if 'application/json' in content_type:
json_data = await response.json()
if json_data.get("model") == "error":
raise RuntimeError(json_data['choices'][0]['message'].get('content', ''))
if stream:
async for line in response.iter_lines():
if line:
line = line.decode('utf-8').strip()
if line.startswith("data: ") and line != "data: [DONE]":
json_data = json.loads(line[6:])
content = json_data['choices'][0]['delta'].get('content', '')
if content:
yield cls._filter_content(content)
else:
content = json_data['choices'][0]['message']['content']
full_response += cls._filter_content(content)
for attempt in range(max_retries):
try:
async with session.post(
cls.api_endpoint_completions,
json=data,
timeout=timeout
) as response:
await raise_for_status(response)
content_type = response.headers.get('Content-Type', '').lower()

if 'application/json' in content_type:
json_data = await response.json()
if json_data.get("model") == "error":
raise RuntimeError(json_data['choices'][0]['message'].get('content', ''))

if stream:
async for line in response.iter_lines():
if line:
line = line.decode('utf-8').strip()
if line.startswith("data: ") and line != "data: [DONE]":
json_data = json.loads(line[6:])
content = json_data['choices'][0]['delta'].get('content', '')
if content:
yield cls._filter_content(content)
else:
content = json_data['choices'][0]['message']['content']
full_response += cls._filter_content(content)

break

except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt == max_retries - 1:
raise RuntimeError(f"Request failed after {max_retries} attempts: {str(e)}")

await asyncio.sleep(2 ** attempt)

yield full_response

Expand All @@ -237,7 +246,7 @@ def _filter_content(cls, part_response: str) -> str:
)

part_response = re.sub(
r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", # any-uncensored
r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", # any-uncensored
'',
part_response
)
Expand Down

0 comments on commit 07df66a

Please sign in to comment.