Skip to content

Commit

Permalink
Merge pull request #203 from yjg30737/Dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
yjg30737 authored Dec 22, 2024
2 parents 6ed8f59 + 26e5e01 commit 476c7d1
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 85 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,5 @@ dmypy.json
pyqt_openai/pyqt_openai.ini
pyqt_openai/*.db
pyqt_openai/test/
pyqt_openai/config.yaml
pyqt_openai/config.yaml
test
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies = [
"docx2txt",
"openpyxl",

"g4f==0.3.3.4",
"g4f",
"nodriver",
"curl_cffi",
"litellm",

Expand Down
1 change: 1 addition & 0 deletions pyqt_openai/g4f_image_widget/g4fImageRightSideBar.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def getArgument(self):
**obj,
"model": self.__modelCmbBox.currentText(),
"provider": self.__providerCmbBox.currentText(),
"response_format": "url",
"prompt": self._promptTextEdit.toPlainText(),
"negative_prompt": self._negativeTextEdit.toPlainText(),
}
31 changes: 9 additions & 22 deletions pyqt_openai/g4f_image_widget/g4fImageThread.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,9 @@ def stop(self):
self.__stop = True

def run(self):
try:
provider = G4F_PROVIDER_DEFAULT
if self.__input_args["provider"] != G4F_PROVIDER_DEFAULT:
provider = self.__input_args["provider"]
self.__input_args["provider"] = convert_to_provider(
self.__input_args["provider"]
)
# try:
if self.__input_args["provider"] == G4F_PROVIDER_DEFAULT:
del self.__input_args["provider"]

for _ in range(self.__number_of_images):
if self.__stop:
Expand All @@ -45,26 +41,17 @@ def run(self):
self.__input_args["prompt"] = generate_random_prompt(
self.__randomizing_prompt_source_arr
)
images = G4F_CLIENT.images
if provider != G4F_PROVIDER_DEFAULT:
images.provider = self.__input_args["provider"]
else:
del self.__input_args["provider"]
provider = images.models.get(self.__input_args['model'], images.provider)
if isinstance(provider, IterListProvider):
if provider.providers:
provider = provider.providers[0]
provider = provider.__name__

response = images.generate(**self.__input_args)
response = G4F_CLIENT.images.generate(
**self.__input_args
)
arg = {
**self.__input_args,
"provider": provider,
"provider": response.provider,
"data": download_image_as_base64(response.data[0].url),
}

result = ImagePromptContainer(**arg)
self.replyGenerated.emit(result)
self.allReplyGenerated.emit()
except Exception as e:
self.errorGenerated.emit(str(e))
# except Exception as e:
# self.errorGenerated.emit(str(e))
2 changes: 1 addition & 1 deletion pyqt_openai/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __createPromptEntry(self):
self.__conn.rollback()
self.__c.execute("PRAGMA foreign_keys=ON") # Ensure foreign keys are re-enabled
else:
print(f"Table {PROMPT_ENTRY_TABLE_NAME} already updated.")
pass
else:
self.__c.execute(
f"""CREATE TABLE {PROMPT_ENTRY_TABLE_NAME} (
Expand Down
132 changes: 73 additions & 59 deletions pyqt_openai/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,19 @@
import sys
import tempfile
import time
import filetype
import traceback
import wave
import zipfile
from datetime import datetime
from io import BytesIO
from pathlib import Path
from inspect import signature

import PIL.Image
import filetype
import numpy as np
import psutil
from g4f import ProviderType
from g4f.providers.base_provider import ProviderModelMixin
from litellm import completion

from pyqt_openai.widgets.scrollableErrorDialog import ScrollableErrorDialog

if sys.platform == "win32":
Expand All @@ -45,7 +43,7 @@
from g4f.Provider import ProviderUtils, __providers__, __map__
from g4f.errors import ProviderNotFoundError
from g4f.models import ModelUtils
from g4f.providers.retry_provider import IterProvider
from g4f.providers.retry_provider import IterListProvider
from jinja2 import Template

import pyqt_openai.util
Expand All @@ -64,8 +62,7 @@
O1_MODELS,
STT_MODEL,
DEFAULT_DATETIME_FORMAT,
DEFAULT_TOKEN_CHUNK_SIZE, DEFAULT_API_CONFIGS, INDENT_SIZE, FAMOUS_LLM_LIST,
)
DEFAULT_TOKEN_CHUNK_SIZE, DEFAULT_API_CONFIGS, INDENT_SIZE, )
from pyqt_openai.config_loader import CONFIG_MANAGER
from pyqt_openai.globals import (
DB,
Expand Down Expand Up @@ -526,7 +523,7 @@ def convert_to_provider(provider: str):
]
if not provider_list:
raise ProviderNotFoundError(f"Providers not found: {provider}")
provider = IterProvider(provider_list)
provider = IterListProvider(provider_list)
elif provider in ProviderUtils.convert:
provider = ProviderUtils.convert[provider]
elif provider:
Expand Down Expand Up @@ -636,32 +633,62 @@ def get_g4f_image_models() -> list:
Get all the models that support image generation
Some of the image providers are not included in this list
"""
image_models = []
image_models = [
### Stability AI ###
'sdxl',
'sd-3',

### Playground ###
'playground-v2.5',

### Flux AI ###
'flux',
'flux-pro',
'flux-dev',
'flux-realism',
'flux-anime',
'flux-3d',
'flux-disney',
'flux-pixel',
'flux-4o',

### OpenAI ###
'dall-e-3',

### Recraft ###
'recraft-v3',

### Other ###
'any-dark'
]
index = []
for provider in __providers__:
if hasattr(provider, "image_models"):
if hasattr(provider, "get_models"):
provider.get_models()
parent = provider
if hasattr(provider, "parent"):
parent = __map__[provider.parent]
if parent.__name__ not in index:
if provider.image_models:
for model in provider.image_models:
image_models.append(
{
"provider": parent.__name__,
"url": parent.url,
"label": parent.label if hasattr(parent, "label") else None,
"image_model": model,
}
)
index.append(parent.__name__)

models = [model["image_model"] for model in image_models]
# Filter out the models in FAMOUS_LLM_LIST
models = [model for model in models if model not in FAMOUS_LLM_LIST]
return models
# for provider in __providers__:
# try:
# if hasattr(provider, "image_models"):
# if hasattr(provider, "get_models"):
# provider.get_models()
# parent = provider
# if hasattr(provider, "parent"):
# parent = __map__[provider.parent]
# if parent.__name__ not in index:
# if provider.image_models:
# for model in provider.image_models:
# image_models.append(
# {
# "provider": parent.__name__,
# "url": parent.url,
# "label": parent.label if hasattr(parent, "label") else None,
# "image_model": model,
# }
# )
# index.append(parent.__name__)
# except Exception as e:
# continue
#
# models = [model["image_model"] for model in image_models]
# # Filter out the models in FAMOUS_LLM_LIST
# models = [model for model in models if model not in FAMOUS_LLM_LIST]
return image_models


def get_g4f_image_providers(including_auto=False) -> list:
Expand Down Expand Up @@ -700,37 +727,24 @@ def get_g4f_image_models_from_provider(provider) -> list:
if provider == G4F_PROVIDER_DEFAULT:
return get_g4f_image_models()

def get_provider_models(provider: str) -> list[dict]:
"""
From g4f/gui/server/api.py
"""
def get_provider_models(provider: str, api_key: str = None):
if provider in __map__:
provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin):
if api_key is not None and "api_key" in signature(provider.get_models).parameters:
models = provider.get_models(api_key=api_key)
else:
models = provider.get_models()
return [
{"model": model, "default": model == provider.default_model}
for model in provider.get_models()
]
elif provider.supports_gpt_35_turbo or provider.supports_gpt_4:
return [
*(
[{"model": "gpt-4", "default": not provider.supports_gpt_4}]
if provider.supports_gpt_4
else []
),
*(
[
{
"model": "gpt-3.5-turbo",
"default": not provider.supports_gpt_4,
}
]
if provider.supports_gpt_35_turbo
else []
),
{
"model": model,
"default": model == provider.default_model,
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
"image": False if provider.image_models is None else model in provider.image_models,
}
for model in models
]
else:
return []
return []

return [model["model"] for model in get_provider_models(provider)]

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ llama-index
docx2txt
openpyxl

g4f==0.3.3.4
g4f
nodriver

curl_cffi
litellm
Expand Down

0 comments on commit 476c7d1

Please sign in to comment.