Skip to content

Commit

Permalink
Update (g4f/models.py g4f/Provider/Cloudflare.py)
Browse files: browse the repository at this point in the history
  • Loading branch information
kqlio67 committed Nov 11, 2024
1 parent c74a694 commit 8e8410c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 63 deletions.
47 changes: 14 additions & 33 deletions g4f/Provider/Cloudflare.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from aiohttp import ClientSession
import asyncio
import json
import uuid
Expand All @@ -10,7 +11,6 @@
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt


class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
label = "Cloudflare AI"
url = "https://playground.ai.cloudflare.com"
Expand All @@ -22,8 +22,6 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):

default_model = '@cf/meta/llama-3.1-8b-instruct-awq'
models = [
'@hf/google/gemma-7b-it',

'@cf/meta/llama-2-7b-chat-fp16',
'@cf/meta/llama-2-7b-chat-int8',

Expand All @@ -38,21 +36,12 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):

'@hf/mistral/mistral-7b-instruct-v0.2',

'@cf/microsoft/phi-2',

'@cf/qwen/qwen1.5-0.5b-chat',
'@cf/qwen/qwen1.5-1.8b-chat',
'@cf/qwen/qwen1.5-14b-chat-awq',
'@cf/qwen/qwen1.5-7b-chat-awq',

'@cf/defog/sqlcoder-7b-2',
]

model_aliases = {
#"falcon-7b": "@cf/tiiuae/falcon-7b-instruct",

"gemma-7b": "@hf/google/gemma-7b-it",

"llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
"llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",

Expand All @@ -65,11 +54,6 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):

"llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",

"phi-2": "@cf/microsoft/phi-2",

"qwen-1.5-0-5b": "@cf/qwen/qwen1.5-0.5b-chat",
"qwen-1.5-1-8b": "@cf/qwen/qwen1.5-1.8b-chat",
"qwen-1.5-14b": "@cf/qwen/qwen1.5-14b-chat-awq",
"qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",

#"sqlcoder-7b": "@cf/defog/sqlcoder-7b-2",
Expand All @@ -90,6 +74,7 @@ async def create_async_generator(
model: str,
messages: Messages,
proxy: str = None,
max_tokens: int = 2048,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
Expand Down Expand Up @@ -117,52 +102,48 @@ async def create_async_generator(

scraper = cloudscraper.create_scraper()


prompt = messages[-1]['content']

data = {
"messages": [
{"role": "user", "content": prompt}
{"role": "user", "content": format_prompt(messages)}
],
"lora": None,
"model": model,
"max_tokens": 2048,
"max_tokens": max_tokens,
"stream": True
}

max_retries = 5
max_retries = 3
full_response = ""

for attempt in range(max_retries):
try:
response = scraper.post(
cls.api_endpoint,
headers=headers,
cookies=cookies,
json=data,
stream=True
stream=True,
proxies={'http': proxy, 'https': proxy} if proxy else None
)

if response.status_code == 403:
await asyncio.sleep(2 ** attempt)
continue

response.raise_for_status()

skip_tokens = ["</s>", "<s>", "</s>", "[DONE]", "<|endoftext|>", "<|end|>"]
filtered_response = ""

for line in response.iter_lines():
if line.startswith(b'data: '):
if line == b'data: [DONE]':
if full_response:
yield full_response
break
try:
content = json.loads(line[6:].decode('utf-8'))
response_text = content['response']
if not any(token in response_text for token in skip_tokens):
filtered_response += response_text
if 'response' in content and content['response'] != '</s>':
yield content['response']
except Exception:
continue

yield filtered_response.strip()
break
except Exception as e:
if attempt == max_retries - 1:
Expand Down
31 changes: 1 addition & 30 deletions g4f/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __all__() -> list[str]:
phi_2 = Model(
name = "phi-2",
base_provider = "Microsoft",
best_provider = IterListProvider([Cloudflare, Airforce])
best_provider = IterListProvider([Airforce])
)

phi_3_5_mini = Model(
Expand Down Expand Up @@ -286,12 +286,6 @@ def __all__() -> list[str]:
best_provider = IterListProvider([ReplicateHome])
)

gemma_7b = Model(
name = 'gemma-7b',
base_provider = 'Google',
best_provider = Cloudflare
)


### Anthropic ###
claude_2_1 = Model(
Expand Down Expand Up @@ -358,30 +352,12 @@ def __all__() -> list[str]:

### Qwen ###
# qwen 1_5
qwen_1_5_5b = Model(
name = 'qwen-1.5-5b',
base_provider = 'Qwen',
best_provider = Cloudflare
)

qwen_1_5_7b = Model(
name = 'qwen-1.5-7b',
base_provider = 'Qwen',
best_provider = Cloudflare
)

qwen_1_5_8b = Model(
name = 'qwen-1.5-8b',
base_provider = 'Qwen',
best_provider = Cloudflare
)

qwen_1_5_14b = Model(
name = 'qwen-1.5-14b',
base_provider = 'Qwen',
best_provider = IterListProvider([Cloudflare])
)

# qwen 2
qwen_2_72b = Model(
name = 'qwen-2-72b',
Expand Down Expand Up @@ -690,7 +666,6 @@ class ModelUtils:


### Microsoft ###
'phi-2': phi_2,
'phi-3.5-mini': phi_3_5_mini,


Expand All @@ -702,7 +677,6 @@ class ModelUtils:

# gemma
'gemma-2b': gemma_2b,
'gemma-7b': gemma_7b,


### Anthropic ###
Expand Down Expand Up @@ -737,10 +711,7 @@ class ModelUtils:

### Qwen ###
# qwen 1.5
'qwen-1.5-5b': qwen_1_5_5b,
'qwen-1.5-7b': qwen_1_5_7b,
'qwen-1.5-8b': qwen_1_5_8b,
'qwen-1.5-14b': qwen_1_5_14b,

# qwen 2
'qwen-2-72b': qwen_2_72b,
Expand Down

0 comments on commit 8e8410c

Please sign in to comment.