Add conversation mode (#9)
AAClause authored Nov 30, 2023
1 parent 850c4aa commit f0dacc0
Showing 3 changed files with 126 additions and 62 deletions.
1 change: 1 addition & 0 deletions addon/globalPlugins/openai/__init__.py
@@ -37,6 +37,7 @@
"TTSModel": f"option({', '.join(TTS_MODELS)}, default={TTS_DEFAULT_MODEL})",
"TTSVoice": f"option({', '.join(TTS_VOICES)}, default={TTS_DEFAULT_VOICE})",
"blockEscapeKey": "boolean(default=False)",
"conversationMode": "boolean(default=True)",
"saveSystem": "boolean(default=False)",
"advancedMode": "boolean(default=False)",
"renewClient": "boolean(default=False)",
50 changes: 5 additions & 45 deletions addon/globalPlugins/openai/imagehelper.py
@@ -1,7 +1,6 @@
import base64
import os
import sys
import re
from logHandler import log
from .consts import ADDON_DIR

@@ -20,60 +19,21 @@ def encode_image(image_path):

def describeFromImageFileList(
client,
pathList,
prompt=None,
max_tokens=700
messages: list,
max_tokens: int = 700,
):
"""
Describe a list of images from a list of file paths.
@param client: OpenAI client
@param pathList: list of file paths
@param prompt: prompt to use
@param messages: list of messages
@param max_tokens: max tokens to use
@return: description
"""
if not prompt or not prompt.strip():
if not messages:
return None
content = [
{
"type": "text",
"text": prompt,
}
]
for path in pathList:
url_re = re.compile(r"^https?://")
if url_re.match(path):
content.append(
{
"type": "image_url",
"image_url": {
"url": path,
},
}
)
elif os.path.isfile(path):
base64_image = encode_image(path)
format = path.split(".")[-1]
mime_type = f"image/{format}"
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
},
}
)
else:
raise ValueError("Invalid path: {}".format(path))
response = client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": content,
}
],
messages=messages,
max_tokens=max_tokens
)
return response.choices[0]
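
Note on the refactor above: describeFromImageFileList no longer builds the image payload or accepts pathList/prompt; the caller now assembles a vision-style messages list (text parts plus image_url parts) and passes it in, and the helper simply forwards it to gpt-4-vision-preview and returns response.choices[0]. A minimal usage sketch, assuming a configured OpenAI v1 client and a purely hypothetical image URL:

    from openai import OpenAI
    from .imagehelper import describeFromImageFileList  # the add-on's own relative import, as in maindialog.py

    client = OpenAI()  # assumes OPENAI_API_KEY is available in the environment
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            # image_url entries may be plain http(s) URLs or data: URIs, as getImages() builds them
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
        ],
    }]
    choice = describeFromImageFileList(client, messages=messages, max_tokens=700)
    if choice is not None:  # the helper returns None when messages is empty
        print(choice.message.content)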

137 changes: 120 additions & 17 deletions addon/globalPlugins/openai/maindialog.py
@@ -1,5 +1,6 @@
import json
import os
import re
import sys
import threading
import winsound
@@ -15,7 +16,7 @@
import ui
from logHandler import log
from .consts import ADDON_DIR, DATA_DIR
from .imagehelper import describeFromImageFileList
from .imagehelper import describeFromImageFileList, encode_image
additionalLibsPath = os.path.join(ADDON_DIR, "lib")
sys.path.insert(0, additionalLibsPath)
import openai
@@ -95,19 +96,25 @@ def run(self):

block.temperature = temperature
block.topP = topP
conversationMode = conf["conversationMode"]
conf["conversationMode"] = wnd.conversationCheckBox.IsChecked()
block.pathList = wnd.pathList

if not 0 <= temperature <= model.maxTemperature * 100:
wx.PostEvent(self._notifyWindow, ResultEvent(_("Invalid temperature")))
return
if not TOP_P_MIN <= topP <= TOP_P_MAX:
wx.PostEvent(self._notifyWindow, ResultEvent(_("Invalid top P")))
return
messages = []
if system:
messages.append({"role": "system", "content": system})
wnd.getMessages(messages)
if prompt:
messages.append({"role": "user", "content": prompt})
params = {
"model": model.name,
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": prompt}
],
"messages": messages,
"temperature": temperature,
"max_tokens": maxTokens,
"top_p": topP,
@@ -175,15 +182,41 @@ class ImageDescriptionThread(threading.Thread):
def __init__(self, notifyWindow):
threading.Thread.__init__(self)
self._notifyWindow = notifyWindow
self._pathList = notifyWindow.pathList

def run(self):
wnd = self._notifyWindow
prompt = wnd.promptText.GetValue()
max_tokens = wnd.maxTokens.GetValue()
client = wnd.client
prompt = wnd.promptText.GetValue().strip()
messages = []
wnd.getMessages(messages)
if wnd.pathList:
content = [
{"type": "text", "text": prompt}
]
content.extend(wnd.getImages())
messages.append({
"role": "user",
"content": content
})
else:
messages.append({"role": "user", "content": prompt})
nbImages = 0
for message in messages:
if message["role"] == "user":
for content in message["content"]:
if not isinstance(content, dict):
continue
if content["type"] == "image_url":
nbImages += 1
if nbImages:
wnd.message(_("%d images to analyze...") % nbImages)
max_tokens = wnd.maxTokens.GetValue()
try:
description = describeFromImageFileList(client, self._pathList, prompt=prompt, max_tokens=max_tokens)
description = describeFromImageFileList(
client,
messages=messages,
max_tokens=max_tokens
)
except BaseException as err:
wx.PostEvent(self._notifyWindow, ResultEvent(repr(err)))
return
@@ -430,6 +463,7 @@ class HistoryBlock():
displayHeader = True
focused = False
responseTerminated = False
pathList = None


class OpenAIDlg(wx.Dialog):
@@ -461,6 +495,11 @@ def __init__(
)
super().__init__(parent, title=title)

self.conversationCheckBox = wx.CheckBox(
parent=self,
label=_("Conversati&on mode")
)
self.conversationCheckBox.SetValue(conf["conversationMode"])
systemLabel = wx.StaticText(
parent=self,
label=_("S&ystem:")
@@ -553,6 +592,7 @@ def __init__(

self.onModelChange(None)
sizer1 = wx.BoxSizer(wx.VERTICAL)
sizer1.Add(self.conversationCheckBox, 0, wx.ALL, 5)
sizer1.Add(systemLabel, 0, wx.ALL, 5)
sizer1.Add(self.systemText, 0, wx.ALL, 5)
sizer1.Add(historyLabel, 0, wx.ALL, 5)
@@ -648,7 +688,7 @@ def __init__(

def loadData(self):
if not os.path.exists(DATA_JSON_FP):
return
return {}
try:
with open(DATA_JSON_FP, 'r') as f :
return json.loads(f.read())
@@ -707,7 +747,11 @@ def onOk(self, evt):
wx.OK|wx.ICON_ERROR
)
return
if model.name == MODEL_VISION and not self.pathList:
if (
model.name == MODEL_VISION
and not self.conversationCheckBox.IsChecked()
and not self.pathList
):
gui.messageBox(
_("No image provided. Please use the Image Description button and select one or more images. Otherwise, please select another model."),
_("Open AI"),
@@ -754,14 +798,18 @@ def OnResult(self, event):
winsound.PlaySound(None, winsound.SND_ASYNC)
if not event.data:
return

if isinstance(event.data, openai.types.chat.chat_completion.Choice):
historyBlock = HistoryBlock()
historyBlock.system = self.systemText.GetValue().strip()
historyBlock.prompt = self.promptText.GetValue().strip()
"""
# TODO: Create a special history block for attached images and additional information
if self.pathList:
historyBlock.prompt += "\n\n" + _("Attached images:")
for path in self.pathList:
historyBlock.prompt += f"\n + <image: \"{path}\">"
self.pathList = None
historyBlock.prompt += f"\n- \"{path}\""
"""
historyBlock.model = self.getCurrentModel().name
if self.conf["advancedMode"]:
historyBlock.temperature = self.temperature.GetValue() / 100
@@ -774,6 +822,8 @@ def OnResult(self, event):
historyBlock.response = event.data
historyBlock.responseText = event.data.message.content
historyBlock.responseTerminated = True
historyBlock.pathList = self.pathList
self.pathList = None
if self.lastBlock is None:
self.firstBlock = self.lastBlock = historyBlock
else:
@@ -784,6 +834,7 @@ def OnResult(self, event):
self.promptText.Clear()
self.promptText.SetFocus()
return

if isinstance(event.data, openai.types.audio.transcription.Transcription):
self.promptText.AppendText(event.data.text)
self.promptText.SetFocus()
@@ -793,6 +844,7 @@ def OnResult(self, event):
True
)
return

if isinstance(event.data, openai._base_client.HttpxBinaryResponseContent):
if os.path.exists(TTS_FILE_NAME):
os.startfile(TTS_FILE_NAME)
@@ -871,6 +923,57 @@ def addShortcuts(self):
accelTable = wx.AcceleratorTable(accelEntries)
self.SetAcceleratorTable(accelTable)

def getImages(
self,
pathList: list = None
) -> list:
if not pathList:
pathList = self.pathList
images = []
for path in pathList:
url_re = re.compile(r"^https?://")
if url_re.match(path):
images.append({"type": "image_url", "image_url": {"url": path}})
elif os.path.isfile(path):
base64_image = encode_image(path)
format = path.split(".")[-1]
mime_type = f"image/{format}"
images.append({
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
}
})
else:
raise ValueError(f"Invalid path: {path}")
break
return images

def getMessages(
self,
messages: list
) -> list:
model = self.getCurrentModel()
if not self.conversationCheckBox.IsChecked():
return messages
block = self.firstBlock
while block is not None:
if block.prompt:
if block.pathList:
content = [
{"type": "text", "text": block.prompt}
]
content.extend(self.getImages(block.pathList))
messages.append({
"role": "user",
"content": content
})
else:
messages.append({"role": "user", "content": block.prompt})
if block.responseText:
messages.append({"role": "system", "content": block.responseText})
block = block.next

def onPreviousPrompt(self, event):
value = self.previousPrompt
if value:
@@ -967,11 +1070,11 @@ def onDeleteBlock(self, evt):
block = segment.owner

if block.segmentBreakLine is not None:
block.segmentBreakLine.delete ()
block.segmentPromptLabel.delete ()
block.segmentBreakLine.delete()
block.segmentPromptLabel.delete()
block.segmentPrompt.delete()
block.segmentResponseLabel.delete ()
block.segmentResponse.delete ()
block.segmentResponseLabel.delete()
block.segmentResponse.delete()

if block.previous is not None:
block.previous.next = block.next
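
For reference, in conversation mode getMessages() replays the HistoryBlock chain (expanding attached images through getImages()), so a CompletionThread request ends up shaped roughly like the sketch below. All values are illustrative, and note that this commit stores earlier responses under the "system" role rather than "assistant":

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},  # optional system prompt, added first
        {   # an earlier prompt that had an attached image
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
            ],
        },
        {"role": "system", "content": "A black cat on a windowsill."},  # the stored response for that block
        {"role": "user", "content": "And what colour is the cat?"},  # the new prompt typed in the dialog
    ]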
