From de14f7af876f2e8e58ec899c8a5d75f102d1bcc7 Mon Sep 17 00:00:00 2001 From: yjg30737 Date: Sat, 23 Nov 2024 13:08:10 +0900 Subject: [PATCH 1/6] Fix bugs related to sentence type prompt generation feature, refactoring --- .../chat_widget/prompt_gen_widget/formPage.py | 138 ++++++++++++------ .../prompt_gen_widget/sentencePage.py | 19 +-- 2 files changed, 101 insertions(+), 56 deletions(-) diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py b/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py index 60d4b65..39d4dab 100644 --- a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py +++ b/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py @@ -52,6 +52,7 @@ class FormGroupList(QWidget): added = Signal(int) deleted = Signal(int) currentRowChanged = Signal(int) + itemChanged = Signal(int) def __init__(self, parent=None): super().__init__(parent) @@ -92,30 +93,30 @@ def __initUi(self): groups = DB.selectPromptGroup(prompt_type="form") - self.__list = QListWidget() + self.list = QListWidget() for group in groups: self.__addGroupItem(group.id, group.name) - self.__list.currentRowChanged.connect(self.currentRowChanged) - self.__list.itemChanged.connect(self.__itemChanged) + self.list.currentRowChanged.connect(self.__currentRowChanged) + self.list.itemChanged.connect(self.__itemChanged) lay = QVBoxLayout() lay.addWidget(topWidget) - lay.addWidget(self.__list) + lay.addWidget(self.list) lay.setContentsMargins(0, 0, 5, 0) self.setLayout(lay) - self.__list.setCurrentRow(0) + self.list.setCurrentRow(0) def __addGroupItem(self, id, name): item = QListWidgetItem() item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) item.setData(Qt.ItemDataRole.UserRole, id) item.setText(name) - self.__list.addItem(item) - self.__list.setCurrentItem(item) + self.list.addItem(item) + self.list.setCurrentItem(item) self.added.emit(id) def __add(self): @@ -127,8 +128,8 @@ def __add(self): self.__addGroupItem(id, name) def __delete(self): - i = self.__list.currentRow() - item = self.__list.takeItem(i) + i = self.list.currentRow() + item = self.list.takeItem(i) id = item.data(Qt.ItemDataRole.UserRole) DB.deletePromptGroup(id) self.deleted.emit(i) @@ -175,6 +176,13 @@ def __export(self): def __itemChanged(self, item): id = item.data(Qt.ItemDataRole.UserRole) DB.updatePromptGroup(id, item.text()) + self.itemChanged.emit(id) + + def __currentRowChanged(self, r_idx): + item = self.list.item(r_idx) + if item: + id = item.data(Qt.ItemDataRole.UserRole) + self.currentRowChanged.emit(id) class PromptTable(QWidget): @@ -184,16 +192,14 @@ class PromptTable(QWidget): updated = Signal(str) - def __init__(self, id, parent=None): + def __init__(self, parent=None): super().__init__(parent) - self.__initVal(id) + self.__initVal() self.__initUi() - def __initVal(self, id): - self.__group_id = id - - self.__title = DB.selectCertainPromptGroup(self.__group_id).name - self.__entries = DB.selectPromptEntry(self.__group_id) + def __initVal(self): + self.__title = "" + self.__entries = [] def __initUi(self): self.__addBtn = Button() @@ -205,8 +211,10 @@ def __initUi(self): self.__addBtn.clicked.connect(self.__add) self.__delBtn.clicked.connect(self.__delete) + self.__titleLbl = QLabel() + lay = QHBoxLayout() - lay.addWidget(QLabel(self.__title)) + lay.addWidget(self.__titleLbl) lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) lay.addWidget(self.__addBtn) lay.addWidget(self.__delBtn) @@ -219,15 +227,16 @@ def __initUi(self): self.__table = QTableWidget() self.__table.setColumnCount(2) self.__table.setRowCount(len(self.__entries)) - self.__table.horizontalHeader().setSectionResizeMode( - 1, QHeaderView.ResizeMode.Stretch - ) self.__table.setSelectionBehavior( QAbstractItemView.SelectionBehavior.SelectRows ) self.__table.setHorizontalHeaderLabels( [LangClass.TRANSLATIONS["Name"], LangClass.TRANSLATIONS["Value"]] ) + self.__table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.Stretch + ) + self.__table.itemChanged.connect(self.__saveChangedPrompt) for i in range(len(self.__entries)): act = self.__entries[i].act @@ -253,6 +262,33 @@ def __initUi(self): self.setLayout(lay) + def showEntries(self, id): + self.__group_id = id + + prompt_group = DB.selectCertainPromptGroup(id=self.__group_id) + self.__title = prompt_group.name + self.__entries = DB.selectPromptEntry(self.__group_id) + + self.__titleLbl.setText(self.__title) + + self.__table.setRowCount(len(self.__entries)) + for i in range(len(self.__entries)): + act = self.__entries[i].act + prompt = self.__entries[i].prompt + + item1 = QTableWidgetItem(act) + item1.setData(Qt.ItemDataRole.UserRole, self.__entries[i].id) + item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + + item2 = QTableWidgetItem(prompt) + item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + + self.__table.setItem(i, 0, item1) + self.__table.setItem(i, 1, item2) + + self.__addBtn.setEnabled(True) + self.__delBtn.setEnabled(True) + def getPromptText(self): prompt_text = "" for i in range(self.__table.rowCount()): @@ -266,11 +302,23 @@ def __generatePrompt(self): prompt_text = self.getPromptText() self.updated.emit(prompt_text) + def setNothingRightNow(self): + self.__title = "" + self.__titleLbl.setText(self.__title) + self.__table.clearContents() + self.__addBtn.setEnabled(False) + self.__delBtn.setEnabled(False) + + def getId(self): + return self.__group_id + def __saveChangedPrompt(self, item: QTableWidgetItem): act = self.__table.item(item.row(), 0) id = act.data(Qt.ItemDataRole.UserRole) act = act.text() - prompt = self.__table.item(item.row(), 1).text() + + prompt = self.__table.item(item.row(), 1) + prompt = prompt.text() if prompt else "" DB.updatePromptEntry(id, act, prompt) def __add(self): @@ -316,21 +364,21 @@ def __initVal(self): self.__groups = DB.selectPromptGroup(prompt_type="form") def __initUi(self): - leftWidget = FormGroupList() - leftWidget.added.connect(self.__added) - leftWidget.deleted.connect(self.__deleted) - leftWidget.currentRowChanged.connect(self.__showEntries) + self.__leftWidget = FormGroupList() + self.__leftWidget.added.connect(self.add) + self.__leftWidget.deleted.connect(self.delete) - self.__rightWidget = QStackedWidget() + self.__leftWidget.currentRowChanged.connect(self.__showEntries) - for group in self.__groups: - promptTable = PromptTable(id=group.id) - promptTable.updated.connect(self.updated) - self.__rightWidget.addWidget(promptTable) + self.__table = PromptTable() + if len(self.__groups) > 0: + self.__leftWidget.list.setCurrentRow(0) + self.__table.showEntries(self.__groups[0].id) + self.__table.updated.connect(self.updated) mainWidget = QSplitter() - mainWidget.addWidget(leftWidget) - mainWidget.addWidget(self.__rightWidget) + mainWidget.addWidget(self.__leftWidget) + mainWidget.addWidget(self.__table) mainWidget.setChildrenCollapsible(False) mainWidget.setSizes([300, 700]) @@ -339,18 +387,14 @@ def __initUi(self): self.setLayout(lay) - def __added(self, id): - promptTable = PromptTable(id) - promptTable.updated.connect(self.updated) - self.__rightWidget.addWidget(promptTable) - self.__rightWidget.setCurrentWidget(promptTable) - - def __deleted(self, n): - w = self.__rightWidget.widget(n) - self.__rightWidget.removeWidget(w) - - def __showEntries(self, n): - self.__rightWidget.setCurrentIndex(n) - w = self.__rightWidget.currentWidget() - if w and isinstance(w, PromptTable): - self.updated.emit(w.getPromptText()) + def add(self, id): + self.__table.showEntries(id) + + def delete(self, id): + if self.__table.getId() == id: + self.__table.setNothingRightNow() + elif len(DB.selectPromptGroup(prompt_type="form")) == 0: + self.__table.setNothingRightNow() + + def __showEntries(self, id): + self.__table.showEntries(id) diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py b/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py index f9df64e..c52ed6c 100644 --- a/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py +++ b/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py @@ -356,17 +356,23 @@ class SentencePage(QWidget): def __init__(self, parent=None): super().__init__(parent) + self.__initVal() self.__initUi() + def __initVal(self): + self.__groups = DB.selectPromptGroup(prompt_type="sentence") + def __initUi(self): leftWidget = SentenceGroupList() leftWidget.added.connect(self.add) leftWidget.deleted.connect(self.delete) leftWidget.currentRowChanged.connect(self.__showEntries) - leftWidget.itemChanged.connect(self.__itemChanged) self.__table = PromptTable() + if len(self.__groups) > 0: + leftWidget.list.setCurrentRow(0) + self.__table.showEntries(self.__groups[0].id) self.__table.updated.connect(self.updated) mainWidget = QSplitter() @@ -380,14 +386,6 @@ def __initUi(self): self.setLayout(lay) - leftWidget.list.setCurrentRow(0) - - def __itemChanged(self, id): - self.__table.showEntries(id) - - def __showEntries(self, id): - self.__table.showEntries(id) - def add(self, id): self.__table.showEntries(id) @@ -396,3 +394,6 @@ def delete(self, id): self.__table.setNothingRightNow() elif len(DB.selectPromptGroup(prompt_type="sentence")) == 0: self.__table.setNothingRightNow() + + def __showEntries(self, id): + self.__table.showEntries(id) From 8089f0f7111b326706083dd1133bedb5832704a5 Mon Sep 17 00:00:00 2001 From: yjg30737 Date: Sat, 23 Nov 2024 14:09:30 +0900 Subject: [PATCH 2/6] Get image provider if it is set to auto, set g4f to 0.3.3.4 temporarily --- pyproject.toml | 1 + .../chat_widget/prompt_gen_widget/formPage.py | 16 ++++++-------- .../g4f_image_widget/g4fImageThread.py | 9 ++++++++ pyqt_openai/util/common.py | 21 ++++++++++--------- requirements.txt | 3 ++- 5 files changed, 29 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fe38556..1d24c3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "openpyxl", "g4f", + "curl_cffi", "litellm", diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py b/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py index 39d4dab..0935d2c 100644 --- a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py +++ b/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py @@ -1,4 +1,5 @@ -import json, os +import json +import os from PySide6.QtCore import Signal, Qt from PySide6.QtWidgets import ( @@ -7,7 +8,6 @@ QMessageBox, QSizePolicy, QSpacerItem, - QStackedWidget, QLabel, QAbstractItemView, QTableWidgetItem, @@ -30,20 +30,20 @@ QFILEDIALOG_DEFAULT_DIRECTORY, INDENT_SIZE, ) -from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupDirectInputDialog import ( - PromptGroupDirectInputDialog, -) from pyqt_openai.chat_widget.prompt_gen_widget.promptEntryDirectInputDialog import ( PromptEntryDirectInputDialog, ) +from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupDirectInputDialog import ( + PromptGroupDirectInputDialog, +) from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupExportDialog import ( PromptGroupExportDialog, ) from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupImportDialog import ( PromptGroupImportDialog, ) -from pyqt_openai.lang.translations import LangClass from pyqt_openai.globals import DB +from pyqt_openai.lang.translations import LangClass from pyqt_openai.util.common import open_directory, get_prompt_data from pyqt_openai.widgets.button import Button @@ -186,10 +186,6 @@ def __currentRowChanged(self, r_idx): class PromptTable(QWidget): - """ - benchmarked https://gptforwork.com/tools/prompt-generator - """ - updated = Signal(str) def __init__(self, parent=None): diff --git a/pyqt_openai/g4f_image_widget/g4fImageThread.py b/pyqt_openai/g4f_image_widget/g4fImageThread.py index ecb8a98..778f055 100644 --- a/pyqt_openai/g4f_image_widget/g4fImageThread.py +++ b/pyqt_openai/g4f_image_widget/g4fImageThread.py @@ -1,4 +1,7 @@ +from abc import ABCMeta + from PySide6.QtCore import QThread, Signal +from g4f.providers.retry_provider import IterListProvider from pyqt_openai import G4F_PROVIDER_DEFAULT from pyqt_openai.globals import G4F_CLIENT @@ -47,6 +50,12 @@ def run(self): images.provider = self.__input_args["provider"] else: del self.__input_args["provider"] + provider = images.models.get(self.__input_args['model'], images.provider) + if isinstance(provider, IterListProvider): + if provider.providers: + provider = provider.providers[0] + provider = provider.__name__ + response = images.generate(**self.__input_args) arg = { **self.__input_args, diff --git a/pyqt_openai/util/common.py b/pyqt_openai/util/common.py index bd701c4..f76b31d 100644 --- a/pyqt_openai/util/common.py +++ b/pyqt_openai/util/common.py @@ -698,16 +698,17 @@ def get_g4f_image_models() -> list: if hasattr(provider, "parent"): parent = __map__[provider.parent] if parent.__name__ not in index: - for model in provider.image_models: - image_models.append( - { - "provider": parent.__name__, - "url": parent.url, - "label": parent.label if hasattr(parent, "label") else None, - "image_model": model, - } - ) - index.append(parent.__name__) + if provider.image_models: + for model in provider.image_models: + image_models.append( + { + "provider": parent.__name__, + "url": parent.url, + "label": parent.label if hasattr(parent, "label") else None, + "image_model": model, + } + ) + index.append(parent.__name__) models = [model["image_model"] for model in image_models] return models diff --git a/requirements.txt b/requirements.txt index 0303e9e..f3f0186 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,8 @@ llama-index docx2txt openpyxl -g4f +g4f==0.3.3.4 + curl_cffi litellm From af582cf3b6bcf943e64d59e0b3943959b611dc4c Mon Sep 17 00:00:00 2001 From: yjg30737 Date: Sat, 23 Nov 2024 17:21:49 +0900 Subject: [PATCH 3/6] Refactoring --- pyproject.toml | 5 +- pyqt_openai/__init__.py | 2 +- .../chat_widget/prompt_gen_widget/formPage.py | 90 +++++++++++-------- .../prompt_gen_widget/sentencePage.py | 25 ++---- version_info.txt | 8 +- 5 files changed, 70 insertions(+), 60 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1d24c3f..4411091 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pyqt-openai" -version = "1.8.1" +version = "1.8.2" description = "Python multipurpose chatbot that user can use GPT, other AI models altogether (Release Name: VividNode)" authors = [{ name = "Jung Gyu Yoon", email = "yjg30737@gmail.com" }] license = { text = "MIT" } @@ -27,8 +27,7 @@ dependencies = [ "docx2txt", "openpyxl", - "g4f", - + "g4f==0.3.3.4", "curl_cffi", "litellm", diff --git a/pyqt_openai/__init__.py b/pyqt_openai/__init__.py index 4dc3070..8522461 100644 --- a/pyqt_openai/__init__.py +++ b/pyqt_openai/__init__.py @@ -23,7 +23,7 @@ # For the sake of following the PEP8 standard, we will declare module-level dunder names. # PEP8 standard about dunder names: https://peps.python.org/pep-0008/#module-level-dunder-names -__version__ = "1.8.1" +__version__ = "1.8.2" __author__ = "Jung Gyu Yoon" # Constants diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py b/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py index 0935d2c..7c8f11c 100644 --- a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py +++ b/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py @@ -62,20 +62,20 @@ def __initUi(self): self.__addBtn = Button() self.__delBtn = Button() + self.__importBtn = Button() + self.__importBtn.setToolTip(LangClass.TRANSLATIONS["Import"]) + + self.__exportBtn = Button() + self.__exportBtn.setToolTip(LangClass.TRANSLATIONS["Export"]) + self.__addBtn.setStyleAndIcon(ICON_ADD) self.__delBtn.setStyleAndIcon(ICON_DELETE) + self.__importBtn.setStyleAndIcon(ICON_IMPORT) + self.__exportBtn.setStyleAndIcon(ICON_EXPORT) self.__addBtn.clicked.connect(self.__add) self.__delBtn.clicked.connect(self.__delete) - - self.__importBtn = Button() - self.__importBtn.setStyleAndIcon(ICON_IMPORT) - self.__importBtn.setToolTip(LangClass.TRANSLATIONS["Import"]) self.__importBtn.clicked.connect(self.__import) - - self.__exportBtn = Button() - self.__exportBtn.setStyleAndIcon(ICON_EXPORT) - self.__exportBtn.setToolTip(LangClass.TRANSLATIONS["Export"]) self.__exportBtn.clicked.connect(self.__export) lay = QHBoxLayout() @@ -92,11 +92,15 @@ def __initUi(self): topWidget.setLayout(lay) groups = DB.selectPromptGroup(prompt_type="form") + if len(groups) <= 0: + self.__delBtn.setEnabled(False) self.list = QListWidget() for group in groups: - self.__addGroupItem(group.id, group.name) + id = group.id + name = group.name + self.__addGroupItem(id, name) self.list.currentRowChanged.connect(self.__currentRowChanged) self.list.itemChanged.connect(self.__itemChanged) @@ -108,8 +112,6 @@ def __initUi(self): self.setLayout(lay) - self.list.setCurrentRow(0) - def __addGroupItem(self, id, name): item = QListWidgetItem() item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) @@ -119,6 +121,8 @@ def __addGroupItem(self, id, name): self.list.setCurrentItem(item) self.added.emit(id) + self.__delBtn.setEnabled(True) + def __add(self): dialog = PromptGroupDirectInputDialog(self) reply = dialog.exec() @@ -132,10 +136,14 @@ def __delete(self): item = self.list.takeItem(i) id = item.data(Qt.ItemDataRole.UserRole) DB.deletePromptGroup(id) - self.deleted.emit(i) + self.deleted.emit(id) + + groups = DB.selectPromptGroup(prompt_type="form") + if len(groups) <= 0: + self.__delBtn.setEnabled(False) def __import(self): - dialog = PromptGroupImportDialog(parent=self) + dialog = PromptGroupImportDialog(parent=self, prompt_type="form") reply = dialog.exec() if reply == QDialog.DialogCode.Accepted: # Get the data @@ -232,24 +240,25 @@ def __initUi(self): self.__table.horizontalHeader().setSectionResizeMode( 1, QHeaderView.ResizeMode.Stretch ) + self.__table.currentItemChanged.connect(self.__rowChanged) self.__table.itemChanged.connect(self.__saveChangedPrompt) - for i in range(len(self.__entries)): - act = self.__entries[i].act - prompt = self.__entries[i].prompt - - item1 = QTableWidgetItem(act) - item1.setData(Qt.ItemDataRole.UserRole, self.__entries[i].id) - item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - - item2 = QTableWidgetItem(prompt) - item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - - self.__table.setItem(i, 0, item1) - self.__table.setItem(i, 1, item2) - - self.__table.itemChanged.connect(self.__generatePrompt) - self.__table.itemChanged.connect(self.__saveChangedPrompt) + # for i in range(len(self.__entries)): + # act = self.__entries[i].act + # prompt = self.__entries[i].prompt + # + # item1 = QTableWidgetItem(act) + # item1.setData(Qt.ItemDataRole.UserRole, self.__entries[i].id) + # item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + # + # item2 = QTableWidgetItem(prompt) + # item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + # + # self.__table.setItem(i, 0, item1) + # self.__table.setItem(i, 1, item2) + # + # self.__table.itemChanged.connect(self.__generatePrompt) + # self.__table.itemChanged.connect(self.__saveChangedPrompt) lay = QVBoxLayout() lay.addWidget(topWidget) @@ -308,6 +317,17 @@ def setNothingRightNow(self): def getId(self): return self.__group_id + def __rowChanged(self, new_item: QTableWidgetItem, old_item: QTableWidgetItem): + prompt = "" + # To avoid AttributeError + if new_item: + prompt = ( + self.__table.item(new_item.row(), 1).text() + if new_item.column() == 0 + else new_item.text() + ) + self.updated.emit(prompt) + def __saveChangedPrompt(self, item: QTableWidgetItem): act = self.__table.item(item.row(), 0) id = act.data(Qt.ItemDataRole.UserRole) @@ -360,20 +380,20 @@ def __initVal(self): self.__groups = DB.selectPromptGroup(prompt_type="form") def __initUi(self): - self.__leftWidget = FormGroupList() - self.__leftWidget.added.connect(self.add) - self.__leftWidget.deleted.connect(self.delete) + leftWidget = FormGroupList() + leftWidget.added.connect(self.add) + leftWidget.deleted.connect(self.delete) - self.__leftWidget.currentRowChanged.connect(self.__showEntries) + leftWidget.currentRowChanged.connect(self.__showEntries) self.__table = PromptTable() if len(self.__groups) > 0: - self.__leftWidget.list.setCurrentRow(0) + leftWidget.list.setCurrentRow(0) self.__table.showEntries(self.__groups[0].id) self.__table.updated.connect(self.updated) mainWidget = QSplitter() - mainWidget.addWidget(self.__leftWidget) + mainWidget.addWidget(leftWidget) mainWidget.addWidget(self.__table) mainWidget.setChildrenCollapsible(False) mainWidget.setSizes([300, 700]) diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py b/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py index c52ed6c..c703ba2 100644 --- a/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py +++ b/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py @@ -60,17 +60,17 @@ def __initUi(self): self.__addBtn = Button() self.__delBtn = Button() - self.__addBtn.setStyleAndIcon(ICON_ADD) - self.__delBtn.setStyleAndIcon(ICON_DELETE) - self.__importBtn = Button() - self.__importBtn.setStyleAndIcon(ICON_IMPORT) self.__importBtn.setToolTip(LangClass.TRANSLATIONS["Import"]) self.__exportBtn = Button() - self.__exportBtn.setStyleAndIcon(ICON_EXPORT) self.__exportBtn.setToolTip(LangClass.TRANSLATIONS["Export"]) + self.__addBtn.setStyleAndIcon(ICON_ADD) + self.__delBtn.setStyleAndIcon(ICON_DELETE) + self.__importBtn.setStyleAndIcon(ICON_IMPORT) + self.__exportBtn.setStyleAndIcon(ICON_EXPORT) + self.__addBtn.clicked.connect(self.__add) self.__delBtn.clicked.connect(self.__delete) self.__importBtn.clicked.connect(self.__import) @@ -79,13 +79,7 @@ def __initUi(self): lay = QHBoxLayout() # Should've added "Sentence Group" to the translation, but it's not in the # translation file for incomplete JSON response issue - lay.addWidget( - QLabel( - LangClass.TRANSLATIONS["Sentence"] - + " " - + LangClass.TRANSLATIONS["Group"] - ) - ) + lay.addWidget(QLabel(LangClass.TRANSLATIONS["Sentence Group"])) lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) lay.addWidget(self.__addBtn) lay.addWidget(self.__delBtn) @@ -97,12 +91,12 @@ def __initUi(self): topWidget = QWidget() topWidget.setLayout(lay) - self.list = QListWidget() - groups = DB.selectPromptGroup(prompt_type="sentence") if len(groups) <= 0: self.__delBtn.setEnabled(False) + self.list = QListWidget() + for group in groups: id = group.id name = group.name @@ -347,9 +341,6 @@ def __delete(self): self.__table.removeRow(i) DB.deletePromptEntry(self.__group_id, id) - def clearContents(self): - self.__table.clearContents() - class SentencePage(QWidget): updated = Signal(str) diff --git a/version_info.txt b/version_info.txt index b5d1c3f..cc91dd0 100644 --- a/version_info.txt +++ b/version_info.txt @@ -5,8 +5,8 @@ # VSVersionInfo( ffi=FixedFileInfo( - filevers=(1, 8, 1), - prodvers=(1, 8, 1), + filevers=(1, 8, 2), + prodvers=(1, 8, 2), mask=0x3f, flags=0x0, OS=0x4, @@ -19,10 +19,10 @@ VSVersionInfo( [ StringTable( u'040904B0', - [StringStruct(u'FileVersion', u'1.8.1'), + [StringStruct(u'FileVersion', u'1.8.2'), StringStruct(u'ProductName', u'VividNode'), StringStruct(u'LegalCopyright', u'Copyright © 2024 Jung Gyu Yoon'), - StringStruct(u'ProductVersion', u'1.8.1')]) + StringStruct(u'ProductVersion', u'1.8.2')]) ]), VarFileInfo([VarStruct(u'Translation', [1033, 1200])]) ] From 166e24968e2e92397102379429ca566a26e15e08 Mon Sep 17 00:00:00 2001 From: yjg30737 Date: Sat, 23 Nov 2024 21:26:41 +0900 Subject: [PATCH 4/6] Refactoring --- .../chat_widget/prompt_gen_widget/formPage.py | 416 ------------------ .../promptGeneratorWidget.py | 10 +- .../prompt_gen_widget/promptGroupList.py | 213 +++++++++ .../prompt_gen_widget/promptPage.py | 59 +++ .../prompt_gen_widget/promptTable.py | 171 +++++++ .../prompt_gen_widget/sentencePage.py | 390 ---------------- 6 files changed, 447 insertions(+), 812 deletions(-) delete mode 100644 pyqt_openai/chat_widget/prompt_gen_widget/formPage.py create mode 100644 pyqt_openai/chat_widget/prompt_gen_widget/promptGroupList.py create mode 100644 pyqt_openai/chat_widget/prompt_gen_widget/promptPage.py create mode 100644 pyqt_openai/chat_widget/prompt_gen_widget/promptTable.py delete mode 100644 pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py b/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py deleted file mode 100644 index 7c8f11c..0000000 --- a/pyqt_openai/chat_widget/prompt_gen_widget/formPage.py +++ /dev/null @@ -1,416 +0,0 @@ -import json -import os - -from PySide6.QtCore import Signal, Qt -from PySide6.QtWidgets import ( - QFileDialog, - QTableWidget, - QMessageBox, - QSizePolicy, - QSpacerItem, - QLabel, - QAbstractItemView, - QTableWidgetItem, - QHeaderView, - QHBoxLayout, - QVBoxLayout, - QWidget, - QDialog, - QListWidget, - QListWidgetItem, - QSplitter, -) - -from pyqt_openai import ( - JSON_FILE_EXT_LIST_STR, - ICON_ADD, - ICON_DELETE, - ICON_IMPORT, - ICON_EXPORT, - QFILEDIALOG_DEFAULT_DIRECTORY, - INDENT_SIZE, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptEntryDirectInputDialog import ( - PromptEntryDirectInputDialog, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupDirectInputDialog import ( - PromptGroupDirectInputDialog, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupExportDialog import ( - PromptGroupExportDialog, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupImportDialog import ( - PromptGroupImportDialog, -) -from pyqt_openai.globals import DB -from pyqt_openai.lang.translations import LangClass -from pyqt_openai.util.common import open_directory, get_prompt_data -from pyqt_openai.widgets.button import Button - - -class FormGroupList(QWidget): - added = Signal(int) - deleted = Signal(int) - currentRowChanged = Signal(int) - itemChanged = Signal(int) - - def __init__(self, parent=None): - super().__init__(parent) - self.__initUi() - - def __initUi(self): - self.__addBtn = Button() - self.__delBtn = Button() - - self.__importBtn = Button() - self.__importBtn.setToolTip(LangClass.TRANSLATIONS["Import"]) - - self.__exportBtn = Button() - self.__exportBtn.setToolTip(LangClass.TRANSLATIONS["Export"]) - - self.__addBtn.setStyleAndIcon(ICON_ADD) - self.__delBtn.setStyleAndIcon(ICON_DELETE) - self.__importBtn.setStyleAndIcon(ICON_IMPORT) - self.__exportBtn.setStyleAndIcon(ICON_EXPORT) - - self.__addBtn.clicked.connect(self.__add) - self.__delBtn.clicked.connect(self.__delete) - self.__importBtn.clicked.connect(self.__import) - self.__exportBtn.clicked.connect(self.__export) - - lay = QHBoxLayout() - lay.addWidget(QLabel(LangClass.TRANSLATIONS["Form Group"])) - lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) - lay.addWidget(self.__addBtn) - lay.addWidget(self.__delBtn) - lay.addWidget(self.__importBtn) - lay.addWidget(self.__exportBtn) - lay.setAlignment(Qt.AlignmentFlag.AlignRight) - lay.setContentsMargins(0, 0, 0, 0) - - topWidget = QWidget() - topWidget.setLayout(lay) - - groups = DB.selectPromptGroup(prompt_type="form") - if len(groups) <= 0: - self.__delBtn.setEnabled(False) - - self.list = QListWidget() - - for group in groups: - id = group.id - name = group.name - self.__addGroupItem(id, name) - - self.list.currentRowChanged.connect(self.__currentRowChanged) - self.list.itemChanged.connect(self.__itemChanged) - - lay = QVBoxLayout() - lay.addWidget(topWidget) - lay.addWidget(self.list) - lay.setContentsMargins(0, 0, 5, 0) - - self.setLayout(lay) - - def __addGroupItem(self, id, name): - item = QListWidgetItem() - item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) - item.setData(Qt.ItemDataRole.UserRole, id) - item.setText(name) - self.list.addItem(item) - self.list.setCurrentItem(item) - self.added.emit(id) - - self.__delBtn.setEnabled(True) - - def __add(self): - dialog = PromptGroupDirectInputDialog(self) - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - name = dialog.getPromptGroupName() - id = DB.insertPromptGroup(name, prompt_type="form") - self.__addGroupItem(id, name) - - def __delete(self): - i = self.list.currentRow() - item = self.list.takeItem(i) - id = item.data(Qt.ItemDataRole.UserRole) - DB.deletePromptGroup(id) - self.deleted.emit(id) - - groups = DB.selectPromptGroup(prompt_type="form") - if len(groups) <= 0: - self.__delBtn.setEnabled(False) - - def __import(self): - dialog = PromptGroupImportDialog(parent=self, prompt_type="form") - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - # Get the data - result = dialog.getSelected() - # Save the data - for group in result: - id = DB.insertPromptGroup(group["name"], prompt_type="form") - for entry in group["data"]: - DB.insertPromptEntry(id, entry["act"], entry["prompt"]) - name = group["name"] - self.__addGroupItem(id, name) - - def __export(self): - try: - # Get the file - file_data = QFileDialog.getSaveFileName( - self, - LangClass.TRANSLATIONS["Save"], - QFILEDIALOG_DEFAULT_DIRECTORY, - JSON_FILE_EXT_LIST_STR, - ) - if file_data[0]: - filename = file_data[0] - # Get the data - data = get_prompt_data(prompt_type="form") - dialog = PromptGroupExportDialog(data=data, parent=self) - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - data = dialog.getSelected() - # Save the data - with open(filename, "w") as f: - json.dump(data, f, indent=INDENT_SIZE) - open_directory(os.path.dirname(filename)) - except Exception as e: - QMessageBox.critical(self, LangClass.TRANSLATIONS["Error"], str(e)) - print(e) - - def __itemChanged(self, item): - id = item.data(Qt.ItemDataRole.UserRole) - DB.updatePromptGroup(id, item.text()) - self.itemChanged.emit(id) - - def __currentRowChanged(self, r_idx): - item = self.list.item(r_idx) - if item: - id = item.data(Qt.ItemDataRole.UserRole) - self.currentRowChanged.emit(id) - - -class PromptTable(QWidget): - updated = Signal(str) - - def __init__(self, parent=None): - super().__init__(parent) - self.__initVal() - self.__initUi() - - def __initVal(self): - self.__title = "" - self.__entries = [] - - def __initUi(self): - self.__addBtn = Button() - self.__delBtn = Button() - - self.__addBtn.setStyleAndIcon(ICON_ADD) - self.__delBtn.setStyleAndIcon(ICON_DELETE) - - self.__addBtn.clicked.connect(self.__add) - self.__delBtn.clicked.connect(self.__delete) - - self.__titleLbl = QLabel() - - lay = QHBoxLayout() - lay.addWidget(self.__titleLbl) - lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) - lay.addWidget(self.__addBtn) - lay.addWidget(self.__delBtn) - lay.setAlignment(Qt.AlignmentFlag.AlignRight) - lay.setContentsMargins(0, 0, 0, 0) - - topWidget = QWidget() - topWidget.setLayout(lay) - - self.__table = QTableWidget() - self.__table.setColumnCount(2) - self.__table.setRowCount(len(self.__entries)) - self.__table.setSelectionBehavior( - QAbstractItemView.SelectionBehavior.SelectRows - ) - self.__table.setHorizontalHeaderLabels( - [LangClass.TRANSLATIONS["Name"], LangClass.TRANSLATIONS["Value"]] - ) - self.__table.horizontalHeader().setSectionResizeMode( - 1, QHeaderView.ResizeMode.Stretch - ) - self.__table.currentItemChanged.connect(self.__rowChanged) - self.__table.itemChanged.connect(self.__saveChangedPrompt) - - # for i in range(len(self.__entries)): - # act = self.__entries[i].act - # prompt = self.__entries[i].prompt - # - # item1 = QTableWidgetItem(act) - # item1.setData(Qt.ItemDataRole.UserRole, self.__entries[i].id) - # item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - # - # item2 = QTableWidgetItem(prompt) - # item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - # - # self.__table.setItem(i, 0, item1) - # self.__table.setItem(i, 1, item2) - # - # self.__table.itemChanged.connect(self.__generatePrompt) - # self.__table.itemChanged.connect(self.__saveChangedPrompt) - - lay = QVBoxLayout() - lay.addWidget(topWidget) - lay.addWidget(self.__table) - lay.setContentsMargins(5, 0, 0, 0) - - self.setLayout(lay) - - def showEntries(self, id): - self.__group_id = id - - prompt_group = DB.selectCertainPromptGroup(id=self.__group_id) - self.__title = prompt_group.name - self.__entries = DB.selectPromptEntry(self.__group_id) - - self.__titleLbl.setText(self.__title) - - self.__table.setRowCount(len(self.__entries)) - for i in range(len(self.__entries)): - act = self.__entries[i].act - prompt = self.__entries[i].prompt - - item1 = QTableWidgetItem(act) - item1.setData(Qt.ItemDataRole.UserRole, self.__entries[i].id) - item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - - item2 = QTableWidgetItem(prompt) - item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - - self.__table.setItem(i, 0, item1) - self.__table.setItem(i, 1, item2) - - self.__addBtn.setEnabled(True) - self.__delBtn.setEnabled(True) - - def getPromptText(self): - prompt_text = "" - for i in range(self.__table.rowCount()): - name = self.__table.item(i, 0).text() if self.__table.item(i, 0) else "" - value = self.__table.item(i, 1).text() if self.__table.item(i, 1) else "" - if value.strip(): - prompt_text += f"{name}: {value}\n" - return prompt_text - - def __generatePrompt(self): - prompt_text = self.getPromptText() - self.updated.emit(prompt_text) - - def setNothingRightNow(self): - self.__title = "" - self.__titleLbl.setText(self.__title) - self.__table.clearContents() - self.__addBtn.setEnabled(False) - self.__delBtn.setEnabled(False) - - def getId(self): - return self.__group_id - - def __rowChanged(self, new_item: QTableWidgetItem, old_item: QTableWidgetItem): - prompt = "" - # To avoid AttributeError - if new_item: - prompt = ( - self.__table.item(new_item.row(), 1).text() - if new_item.column() == 0 - else new_item.text() - ) - self.updated.emit(prompt) - - def __saveChangedPrompt(self, item: QTableWidgetItem): - act = self.__table.item(item.row(), 0) - id = act.data(Qt.ItemDataRole.UserRole) - act = act.text() - - prompt = self.__table.item(item.row(), 1) - prompt = prompt.text() if prompt else "" - DB.updatePromptEntry(id, act, prompt) - - def __add(self): - dialog = PromptEntryDirectInputDialog(self.__group_id, self) - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - self.__table.itemChanged.disconnect(self.__saveChangedPrompt) - - act = dialog.getAct() - self.__table.setRowCount(self.__table.rowCount() + 1) - - item1 = QTableWidgetItem(act) - item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - self.__table.setItem(self.__table.rowCount() - 1, 0, item1) - - item2 = QTableWidgetItem("") - item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - self.__table.setItem(self.__table.rowCount() - 1, 1, item2) - - id = DB.insertPromptEntry(self.__group_id, act, "") - item1.setData(Qt.ItemDataRole.UserRole, id) - - self.__table.itemChanged.connect(self.__saveChangedPrompt) - - def __delete(self): - for i in sorted( - set([i.row() for i in self.__table.selectedIndexes()]), reverse=True - ): - id = self.__table.item(i, 0).data(Qt.ItemDataRole.UserRole) - self.__table.removeRow(i) - DB.deletePromptEntry(self.__group_id, id) - - -class FormPage(QWidget): - updated = Signal(str) - - def __init__(self, parent=None): - super().__init__(parent) - self.__initVal() - self.__initUi() - - def __initVal(self): - self.__groups = DB.selectPromptGroup(prompt_type="form") - - def __initUi(self): - leftWidget = FormGroupList() - leftWidget.added.connect(self.add) - leftWidget.deleted.connect(self.delete) - - leftWidget.currentRowChanged.connect(self.__showEntries) - - self.__table = PromptTable() - if len(self.__groups) > 0: - leftWidget.list.setCurrentRow(0) - self.__table.showEntries(self.__groups[0].id) - self.__table.updated.connect(self.updated) - - mainWidget = QSplitter() - mainWidget.addWidget(leftWidget) - mainWidget.addWidget(self.__table) - mainWidget.setChildrenCollapsible(False) - mainWidget.setSizes([300, 700]) - - lay = QVBoxLayout() - lay.addWidget(mainWidget) - - self.setLayout(lay) - - def add(self, id): - self.__table.showEntries(id) - - def delete(self, id): - if self.__table.getId() == id: - self.__table.setNothingRightNow() - elif len(DB.selectPromptGroup(prompt_type="form")) == 0: - self.__table.setNothingRightNow() - - def __showEntries(self, id): - self.__table.showEntries(id) diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/promptGeneratorWidget.py b/pyqt_openai/chat_widget/prompt_gen_widget/promptGeneratorWidget.py index cb39fc5..3ac7037 100644 --- a/pyqt_openai/chat_widget/prompt_gen_widget/promptGeneratorWidget.py +++ b/pyqt_openai/chat_widget/prompt_gen_widget/promptGeneratorWidget.py @@ -1,5 +1,5 @@ import pyperclip - +from PySide6.QtCore import Qt from PySide6.QtWidgets import ( QTextBrowser, QSplitter, @@ -10,10 +10,8 @@ QTabWidget, QScrollArea, ) -from PySide6.QtCore import Qt -from pyqt_openai.chat_widget.prompt_gen_widget.formPage import FormPage -from pyqt_openai.chat_widget.prompt_gen_widget.sentencePage import SentencePage +from pyqt_openai.chat_widget.prompt_gen_widget.promptPage import PromptPage from pyqt_openai.lang.translations import LangClass @@ -25,10 +23,10 @@ def __init__(self, parent=None): def __initUi(self): promptLbl = QLabel(LangClass.TRANSLATIONS["Prompt"]) - formPage = FormPage() + formPage = PromptPage(prompt_type='form') formPage.updated.connect(self.__textChanged) - sentencePage = SentencePage() + sentencePage = PromptPage(prompt_type='sentence') sentencePage.updated.connect(self.__textChanged) self.__prompt = QTextBrowser() diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/promptGroupList.py b/pyqt_openai/chat_widget/prompt_gen_widget/promptGroupList.py new file mode 100644 index 0000000..e118fb3 --- /dev/null +++ b/pyqt_openai/chat_widget/prompt_gen_widget/promptGroupList.py @@ -0,0 +1,213 @@ +import json +import os + +from PySide6.QtCore import Signal, Qt +from PySide6.QtWidgets import ( + QFileDialog, + QMessageBox, + QSizePolicy, + QSpacerItem, + QLabel, + QHBoxLayout, + QVBoxLayout, + QWidget, + QDialog, + QListWidget, + QListWidgetItem, +) + +from pyqt_openai import ( + JSON_FILE_EXT_LIST_STR, + ICON_ADD, + ICON_DELETE, + ICON_IMPORT, + ICON_EXPORT, + QFILEDIALOG_DEFAULT_DIRECTORY, + INDENT_SIZE, +) +from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupDirectInputDialog import ( + PromptGroupDirectInputDialog, +) +from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupExportDialog import ( + PromptGroupExportDialog, +) +from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupImportDialog import ( + PromptGroupImportDialog, +) +from pyqt_openai.globals import DB +from pyqt_openai.lang.translations import LangClass +from pyqt_openai.util.common import open_directory, get_prompt_data, export_prompt +from pyqt_openai.widgets.button import Button + + +class PromptGroupList(QWidget): + added = Signal(int) + deleted = Signal(int) + currentRowChanged = Signal(int) + itemChanged = Signal(int) + + def __init__(self, prompt_type='form', parent=None): + super().__init__(parent) + self.__initVal(prompt_type) + self.__initUi() + + def __initVal(self, prompt_type): + self.prompt_type = prompt_type + + def __initUi(self): + self.__addBtn = Button() + self.__delBtn = Button() + + self.__importBtn = Button() + self.__importBtn.setToolTip(LangClass.TRANSLATIONS["Import"]) + + self.__exportBtn = Button() + self.__exportBtn.setToolTip(LangClass.TRANSLATIONS["Export"]) + + self.__addBtn.setStyleAndIcon(ICON_ADD) + self.__delBtn.setStyleAndIcon(ICON_DELETE) + self.__importBtn.setStyleAndIcon(ICON_IMPORT) + self.__exportBtn.setStyleAndIcon(ICON_EXPORT) + + self.__addBtn.clicked.connect(self.__add) + self.__delBtn.clicked.connect(self.__delete) + self.__importBtn.clicked.connect(self.__import) + self.__exportBtn.clicked.connect(self.__export) + + lay = QHBoxLayout() + lay.addWidget(QLabel(LangClass.TRANSLATIONS[f"{self.prompt_type.capitalize()} Group"])) + lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) + lay.addWidget(self.__addBtn) + lay.addWidget(self.__delBtn) + lay.addWidget(self.__importBtn) + lay.addWidget(self.__exportBtn) + lay.setAlignment(Qt.AlignmentFlag.AlignRight) + lay.setContentsMargins(0, 0, 0, 0) + + topWidget = QWidget() + topWidget.setLayout(lay) + + groups = DB.selectPromptGroup(prompt_type=self.prompt_type) + if len(groups) <= 0: + self.__delBtn.setEnabled(False) + + self.list = QListWidget() + + for group in groups: + id = group.id + name = group.name + self.__addGroupItem(id, name) + + self.list.currentRowChanged.connect(self.__currentRowChanged) + self.list.itemChanged.connect(self.__itemChanged) + + lay = QVBoxLayout() + lay.addWidget(topWidget) + lay.addWidget(self.list) + lay.setContentsMargins(0, 0, 5, 0) + + self.setLayout(lay) + + def __addGroupItem(self, id, name): + item = QListWidgetItem() + item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) + item.setData(Qt.ItemDataRole.UserRole, id) + item.setText(name) + self.list.addItem(item) + self.list.setCurrentItem(item) + self.added.emit(id) + + self.__delBtn.setEnabled(True) + + def __add(self): + dialog = PromptGroupDirectInputDialog(self) + reply = dialog.exec() + if reply == QDialog.DialogCode.Accepted: + name = dialog.getPromptGroupName() + id = DB.insertPromptGroup(name, prompt_type=self.prompt_type) + self.__addGroupItem(id, name) + + def __delete(self): + i = self.list.currentRow() + item = self.list.takeItem(i) + id = item.data(Qt.ItemDataRole.UserRole) + DB.deletePromptGroup(id) + self.deleted.emit(id) + + groups = DB.selectPromptGroup(prompt_type=self.prompt_type) + if len(groups) <= 0: + self.__delBtn.setEnabled(False) + + def __import(self): + dialog = PromptGroupImportDialog(parent=self, prompt_type=self.prompt_type) + reply = dialog.exec() + if reply == QDialog.DialogCode.Accepted: + # Get the data + result = dialog.getSelected() + # Save the data + for group in result: + id = DB.insertPromptGroup(group["name"], prompt_type=self.prompt_type) + for entry in group["data"]: + DB.insertPromptEntry(id, entry["act"], entry["prompt"]) + name = group["name"] + self.__addGroupItem(id, name) + + def __export(self): + try: + if self.prompt_type == 'form': + # Get the file + file_data = QFileDialog.getSaveFileName( + self, + LangClass.TRANSLATIONS["Save"], + QFILEDIALOG_DEFAULT_DIRECTORY, + JSON_FILE_EXT_LIST_STR, + ) + if file_data[0]: + filename = file_data[0] + # Get the data + data = get_prompt_data(prompt_type=self.prompt_type) + dialog = PromptGroupExportDialog(data=data, parent=self) + reply = dialog.exec() + if reply == QDialog.DialogCode.Accepted: + data = dialog.getSelected() + # Save the data + with open(filename, "w") as f: + json.dump(data, f, indent=INDENT_SIZE) + elif self.prompt_type == 'sentence': + # Get the file + file_data = QFileDialog.getSaveFileName( + self, + LangClass.TRANSLATIONS["Save"], + QFILEDIALOG_DEFAULT_DIRECTORY, + f"CSV files Compressed File (*.zip);;{JSON_FILE_EXT_LIST_STR}", + ) + if file_data[0]: + filename = file_data[0] + # Get the data + data = get_prompt_data(self.prompt_type) + # Get extension + ext = os.path.splitext(filename)[1] + # If it is a compressed file, it is a compressed csv, so change the extension to csv + if ext == ".zip": + ext = ".csv" + dialog = PromptGroupExportDialog(data=data, ext=ext, parent=self) + reply = dialog.exec() + if reply == QDialog.DialogCode.Accepted: + data = dialog.getSelected() + export_prompt(data, filename, ext) + open_directory(os.path.dirname(filename)) + open_directory(os.path.dirname(filename)) + except Exception as e: + QMessageBox.critical(self, LangClass.TRANSLATIONS["Error"], str(e)) + print(e) + + def __itemChanged(self, item): + id = item.data(Qt.ItemDataRole.UserRole) + DB.updatePromptGroup(id, item.text()) + self.itemChanged.emit(id) + + def __currentRowChanged(self, r_idx): + item = self.list.item(r_idx) + if item: + id = item.data(Qt.ItemDataRole.UserRole) + self.currentRowChanged.emit(id) \ No newline at end of file diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/promptPage.py b/pyqt_openai/chat_widget/prompt_gen_widget/promptPage.py new file mode 100644 index 0000000..cb29194 --- /dev/null +++ b/pyqt_openai/chat_widget/prompt_gen_widget/promptPage.py @@ -0,0 +1,59 @@ +from PySide6.QtCore import Signal +from PySide6.QtWidgets import ( + QVBoxLayout, + QWidget, + QSplitter, +) + +from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupList import PromptGroupList +from pyqt_openai.chat_widget.prompt_gen_widget.promptTable import PromptTable +from pyqt_openai.globals import DB + + +class PromptPage(QWidget): + updated = Signal(str) + + def __init__(self, prompt_type='form', parent=None): + super().__init__(parent) + self.__initVal(prompt_type) + self.__initUi() + + def __initVal(self, prompt_type): + self.prompt_type = prompt_type + self.__groups = DB.selectPromptGroup(prompt_type=self.prompt_type) + + def __initUi(self): + leftWidget = PromptGroupList(prompt_type=self.prompt_type) + leftWidget.added.connect(self.add) + leftWidget.deleted.connect(self.delete) + + leftWidget.currentRowChanged.connect(self.__showEntries) + + self.__table = PromptTable() + if len(self.__groups) > 0: + leftWidget.list.setCurrentRow(0) + self.__table.showEntries(self.__groups[0].id) + self.__table.updated.connect(self.updated) + + mainWidget = QSplitter() + mainWidget.addWidget(leftWidget) + mainWidget.addWidget(self.__table) + mainWidget.setChildrenCollapsible(False) + mainWidget.setSizes([300, 700]) + + lay = QVBoxLayout() + lay.addWidget(mainWidget) + + self.setLayout(lay) + + def add(self, id): + self.__table.showEntries(id) + + def delete(self, id): + if self.__table.getId() == id: + self.__table.setNothingRightNow() + elif len(DB.selectPromptGroup(prompt_type=self.prompt_type)) == 0: + self.__table.setNothingRightNow() + + def __showEntries(self, id): + self.__table.showEntries(id) diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/promptTable.py b/pyqt_openai/chat_widget/prompt_gen_widget/promptTable.py new file mode 100644 index 0000000..97043ba --- /dev/null +++ b/pyqt_openai/chat_widget/prompt_gen_widget/promptTable.py @@ -0,0 +1,171 @@ +from PySide6.QtCore import Signal, Qt +from PySide6.QtWidgets import ( + QTableWidget, + QSizePolicy, + QSpacerItem, + QLabel, + QAbstractItemView, + QTableWidgetItem, + QHeaderView, + QHBoxLayout, + QVBoxLayout, + QWidget, + QDialog, +) + +from pyqt_openai import ( + ICON_ADD, + ICON_DELETE, +) +from pyqt_openai.chat_widget.prompt_gen_widget.promptEntryDirectInputDialog import ( + PromptEntryDirectInputDialog, +) +from pyqt_openai.globals import DB +from pyqt_openai.lang.translations import LangClass +from pyqt_openai.widgets.button import Button + + +class PromptTable(QWidget): + updated = Signal(str) + + def __init__(self, parent=None): + super().__init__(parent) + self.__initVal() + self.__initUi() + + def __initVal(self): + self.__title = "" + self.__entries = [] + + def __initUi(self): + self.__addBtn = Button() + self.__delBtn = Button() + + self.__addBtn.setStyleAndIcon(ICON_ADD) + self.__delBtn.setStyleAndIcon(ICON_DELETE) + + self.__addBtn.clicked.connect(self.__add) + self.__delBtn.clicked.connect(self.__delete) + + self.__titleLbl = QLabel() + + lay = QHBoxLayout() + lay.addWidget(self.__titleLbl) + lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) + lay.addWidget(self.__addBtn) + lay.addWidget(self.__delBtn) + lay.setAlignment(Qt.AlignmentFlag.AlignRight) + lay.setContentsMargins(0, 0, 0, 0) + + topWidget = QWidget() + topWidget.setLayout(lay) + + self.__table = QTableWidget() + self.__table.setColumnCount(2) + self.__table.setSelectionBehavior( + QAbstractItemView.SelectionBehavior.SelectRows + ) + self.__table.setHorizontalHeaderLabels( + [LangClass.TRANSLATIONS["Name"], LangClass.TRANSLATIONS["Value"]] + ) + self.__table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.Stretch + ) + self.__table.currentItemChanged.connect(self.__rowChanged) + self.__table.itemChanged.connect(self.__saveChangedPrompt) + + lay = QVBoxLayout() + lay.addWidget(topWidget) + lay.addWidget(self.__table) + lay.setContentsMargins(5, 0, 0, 0) + + self.setLayout(lay) + + def showEntries(self, id): + self.__group_id = id + + prompt_group = DB.selectCertainPromptGroup(id=self.__group_id) + self.__title = prompt_group.name + self.__entries = DB.selectPromptEntry(self.__group_id) + + self.__titleLbl.setText(self.__title) + + self.__table.setRowCount(len(self.__entries)) + for i in range(len(self.__entries)): + act = self.__entries[i].act + prompt = self.__entries[i].prompt + + item1 = QTableWidgetItem(act) + item1.setData(Qt.ItemDataRole.UserRole, self.__entries[i].id) + item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + + item2 = QTableWidgetItem(prompt) + item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + + self.__table.setItem(i, 0, item1) + self.__table.setItem(i, 1, item2) + + self.__addBtn.setEnabled(True) + self.__delBtn.setEnabled(True) + + def setNothingRightNow(self): + self.__title = "" + self.__titleLbl.setText(self.__title) + self.__table.clearContents() + self.__addBtn.setEnabled(False) + self.__delBtn.setEnabled(False) + + def getId(self): + return self.__group_id + + def __rowChanged(self, new_item: QTableWidgetItem, old_item: QTableWidgetItem): + prompt = "" + # To avoid AttributeError + if new_item: + prompt = ( + self.__table.item(new_item.row(), 1).text() + if new_item.column() == 0 + else new_item.text() + ) + self.updated.emit(prompt) + + def __saveChangedPrompt(self, item: QTableWidgetItem): + act = self.__table.item(item.row(), 0) + id = act.data(Qt.ItemDataRole.UserRole) + act = act.text() + + prompt = self.__table.item(item.row(), 1) + prompt = prompt.text() if prompt else "" + DB.updatePromptEntry(id, act, prompt) + + def __add(self): + dialog = PromptEntryDirectInputDialog(self.__group_id, self) + reply = dialog.exec() + if reply == QDialog.DialogCode.Accepted: + self.__table.itemChanged.disconnect(self.__saveChangedPrompt) + + act = dialog.getAct() + self.__table.setRowCount(self.__table.rowCount() + 1) + + item1 = QTableWidgetItem(act) + item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + self.__table.setItem(self.__table.rowCount() - 1, 0, item1) + + prompt = dialog.getPrompt() + + item2 = QTableWidgetItem(prompt) + item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) + self.__table.setItem(self.__table.rowCount() - 1, 1, item2) + + id = DB.insertPromptEntry(self.__group_id, act, prompt) + item1.setData(Qt.ItemDataRole.UserRole, id) + + self.__table.itemChanged.connect(self.__saveChangedPrompt) + + def __delete(self): + for i in sorted( + set([i.row() for i in self.__table.selectedIndexes()]), reverse=True + ): + id = self.__table.item(i, 0).data(Qt.ItemDataRole.UserRole) + self.__table.removeRow(i) + DB.deletePromptEntry(self.__group_id, id) \ No newline at end of file diff --git a/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py b/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py deleted file mode 100644 index c703ba2..0000000 --- a/pyqt_openai/chat_widget/prompt_gen_widget/sentencePage.py +++ /dev/null @@ -1,390 +0,0 @@ -import os - -from PySide6.QtCore import Signal, Qt -from PySide6.QtWidgets import ( - QWidget, - QDialog, - QTableWidget, - QVBoxLayout, - QHBoxLayout, - QHeaderView, - QTableWidgetItem, - QAbstractItemView, - QFileDialog, - QLabel, - QSpacerItem, - QListWidget, - QListWidgetItem, - QSizePolicy, - QSplitter, - QMessageBox, -) - -from pyqt_openai import ( - ICON_ADD, - ICON_DELETE, - ICON_IMPORT, - ICON_EXPORT, - QFILEDIALOG_DEFAULT_DIRECTORY, - JSON_FILE_EXT_LIST_STR, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptEntryDirectInputDialog import ( - PromptEntryDirectInputDialog, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupDirectInputDialog import ( - PromptGroupDirectInputDialog, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupExportDialog import ( - PromptGroupExportDialog, -) -from pyqt_openai.chat_widget.prompt_gen_widget.promptGroupImportDialog import ( - PromptGroupImportDialog, -) -from pyqt_openai.globals import DB -from pyqt_openai.lang.translations import LangClass -from pyqt_openai.util.common import open_directory, get_prompt_data, export_prompt -from pyqt_openai.widgets.button import Button - - -class SentenceGroupList(QWidget): - added = Signal(int) - deleted = Signal(int) - currentRowChanged = Signal(int) - itemChanged = Signal(int) - - def __init__(self, parent=None): - super().__init__(parent) - self.__initUi() - - def __initUi(self): - self.__addBtn = Button() - self.__delBtn = Button() - - self.__importBtn = Button() - self.__importBtn.setToolTip(LangClass.TRANSLATIONS["Import"]) - - self.__exportBtn = Button() - self.__exportBtn.setToolTip(LangClass.TRANSLATIONS["Export"]) - - self.__addBtn.setStyleAndIcon(ICON_ADD) - self.__delBtn.setStyleAndIcon(ICON_DELETE) - self.__importBtn.setStyleAndIcon(ICON_IMPORT) - self.__exportBtn.setStyleAndIcon(ICON_EXPORT) - - self.__addBtn.clicked.connect(self.__add) - self.__delBtn.clicked.connect(self.__delete) - self.__importBtn.clicked.connect(self.__import) - self.__exportBtn.clicked.connect(self.__export) - - lay = QHBoxLayout() - # Should've added "Sentence Group" to the translation, but it's not in the - # translation file for incomplete JSON response issue - lay.addWidget(QLabel(LangClass.TRANSLATIONS["Sentence Group"])) - lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) - lay.addWidget(self.__addBtn) - lay.addWidget(self.__delBtn) - lay.addWidget(self.__importBtn) - lay.addWidget(self.__exportBtn) - lay.setAlignment(Qt.AlignmentFlag.AlignRight) - lay.setContentsMargins(0, 0, 0, 0) - - topWidget = QWidget() - topWidget.setLayout(lay) - - groups = DB.selectPromptGroup(prompt_type="sentence") - if len(groups) <= 0: - self.__delBtn.setEnabled(False) - - self.list = QListWidget() - - for group in groups: - id = group.id - name = group.name - self.__addGroupItem(id, name) - - self.list.currentRowChanged.connect(self.__currentRowChanged) - self.list.itemChanged.connect(self.__itemChanged) - - lay = QVBoxLayout() - lay.addWidget(topWidget) - lay.addWidget(self.list) - lay.setContentsMargins(0, 0, 5, 0) - - self.setLayout(lay) - - def __addGroupItem(self, id, name): - item = QListWidgetItem() - item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) - item.setData(Qt.ItemDataRole.UserRole, id) - item.setText(name) - self.list.addItem(item) - self.list.setCurrentItem(item) - self.added.emit(id) - - self.__delBtn.setEnabled(True) - - def __add(self): - dialog = PromptGroupDirectInputDialog(self) - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - name = dialog.getPromptGroupName() - id = DB.insertPromptGroup(name, prompt_type="sentence") - self.__addGroupItem(id, name) - - def __delete(self): - i = self.list.currentRow() - item = self.list.takeItem(i) - id = item.data(Qt.ItemDataRole.UserRole) - DB.deletePromptGroup(id) - self.deleted.emit(id) - - groups = DB.selectPromptGroup(prompt_type="sentence") - if len(groups) <= 0: - self.__delBtn.setEnabled(False) - - def __import(self): - dialog = PromptGroupImportDialog(parent=self, prompt_type="sentence") - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - # Get the data - result = dialog.getSelected() - # Save the data - for group in result: - id = DB.insertPromptGroup(group["name"], prompt_type="sentence") - for entry in group["data"]: - DB.insertPromptEntry(id, entry["act"], entry["prompt"]) - name = group["name"] - self.__addGroupItem(id, name) - - def __export(self): - try: - # Get the file - file_data = QFileDialog.getSaveFileName( - self, - LangClass.TRANSLATIONS["Save"], - QFILEDIALOG_DEFAULT_DIRECTORY, - f"CSV files Compressed File (*.zip);;{JSON_FILE_EXT_LIST_STR}", - ) - if file_data[0]: - filename = file_data[0] - # Get the data - data = get_prompt_data("sentence") - # Get extension - ext = os.path.splitext(filename)[1] - # If it is a compressed file, it is a compressed csv, so change the extension to csv - if ext == ".zip": - ext = ".csv" - dialog = PromptGroupExportDialog(data=data, ext=ext, parent=self) - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - data = dialog.getSelected() - export_prompt(data, filename, ext) - open_directory(os.path.dirname(filename)) - except Exception as e: - QMessageBox.critical(self, LangClass.TRANSLATIONS["Error"], str(e)) - print(e) - - def __itemChanged(self, item): - id = item.data(Qt.ItemDataRole.UserRole) - DB.updatePromptGroup(id, item.text()) - self.itemChanged.emit(id) - - def __currentRowChanged(self, r_idx): - item = self.list.item(r_idx) - if item: - id = item.data(Qt.ItemDataRole.UserRole) - self.currentRowChanged.emit(id) - - -class PromptTable(QWidget): - updated = Signal(str) - - def __init__(self, parent=None): - super().__init__(parent) - self.__initVal() - self.__initUi() - - def __initVal(self): - self.__title = "" - self.__entries = [] - - def __initUi(self): - self.__addBtn = Button() - self.__delBtn = Button() - - self.__addBtn.setStyleAndIcon(ICON_ADD) - self.__delBtn.setStyleAndIcon(ICON_DELETE) - - self.__addBtn.clicked.connect(self.__add) - self.__delBtn.clicked.connect(self.__delete) - - self.__titleLbl = QLabel() - - lay = QHBoxLayout() - lay.addWidget(self.__titleLbl) - lay.addSpacerItem(QSpacerItem(10, 10, QSizePolicy.Policy.MinimumExpanding)) - lay.addWidget(self.__addBtn) - lay.addWidget(self.__delBtn) - lay.setAlignment(Qt.AlignmentFlag.AlignRight) - lay.setContentsMargins(0, 0, 0, 0) - - topWidget = QWidget() - topWidget.setLayout(lay) - - self.__table = QTableWidget() - self.__table.setColumnCount(2) - self.__table.setSelectionBehavior( - QAbstractItemView.SelectionBehavior.SelectRows - ) - self.__table.setHorizontalHeaderLabels( - [LangClass.TRANSLATIONS["Name"], LangClass.TRANSLATIONS["Value"]] - ) - self.__table.horizontalHeader().setSectionResizeMode( - 1, QHeaderView.ResizeMode.Stretch - ) - self.__table.currentItemChanged.connect(self.__rowChanged) - self.__table.itemChanged.connect(self.__saveChangedPrompt) - - lay = QVBoxLayout() - lay.addWidget(topWidget) - lay.addWidget(self.__table) - lay.setContentsMargins(5, 0, 0, 0) - - self.setLayout(lay) - - def showEntries(self, id): - self.__group_id = id - - prompt_group = DB.selectCertainPromptGroup(id=self.__group_id) - self.__title = prompt_group.name - self.__entries = DB.selectPromptEntry(self.__group_id) - - self.__titleLbl.setText(self.__title) - - self.__table.setRowCount(len(self.__entries)) - for i in range(len(self.__entries)): - act = self.__entries[i].act - prompt = self.__entries[i].prompt - - item1 = QTableWidgetItem(act) - item1.setData(Qt.ItemDataRole.UserRole, self.__entries[i].id) - item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - - item2 = QTableWidgetItem(prompt) - item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - - self.__table.setItem(i, 0, item1) - self.__table.setItem(i, 1, item2) - - self.__addBtn.setEnabled(True) - self.__delBtn.setEnabled(True) - - def setNothingRightNow(self): - self.__title = "" - self.__titleLbl.setText(self.__title) - self.__table.clearContents() - self.__addBtn.setEnabled(False) - self.__delBtn.setEnabled(False) - - def getId(self): - return self.__group_id - - def __rowChanged(self, new_item: QTableWidgetItem, old_item: QTableWidgetItem): - prompt = "" - # To avoid AttributeError - if new_item: - prompt = ( - self.__table.item(new_item.row(), 1).text() - if new_item.column() == 0 - else new_item.text() - ) - self.updated.emit(prompt) - - def __saveChangedPrompt(self, item: QTableWidgetItem): - act = self.__table.item(item.row(), 0) - id = act.data(Qt.ItemDataRole.UserRole) - act = act.text() - - prompt = self.__table.item(item.row(), 1) - prompt = prompt.text() if prompt else "" - DB.updatePromptEntry(id, act, prompt) - - def __add(self): - dialog = PromptEntryDirectInputDialog(self.__group_id, self) - reply = dialog.exec() - if reply == QDialog.DialogCode.Accepted: - self.__table.itemChanged.disconnect(self.__saveChangedPrompt) - - act = dialog.getAct() - self.__table.setRowCount(self.__table.rowCount() + 1) - - item1 = QTableWidgetItem(act) - item1.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - self.__table.setItem(self.__table.rowCount() - 1, 0, item1) - - prompt = dialog.getPrompt() - - item2 = QTableWidgetItem(prompt) - item2.setTextAlignment(Qt.AlignmentFlag.AlignCenter) - self.__table.setItem(self.__table.rowCount() - 1, 1, item2) - - id = DB.insertPromptEntry(self.__group_id, act, prompt) - item1.setData(Qt.ItemDataRole.UserRole, id) - - self.__table.itemChanged.connect(self.__saveChangedPrompt) - - def __delete(self): - for i in sorted( - set([i.row() for i in self.__table.selectedIndexes()]), reverse=True - ): - id = self.__table.item(i, 0).data(Qt.ItemDataRole.UserRole) - self.__table.removeRow(i) - DB.deletePromptEntry(self.__group_id, id) - - -class SentencePage(QWidget): - updated = Signal(str) - - def __init__(self, parent=None): - super().__init__(parent) - self.__initVal() - self.__initUi() - - def __initVal(self): - self.__groups = DB.selectPromptGroup(prompt_type="sentence") - - def __initUi(self): - leftWidget = SentenceGroupList() - leftWidget.added.connect(self.add) - leftWidget.deleted.connect(self.delete) - - leftWidget.currentRowChanged.connect(self.__showEntries) - - self.__table = PromptTable() - if len(self.__groups) > 0: - leftWidget.list.setCurrentRow(0) - self.__table.showEntries(self.__groups[0].id) - self.__table.updated.connect(self.updated) - - mainWidget = QSplitter() - mainWidget.addWidget(leftWidget) - mainWidget.addWidget(self.__table) - mainWidget.setChildrenCollapsible(False) - mainWidget.setSizes([300, 700]) - - lay = QVBoxLayout() - lay.addWidget(mainWidget) - - self.setLayout(lay) - - def add(self, id): - self.__table.showEntries(id) - - def delete(self, id): - if self.__table.getId() == id: - self.__table.setNothingRightNow() - elif len(DB.selectPromptGroup(prompt_type="sentence")) == 0: - self.__table.setNothingRightNow() - - def __showEntries(self, id): - self.__table.showEntries(id) From d2d46e1ba65cd543c1e60b5c40db8efcc5f3e52f Mon Sep 17 00:00:00 2001 From: yjg30737 Date: Sun, 24 Nov 2024 09:29:46 +0900 Subject: [PATCH 5/6] Filter image models only as far as possible --- pyqt_openai/__init__.py | 3 +++ pyqt_openai/util/common.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pyqt_openai/__init__.py b/pyqt_openai/__init__.py index 8522461..0539d75 100644 --- a/pyqt_openai/__init__.py +++ b/pyqt_openai/__init__.py @@ -411,6 +411,9 @@ def move_bin(filename, dst_dir): # This has to be managed separately since some of the arguments are different with usual models O1_MODELS = ["o1-preview", "o1-mini"] +# For filtering out famous LLMs for image models +FAMOUS_LLM_LIST = ["gpt", "claude", "gemini", "llama", "meta", "qwen", "falcon"] + # Overall API configuration data DEFAULT_API_CONFIGS = [ # OpenAI diff --git a/pyqt_openai/util/common.py b/pyqt_openai/util/common.py index f76b31d..2d73b1a 100644 --- a/pyqt_openai/util/common.py +++ b/pyqt_openai/util/common.py @@ -63,7 +63,7 @@ O1_MODELS, STT_MODEL, DEFAULT_DATETIME_FORMAT, - DEFAULT_TOKEN_CHUNK_SIZE, DEFAULT_API_CONFIGS, INDENT_SIZE, + DEFAULT_TOKEN_CHUNK_SIZE, DEFAULT_API_CONFIGS, INDENT_SIZE, FAMOUS_LLM_LIST, ) from pyqt_openai.config_loader import CONFIG_MANAGER from pyqt_openai.globals import ( @@ -711,6 +711,8 @@ def get_g4f_image_models() -> list: index.append(parent.__name__) models = [model["image_model"] for model in image_models] + # Filter out the models in FAMOUS_LLM_LIST + models = [model for model in models if model not in FAMOUS_LLM_LIST] return models From 5b87c8b098114297832daa8a8f2a2a8fc2e2828f Mon Sep 17 00:00:00 2001 From: yjg30737 Date: Sun, 24 Nov 2024 10:36:50 +0900 Subject: [PATCH 6/6] Use filetype to get the mime type of image bytes, fix vision related issue, remove some unused functions --- pyproject.toml | 1 + pyqt_openai/util/common.py | 74 ++++++-------------------------------- requirements.txt | 1 + 3 files changed, 13 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4411091..9658bdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "pyaudio", "pillow", "psutil", + "filetype", "openai", "anthropic", diff --git a/pyqt_openai/util/common.py b/pyqt_openai/util/common.py index 2d73b1a..fd87d16 100644 --- a/pyqt_openai/util/common.py +++ b/pyqt_openai/util/common.py @@ -17,6 +17,7 @@ import sys import tempfile import time +import filetype import traceback import wave import zipfile @@ -581,63 +582,6 @@ def get_chat_model(is_g4f=False): all_models.extend(obj.get("model_list", [])) return all_models -def get_gemini_argument(model, system, messages, cur_text, stream, images): - try: - args = { - "system": system, - "model": model, - "messages": messages, - "stream": stream, - } - if len(images) > 0: - args["images"] = [PIL.Image.open(BytesIO(image)) for image in images] - args["messages"].append({"role": "user", "content": cur_text}) - return args - except Exception as e: - print(e) - raise e - - -def get_claude_argument(model, system, messages, cur_text, stream, images): - try: - args = { - "model": model, - "system": system, - "messages": messages, - "max_tokens": DEFAULT_TOKEN_CHUNK_SIZE, - "stream": stream, - } - # TODO REFACTORING (FOR COMMON FUNCTION FOR VISION) - # Vision - if len(images) > 0: - multiple_images_content = [] - for image in images: - multiple_images_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": get_image_url_from_local(image), - }, - } - ) - - multiple_images_content = multiple_images_content[:] + [ - {"type": "text", "text": cur_text} - ] - - args["messages"].append( - {"role": "user", "content": multiple_images_content} - ) - else: - args["messages"].append({"role": "user", "content": cur_text}) - return args - except Exception as e: - print(e) - raise e - - def set_api_key(env_var_name, api_key): api_key = api_key.strip() if api_key else "" if env_var_name == "OPENAI_API_KEY": @@ -655,7 +599,14 @@ def set_api_key(env_var_name, api_key): # Set environment variables dynamically os.environ[env_var_name] = api_key -def get_image_url_from_local(image, is_openai=False): +def get_mime_type_from_bytes(byte_data): + kind = filetype.guess(byte_data) + if kind is None: + raise ValueError("Could not determine MIME type from bytes") + print(kind.mime) + return kind.mime + +def get_image_url_from_local(image): """ Image is bytes, this function converts it to base64 and returns the image url """ @@ -665,10 +616,7 @@ def encode_image(image): return base64.b64encode(image).decode("utf-8") base64_image = encode_image(image) - if is_openai: - return f"data:image/jpeg;base64,{base64_image}" - else: - return base64_image + return f"data:{get_mime_type_from_bytes(image)};base64,{base64_image}" def get_message_obj(role, content): @@ -839,7 +787,7 @@ def get_api_argument( { "type": "image_url", "image_url": { - "url": get_image_url_from_local(image, is_openai=True), + "url": get_image_url_from_local(image), }, } ) diff --git a/requirements.txt b/requirements.txt index f3f0186..03a0c3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ requests pyaudio pillow psutil +filetype openai anthropic