Skip to content

Commit

Permalink
Merge branch 'xtekky:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
kqlio67 authored Nov 26, 2024
2 parents 3811099 + f0308ab commit f150db2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 37 deletions.
22 changes: 11 additions & 11 deletions g4f/Provider/Airforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
from urllib.parse import quote

from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
Expand Down Expand Up @@ -36,7 +37,7 @@ def fetch_imagine_models(cls):

default_model = "gpt-4o-mini"
default_image_model = "flux"
additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"]
additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "flux-1.1-pro"]

@classmethod
def get_models(cls):
Expand Down Expand Up @@ -86,7 +87,7 @@ def get_models(cls):
### imagine ###
"sdxl": "stable-diffusion-xl-base",
"sdxl": "stable-diffusion-xl-lightning",
"flux-pro": "Flux-1.1-Pro",
"flux-pro": "flux-1.1-pro",
}

@classmethod
def create_async_generator(
    # NOTE(review): the `def`/`cls` header lines were elided in the scraped
    # diff (only the hunk header shows the name); reconstructed here —
    # confirm against the repository before applying.
    cls,
    model: str,
    messages: Messages,
    proxy: str = None,
    prompt: str = None,
    seed: int = None,
    size: str = "1:1",  # "1:1", "16:9", "9:16", "21:9", "9:21", "1:2", "2:1"
    stream: bool = False,
    **kwargs
) -> AsyncResult:
    """Dispatch a request to image or text generation.

    Image models are routed to ``_generate_image`` with a plain prompt
    string; everything else goes to ``_generate_text`` with the full
    message list.

    Args:
        model: Model alias or id; resolved via ``cls.get_model``.
        messages: Chat history; the last message's content is used as the
            image prompt when ``prompt`` is not given.
        proxy: Optional HTTP proxy URL.
        prompt: Explicit image prompt; overrides the last message.
        seed: Image seed; a random one is chosen downstream when ``None``.
        size: Aspect ratio string for image generation.
        stream: Whether to stream the text response.

    Returns:
        An async generator of response chunks.
    """
    model = cls.get_model(model)

    if model in cls.image_models:
        # Fix: the scraped diff retained the superseded pre-merge call
        # `_generate_image(model, messages, ...)` alongside the post-merge
        # prompt-based call; only the post-merge form is kept here.
        if prompt is None:
            prompt = messages[-1]['content']
        return cls._generate_image(model, prompt, proxy, seed, size)
    else:
        return cls._generate_text(model, messages, proxy, stream, **kwargs)

