From 79c407b9397c9e10807dcda4d9df166609284b64 Mon Sep 17 00:00:00 2001 From: H Lohaus Date: Fri, 29 Nov 2024 13:56:11 +0100 Subject: [PATCH] IterListProvider support for generating images (#2441) * IterListProvider support for generating images * Add missing get_har_files import in Copilot * Fix typo in dall-e-3 model name * Add image client unittests * Add MicrosoftDesigner provider * Import MicrosoftDesigner and add it to the model list --- etc/unittest/__main__.py | 1 + etc/unittest/image_client.py | 44 +++++ etc/unittest/mocks.py | 21 +++ g4f/Provider/AmigoChat.py | 10 +- g4f/Provider/Copilot.py | 79 ++++----- g4f/Provider/needs_auth/MicrosoftDesigner.py | 167 +++++++++++++++++++ g4f/Provider/needs_auth/OpenaiChat.py | 4 +- g4f/Provider/needs_auth/__init__.py | 50 +++--- g4f/Provider/openai/har_file.py | 4 +- g4f/Provider/you/har_file.py | 5 +- g4f/client/__init__.py | 120 +++++++------ g4f/errors.py | 3 + g4f/models.py | 11 +- g4f/providers/asyncio.py | 4 +- g4f/providers/retry_provider.py | 2 +- g4f/typing.py | 3 +- 16 files changed, 392 insertions(+), 136 deletions(-) create mode 100644 etc/unittest/image_client.py create mode 100644 g4f/Provider/needs_auth/MicrosoftDesigner.py diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py index e49dec30538..3719c374dbb 100644 --- a/etc/unittest/__main__.py +++ b/etc/unittest/__main__.py @@ -5,6 +5,7 @@ from .main import * from .model import * from .client import * +from .image_client import * from .include import * from .retry_provider import * diff --git a/etc/unittest/image_client.py b/etc/unittest/image_client.py new file mode 100644 index 00000000000..b52ba8b004d --- /dev/null +++ b/etc/unittest/image_client.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import asyncio +import unittest + +from g4f.client import AsyncClient, ImagesResponse +from g4f.providers.retry_provider import IterListProvider +from .mocks import ( + YieldImageResponseProviderMock, + MissingAuthProviderMock, + AsyncRaiseExceptionProviderMock, + YieldNoneProviderMock +) + +DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] + +class TestIterListProvider(unittest.IsolatedAsyncioTestCase): + + async def test_skip_provider(self): + client = AsyncClient(image_provider=IterListProvider([MissingAuthProviderMock, YieldImageResponseProviderMock], False)) + response = await client.images.generate("Hello", "", response_format="orginal") + self.assertIsInstance(response, ImagesResponse) + self.assertEqual("Hello", response.data[0].url) + + async def test_only_one_result(self): + client = AsyncClient(image_provider=IterListProvider([YieldImageResponseProviderMock, YieldImageResponseProviderMock], False)) + response = await client.images.generate("Hello", "", response_format="orginal") + self.assertIsInstance(response, ImagesResponse) + self.assertEqual("Hello", response.data[0].url) + + async def test_skip_none(self): + client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, YieldImageResponseProviderMock], False)) + response = await client.images.generate("Hello", "", response_format="orginal") + self.assertIsInstance(response, ImagesResponse) + self.assertEqual("Hello", response.data[0].url) + + def test_raise_exception(self): + async def run_exception(): + client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, AsyncRaiseExceptionProviderMock], False)) + await client.images.generate("Hello", "") + self.assertRaises(RuntimeError, asyncio.run, run_exception()) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py index c2058e34e6b..c43d98ccaf9 100644 --- a/etc/unittest/mocks.py +++ b/etc/unittest/mocks.py @@ -1,4 +1,6 @@ from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider +from g4f.image import ImageResponse +from g4f.errors import MissingAuthError class ProviderMock(AbstractProvider): working = True @@ -41,6 +43,25 @@ async def create_async_generator( for message in messages: yield message["content"] +class YieldImageResponseProviderMock(AsyncGeneratorProvider): + working = True + + @classmethod + async def create_async_generator( + cls, model, messages, stream, prompt: str, **kwargs + ): + yield ImageResponse(prompt, "") + +class MissingAuthProviderMock(AbstractProvider): + working = True + + @classmethod + def create_completion( + cls, model, messages, stream, **kwargs + ): + raise MissingAuthError(cls.__name__) + yield cls.__name__ + class RaiseExceptionProviderMock(AbstractProvider): working = True diff --git a/g4f/Provider/AmigoChat.py b/g4f/Provider/AmigoChat.py index 0acb58543a8..f79dc15b8aa 100644 --- a/g4f/Provider/AmigoChat.py +++ b/g4f/Provider/AmigoChat.py @@ -65,9 +65,9 @@ 'flux-pro/v1.1-ultra': {'persona_id': "flux-pro-v1.1-ultra"}, # Amigo, your balance is not enough to make the request, wait until 12 UTC or upgrade your plan 'flux-pro/v1.1-ultra-raw': {'persona_id': "flux-pro-v1.1-ultra-raw"}, # Amigo, your balance is not enough to make the request, wait until 12 UTC or upgrade your plan 'flux/dev': {'persona_id': "flux-dev"}, - - 'dalle-e-3': {'persona_id': "dalle-three"}, - + + 'dall-e-3': {'persona_id': "dalle-three"}, + 'recraft-v3': {'persona_id': "recraft"} } } @@ -129,8 +129,8 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin): ### image ### "flux-realism": "flux-realism", "flux-dev": "flux/dev", - - "dalle-3": "dalle-e-3", + + "dalle-3": "dall-e-3", } @classmethod diff --git a/g4f/Provider/Copilot.py b/g4f/Provider/Copilot.py index a64d52aa293..ee9daf33b49 100644 --- a/g4f/Provider/Copilot.py +++ b/g4f/Provider/Copilot.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import json import asyncio from http.cookiejar import CookieJar @@ -20,10 +19,10 @@ from .base_provider import AbstractProvider, ProviderModelMixin, BaseConversation from .helper import format_prompt from ..typing import CreateResult, Messages, ImageType -from ..errors import MissingRequirementsError +from ..errors import MissingRequirementsError, NoValidHarFileError from ..requests.raise_for_status import raise_for_status from ..providers.asyncio import get_running_loop -from .openai.har_file import NoValidHarFileError, get_headers +from .openai.har_file import get_headers, get_har_files from ..requests import get_nodriver from ..image import ImageResponse, to_bytes, is_accepted_format from .. import debug @@ -76,12 +75,12 @@ def create_completion( if cls.needs_auth or image is not None: if conversation is None or conversation.access_token is None: try: - access_token, cookies = readHAR() + access_token, cookies = readHAR(cls.url) except NoValidHarFileError as h: debug.log(f"Copilot: {h}") try: get_running_loop(check_nested=True) - access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy)) + access_token, cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy)) except MissingRequirementsError: raise h else: @@ -162,35 +161,34 @@ def create_completion( if not is_started: raise RuntimeError(f"Invalid response: {last_msg}") - @classmethod - async def get_access_token_and_cookies(cls, proxy: str = None): - browser = await get_nodriver(proxy=proxy) - page = await browser.get(cls.url) - access_token = None - while access_token is None: - access_token = await page.evaluate(""" - (() => { - for (var i = 0; i < localStorage.length; i++) { - try { - item = JSON.parse(localStorage.getItem(localStorage.key(i))); - if (item.credentialType == "AccessToken" - && item.expiresOn > Math.floor(Date.now() / 1000) - && item.target.includes("ChatAI")) { - return item.secret; - } - } catch(e) {} - } - })() - """) - if access_token is None: - await asyncio.sleep(1) - cookies = {} - for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])): - cookies[c.name] = c.value - await page.close() - return access_token, cookies - -def readHAR(): +async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",): + browser = await get_nodriver(proxy=proxy) + page = await browser.get(url) + access_token = None + while access_token is None: + access_token = await page.evaluate(""" + (() => { + for (var i = 0; i < localStorage.length; i++) { + try { + item = JSON.parse(localStorage.getItem(localStorage.key(i))); + if (item.credentialType == "AccessToken" + && item.expiresOn > Math.floor(Date.now() / 1000) + && item.target.includes("target")) { + return item.secret; + } + } catch(e) {} + } + })() + """.replace('"target"', json.dumps(target))) + if access_token is None: + await asyncio.sleep(1) + cookies = {} + for c in await page.send(nodriver.cdp.network.get_cookies([url])): + cookies[c.name] = c.value + await page.close() + return access_token, cookies + +def readHAR(url: str): api_key = None cookies = None for path in get_har_files(): @@ -201,16 +199,13 @@ def readHAR(): # Error: not a HAR file! continue for v in harFile['log']['entries']: - v_headers = get_headers(v) - if v['request']['url'].startswith(Copilot.url): - try: - if "authorization" in v_headers: - api_key = v_headers["authorization"].split(maxsplit=1).pop() - except Exception as e: - debug.log(f"Error on read headers: {e}") + if v['request']['url'].startswith(url): + v_headers = get_headers(v) + if "authorization" in v_headers: + api_key = v_headers["authorization"].split(maxsplit=1).pop() if v['request']['cookies']: cookies = {c['name']: c['value'] for c in v['request']['cookies']} if api_key is None: raise NoValidHarFileError("No access token found in .har files") - return api_key, cookies + return api_key, cookies \ No newline at end of file diff --git a/g4f/Provider/needs_auth/MicrosoftDesigner.py b/g4f/Provider/needs_auth/MicrosoftDesigner.py new file mode 100644 index 00000000000..715f11acc6f --- /dev/null +++ b/g4f/Provider/needs_auth/MicrosoftDesigner.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import uuid +import aiohttp +import random +import asyncio +import json + +from ...image import ImageResponse +from ...errors import MissingRequirementsError, NoValidHarFileError +from ...typing import AsyncResult, Messages +from ...requests.raise_for_status import raise_for_status +from ...requests.aiohttp import get_connector +from ...requests import get_nodriver +from ..Copilot import get_headers, get_har_files +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..helper import get_random_hex +from ... import debug + +class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin): + label = "Microsoft Designer" + url = "https://designer.microsoft.com" + working = True + needs_auth = True + default_image_model = "dall-e-3" + image_models = [default_image_model, "1024x1024", "1024x1792", "1792x1024"] + models = image_models + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + prompt: str = None, + proxy: str = None, + **kwargs + ) -> AsyncResult: + image_size = "1024x1024" + if model != cls.default_image_model and model in cls.image_models: + image_size = model + yield await cls.generate(messages[-1]["content"] if prompt is None else prompt, image_size, proxy) + + @classmethod + async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse: + try: + access_token, user_agent = readHAR("https://designerapp.officeapps.live.com") + except NoValidHarFileError as h: + debug.log(f"{cls.__name__}: {h}") + try: + access_token, user_agent = await get_access_token_and_user_agent(cls.url, proxy) + except MissingRequirementsError: + raise h + images = await create_images(prompt, access_token, user_agent, image_size, proxy) + return ImageResponse(images, prompt) + +async def create_images(prompt: str, access_token: str, user_agent: str, image_size: str, proxy: str = None, seed: int = None): + url = 'https://designerapp.officeapps.live.com/designerapp/DallE.ashx?action=GetDallEImagesCogSci' + if seed is None: + seed = random.randint(0, 10000) + + headers = { + "User-Agent": user_agent, + "Accept": "application/json, text/plain, */*", + "Accept-Language": "en-US", + 'Authorization': f'Bearer {access_token}', + "AudienceGroup": "Production", + "Caller": "DesignerApp", + "ClientId": "b5c2664a-7e9b-4a7a-8c9a-cd2c52dcf621", + "SessionId": str(uuid.uuid4()), + "UserId": get_random_hex(16), + "ContainerId": "1e2843a7-2a98-4a6c-93f2-42002de5c478", + "FileToken": "9f1a4cb7-37e7-4c90-b44d-cb61cfda4bb8", + "x-upload-to-storage-das": "1", + "traceparent": "", + "X-DC-Hint": "FranceCentral", + "Platform": "Web", + "HostApp": "DesignerApp", + "ReleaseChannel": "", + "IsSignedInUser": "true", + "Locale": "de-DE", + "UserType": "MSA", + "x-req-start": "2615401", + "ClientBuild": "1.0.20241120.9", + "ClientName": "DesignerApp", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "cross-site", + "Pragma": "no-cache", + "Cache-Control": "no-cache", + "Referer": "https://designer.microsoft.com/" + } + + form_data = aiohttp.FormData() + form_data.add_field('dalle-caption', prompt) + form_data.add_field('dalle-scenario-name', 'TextToImage') + form_data.add_field('dalle-batch-size', '4') + form_data.add_field('dalle-image-response-format', 'UrlWithBase64Thumbnail') + form_data.add_field('dalle-seed', seed) + form_data.add_field('ClientFlights', 'EnableBICForDALLEFlight') + form_data.add_field('dalle-hear-back-in-ms', 1000) + form_data.add_field('dalle-include-b64-thumbnails', 'true') + form_data.add_field('dalle-aspect-ratio-scaling-factor-b64-thumbnails', 0.3) + form_data.add_field('dalle-image-size', image_size) + + async with aiohttp.ClientSession(connector=get_connector(proxy=proxy)) as session: + async with session.post(url, headers=headers, data=form_data) as response: + await raise_for_status(response) + response_data = await response.json() + form_data.add_field('dalle-boost-count', response_data.get('dalle-boost-count', 0)) + polling_meta_data = response_data.get('polling_response', {}).get('polling_meta_data', {}) + form_data.add_field('dalle-poll-url', polling_meta_data.get('poll_url', '')) + + while True: + await asyncio.sleep(polling_meta_data.get('poll_interval', 1000) / 1000) + async with session.post(url, headers=headers, data=form_data) as response: + await raise_for_status(response) + response_data = await response.json() + images = [image["ImageUrl"] for image in response_data.get('image_urls_thumbnail', [])] + if images: + return images + +def readHAR(url: str) -> tuple[str, str]: + api_key = None + user_agent = None + for path in get_har_files(): + with open(path, 'rb') as file: + try: + harFile = json.loads(file.read()) + except json.JSONDecodeError: + # Error: not a HAR file! + continue + for v in harFile['log']['entries']: + if v['request']['url'].startswith(url): + v_headers = get_headers(v) + if "authorization" in v_headers: + api_key = v_headers["authorization"].split(maxsplit=1).pop() + if "user-agent" in v_headers: + user_agent = v_headers["user-agent"] + if api_key is None: + raise NoValidHarFileError("No access token found in .har files") + + return api_key, user_agent + +async def get_access_token_and_user_agent(url: str, proxy: str = None): + browser = await get_nodriver(proxy=proxy) + page = await browser.get(url) + user_agent = await page.evaluate("navigator.userAgent") + access_token = None + while access_token is None: + access_token = await page.evaluate(""" + (() => { + for (var i = 0; i < localStorage.length; i++) { + try { + item = JSON.parse(localStorage.getItem(localStorage.key(i))); + if (item.credentialType == "AccessToken" + && item.expiresOn > Math.floor(Date.now() / 1000) + && item.target.includes("designerappservice")) { + return item.secret; + } + } catch(e) {} + } + })() + """) + if access_token is None: + await asyncio.sleep(1) + await page.close() + return access_token, user_agent \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 929e86647e1..9378a8c74ff 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -22,10 +22,10 @@ from ...requests import StreamSession from ...requests import get_nodriver from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format -from ...errors import MissingAuthError +from ...errors import MissingAuthError, NoValidHarFileError from ...providers.response import BaseConversation, FinishReason, SynthesizeData from ..helper import format_cookies -from ..openai.har_file import get_request_config, NoValidHarFileError +from ..openai.har_file import get_request_config from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url from ..openai.proofofwork import generate_proof_token from ..openai.new import get_requirements_token diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py index 7e327471bdb..dcfcdd8ab04 100644 --- a/g4f/Provider/needs_auth/__init__.py +++ b/g4f/Provider/needs_auth/__init__.py @@ -1,25 +1,27 @@ -from .gigachat import * +from .gigachat import * -from .BingCreateImages import BingCreateImages -from .Cerebras import Cerebras -from .CopilotAccount import CopilotAccount -from .DeepInfra import DeepInfra -from .DeepInfraImage import DeepInfraImage -from .Gemini import Gemini -from .GeminiPro import GeminiPro -from .GithubCopilot import GithubCopilot -from .Groq import Groq -from .HuggingFace import HuggingFace -from .HuggingFace2 import HuggingFace2 -from .MetaAI import MetaAI -from .MetaAIAccount import MetaAIAccount -from .OpenaiAPI import OpenaiAPI -from .OpenaiChat import OpenaiChat -from .PerplexityApi import PerplexityApi -from .Poe import Poe -from .PollinationsAI import PollinationsAI -from .Raycast import Raycast -from .Replicate import Replicate -from .Theb import Theb -from .ThebApi import ThebApi -from .WhiteRabbitNeo import WhiteRabbitNeo +from .BingCreateImages import BingCreateImages +from .Cerebras import Cerebras +from .CopilotAccount import CopilotAccount +from .DeepInfra import DeepInfra +from .DeepInfraImage import DeepInfraImage +from .Gemini import Gemini +from .GeminiPro import GeminiPro +from .GithubCopilot import GithubCopilot +from .Groq import Groq +from .HuggingFace import HuggingFace +from .HuggingFace2 import HuggingFace2 +from .MetaAI import MetaAI +from .MetaAIAccount import MetaAIAccount +from .MicrosoftDesigner import MicrosoftDesigner +from .OpenaiAccount import OpenaiAccount +from .OpenaiAPI import OpenaiAPI +from .OpenaiChat import OpenaiChat +from .PerplexityApi import PerplexityApi +from .Poe import Poe +from .PollinationsAI import PollinationsAI +from .Raycast import Raycast +from .Replicate import Replicate +from .Theb import Theb +from .ThebApi import ThebApi +from .WhiteRabbitNeo import WhiteRabbitNeo \ No newline at end of file diff --git a/g4f/Provider/openai/har_file.py b/g4f/Provider/openai/har_file.py index 819952cda14..9af81332fd1 100644 --- a/g4f/Provider/openai/har_file.py +++ b/g4f/Provider/openai/har_file.py @@ -13,6 +13,7 @@ from .crypt import decrypt, encrypt from ...requests import StreamSession from ...cookies import get_cookies_dir +from ...errors import NoValidHarFileError from ... import debug arkose_url = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147" @@ -21,9 +22,6 @@ start_url = "https://chatgpt.com/" conversation_url = "https://chatgpt.com/c/" -class NoValidHarFileError(Exception): - pass - class RequestConfig: cookies: dict = None headers: dict = None diff --git a/g4f/Provider/you/har_file.py b/g4f/Provider/you/har_file.py index 40bf388267d..5ed0abd6b1f 100644 --- a/g4f/Provider/you/har_file.py +++ b/g4f/Provider/you/har_file.py @@ -8,14 +8,11 @@ from ...requests import StreamSession, raise_for_status from ...cookies import get_cookies_dir -from ...errors import MissingRequirementsError +from ...errors import MissingRequirementsError, NoValidHarFileError from ... import debug logger = logging.getLogger(__name__) -class NoValidHarFileError(Exception): - ... - class arkReq: def __init__(self, arkURL, arkHeaders, arkBody, arkCookies, userAgent): self.arkURL = arkURL diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index ea47ec73da6..55068317636 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -12,15 +12,17 @@ from ..typing import Messages, ImageType from ..providers.types import ProviderType from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData -from ..errors import NoImageResponseError, ModelNotFoundError +from ..errors import NoImageResponseError, MissingAuthError, NoValidHarFileError from ..providers.retry_provider import IterListProvider -from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list +from ..providers.asyncio import to_sync_generator, async_generator_to_list from ..Provider.needs_auth import BingCreateImages, OpenaiAccount +from ..image import to_bytes from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .image_models import ImageModels from .types import IterResponse, ImageProvider, Client as BaseClient from .service import get_model_and_provider, get_last_provider, convert_to_provider from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator +from .. import debug ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]] AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]] @@ -274,11 +276,6 @@ async def get_provider_handler(self, model: Optional[str], provider: Optional[Im provider_handler = provider if provider_handler is None: return default - if isinstance(provider_handler, IterListProvider): - if provider_handler.providers: - provider_handler = provider_handler.providers[0] - else: - raise ModelNotFoundError(f"IterListProvider for model {model} has no providers") return provider_handler async def async_generate( @@ -291,33 +288,23 @@ async def async_generate( **kwargs ) -> ImagesResponse: provider_handler = await self.get_provider_handler(model, provider, BingCreateImages) - provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__ + provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__ if proxy is None: proxy = self.client.proxy response = None - if hasattr(provider_handler, "create_async_generator"): - messages = [{"role": "user", "content": f"Generate a image: {prompt}"}] - async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs): - if isinstance(item, ImageResponse): - response = item - break - elif hasattr(provider_handler, 'create'): - if asyncio.iscoroutinefunction(provider_handler.create): - response = await provider_handler.create(prompt) - else: - response = provider_handler.create(prompt) - if isinstance(response, str): - response = ImageResponse([response], prompt) - elif hasattr(provider_handler, "create_completion"): - get_running_loop(check_nested=True) - messages = [{"role": "user", "content": f"Generate a image: {prompt}"}] - for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs): - if isinstance(item, ImageResponse): - response = item - break + if isinstance(provider_handler, IterListProvider): + for provider in provider_handler.providers: + try: + response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs) + if response is not None: + provider_name = provider.__name__ + break + except (MissingAuthError, NoValidHarFileError) as e: + debug.log(f"Image provider {provider.__name__}: {e}") else: - raise ValueError(f"Provider {provider_name} does not support image generation") + response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) + if isinstance(response, ImageResponse): return await self._process_image_response( response, @@ -330,6 +317,46 @@ async def async_generate( raise NoImageResponseError(f"No image response from {provider_name}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") + async def _generate_image_response( + self, + provider_handler, + provider_name, + model: str, + prompt: str, + prompt_prefix: str = "Generate a image: ", + image: ImageType = None, + **kwargs + ) -> ImageResponse: + messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}] + response = None + if hasattr(provider_handler, "create_async_generator"): + async for item in provider_handler.create_async_generator( + model, + messages, + stream=True, + prompt=prompt, + image=image, + **kwargs + ): + if isinstance(item, ImageResponse): + response = item + break + elif hasattr(provider_handler, "create_completion"): + for item in provider_handler.create_completion( + model, + messages, + True, + prompt=prompt, + image=image, + **kwargs + ): + if isinstance(item, ImageResponse): + response = item + break + else: + raise ValueError(f"Provider {provider_name} does not support image generation") + return response + def create_variation( self, image: ImageType, @@ -352,33 +379,28 @@ async def async_create_variation( **kwargs ) -> ImagesResponse: provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount) - provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__ + provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__ if proxy is None: proxy = self.client.proxy + prompt = "create a variation of this image" - if hasattr(provider_handler, "create_async_generator"): - messages = [{"role": "user", "content": "create a variation of this image"}] - generator = None - try: - generator = provider_handler.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs) - async for chunk in generator: - if isinstance(chunk, ImageResponse): - response = chunk + response = None + if isinstance(provider_handler, IterListProvider): + # File pointer can be read only once, so we need to convert it to bytes + image = to_bytes(image) + for provider in provider_handler.providers: + try: + response = await self._generate_image_response(provider, provider.__name__, model, prompt, image=image, **kwargs) + if response is not None: + provider_name = provider.__name__ break - finally: - await safe_aclose(generator) - elif hasattr(provider_handler, 'create_variation'): - if asyncio.iscoroutinefunction(provider.provider_handler): - response = await provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) - else: - response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) + except (MissingAuthError, NoValidHarFileError) as e: + debug.log(f"Image provider {provider.__name__}: {e}") else: - raise NoImageResponseError(f"Provider {provider_name} does not support image variation") + response = await self._generate_image_response(provider_handler, provider_name, model, prompt, image=image, **kwargs) - if isinstance(response, str): - response = ImageResponse([response]) if isinstance(response, ImageResponse): - return self._process_image_response(response, response_format, proxy, model, provider_name) + return await self._process_image_response(response, response_format, proxy, model, provider_name) if response is None: raise NoImageResponseError(f"No image response from {provider_name}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") diff --git a/g4f/errors.py b/g4f/errors.py index 3d553ba6170..4ae7a7ce6dd 100644 --- a/g4f/errors.py +++ b/g4f/errors.py @@ -44,4 +44,7 @@ class ResponseError(Exception): ... class ResponseStatusError(Exception): + ... + +class NoValidHarFileError(Exception): ... \ No newline at end of file diff --git a/g4f/models.py b/g4f/models.py index 4092858d0f3..a376472bed1 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -7,10 +7,12 @@ AIChatFree, AmigoChat, Blackbox, + BingCreateImages, ChatGpt, ChatGptEs, Cloudflare, Copilot, + CopilotAccount, DarkAI, DDG, DeepInfraChat, @@ -25,7 +27,9 @@ MagickPen, Mhystical, MetaAI, + MicrosoftDesigner, OpenaiChat, + OpenaiAccount, PerplexityLabs, Pi, Pizzagpt, @@ -629,9 +633,9 @@ def __all__() -> list[str]: ### OpenAI ### dalle_3 = Model( - name = 'dalle-3', + name = 'dall-e-3', base_provider = 'OpenAI', - best_provider = AmigoChat + best_provider = IterListProvider([CopilotAccount, OpenaiAccount, MicrosoftDesigner, BingCreateImages]) ) ### Recraft ### @@ -828,7 +832,8 @@ class ModelUtils: ### OpenAI ### 'dalle-3': dalle_3, - + 'dall-e-3': dalle_3, + ### Recraft ### 'recraft-v3': recraft_v3, diff --git a/g4f/providers/asyncio.py b/g4f/providers/asyncio.py index cf0ce1a0faf..2b83bfeb6a9 100644 --- a/g4f/providers/asyncio.py +++ b/g4f/providers/asyncio.py @@ -2,7 +2,7 @@ import asyncio from asyncio import AbstractEventLoop, runners -from typing import Union, Callable, AsyncGenerator, Generator +from typing import Optional, Callable, AsyncGenerator, Generator from ..errors import NestAsyncioError @@ -17,7 +17,7 @@ except ImportError: has_uvloop = False -def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: +def get_running_loop(check_nested: bool) -> Optional[AbstractEventLoop]: try: loop = asyncio.get_running_loop() # Do not patch uvloop loop because its incompatible. diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index 92386955bc6..5bd5790db32 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -3,7 +3,7 @@ import asyncio import random -from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult +from ..typing import Type, List, CreateResult, Messages, AsyncResult from .types import BaseProvider, BaseRetryProvider, ProviderType from .. import debug from ..errors import RetryProviderError, RetryNoProviderError diff --git a/g4f/typing.py b/g4f/typing.py index 1bc71141439..48a4e90d589 100644 --- a/g4f/typing.py +++ b/g4f/typing.py @@ -4,7 +4,8 @@ try: from PIL.Image import Image except ImportError: - from typing import Type as Image + class Image: + pass if sys.version_info >= (3, 8): from typing import TypedDict