diff --git a/README.md b/README.md
index cfaaadc05dd..4f61d79150d 100644
--- a/README.md
+++ b/README.md
@@ -105,15 +105,18 @@ docker run \
hlohaus789/g4f:latest
```
-Or run this command to start the gui without a browser and in the debug mode:
+Start the GUI without a browser requirement and in debug mode.
+There's no need to update the Docker image every time.
+Simply remove the g4f package from the image and install the Python package:
```bash
-docker pull hlohaus789/g4f:latest-slim
docker run \
-p 8080:8080 \
-v ${PWD}/har_and_cookies:/app/har_and_cookies \
-v ${PWD}/generated_images:/app/generated_images \
hlohaus789/g4f:latest-slim \
- python -m g4f.cli gui -debug
+ rm -r -f /app/g4f/ \
+ && pip install -U g4f[slim] \
+ && python -m g4f.cli gui -d
```
3. **Access the Client:**
diff --git a/etc/examples/api.py b/etc/examples/api.py
index f8f5d5eca40..2485baded6c 100644
--- a/etc/examples/api.py
+++ b/etc/examples/api.py
@@ -1,13 +1,17 @@
import requests
import json
+import uuid
+
url = "http://localhost:1337/v1/chat/completions"
+conversation_id = str(uuid.uuid4())
body = {
"model": "",
- "provider": "",
+ "provider": "Copilot",
"stream": True,
"messages": [
- {"role": "user", "content": "What can you do? Who are you?"}
- ]
+ {"role": "user", "content": "Hello, i am Heiner. How are you?"}
+ ],
+ "conversation_id": conversation_id
}
response = requests.post(url, json=body, stream=True)
response.raise_for_status()
@@ -21,4 +25,27 @@
print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
except json.JSONDecodeError:
pass
-print()
\ No newline at end of file
+print()
+print()
+print()
+body = {
+ "model": "",
+ "provider": "Copilot",
+ "stream": True,
+ "messages": [
+ {"role": "user", "content": "Tell me somethings about my name"}
+ ],
+ "conversation_id": conversation_id
+}
+response = requests.post(url, json=body, stream=True)
+response.raise_for_status()
+for line in response.iter_lines():
+ if line.startswith(b"data: "):
+ try:
+ json_data = json.loads(line[6:])
+ if json_data.get("error"):
+ print(json_data)
+ break
+ print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
+ except json.JSONDecodeError:
+ pass
\ No newline at end of file
diff --git a/g4f/Provider/Airforce.py b/g4f/Provider/Airforce.py
index 54bb543b170..f5bcfefad2e 100644
--- a/g4f/Provider/Airforce.py
+++ b/g4f/Provider/Airforce.py
@@ -20,7 +20,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
working = True
supports_system_message = True
supports_message_history = True
-
+
@classmethod
def fetch_completions_models(cls):
response = requests.get('https://api.airforce/models', verify=False)
@@ -34,19 +34,20 @@ def fetch_imagine_models(cls):
response.raise_for_status()
return response.json()
- completions_models = fetch_completions_models.__func__(None)
- imagine_models = fetch_imagine_models.__func__(None)
-
default_model = "gpt-4o-mini"
default_image_model = "flux"
additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"]
- text_models = completions_models
- image_models = [*imagine_models, *additional_models_imagine]
- models = [
- *text_models,
- *image_models,
- ]
-
+
+ @classmethod
+ def get_models(cls):
+ if not cls.models:
+ cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine]
+ cls.models = [
+ *cls.fetch_completions_models(),
+ *cls.image_models
+ ]
+ return cls.models
+
model_aliases = {
### completions ###
# openchat
@@ -100,7 +101,6 @@ def create_async_generator(
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
-
if model in cls.image_models:
return cls._generate_image(model, messages, proxy, seed, size)
else:
diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py
index b259b4aada7..419055374e1 100644
--- a/g4f/Provider/Blackbox.py
+++ b/g4f/Provider/Blackbox.py
@@ -20,15 +20,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
supports_message_history = True
_last_validated_value = None
-
+
default_model = 'blackboxai'
default_vision_model = default_model
default_image_model = 'Image Generation'
image_models = ['Image Generation', 'repomap']
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'gemini-1.5-flash', 'llama-3.1-8b', 'llama-3.1-70b', 'llama-3.1-405b']
-
+
userSelectedModel = ['gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
-
+
agentMode = {
'Image Generation': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
}
@@ -77,22 +77,21 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
}
additional_prefixes = {
- 'gpt-4o': '@gpt-4o',
- 'gemini-pro': '@gemini-pro',
- 'claude-sonnet-3.5': '@claude-sonnet'
- }
+ 'gpt-4o': '@gpt-4o',
+ 'gemini-pro': '@gemini-pro',
+ 'claude-sonnet-3.5': '@claude-sonnet'
+ }
model_prefixes = {
- **{mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
- if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]},
- **additional_prefixes
- }
+ **{
+ mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
+ if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]
+ },
+ **additional_prefixes
+ }
-
models = list(dict.fromkeys([default_model, *userSelectedModel, *list(agentMode.keys()), *list(trendingAgentMode.keys())]))
-
-
model_aliases = {
"gemini-flash": "gemini-1.5-flash",
"claude-3.5-sonnet": "claude-sonnet-3.5",
@@ -131,12 +130,11 @@ async def fetch_validated(cls):
return cls._last_validated_value
-
@staticmethod
def generate_id(length=7):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))
-
+
@classmethod
def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages:
prefix = cls.model_prefixes.get(model, "")
@@ -157,6 +155,7 @@ async def create_async_generator(
cls,
model: str,
messages: Messages,
+ prompt: str = None,
proxy: str = None,
web_search: bool = False,
image: ImageType = None,
@@ -191,7 +190,7 @@ async def create_async_generator(
'sec-fetch-site': 'same-origin',
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36'
}
-
+
data = {
"messages": messages,
"id": message_id,
@@ -221,26 +220,25 @@ async def create_async_generator(
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
response_text = await response.text()
-
+
if model in cls.image_models:
image_matches = re.findall(r'!\[.*?\]\((https?://[^\)]+)\)', response_text)
if image_matches:
image_url = image_matches[0]
- image_response = ImageResponse(images=[image_url], alt="Generated Image")
- yield image_response
+ yield ImageResponse(image_url, prompt)
return
response_text = re.sub(r'Generated by BLACKBOX.AI, try unlimited chat https://www.blackbox.ai', '', response_text, flags=re.DOTALL)
-
+
json_match = re.search(r'\$~~~\$(.*?)\$~~~\$', response_text, re.DOTALL)
if json_match:
search_results = json.loads(json_match.group(1))
answer = response_text.split('$~~~$')[-1].strip()
-
+
formatted_response = f"{answer}\n\n**Source:**"
for i, result in enumerate(search_results, 1):
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
-
+
yield formatted_response
else:
yield response_text.strip()
diff --git a/g4f/Provider/Copilot.py b/g4f/Provider/Copilot.py
index e8eea0a5171..2f37b1ebf8f 100644
--- a/g4f/Provider/Copilot.py
+++ b/g4f/Provider/Copilot.py
@@ -57,6 +57,7 @@ def create_completion(
image: ImageType = None,
conversation: Conversation = None,
return_conversation: bool = False,
+ web_search: bool = True,
**kwargs
) -> CreateResult:
if not has_curl_cffi:
@@ -72,10 +73,9 @@ def create_completion(
else:
access_token = conversation.access_token
debug.log(f"Copilot: Access token: {access_token[:7]}...{access_token[-5:]}")
- debug.log(f"Copilot: Cookies: {';'.join([*cookies])}")
websocket_url = f"{websocket_url}&accessToken={quote(access_token)}"
- headers = {"authorization": f"Bearer {access_token}", "cookie": format_cookies(cookies)}
-
+ headers = {"authorization": f"Bearer {access_token}"}
+
with Session(
timeout=timeout,
proxy=proxy,
@@ -124,12 +124,14 @@ def create_completion(
is_started = False
msg = None
image_prompt: str = None
+ last_msg = None
while True:
try:
msg = wss.recv()[0]
msg = json.loads(msg)
except:
break
+ last_msg = msg
if msg.get("event") == "appendText":
is_started = True
yield msg.get("text")
@@ -139,8 +141,12 @@ def create_completion(
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
elif msg.get("event") == "done":
break
+ elif msg.get("event") == "error":
+ raise RuntimeError(f"Error: {msg}")
+ elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]:
+ debug.log(f"Copilot Message: {msg}")
if not is_started:
- raise RuntimeError(f"Last message: {msg}")
+ raise RuntimeError(f"Invalid response: {last_msg}")
@classmethod
async def get_access_token_and_cookies(cls, proxy: str = None):
diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py
index 57597bf17b0..a30f896d302 100644
--- a/g4f/Provider/PollinationsAI.py
+++ b/g4f/Provider/PollinationsAI.py
@@ -3,7 +3,6 @@
from urllib.parse import quote
import random
import requests
-from sys import maxsize
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
@@ -40,6 +39,7 @@ async def create_async_generator(
cls,
model: str,
messages: Messages,
+ prompt: str = None,
api_base: str = "https://text.pollinations.ai/openai",
api_key: str = None,
proxy: str = None,
@@ -49,9 +49,10 @@ async def create_async_generator(
if model:
model = cls.get_model(model)
if model in cls.image_models:
- prompt = messages[-1]["content"]
+ if prompt is None:
+ prompt = messages[-1]["content"]
if seed is None:
- seed = random.randint(0, maxsize)
+ seed = random.randint(0, 100000)
image = f"https://image.pollinations.ai/prompt/{quote(prompt)}?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}"
yield ImageResponse(image, prompt)
return
diff --git a/g4f/Provider/ReplicateHome.py b/g4f/Provider/ReplicateHome.py
index a7fc9b54543..00de09e0dff 100644
--- a/g4f/Provider/ReplicateHome.py
+++ b/g4f/Provider/ReplicateHome.py
@@ -6,6 +6,8 @@
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..requests.aiohttp import get_connector
+from ..requests.raise_for_status import raise_for_status
from .helper import format_prompt
from ..image import ImageResponse
@@ -32,10 +34,8 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
'yorickvp/llava-13b',
]
-
-
models = text_models + image_models
-
+
model_aliases = {
# image_models
"sd-3": "stability-ai/stable-diffusion-3",
@@ -56,23 +56,14 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
# text_models
"google-deepmind/gemma-2b-it": "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626",
"yorickvp/llava-13b": "80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
-
}
- @classmethod
- def get_model(cls, model: str) -> str:
- if model in cls.models:
- return model
- elif model in cls.model_aliases:
- return cls.model_aliases[model]
- else:
- return cls.default_model
-
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
+ prompt: str = None,
proxy: str = None,
**kwargs
) -> AsyncResult:
@@ -96,29 +87,30 @@ async def create_async_generator(
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"
}
- async with ClientSession(headers=headers) as session:
- if model in cls.image_models:
- prompt = messages[-1]['content'] if messages else ""
- else:
- prompt = format_prompt(messages)
-
+ async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session:
+ if prompt is None:
+ if model in cls.image_models:
+ prompt = messages[-1]['content']
+ else:
+ prompt = format_prompt(messages)
+
data = {
"model": model,
"version": cls.model_versions[model],
"input": {"prompt": prompt},
}
-
- async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
- response.raise_for_status()
+
+ async with session.post(cls.api_endpoint, json=data) as response:
+ await raise_for_status(response)
result = await response.json()
prediction_id = result['id']
-
+
poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}"
max_attempts = 30
delay = 5
for _ in range(max_attempts):
- async with session.get(poll_url, proxy=proxy) as response:
- response.raise_for_status()
+ async with session.get(poll_url) as response:
+ await raise_for_status(response)
try:
result = await response.json()
except ContentTypeError:
@@ -131,7 +123,7 @@ async def create_async_generator(
if result['status'] == 'succeeded':
if model in cls.image_models:
image_url = result['output'][0]
- yield ImageResponse(image_url, "Generated image")
+ yield ImageResponse(image_url, prompt)
return
else:
for chunk in result['output']:
@@ -140,6 +132,6 @@ async def create_async_generator(
elif result['status'] == 'failed':
raise Exception(f"Prediction failed: {result.get('error')}")
await asyncio.sleep(delay)
-
+
if result['status'] != 'succeeded':
raise Exception("Prediction timed out")
diff --git a/g4f/Provider/RubiksAI.py b/g4f/Provider/RubiksAI.py
index c06e6c3dfb2..816ea60c533 100644
--- a/g4f/Provider/RubiksAI.py
+++ b/g4f/Provider/RubiksAI.py
@@ -9,7 +9,7 @@
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, Sources
from ..requests.raise_for_status import raise_for_status
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
@@ -23,7 +23,6 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
default_model = 'gpt-4o-mini'
models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
-
model_aliases = {
"llama-3.1-70b": "llama-3.1-70b-versatile",
}
@@ -118,7 +117,7 @@ async def create_async_generator(
if 'url' in json_data and 'title' in json_data:
if web_search:
- sources.append({'title': json_data['title'], 'url': json_data['url']})
+ sources.append(json_data)
elif 'choices' in json_data:
for choice in json_data['choices']:
@@ -128,5 +127,4 @@ async def create_async_generator(
yield content
if web_search and sources:
- sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
- yield f"\n\n**Source:**\n{sources_text}"
\ No newline at end of file
+ yield Sources(sources)
\ No newline at end of file
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 667f6964b88..c0d8edf076b 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,4 +1,4 @@
from ..providers.base_provider import *
-from ..providers.types import FinishReason, Streaming
-from ..providers.conversation import BaseConversation
+from ..providers.types import Streaming
+from ..providers.response import BaseConversation, Sources, FinishReason
from .helper import get_cookies, format_prompt
\ No newline at end of file
diff --git a/g4f/Provider/bing/conversation.py b/g4f/Provider/bing/conversation.py
index b5c237f9a53..43bcbb4d483 100644
--- a/g4f/Provider/bing/conversation.py
+++ b/g4f/Provider/bing/conversation.py
@@ -2,7 +2,7 @@
from ...requests import StreamSession, raise_for_status
from ...errors import RateLimitError
-from ...providers.conversation import BaseConversation
+from ...providers.response import BaseConversation
class Conversation(BaseConversation):
"""
diff --git a/g4f/Provider/needs_auth/BingCreateImages.py b/g4f/Provider/needs_auth/BingCreateImages.py
index 80984d40239..b95a78c3b7a 100644
--- a/g4f/Provider/needs_auth/BingCreateImages.py
+++ b/g4f/Provider/needs_auth/BingCreateImages.py
@@ -28,13 +28,14 @@ async def create_async_generator(
cls,
model: str,
messages: Messages,
+ prompt: str = None,
api_key: str = None,
cookies: Cookies = None,
proxy: str = None,
**kwargs
) -> AsyncResult:
session = BingCreateImages(cookies, proxy, api_key)
- yield await session.generate(messages[-1]["content"])
+ yield await session.generate(messages[-1]["content"] if prompt is None else prompt)
async def generate(self, prompt: str) -> ImageResponse:
"""
diff --git a/g4f/Provider/needs_auth/DeepInfraImage.py b/g4f/Provider/needs_auth/DeepInfraImage.py
index 24df04e3da8..4479056117e 100644
--- a/g4f/Provider/needs_auth/DeepInfraImage.py
+++ b/g4f/Provider/needs_auth/DeepInfraImage.py
@@ -29,9 +29,10 @@ async def create_async_generator(
cls,
model: str,
messages: Messages,
+ prompt: str = None,
**kwargs
) -> AsyncResult:
- yield await cls.create_async(messages[-1]["content"], model, **kwargs)
+ yield await cls.create_async(messages[-1]["content"] if prompt is None else prompt, model, **kwargs)
@classmethod
async def create_async(
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index 781aa410c19..89f6f802fc4 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -4,8 +4,10 @@
import json
import random
import re
+import base64
from aiohttp import ClientSession, BaseConnector
+
try:
import nodriver
has_nodriver = True
@@ -14,12 +16,13 @@
from ... import debug
from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator
-from ..base_provider import AsyncGeneratorProvider, BaseConversation
+from ..base_provider import AsyncGeneratorProvider, BaseConversation, SynthesizeData
from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
from ...errors import MissingAuthError
from ...image import ImageResponse, to_bytes
+from ... import debug
REQUEST_HEADERS = {
"authority": "gemini.google.com",
@@ -54,6 +57,7 @@ class Gemini(AsyncGeneratorProvider):
image_models = ["gemini"]
default_vision_model = "gemini"
models = ["gemini", "gemini-1.5-flash", "gemini-1.5-pro"]
+ synthesize_content_type = "audio/vnd.wav"
_cookies: Cookies = None
_snlm0e: str = None
_sid: str = None
@@ -106,6 +110,7 @@ async def create_async_generator(
prompt = format_prompt(messages) if conversation is None else messages[-1]["content"]
cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True)
base_connector = get_connector(connector, proxy)
+
async with ClientSession(
headers=REQUEST_HEADERS,
connector=base_connector
@@ -122,6 +127,7 @@ async def create_async_generator(
if not cls._snlm0e:
raise RuntimeError("Invalid cookies. SNlM0e not found")
+ yield SynthesizeData(cls.__name__, {"text": messages[-1]["content"]})
image_url = await cls.upload_image(base_connector, to_bytes(image), image_name) if image else None
async with ClientSession(
@@ -198,6 +204,39 @@ async def create_async_generator(
except TypeError:
pass
+ @classmethod
+ async def synthesize(cls, params: dict, proxy: str = None) -> AsyncIterator[bytes]:
+ if "text" not in params:
+ raise ValueError("Missing parameter text")
+ async with ClientSession(
+ cookies=cls._cookies,
+ headers=REQUEST_HEADERS,
+ connector=get_connector(proxy=proxy),
+ ) as session:
+ if not cls._snlm0e:
+ await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
+ inner_data = json.dumps([None, params["text"], "de-DE", None, 2])
+ async with session.post(
+ "https://gemini.google.com/_/BardChatUi/data/batchexecute",
+ data={
+ "f.req": json.dumps([[["XqA3Ic", inner_data, None, "generic"]]]),
+ "at": cls._snlm0e,
+ },
+ params={
+ "rpcids": "XqA3Ic",
+ "source-path": "/app/2704fb4aafcca926",
+ "bl": "boq_assistant-bard-web-server_20241119.00_p1",
+ "f.sid": "" if cls._sid is None else cls._sid,
+ "hl": "de",
+ "_reqid": random.randint(1111, 9999),
+ "rt": "c"
+ },
+ ) as response:
+ await raise_for_status(response)
+ iter_base64_response = iter_filter_base64(response.content.iter_chunked(1024))
+ async for chunk in iter_base64_decode(iter_base64_response):
+ yield chunk
+
def build_request(
prompt: str,
language: str,
@@ -280,3 +319,27 @@ def __init__(self,
self.conversation_id = conversation_id
self.response_id = response_id
self.choice_id = choice_id
+async def iter_filter_base64(response_iter: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
+ search_for = b'[["wrb.fr","XqA3Ic","[\\"'
+ end_with = b'\\'
+ is_started = False
+ async for chunk in response_iter:
+ if is_started:
+ if end_with in chunk:
+ yield chunk.split(end_with, 1).pop(0)
+ break
+ else:
+ yield chunk
+ elif search_for in chunk:
+ is_started = True
+ yield chunk.split(search_for, 1).pop()
+ else:
+ raise RuntimeError(f"Response: {chunk}")
+
+async def iter_base64_decode(response_iter: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
+ buffer = b""
+ async for chunk in response_iter:
+ chunk = buffer + chunk
+ rest = len(chunk) % 4
+ buffer = chunk[-rest:]
+ yield base64.b64decode(chunk[:-rest])
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/GithubCopilot.py b/g4f/Provider/needs_auth/GithubCopilot.py
new file mode 100644
index 00000000000..3eb66b5ec32
--- /dev/null
+++ b/g4f/Provider/needs_auth/GithubCopilot.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+import json
+
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConversation
+from ...typing import AsyncResult, Messages, Cookies
+from ...requests.raise_for_status import raise_for_status
+from ...requests import StreamSession
+from ...providers.helper import format_prompt
+from ...cookies import get_cookies
+
+class Conversation(BaseConversation):
+ conversation_id: str
+
+ def __init__(self, conversation_id: str):
+ self.conversation_id = conversation_id
+
+class GithubCopilot(AsyncGeneratorProvider, ProviderModelMixin):
+ url = "https://copilot.microsoft.com"
+ working = True
+ needs_auth = True
+ supports_stream = True
+ default_model = "gpt-4o"
+ models = [default_model, "o1-mini", "o1-preview", "claude-3.5-sonnet"]
+
+ @classmethod
+ async def create_async_generator(
+ cls,
+ model: str,
+ messages: Messages,
+ stream: bool = False,
+ api_key: str = None,
+ proxy: str = None,
+ cookies: Cookies = None,
+ conversation_id: str = None,
+ conversation: Conversation = None,
+ return_conversation: bool = False,
+ **kwargs
+ ) -> AsyncResult:
+ if not model:
+ model = cls.default_model
+ if cookies is None:
+ cookies = get_cookies(".github.com")
+ async with StreamSession(
+ proxy=proxy,
+ impersonate="chrome",
+ cookies=cookies,
+ headers={
+ "GitHub-Verified-Fetch": "true",
+ }
+ ) as session:
+ headers = {}
+ if api_key is None:
+ async with session.post("https://github.com/github-copilot/chat/token") as response:
+ await raise_for_status(response, "Get token")
+ api_key = (await response.json()).get("token")
+ headers = {
+ "Authorization": f"GitHub-Bearer {api_key}",
+ }
+ if conversation is not None:
+ conversation_id = conversation.conversation_id
+ if conversation_id is None:
+ print(headers)
+ async with session.post("https://api.individual.githubcopilot.com/github/chat/threads", headers=headers) as response:
+ await raise_for_status(response)
+ conversation_id = (await response.json()).get("thread_id")
+ if return_conversation:
+ yield Conversation(conversation_id)
+ content = messages[-1]["content"]
+ else:
+ content = format_prompt(messages)
+ json_data = {
+ "content": content,
+ "intent": "conversation",
+ "references":[],
+ "context": [],
+ "currentURL": f"https://github.com/copilot/c/{conversation_id}",
+ "streaming": True,
+ "confirmations": [],
+ "customInstructions": [],
+ "model": model,
+ "mode": "immersive"
+ }
+ async with session.post(
+ f"https://api.individual.githubcopilot.com/github/chat/threads/{conversation_id}/messages",
+ json=json_data,
+ headers=headers
+ ) as response:
+ async for line in response.iter_lines():
+ if line.startswith(b"data: "):
+ data = json.loads(line[6:])
+ if data.get("type") == "content":
+ yield data.get("body")
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 97515ec4063..37bdf0742c8 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -7,7 +7,6 @@
import base64
import time
import requests
-from aiohttp import ClientWebSocketResponse
from copy import copy
try:
@@ -16,19 +15,15 @@
has_nodriver = True
except ImportError:
has_nodriver = False
-try:
- from platformdirs import user_config_dir
- has_platformdirs = True
-except ImportError:
- has_platformdirs = False
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...typing import AsyncResult, Messages, Cookies, ImageType, AsyncIterator
from ...requests.raise_for_status import raise_for_status
-from ...requests.aiohttp import StreamSession
+from ...requests import StreamSession
+from ...requests import get_nodriver
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
-from ...errors import MissingAuthError, ResponseError
-from ...providers.conversation import BaseConversation
+from ...errors import MissingAuthError
+from ...providers.response import BaseConversation, FinishReason, SynthesizeData
from ..helper import format_cookies
from ..openai.har_file import get_request_config, NoValidHarFileError
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
@@ -63,9 +58,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
default_model = "auto"
default_vision_model = "gpt-4o"
- fallback_models = ["auto", "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1-preview", "o1-mini"]
+ fallback_models = [default_model, "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1-preview", "o1-mini"]
vision_models = fallback_models
image_models = fallback_models
+ synthesize_content_type = "audio/mpeg"
_api_key: str = None
_headers: dict = None
@@ -84,51 +80,6 @@ def get_models(cls):
cls.models = cls.fallback_models
return cls.models
- @classmethod
- async def create(
- cls,
- prompt: str = None,
- model: str = "",
- messages: Messages = [],
- action: str = "next",
- **kwargs
- ) -> Response:
- """
- Create a new conversation or continue an existing one
-
- Args:
- prompt: The user input to start or continue the conversation
- model: The name of the model to use for generating responses
- messages: The list of previous messages in the conversation
- history_disabled: A flag indicating if the history and training should be disabled
- action: The type of action to perform, either "next", "continue", or "variant"
- conversation_id: The ID of the existing conversation, if any
- parent_id: The ID of the parent message, if any
- image: The image to include in the user input, if any
- **kwargs: Additional keyword arguments to pass to the generator
-
- Returns:
- A Response object that contains the generator, action, messages, and options
- """
- # Add the user input to the messages list
- if prompt is not None:
- messages.append({
- "role": "user",
- "content": prompt
- })
- generator = cls.create_async_generator(
- model,
- messages,
- return_conversation=True,
- **kwargs
- )
- return Response(
- generator,
- action,
- messages,
- kwargs
- )
-
@classmethod
async def upload_image(
cls,
@@ -160,7 +111,7 @@ async def upload_image(
# Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
cls._update_request_args(session)
- await raise_for_status(response)
+ await raise_for_status(response, "Create file failed")
image_data = {
**data,
**await response.json(),
@@ -178,7 +129,7 @@ async def upload_image(
"x-ms-blob-type": "BlockBlob"
}
) as response:
- await raise_for_status(response)
+ await raise_for_status(response, "Send file failed")
# Post the file ID to the service and get the download URL
async with session.post(
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
@@ -186,38 +137,12 @@ async def upload_image(
headers=headers
) as response:
cls._update_request_args(session)
- await raise_for_status(response)
+ await raise_for_status(response, "Get download url failed")
image_data["download_url"] = (await response.json())["download_url"]
return ImageRequest(image_data)
@classmethod
- async def get_default_model(cls, session: StreamSession, headers: dict):
- """
- Get the default model name from the service
-
- Args:
- session: The StreamSession object to use for requests
- headers: The headers to include in the requests
-
- Returns:
- The default model name as a string
- """
- if not cls.default_model:
- url = f"{cls.url}/backend-anon/models" if cls._api_key is None else f"{cls.url}/backend-api/models"
- async with session.get(url, headers=headers) as response:
- cls._update_request_args(session)
- if response.status == 401:
- raise MissingAuthError('Add a .har file for OpenaiChat' if cls._api_key is None else "Invalid api key")
- await raise_for_status(response)
- data = await response.json()
- if "categories" in data:
- cls.default_model = data["categories"][-1]["default_model"]
- return cls.default_model
- raise ResponseError(data)
- return cls.default_model
-
- @classmethod
- def create_messages(cls, messages: Messages, image_request: ImageRequest = None):
+ def create_messages(cls, messages: Messages, image_request: ImageRequest = None, system_hints: list = None):
"""
Create a list of messages for the user input
@@ -235,7 +160,7 @@ def create_messages(cls, messages: Messages, image_request: ImageRequest = None)
"id": str(uuid.uuid4()),
"create_time": int(time.time()),
"id": str(uuid.uuid4()),
- "metadata": {"serialization_metadata": {"custom_symbol_offsets": []}}
+ "metadata": {"serialization_metadata": {"custom_symbol_offsets": []}, "system_hints": system_hints},
} for message in messages]
# Check if there is an image response
@@ -264,7 +189,7 @@ def create_messages(cls, messages: Messages, image_request: ImageRequest = None)
return messages
@classmethod
- async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict) -> ImageResponse:
+ async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict, prompt: str = None) -> ImageResponse:
"""
Retrieves the image response based on the message content.
@@ -286,6 +211,8 @@ async def get_generated_image(cls, session: StreamSession, headers: dict, elemen
try:
prompt = element["metadata"]["dalle"]["prompt"]
file_id = element["asset_pointer"].split("file-service://", 1)[1]
+ except TypeError:
+ return
except Exception as e:
raise RuntimeError(f"No Image: {e.__class__.__name__}: {e}")
try:
@@ -297,30 +224,6 @@ async def get_generated_image(cls, session: StreamSession, headers: dict, elemen
except Exception as e:
raise RuntimeError(f"Error in downloading image: {e}")
- @classmethod
- async def delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
- """
- Deletes a conversation by setting its visibility to False.
-
- This method sends an HTTP PATCH request to update the visibility of a conversation.
- It's used to effectively delete a conversation from being accessed or displayed in the future.
-
- Args:
- session (StreamSession): The StreamSession object used for making HTTP requests.
- headers (dict): HTTP headers to be used for the request.
- conversation_id (str): The unique identifier of the conversation to be deleted.
-
- Raises:
- HTTPError: If the HTTP request fails or returns an unsuccessful status code.
- """
- async with session.patch(
- f"{cls.url}/backend-api/conversation/{conversation_id}",
- json={"is_visible": False},
- headers=headers
- ) as response:
- cls._update_request_args(session)
- ...
-
@classmethod
async def create_async_generator(
cls,
@@ -328,7 +231,6 @@ async def create_async_generator(
messages: Messages,
proxy: str = None,
timeout: int = 180,
- api_key: str = None,
cookies: Cookies = None,
auto_continue: bool = False,
history_disabled: bool = False,
@@ -340,6 +242,7 @@ async def create_async_generator(
image_name: str = None,
return_conversation: bool = False,
max_retries: int = 3,
+ web_search: bool = False,
**kwargs
) -> AsyncResult:
"""
@@ -367,19 +270,13 @@ async def create_async_generator(
Raises:
RuntimeError: If an error occurs during processing.
"""
+ await cls.login(proxy)
+
async with StreamSession(
proxy=proxy,
impersonate="chrome",
timeout=timeout
) as session:
- if cls._expires is not None and cls._expires < time.time():
- cls._headers = cls._api_key = None
- try:
- await get_request_config(proxy)
- cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
- cls._set_api_key(RequestConfig.access_token)
- except NoValidHarFileError as e:
- await cls.nodriver_auth(proxy)
try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e:
@@ -419,12 +316,13 @@ async def create_async_generator(
if "proofofwork" in chat_requirements:
proofofwork = generate_proof_token(
**chat_requirements["proofofwork"],
- user_agent=cls._headers["user-agent"],
+ user_agent=cls._headers.get("user-agent"),
proof_token=RequestConfig.proof_token
)
[debug.log(text) for text in (
f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}",
f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
+ f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
)]
data = {
"action": action,
@@ -436,23 +334,25 @@ async def create_async_generator(
"conversation_mode": {"kind":"primary_assistant"},
"websocket_request_id": str(uuid.uuid4()),
"supported_encodings": ["v1"],
- "supports_buffering": True
+ "supports_buffering": True,
+ "system_hints": ["search"] if web_search else None
}
if conversation.conversation_id is not None:
data["conversation_id"] = conversation.conversation_id
debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
if action != "continue":
messages = messages if conversation_id is None else [messages[-1]]
- data["messages"] = cls.create_messages(messages, image_request)
+ data["messages"] = cls.create_messages(messages, image_request, ["search"] if web_search else None)
headers = {
+ **cls._headers,
"accept": "text/event-stream",
- "Openai-Sentinel-Chat-Requirements-Token": chat_token,
- **cls._headers
+ "content-type": "application/json",
+ "openai-sentinel-chat-requirements-token": chat_token,
}
if RequestConfig.arkose_token:
- headers["Openai-Sentinel-Arkose-Token"] = RequestConfig.arkose_token
+ headers["openai-sentinel-arkose-token"] = RequestConfig.arkose_token
if proofofwork is not None:
- headers["Openai-Sentinel-Proof-Token"] = proofofwork
+ headers["openai-sentinel-proof-token"] = proofofwork
if need_turnstile and RequestConfig.turnstile_token is not None:
headers['openai-sentinel-turnstile-token'] = RequestConfig.turnstile_token
async with session.post(
@@ -469,31 +369,24 @@ async def create_async_generator(
await asyncio.sleep(5)
continue
await raise_for_status(response)
- async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation):
- if return_conversation:
- history_disabled = False
- return_conversation = False
- yield conversation
- yield chunk
+ if return_conversation:
+ yield conversation
+ async for line in response.iter_lines():
+ async for chunk in cls.iter_messages_line(session, line, conversation):
+ yield chunk
+ if not history_disabled:
+ yield SynthesizeData(cls.__name__, {
+ "conversation_id": conversation.conversation_id,
+ "message_id": conversation.message_id,
+ "voice": "maple",
+ })
if auto_continue and conversation.finish_reason == "max_tokens":
conversation.finish_reason = None
action = "continue"
await asyncio.sleep(5)
else:
break
- if history_disabled and auto_continue:
- await cls.delete_conversation(session, cls._headers, conversation.conversation_id)
-
- @classmethod
- async def iter_messages_chunk(
- cls,
- messages: AsyncIterator,
- session: StreamSession,
- fields: Conversation,
- ) -> AsyncIterator:
- async for message in messages:
- async for chunk in cls.iter_messages_line(session, message, fields):
- yield chunk
+ yield FinishReason(conversation.finish_reason)
@classmethod
async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation) -> AsyncIterator:
@@ -530,9 +423,9 @@ async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: C
generated_images = []
for element in c.get("parts"):
if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
- generated_images.append(
- cls.get_generated_image(session, cls._headers, element)
- )
+ image = cls.get_generated_image(session, cls._headers, element)
+ if image is not None:
+ generated_images.append(image)
for image_response in await asyncio.gather(*generated_images):
yield image_response
if m.get("author", {}).get("role") == "assistant":
@@ -541,19 +434,39 @@ async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: C
if "error" in line and line.get("error"):
raise RuntimeError(line.get("error"))
+ @classmethod
+ async def synthesize(cls, params: dict) -> AsyncIterator[bytes]:
+ await cls.login()
+ async with StreamSession(
+ impersonate="chrome",
+ timeout=900
+ ) as session:
+ async with session.get(
+ f"{cls.url}/backend-api/synthesize",
+ params=params,
+ headers=cls._headers
+ ) as response:
+ await raise_for_status(response)
+ async for chunk in response.iter_content():
+ yield chunk
+
+ @classmethod
+ async def login(cls, proxy: str = None):
+ if cls._expires is not None and cls._expires < time.time():
+ cls._headers = cls._api_key = None
+ try:
+ await get_request_config(proxy)
+ cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
+ cls._set_api_key(RequestConfig.access_token)
+ except NoValidHarFileError:
+ if has_nodriver:
+ await cls.nodriver_auth(proxy)
+ else:
+ raise
+
@classmethod
async def nodriver_auth(cls, proxy: str = None):
- if not has_nodriver:
- return
- if has_platformdirs:
- user_data_dir = user_config_dir("g4f-nodriver")
- else:
- user_data_dir = None
- debug.log(f"Open nodriver with user_dir: {user_data_dir}")
- browser = await nodriver.start(
- user_data_dir=user_data_dir,
- browser_args=None if proxy is None else [f"--proxy-server={proxy}"],
- )
+ browser = await get_nodriver(proxy=proxy)
page = browser.main_tab
def on_request(event: nodriver.cdp.network.RequestWillBeSent):
if event.request.url == start_url or event.request.url.startswith(conversation_url):
@@ -592,14 +505,14 @@ def on_request(event: nodriver.cdp.network.RequestWillBeSent):
pass
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
RequestConfig.cookies[c.name] = c.value
- RequestConfig.user_agent = await page.evaluate("window.navigator.userAgent")
+ user_agent = await page.evaluate("window.navigator.userAgent")
await page.select("#prompt-textarea", 240)
while True:
if RequestConfig.proof_token:
break
await asyncio.sleep(1)
await page.close()
- cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=RequestConfig.user_agent)
+ cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=user_agent)
cls._set_api_key(RequestConfig.access_token)
@staticmethod
@@ -642,90 +555,4 @@ def __init__(self, conversation_id: str = None, message_id: str = None, finish_r
self.conversation_id = conversation_id
self.message_id = message_id
self.finish_reason = finish_reason
- self.is_recipient = False
-
-class Response():
- """
- Class to encapsulate a response from the chat service.
- """
- def __init__(
- self,
- generator: AsyncResult,
- action: str,
- messages: Messages,
- options: dict
- ):
- self._generator = generator
- self.action = action
- self.is_end = False
- self._message = None
- self._messages = messages
- self._options = options
- self._fields = None
-
- async def generator(self) -> AsyncIterator:
- if self._generator is not None:
- self._generator = None
- chunks = []
- async for chunk in self._generator:
- if isinstance(chunk, Conversation):
- self._fields = chunk
- else:
- yield chunk
- chunks.append(str(chunk))
- self._message = "".join(chunks)
- if self._fields is None:
- raise RuntimeError("Missing response fields")
- self.is_end = self._fields.finish_reason == "stop"
-
- def __aiter__(self):
- return self.generator()
-
- async def get_message(self) -> str:
- await self.generator()
- return self._message
-
- async def get_fields(self) -> dict:
- await self.generator()
- return {
- "conversation_id": self._fields.conversation_id,
- "parent_id": self._fields.message_id
- }
-
- async def create_next(self, prompt: str, **kwargs) -> Response:
- return await OpenaiChat.create(
- **self._options,
- prompt=prompt,
- messages=await self.get_messages(),
- action="next",
- **await self.get_fields(),
- **kwargs
- )
-
- async def do_continue(self, **kwargs) -> Response:
- fields = await self.get_fields()
- if self.is_end:
- raise RuntimeError("Can't continue message. Message already finished.")
- return await OpenaiChat.create(
- **self._options,
- messages=await self.get_messages(),
- action="continue",
- **fields,
- **kwargs
- )
-
- async def create_variant(self, **kwargs) -> Response:
- if self.action != "next":
- raise RuntimeError("Can't create variant from continue or variant request.")
- return await OpenaiChat.create(
- **self._options,
- messages=self._messages,
- action="variant",
- **await self.get_fields(),
- **kwargs
- )
-
- async def get_messages(self) -> list:
- messages = self._messages
- messages.append({"role": "assistant", "content": await self.message()})
- return messages
+ self.is_recipient = False
\ No newline at end of file
diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py
index 1c7fe7c598b..f339170616f 100644
--- a/g4f/Provider/needs_auth/__init__.py
+++ b/g4f/Provider/needs_auth/__init__.py
@@ -7,6 +7,7 @@
from .DeepInfraImage import DeepInfraImage
from .Gemini import Gemini
from .GeminiPro import GeminiPro
+from .GithubCopilot import GithubCopilot
from .Groq import Groq
from .HuggingFace import HuggingFace
from .HuggingFace2 import HuggingFace2
diff --git a/g4f/Provider/openai/har_file.py b/g4f/Provider/openai/har_file.py
index 4569e1b7e98..e863b6acf77 100644
--- a/g4f/Provider/openai/har_file.py
+++ b/g4f/Provider/openai/har_file.py
@@ -25,7 +25,6 @@ class NoValidHarFileError(Exception):
pass
class RequestConfig:
- user_agent: str = None
cookies: dict = None
headers: dict = None
access_request_id: str = None
@@ -63,28 +62,30 @@ def readHAR():
continue
for v in harFile['log']['entries']:
v_headers = get_headers(v)
- try:
- if "openai-sentinel-proof-token" in v_headers:
- RequestConfig.proof_token = json.loads(base64.b64decode(
- v_headers["openai-sentinel-proof-token"].split("gAAAAAB", 1)[-1].encode()
- ).decode())
- if "openai-sentinel-turnstile-token" in v_headers:
- RequestConfig.turnstile_token = v_headers["openai-sentinel-turnstile-token"]
- except Exception as e:
- debug.log(f"Read proof token: {e}")
if arkose_url == v['request']['url']:
RequestConfig.arkose_request = parseHAREntry(v)
- elif v['request']['url'] == start_url or v['request']['url'].startswith(conversation_url):
+ elif v['request']['url'].startswith(start_url):
try:
match = re.search(r'"accessToken":"(.*?)"', v["response"]["content"]["text"])
if match:
RequestConfig.access_token = match.group(1)
except KeyError:
- continue
- RequestConfig.cookies = {c['name']: c['value'] for c in v['request']['cookies'] if c['name'] != "oai-did"}
- RequestConfig.headers = v_headers
- if RequestConfig.access_token is None:
- raise NoValidHarFileError("No accessToken found in .har files")
+ pass
+ try:
+ if "openai-sentinel-proof-token" in v_headers:
+ RequestConfig.headers = v_headers
+ RequestConfig.proof_token = json.loads(base64.b64decode(
+ v_headers["openai-sentinel-proof-token"].split("gAAAAAB", 1)[-1].encode()
+ ).decode())
+ if "openai-sentinel-turnstile-token" in v_headers:
+ RequestConfig.turnstile_token = v_headers["openai-sentinel-turnstile-token"]
+ if "authorization" in v_headers:
+ RequestConfig.access_token = v_headers["authorization"].split(" ")[1]
+ RequestConfig.cookies = {c['name']: c['value'] for c in v['request']['cookies']}
+ except Exception as e:
+ debug.log(f"Error on read headers: {e}")
+ if RequestConfig.proof_token is None:
+ raise NoValidHarFileError("No proof_token found in .har files")
def get_headers(entry) -> dict:
return {h['name'].lower(): h['value'] for h in entry['request']['headers'] if h['name'].lower() not in ['content-length', 'cookie'] and not h['name'].startswith(':')}
@@ -149,7 +150,7 @@ def getN() -> str:
return base64.b64encode(timestamp.encode()).decode()
async def get_request_config(proxy: str) -> RequestConfig:
- if RequestConfig.access_token is None:
+ if RequestConfig.proof_token is None:
readHAR()
if RequestConfig.arkose_request is not None:
RequestConfig.arkose_token = await sendRequest(genArkReq(RequestConfig.arkose_request), proxy)
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index 02ba5260d67..f67a2aea033 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -4,8 +4,11 @@
import json
import uvicorn
import secrets
+import os
+import shutil
-from fastapi import FastAPI, Response, Request
+import os.path
+from fastapi import FastAPI, Response, Request, UploadFile
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.security import APIKeyHeader
@@ -13,15 +16,20 @@
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
+from starlette.responses import FileResponse
from pydantic import BaseModel
-from typing import Union, Optional
+from typing import Union, Optional, List
import g4f
import g4f.debug
-from g4f.client import AsyncClient, ChatCompletion
+from g4f.client import AsyncClient, ChatCompletion, convert_to_provider
+from g4f.providers.response import BaseConversation
from g4f.client.helper import filter_none
+from g4f.image import is_accepted_format, images_dir
from g4f.typing import Messages
-from g4f.cookies import read_cookie_files
+from g4f.errors import ProviderNotFoundError
+from g4f.cookies import read_cookie_files, get_cookies_dir
+from g4f.Provider import ProviderType, ProviderUtils, __providers__
logger = logging.getLogger(__name__)
@@ -63,6 +71,7 @@ class ChatCompletionsConfig(BaseModel):
api_key: Optional[str] = None
web_search: Optional[bool] = None
proxy: Optional[str] = None
+ conversation_id: str = None
class ImageGenerationConfig(BaseModel):
prompt: str
@@ -72,6 +81,18 @@ class ImageGenerationConfig(BaseModel):
api_key: Optional[str] = None
proxy: Optional[str] = None
+class ProviderResponseModel(BaseModel):
+ id: str
+ object: str = "provider"
+ created: int
+ owned_by: Optional[str]
+
+class ModelResponseModel(BaseModel):
+ id: str
+ object: str = "model"
+ created: int
+ owned_by: Optional[str]
+
class AppConfig:
ignored_providers: Optional[list[str]] = None
g4f_api_key: Optional[str] = None
@@ -98,11 +119,12 @@ def __init__(self, app: FastAPI, g4f_api_key=None) -> None:
self.client = AsyncClient()
self.g4f_api_key = g4f_api_key
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
+ self.conversations: dict[str, dict[str, BaseConversation]] = {}
def register_authorization(self):
@self.app.middleware("http")
async def authorization(request: Request, call_next):
- if self.g4f_api_key and request.url.path in ["/v1/chat/completions", "/v1/completions", "/v1/images/generate"]:
+ if self.g4f_api_key and request.url.path not in ("/", "/v1"):
try:
user_g4f_api_key = await self.get_g4f_api_key(request)
except HTTPException as e:
@@ -116,9 +138,7 @@ async def authorization(request: Request, call_next):
status_code=HTTP_403_FORBIDDEN,
content=jsonable_encoder({"detail": "Invalid G4F API key"}),
)
-
- response = await call_next(request)
- return response
+ return await call_next(request)
def register_validation_exception_handler(self):
@self.app.exception_handler(RequestValidationError)
@@ -146,25 +166,26 @@ async def read_root_v1():
return HTMLResponse('g4f API: Go to '
'models, '
'chat/completions, or '
- 'images/generate.')
+ 'images/generate
'
+ 'Open Swagger UI at: '
+ '/docs')
@self.app.get("/v1/models")
- async def models():
+ async def models() -> list[ModelResponseModel]:
model_list = dict(
(model, g4f.models.ModelUtils.convert[model])
for model in g4f.Model.__all__()
)
- model_list = [{
+ return [{
'id': model_id,
'object': 'model',
'created': 0,
'owned_by': model.base_provider
} for model_id, model in model_list.items()]
- return JSONResponse(model_list)
@self.app.get("/v1/models/{model_name}")
async def model_info(model_name: str):
- try:
+ if model_name in g4f.models.ModelUtils.convert:
model_info = g4f.models.ModelUtils.convert[model_name]
return JSONResponse({
'id': model_name,
@@ -172,19 +193,27 @@ async def model_info(model_name: str):
'created': 0,
'owned_by': model_info.base_provider
})
- except:
- return JSONResponse({"error": "The model does not exist."})
+ return JSONResponse({"error": "The model does not exist."}, 404)
@self.app.post("/v1/chat/completions")
async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
try:
config.provider = provider if config.provider is None else config.provider
+ if config.provider is None:
+ config.provider = AppConfig.provider
if config.api_key is None and request is not None:
auth_header = request.headers.get("Authorization")
if auth_header is not None:
- auth_header = auth_header.split(None, 1)[-1]
- if auth_header and auth_header != "Bearer":
- config.api_key = auth_header
+ api_key = auth_header.split(None, 1)[-1]
+ if api_key and api_key != "Bearer":
+ config.api_key = api_key
+
+ conversation = return_conversation = None
+ if config.conversation_id is not None and config.provider is not None:
+ return_conversation = True
+ if config.conversation_id in self.conversations:
+ if config.provider in self.conversations[config.conversation_id]:
+ conversation = self.conversations[config.conversation_id][config.provider]
# Create the completion response
response = self.client.chat.completions.create(
@@ -194,6 +223,11 @@ async def chat_completions(config: ChatCompletionsConfig, request: Request = Non
"provider": AppConfig.provider,
"proxy": AppConfig.proxy,
**config.dict(exclude_none=True),
+ **{
+ "conversation_id": None,
+ "return_conversation": return_conversation,
+ "conversation": conversation
+ }
},
ignored=AppConfig.ignored_providers
),
@@ -206,7 +240,13 @@ async def chat_completions(config: ChatCompletionsConfig, request: Request = Non
async def streaming():
try:
async for chunk in response:
- yield f"data: {json.dumps(chunk.to_json())}\n\n"
+ if isinstance(chunk, BaseConversation):
+ if config.conversation_id is not None and config.provider is not None:
+ if config.conversation_id not in self.conversations:
+ self.conversations[config.conversation_id] = {}
+ self.conversations[config.conversation_id][config.provider] = chunk
+ else:
+ yield f"data: {json.dumps(chunk.to_json())}\n\n"
except GeneratorExit:
pass
except Exception as e:
@@ -222,7 +262,13 @@ async def streaming():
@self.app.post("/v1/images/generate")
@self.app.post("/v1/images/generations")
- async def generate_image(config: ImageGenerationConfig):
+ async def generate_image(config: ImageGenerationConfig, request: Request):
+ if config.api_key is None:
+ auth_header = request.headers.get("Authorization")
+ if auth_header is not None:
+ api_key = auth_header.split(None, 1)[-1]
+ if api_key and api_key != "Bearer":
+ config.api_key = api_key
try:
response = await self.client.images.generate(
prompt=config.prompt,
@@ -234,14 +280,87 @@ async def generate_image(config: ImageGenerationConfig):
proxy = config.proxy
)
)
+ for image in response.data:
+ if hasattr(image, "url") and image.url.startswith("/"):
+ image.url = f"{request.base_url}{image.url.lstrip('/')}"
return JSONResponse(response.to_json())
except Exception as e:
logger.exception(e)
return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json")
- @self.app.post("/v1/completions")
- async def completions():
- return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
+ @self.app.get("/v1/providers")
+ async def providers() -> list[ProviderResponseModel]:
+ return [{
+ 'id': provider.__name__,
+ 'object': 'provider',
+ 'created': 0,
+ 'url': provider.url,
+ 'label': getattr(provider, "label", None),
+ } for provider in __providers__ if provider.working]
+
+ @self.app.get("/v1/providers/{provider}")
+ async def providers_info(provider: str) -> ProviderResponseModel:
+ if provider not in ProviderUtils.convert:
+ return JSONResponse({"error": "The provider does not exist."}, 404)
+ provider: ProviderType = ProviderUtils.convert[provider]
+ def safe_get_models(provider: ProviderType) -> list[str]:
+ try:
+ return provider.get_models() if hasattr(provider, "get_models") else []
+ except:
+ return []
+ return {
+ 'id': provider.__name__,
+ 'object': 'provider',
+ 'created': 0,
+ 'url': provider.url,
+ 'label': getattr(provider, "label", None),
+ 'models': safe_get_models(provider),
+ 'image_models': getattr(provider, "image_models", []) or [],
+ 'vision_models': [model for model in [getattr(provider, "default_vision_model", None)] if model],
+ 'params': [*provider.get_parameters()] if hasattr(provider, "get_parameters") else []
+ }
+
+ @self.app.post("/v1/upload_cookies")
+ def upload_cookies(files: List[UploadFile]):
+ response_data = []
+ for file in files:
+ try:
+ if file and file.filename.endswith(".json") or file.filename.endswith(".har"):
+ filename = os.path.basename(file.filename)
+ with open(os.path.join(get_cookies_dir(), filename), 'wb') as f:
+ shutil.copyfileobj(file.file, f)
+ response_data.append({"filename": filename})
+ finally:
+ file.file.close()
+ return response_data
+
+ @self.app.get("/v1/synthesize/{provider}")
+ async def synthesize(request: Request, provider: str):
+ try:
+ provider_handler = convert_to_provider(provider)
+ except ProviderNotFoundError:
+ return Response("Provider not found", 404)
+ if not hasattr(provider_handler, "synthesize"):
+ return Response("Provider doesn't support synthesize", 500)
+ if len(request.query_params) == 0:
+ return Response("Missing query params", 500)
+ response_data = provider_handler.synthesize({**request.query_params})
+ content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
+ return StreamingResponse(response_data, media_type=content_type)
+
+ @self.app.get("/images/{filename}")
+ async def get_image(filename) -> FileResponse:
+ target = os.path.join(images_dir, filename)
+
+ if not os.path.isfile(target):
+ return Response(status_code=404)
+
+ with open(target, "rb") as f:
+ content_type = is_accepted_format(f.read(12))
+
+ return FileResponse(target, media_type=content_type)
+
+
def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str:
last_provider = {} if not image else g4f.get_last_provider(True)
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index 549a244b412..f6a0f5e8eb9 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -6,23 +6,25 @@
import string
import asyncio
import base64
-import aiohttp
-import logging
-from typing import Union, AsyncIterator, Iterator, Coroutine
+from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
from ..providers.base_provider import AsyncGeneratorProvider
-from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP
-from ..typing import Messages, Image
-from ..providers.types import ProviderType, FinishReason, BaseConversation
-from ..errors import NoImageResponseError
+from ..image import ImageResponse, copy_images, images_dir
+from ..typing import Messages, Image, ImageType
+from ..providers.types import ProviderType
+from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
+from ..errors import NoImageResponseError, ModelNotFoundError
from ..providers.retry_provider import IterListProvider
+from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list
from ..Provider.needs_auth.BingCreateImages import BingCreateImages
-from ..requests.aiohttp import get_connector
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, get_last_provider, convert_to_provider
-from .helper import find_stop, filter_json, filter_none, safe_aclose, to_sync_iter, to_async_iterator
+from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
+
+ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
+AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
try:
anext # Python 3.8+
@@ -35,20 +37,19 @@ async def anext(aiter):
# Synchronous iter_response function
def iter_response(
- response: Union[Iterator[str], AsyncIterator[str]],
+ response: Union[Iterator[Union[str, ResponseType]]],
stream: bool,
- response_format: dict = None,
- max_tokens: int = None,
- stop: list = None
-) -> Iterator[Union[ChatCompletion, ChatCompletionChunk]]:
+ response_format: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ stop: Optional[list[str]] = None
+) -> ChatCompletionResponseType:
content = ""
finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
idx = 0
if hasattr(response, '__aiter__'):
- # It's an async iterator, wrap it into a sync iterator
- response = to_sync_iter(response)
+ response = to_sync_generator(response)
for chunk in response:
if isinstance(chunk, FinishReason):
@@ -57,6 +58,8 @@ def iter_response(
elif isinstance(chunk, BaseConversation):
yield chunk
continue
+ elif isinstance(chunk, SynthesizeData):
+ continue
chunk = str(chunk)
content += chunk
@@ -88,22 +91,23 @@ def iter_response(
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
# Synchronous iter_append_model_and_provider function
-def iter_append_model_and_provider(response: Iterator[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]:
+def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
last_provider = None
for chunk in response:
- last_provider = get_last_provider(True) if last_provider is None else last_provider
- chunk.model = last_provider.get("model")
- chunk.provider = last_provider.get("name")
- yield chunk
+ if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
+ last_provider = get_last_provider(True) if last_provider is None else last_provider
+ chunk.model = last_provider.get("model")
+ chunk.provider = last_provider.get("name")
+ yield chunk
async def async_iter_response(
- response: AsyncIterator[str],
+ response: AsyncIterator[Union[str, ResponseType]],
stream: bool,
- response_format: dict = None,
- max_tokens: int = None,
- stop: list = None
-) -> AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]:
+ response_format: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ stop: Optional[list[str]] = None
+) -> AsyncChatCompletionResponseType:
content = ""
finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
@@ -117,6 +121,8 @@ async def async_iter_response(
elif isinstance(chunk, BaseConversation):
yield chunk
continue
+ elif isinstance(chunk, SynthesizeData):
+ continue
chunk = str(chunk)
content += chunk
@@ -149,13 +155,16 @@ async def async_iter_response(
if hasattr(response, 'aclose'):
await safe_aclose(response)
-async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompletionChunk]) -> AsyncIterator:
+async def async_iter_append_model_and_provider(
+ response: AsyncChatCompletionResponseType
+ ) -> AsyncChatCompletionResponseType:
last_provider = None
try:
async for chunk in response:
- last_provider = get_last_provider(True) if last_provider is None else last_provider
- chunk.model = last_provider.get("model")
- chunk.provider = last_provider.get("name")
+ if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
+ last_provider = get_last_provider(True) if last_provider is None else last_provider
+ chunk.model = last_provider.get("model")
+ chunk.provider = last_provider.get("name")
yield chunk
finally:
if hasattr(response, 'aclose'):
@@ -164,8 +173,8 @@ async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompl
class Client(BaseClient):
def __init__(
self,
- provider: ProviderType = None,
- image_provider: ImageProvider = None,
+ provider: Optional[ProviderType] = None,
+ image_provider: Optional[ImageProvider] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
@@ -173,7 +182,7 @@ def __init__(
self.images: Images = Images(self, image_provider)
class Completions:
- def __init__(self, client: Client, provider: ProviderType = None):
+ def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.client: Client = client
self.provider: ProviderType = provider
@@ -181,16 +190,16 @@ def create(
self,
messages: Messages,
model: str,
- provider: ProviderType = None,
- stream: bool = False,
- proxy: str = None,
- response_format: dict = None,
- max_tokens: int = None,
- stop: Union[list[str], str] = None,
- api_key: str = None,
- ignored: list[str] = None,
- ignore_working: bool = False,
- ignore_stream: bool = False,
+ provider: Optional[ProviderType] = None,
+ stream: Optional[bool] = False,
+ proxy: Optional[str] = None,
+ response_format: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ stop: Optional[Union[list[str], str]] = None,
+ api_key: Optional[str] = None,
+ ignored: Optional[list[str]] = None,
+ ignore_working: Optional[bool] = False,
+ ignore_stream: Optional[bool] = False,
**kwargs
) -> IterResponse:
model, provider = get_model_and_provider(
@@ -220,10 +229,10 @@ def create(
response = asyncio.run(response)
if stream and hasattr(response, '__aiter__'):
# It's an async generator, wrap it into a sync iterator
- response = to_sync_iter(response)
+ response = to_sync_generator(response)
elif hasattr(response, '__aiter__'):
# If response is an async generator, collect it into a list
- response = list(to_sync_iter(response))
+ response = asyncio.run(async_generator_to_list(response))
response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response)
if stream:
@@ -234,22 +243,38 @@ def create(
class Chat:
completions: Completions
- def __init__(self, client: Client, provider: ProviderType = None):
+ def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.completions = Completions(client, provider)
class Images:
- def __init__(self, client: Client, provider: ProviderType = None):
+ def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.client: Client = client
- self.provider: ProviderType = provider
+ self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
- def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
+ def generate(
+ self,
+ prompt: str,
+ model: str = None,
+ provider: Optional[ProviderType] = None,
+ response_format: str = "url",
+ proxy: Optional[str] = None,
+ **kwargs
+ ) -> ImagesResponse:
"""
Synchronous generate method that runs the async_generate method in an event loop.
"""
- return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy, **kwargs))
+ return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
- async def async_generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
+ async def async_generate(
+ self,
+ prompt: str,
+ model: Optional[str] = None,
+ provider: Optional[ProviderType] = None,
+ response_format: Optional[str] = "url",
+ proxy: Optional[str] = None,
+ **kwargs
+ ) -> ImagesResponse:
if provider is None:
provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
elif isinstance(provider, str):
@@ -257,97 +282,73 @@ async def async_generate(self, prompt: str, model: str = None, provider: Provide
else:
provider_handler = provider
if provider_handler is None:
- raise ValueError(f"Unknown model: {model}")
- if proxy is None:
- proxy = self.client.proxy
-
+ raise ModelNotFoundError(f"Unknown model: {model}")
if isinstance(provider_handler, IterListProvider):
if provider_handler.providers:
provider_handler = provider_handler.providers[0]
else:
- raise ValueError(f"IterListProvider for model {model} has no providers")
+ raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
+ if proxy is None:
+ proxy = self.client.proxy
response = None
- if hasattr(provider_handler, "create_async_generator"):
- messages = [{"role": "user", "content": prompt}]
- async for item in provider_handler.create_async_generator(model, messages, **kwargs):
+ if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
+ async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs):
if isinstance(item, ImageResponse):
response = item
break
- elif hasattr(provider, 'create'):
+ elif hasattr(provider_handler, 'create'):
if asyncio.iscoroutinefunction(provider_handler.create):
response = await provider_handler.create(prompt)
else:
response = provider_handler.create(prompt)
if isinstance(response, str):
response = ImageResponse([response], prompt)
+ elif hasattr(provider_handler, "create_completion"):
+ get_running_loop(check_nested=True)
+ messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
+ for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs):
+ if isinstance(item, ImageResponse):
+ response = item
+ break
else:
raise ValueError(f"Provider {provider} does not support image generation")
if isinstance(response, ImageResponse):
- return await self._process_image_response(response, response_format, proxy, model=model, provider=provider)
-
+ return await self._process_image_response(
+ response,
+ response_format,
+ proxy,
+ model,
+ getattr(provider_handler, "__name__", None)
+ )
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
- async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse:
- async def process_image_item(session: aiohttp.ClientSession, image_data: str):
- image_data_bytes = None
- if image_data.startswith("http://") or image_data.startswith("https://"):
- if response_format == "url":
- return Image(url=image_data, revised_prompt=response.alt)
- elif response_format == "b64_json":
- # Fetch the image data and convert it to base64
- image_data_bytes = await self._fetch_image(session, image_data)
- b64_json = base64.b64encode(image_data_bytes).decode("utf-8")
- return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt)
- else:
- # Assume image_data is base64 data or binary
- if response_format == "url":
- if image_data.startswith("data:image"):
- # Remove the data URL scheme and get the base64 data
- base64_data = image_data.split(",", 1)[-1]
- else:
- base64_data = image_data
- # Decode the base64 data
- image_data_bytes = base64.b64decode(base64_data)
- if image_data_bytes:
- file_name = self._save_image(image_data_bytes)
- return Image(url=file_name, revised_prompt=response.alt)
- else:
- raise ValueError("Unable to process image data")
-
- last_provider = get_last_provider(True)
- async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
- return ImagesResponse(
- await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
- model=last_provider.get("model") if model is None else model,
- provider=last_provider.get("name") if provider is None else provider
- )
-
- async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes:
- # Asynchronously fetch image data from the URL
- async with session.get(url) as resp:
- if resp.status == 200:
- return await resp.read()
- else:
- raise RuntimeError(f"Failed to fetch image from {url}, status code {resp.status}")
-
- def _save_image(self, image_data_bytes: bytes) -> str:
- os.makedirs('generated_images', exist_ok=True)
- image = to_image(image_data_bytes)
- file_name = f"generated_images/image_{int(time.time())}_{random.randint(0, 10000)}.{EXTENSIONS_MAP[is_accepted_format(image_data_bytes)]}"
- image.save(file_name)
- return file_name
-
- def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
+ def create_variation(
+ self,
+ image: Union[str, bytes],
+ model: str = None,
+ provider: Optional[ProviderType] = None,
+ response_format: str = "url",
+ **kwargs
+ ) -> ImagesResponse:
return asyncio.run(self.async_create_variation(
image, model, provider, response_format, **kwargs
))
- async def async_create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
+ async def async_create_variation(
+ self,
+ image: ImageType,
+ model: Optional[str] = None,
+ provider: Optional[ProviderType] = None,
+ response_format: str = "url",
+ proxy: Optional[str] = None,
+ **kwargs
+ ) -> ImagesResponse:
if provider is None:
provider = self.models.get(model, provider or self.provider or BingCreateImages)
if provider is None:
- raise ValueError(f"Unknown model: {model}")
+ raise ModelNotFoundError(f"Unknown model: {model}")
if isinstance(provider, str):
provider = convert_to_provider(provider)
if proxy is None:
@@ -355,38 +356,61 @@ async def async_create_variation(self, image: Union[str, bytes], model: str = No
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
messages = [{"role": "user", "content": "create a variation of this image"}]
- image_data = to_data_uri(image)
generator = None
try:
- generator = provider.create_async_generator(model, messages, image=image_data, response_format=response_format, proxy=proxy, **kwargs)
- async for response in generator:
- if isinstance(response, ImageResponse):
- return self._process_image_response(response)
- except RuntimeError as e:
- if "async generator ignored GeneratorExit" in str(e):
- logging.warning("Generator ignored GeneratorExit in create_variation, handling gracefully")
- else:
- raise
+ generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
+ async for chunk in generator:
+ if isinstance(chunk, ImageResponse):
+ response = chunk
+ break
finally:
if generator and hasattr(generator, 'aclose'):
await safe_aclose(generator)
- logging.info("AsyncGeneratorProvider processing completed in create_variation")
elif hasattr(provider, 'create_variation'):
if asyncio.iscoroutinefunction(provider.create_variation):
response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
else:
response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
- if isinstance(response, str):
- response = ImageResponse([response])
- return self._process_image_response(response)
else:
- raise ValueError(f"Provider {provider} does not support image variation")
+ raise NoImageResponseError(f"Provider {provider} does not support image variation")
+
+ if isinstance(response, str):
+ response = ImageResponse([response])
+ if isinstance(response, ImageResponse):
+ return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None))
+ raise NoImageResponseError(f"Unexpected response type: {type(response)}")
+
+ async def _process_image_response(
+ self,
+ response: ImageResponse,
+ response_format: str,
+ proxy: str = None,
+ model: Optional[str] = None,
+ provider: Optional[str] = None
+ ) -> list[Image]:
+ if response_format in ("url", "b64_json"):
+ images = await copy_images(response.get_list(), response.options.get("cookies"), proxy)
+ async def process_image_item(image_file: str) -> Image:
+ if response_format == "b64_json":
+ with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file:
+ image_data = base64.b64encode(file.read()).decode()
+ return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt)
+ return Image(url=image_file, revised_prompt=response.alt)
+ images = await asyncio.gather(*[process_image_item(image) for image in images])
+ else:
+ images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()]
+ last_provider = get_last_provider(True)
+ return ImagesResponse(
+ images,
+ model=last_provider.get("model") if model is None else model,
+ provider=last_provider.get("name") if provider is None else provider
+ )
class AsyncClient(BaseClient):
def __init__(
self,
- provider: ProviderType = None,
- image_provider: ImageProvider = None,
+ provider: Optional[ProviderType] = None,
+ image_provider: Optional[ImageProvider] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
@@ -396,11 +420,11 @@ def __init__(
class AsyncChat:
completions: AsyncCompletions
- def __init__(self, client: AsyncClient, provider: ProviderType = None):
+ def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.completions = AsyncCompletions(client, provider)
class AsyncCompletions:
- def __init__(self, client: AsyncClient, provider: ProviderType = None):
+ def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client
self.provider: ProviderType = provider
@@ -408,18 +432,18 @@ def create(
self,
messages: Messages,
model: str,
- provider: ProviderType = None,
- stream: bool = False,
- proxy: str = None,
- response_format: dict = None,
- max_tokens: int = None,
- stop: Union[list[str], str] = None,
- api_key: str = None,
- ignored: list[str] = None,
- ignore_working: bool = False,
- ignore_stream: bool = False,
+ provider: Optional[ProviderType] = None,
+ stream: Optional[bool] = False,
+ proxy: Optional[str] = None,
+ response_format: Optional[dict] = None,
+ max_tokens: Optional[int] = None,
+ stop: Optional[Union[list[str], str]] = None,
+ api_key: Optional[str] = None,
+ ignored: Optional[list[str]] = None,
+ ignore_working: Optional[bool] = False,
+ ignore_stream: Optional[bool] = False,
**kwargs
- ) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk]]:
+ ) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
@@ -450,15 +474,29 @@ def create(
return response if stream else anext(response)
class AsyncImages(Images):
- def __init__(self, client: AsyncClient, provider: ImageProvider = None):
+ def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client
- self.provider: ImageProvider = provider
+ self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
- async def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
+ async def generate(
+ self,
+ prompt: str,
+ model: Optional[str] = None,
+ provider: Optional[ProviderType] = None,
+ response_format: str = "url",
+ **kwargs
+ ) -> ImagesResponse:
return await self.async_generate(prompt, model, provider, response_format, **kwargs)
- async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
+ async def create_variation(
+ self,
+ image: ImageType,
+ model: str = None,
+ provider: ProviderType = None,
+ response_format: str = "url",
+ **kwargs
+ ) -> ImagesResponse:
return await self.async_create_variation(
image, model, provider, response_format, **kwargs
- )
+ )
\ No newline at end of file
diff --git a/g4f/client/helper.py b/g4f/client/helper.py
index 71bfd38ae90..909cc1320cb 100644
--- a/g4f/client/helper.py
+++ b/g4f/client/helper.py
@@ -1,12 +1,9 @@
from __future__ import annotations
import re
-import queue
-import threading
import logging
-import asyncio
-from typing import AsyncIterator, Iterator, AsyncGenerator
+from typing import AsyncIterator, Iterator, AsyncGenerator, Optional
def filter_json(text: str) -> str:
"""
@@ -23,7 +20,7 @@ def filter_json(text: str) -> str:
return match.group("code")
return text
-def find_stop(stop, content: str, chunk: str = None):
+def find_stop(stop: Optional[list[str]], content: str, chunk: str = None):
first = -1
word = None
if stop is not None:
@@ -53,33 +50,6 @@ async def safe_aclose(generator: AsyncGenerator) -> None:
except Exception as e:
logging.warning(f"Error while closing generator: {e}")
-# Helper function to convert an async generator to a synchronous iterator
-def to_sync_iter(async_gen: AsyncIterator) -> Iterator:
- q = queue.Queue()
- loop = asyncio.new_event_loop()
- done = object()
-
- def _run():
- asyncio.set_event_loop(loop)
-
- async def iterate():
- try:
- async for item in async_gen:
- q.put(item)
- finally:
- q.put(done)
-
- loop.run_until_complete(iterate())
- loop.close()
-
- threading.Thread(target=_run).start()
-
- while True:
- item = q.get()
- if item is done:
- break
- yield item
-
# Helper function to convert a synchronous iterator to an async iterator
async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
for item in iterator:
diff --git a/g4f/client/image_models.py b/g4f/client/image_models.py
index edaa4592242..0b97a56b959 100644
--- a/g4f/client/image_models.py
+++ b/g4f/client/image_models.py
@@ -1,7 +1,5 @@
from __future__ import annotations
-from .types import Client, ImageProvider
-
from ..models import ModelUtils
class ImageModels():
diff --git a/g4f/client/types.py b/g4f/client/types.py
index 4f252ba991b..5010e098c96 100644
--- a/g4f/client/types.py
+++ b/g4f/client/types.py
@@ -3,7 +3,7 @@
import os
from .stubs import ChatCompletion, ChatCompletionChunk
-from ..providers.types import BaseProvider, ProviderType, FinishReason
+from ..providers.types import BaseProvider
from typing import Union, Iterator, AsyncIterator
ImageProvider = Union[BaseProvider, object]
diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html
index 48214093926..8cbcd578338 100644
--- a/g4f/gui/client/index.html
+++ b/g4f/gui/client/index.html
@@ -111,6 +111,11 @@