@classmethod
async def _generate_image(
cls,
model: str,
messages: Messages,
prompt: str,
proxy: str = None,
seed: int = None,
size: str = "1:1",
Expand All @@ -125,7 +130,6 @@ async def _generate_image(
}
if seed is None:
seed = random.randint(0, 100000)
prompt = messages[-1]['content']

async with StreamSession(headers=headers, proxy=proxy) as session:
params = {
Expand All @@ -140,12 +144,8 @@ async def _generate_image(

if 'application/json' in content_type:
raise RuntimeError(await response.json().get("error", {}).get("message"))
elif 'image' in content_type:
image_data = b""
async for chunk in response.iter_content():
if chunk:
image_data += chunk
image_url = f"{cls.api_endpoint_imagine}?model={model}&prompt={prompt}&size={size}&seed={seed}"
elif content_type.startswith("image/"):
image_url = f"{cls.api_endpoint_imagine}?model={model}&prompt={quote(prompt)}&size={size}&seed={seed}"
yield ImageResponse(images=image_url, alt=prompt)

@classmethod
Expand Down
6 changes: 5 additions & 1 deletion g4f/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Annotated:

logger = logging.getLogger(__name__)

DEFAULT_PORT = 1337

def create_app(g4f_api_key: str = None):
app = FastAPI()

Expand Down Expand Up @@ -493,7 +495,7 @@ def format_exception(e: Union[Exception, str], config: Union[ChatCompletionsConf

def run_api(
host: str = '0.0.0.0',
port: int = 1337,
port: int = None,
bind: str = None,
debug: bool = False,
workers: int = None,
Expand All @@ -505,6 +507,8 @@ def run_api(
use_colors = debug
if bind is not None:
host, port = bind.split(":")
if port is None:
port = DEFAULT_PORT
uvicorn.run(
f"g4f.api:create_app{'_debug' if debug else ''}",
host=host,
Expand Down
55 changes: 30 additions & 25 deletions g4f/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..errors import NoImageResponseError, ModelNotFoundError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list
from ..Provider.needs_auth.BingCreateImages import BingCreateImages
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
Expand Down Expand Up @@ -264,28 +264,34 @@ def generate(
"""
return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))

async def async_generate(
self,
prompt: str,
model: Optional[str] = None,
provider: Optional[ProviderType] = None,
response_format: Optional[str] = "url",
proxy: Optional[str] = None,
**kwargs
) -> ImagesResponse:
async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider:
    """Resolve the concrete provider to use for an image request.

    Resolution order when ``provider`` is ``None``: the client-level
    ``self.provider``, then the model registry (``self.models``), then
    ``default``. String providers are converted via ``convert_to_provider``.
    An ``IterListProvider`` is unwrapped to its first concrete provider.

    Note: declared ``async`` although it awaits nothing — kept so existing
    ``await`` call sites remain valid.

    Args:
        model: Model name used for registry lookup (may be ``None``).
        provider: Explicit provider object, provider name, or ``None``.
        default: Fallback provider when nothing else resolves.

    Returns:
        A concrete provider handler.

    Raises:
        ModelNotFoundError: If an ``IterListProvider`` has no providers.
    """
    # Fix: the scraped diff interleaved the removed pre-merge assignment
    # (`self.models.get(model, provider or self.provider or BingCreateImages)`)
    # with the post-merge logic, leaving two conflicting assignments; this is
    # the coherent post-merge version.
    if provider is None:
        provider_handler = self.provider
        if provider_handler is None:
            provider_handler = self.models.get(model, default)
    elif isinstance(provider, str):
        provider_handler = convert_to_provider(provider)
    else:
        provider_handler = provider
    if provider_handler is None:
        return default
    if isinstance(provider_handler, IterListProvider):
        if provider_handler.providers:
            # Unwrap to the first concrete provider in the list.
            provider_handler = provider_handler.providers[0]
        else:
            raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
    return provider_handler

async def async_generate(
self,
prompt: str,
model: Optional[str] = None,
provider: Optional[ProviderType] = None,
response_format: Optional[str] = "url",
proxy: Optional[str] = None,
**kwargs
) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
if proxy is None:
proxy = self.client.proxy

Expand All @@ -311,7 +317,7 @@ async def async_generate(
response = item
break
else:
raise ValueError(f"Provider {provider} does not support image generation")
raise ValueError(f"Provider {getattr(provider_handler, '__name__')} does not support image generation")
if isinstance(response, ImageResponse):
return await self._process_image_response(
response,
Expand All @@ -320,6 +326,8 @@ async def async_generate(
model,
getattr(provider_handler, "__name__", None)
)
if response is None:
raise NoImageResponseError(f"No image response from {getattr(provider_handler, '__name__')}")
raise NoImageResponseError(f"Unexpected response type: {type(response)}")

def create_variation(
async def async_create_variation(
    # NOTE(review): the leading parameters were elided in the scraped diff;
    # reconstructed to match async_generate's style — confirm against the
    # repository before applying.
    self,
    image: ImageType,
    model: Optional[str] = None,
    provider: Optional[ProviderType] = None,
    response_format: Optional[str] = "url",
    proxy: Optional[str] = None,
    **kwargs
) -> ImagesResponse:
    """Create a variation of an existing image via the resolved provider.

    Prefers the provider's ``create_async_generator`` interface; falls back
    to ``create_variation`` (sync or async).

    Raises:
        NoImageResponseError: If the provider supports neither interface,
            yields no image, or returns an unexpected type.
    """
    provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
    if proxy is None:
        proxy = self.client.proxy

    # Fix: `response` was previously unbound when the generator branch yielded
    # no ImageResponse, causing a NameError instead of NoImageResponseError.
    response = None
    if hasattr(provider_handler, "create_async_generator"):
        messages = [{"role": "user", "content": "create a variation of this image"}]
        generator = None
        try:
            generator = provider_handler.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
            async for chunk in generator:
                if isinstance(chunk, ImageResponse):
                    response = chunk
                    break
        finally:
            # Always close the generator, even when we break out early.
            await safe_aclose(generator)
    elif hasattr(provider_handler, 'create_variation'):
        # Fix: original tested `asyncio.iscoroutinefunction(provider.provider_handler)`,
        # a nonexistent attribute on the wrong object; the coroutine check
        # belongs on the method actually being called.
        if asyncio.iscoroutinefunction(provider_handler.create_variation):
            response = await provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
        else:
            response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
    else:
        # Fix: error messages referenced `provider` (possibly None) instead of
        # the resolved handler; use getattr with a default so class-less
        # handlers cannot raise AttributeError while reporting the error.
        raise NoImageResponseError(f"Provider {getattr(provider_handler, '__name__', provider_handler)} does not support image variation")

    if isinstance(response, str):
        response = ImageResponse([response])
    if isinstance(response, ImageResponse):
        # Fix: `_process_image_response` is a coroutine (awaited in
        # async_generate); the original returned it without awaiting, handing
        # the caller an unawaited coroutine instead of an ImagesResponse.
        return await self._process_image_response(response, response_format, proxy, model, getattr(provider_handler, "__name__", None))
    if response is None:
        raise NoImageResponseError(f"No image response from {getattr(provider_handler, '__name__', provider_handler)}")
    raise NoImageResponseError(f"Unexpected response type: {type(response)}")

async def _process_image_response(
Expand Down

0 comments on commit f150db2

Please sign in to comment.