Skip to content

Commit

Permalink
feat: Claude 3 multimodal improvement, ChuanhuChat now converts unsu…
Browse files Browse the repository at this point in the history
…pported image formats such as .jpg into jpeg
  • Loading branch information
GaiZhenbiao committed Mar 5, 2024
1 parent 94991b8 commit 5097de6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
9 changes: 9 additions & 0 deletions modules/models/Claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def _get_claude_style_history(self):
elif message["role"] == "image":
image_buffer.append(message["content"])
image_count += 1
# history with base64 data replaced with "#base64#"
# history_for_display = history.copy()
# for message in history_for_display:
# if message["role"] == "user":
# if type(message["content"]) == list:
# for content in message["content"]:
# if content["type"] == "image":
# content["source"]["data"] = "#base64#"
# logging.info(f"History for Claude: {history_for_display}")
return history

def get_answer_stream_iter(self):
Expand Down
24 changes: 17 additions & 7 deletions modules/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from itertools import islice
from threading import Condition, Thread
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import PIL
from io import BytesIO

import aiohttp
import colorama
Expand Down Expand Up @@ -403,7 +405,7 @@ def handle_file_upload(self, files, chatbot, language):
other_files = []
if files:
for f in files:
if f.name.endswith((".jpg", ".png", ".jpeg", ".gif", ".webp")):
if f.name.endswith(IMAGE_FORMATS):
image_files.append(f)
else:
other_files.append(f)
Expand Down Expand Up @@ -1123,14 +1125,22 @@ def clear_cuda_cache(self):
torch.cuda.empty_cache()

def get_base64_image(self, image_path):
with open(image_path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
if image_path.endswith(DIRECTLY_SUPPORTED_IMAGE_FORMATS):
with open(image_path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
else:
# convert to jpeg
image = PIL.Image.open(image_path)
image = image.convert("RGB")
buffer = BytesIO()
image.save(buffer, format="JPEG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")

def get_image_type(self, image_path):
extension = os.path.splitext(image_path)[1][1:]
if extension == "jpg":
extension = "jpeg"
return extension
if image_path.lower().endswith(DIRECTLY_SUPPORTED_IMAGE_FORMATS):
return os.path.splitext(image_path)[1][1:].lower()
else:
return "jpeg"


class Base_Chat_Langchain_Client(BaseLLMModel):
Expand Down
3 changes: 3 additions & 0 deletions modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@
i18n("模型自动总结(消耗tokens)"),
]

DIRECTLY_SUPPORTED_IMAGE_FORMATS = (".png", ".jpeg", ".gif", ".webp") # image types that can be directly uploaded, other formats will be converted to jpeg
IMAGE_FORMATS = DIRECTLY_SUPPORTED_IMAGE_FORMATS + (".jpg", ".bmp", "heic", "heif") # all supported image formats


WEBSEARCH_PTOMPT_TEMPLATE = """\
Web search results:
Expand Down

0 comments on commit 5097de6

Please sign in to comment.