Skip to content

Commit

Permalink
Fix history support for OpenaiChat
Browse files Browse the repository at this point in the history
  • Loading branch information
hlohaus committed Mar 28, 2024
1 parent d0da590 commit 03fd5ac
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
11 changes: 5 additions & 6 deletions g4f/Provider/needs_auth/OpenaiChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,19 +389,17 @@ async def create_async_generator(
print(f"{e.__class__.__name__}: {e}")

model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
fields = Conversation() if conversation is None else copy(conversation)
fields = Conversation(conversation_id, parent_id) if conversation is None else copy(conversation)
fields.finish_reason = None
while fields.finish_reason is None:
conversation_id = fields.conversation_id if hasattr(fields, "conversation_id") else conversation_id
parent_id = fields.message_id if hasattr(fields, "message_id") else parent_id
websocket_request_id = str(uuid.uuid4())
data = {
"action": action,
"conversation_mode": {"kind": "primary_assistant"},
"force_paragen": False,
"force_rate_limit": False,
"conversation_id": conversation_id,
"parent_message_id": parent_id,
"conversation_id": fields.conversation_id,
"parent_message_id": fields.message_id,
"model": model,
"history_and_training_disabled": history_disabled and not auto_continue and not return_conversation,
"websocket_request_id": websocket_request_id
Expand All @@ -425,14 +423,15 @@ async def create_async_generator(
await raise_for_status(response)
async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields):
if return_conversation:
history_disabled = False
return_conversation = False
yield fields
yield chunk
if not auto_continue:
break
action = "continue"
await asyncio.sleep(5)
if history_disabled and auto_continue and not return_conversation:
if history_disabled and auto_continue:
await cls.delete_conversation(session, cls._headers, fields.conversation_id)

@staticmethod
Expand Down
15 changes: 9 additions & 6 deletions g4f/gui/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from g4f.Provider.bing.create_images import patch_provider
from g4f.providers.conversation import BaseConversation

conversations: dict[str, BaseConversation] = {}
conversations: dict[dict[str, BaseConversation]] = {}

class Api():

Expand Down Expand Up @@ -106,7 +106,8 @@ def get_conversation(self, options: dict, **kwargs) -> Iterator:
kwargs["image"] = open(self.image, "rb")
for message in self._create_response_stream(
self._prepare_conversation_kwargs(options, kwargs),
options.get("conversation_id")
options.get("conversation_id"),
options.get('provider')
):
if not window.evaluate_js(f"if (!this.abort) this.add_message_chunk({json.dumps(message)}); !this.abort && !this.error;"):
break
Expand Down Expand Up @@ -193,8 +194,8 @@ def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
messages[-1]["content"] = get_search_message(messages[-1]["content"])

conversation_id = json_data.get("conversation_id")
if conversation_id and conversation_id in conversations:
kwargs["conversation"] = conversations[conversation_id]
if conversation_id and provider in conversations and conversation_id in conversations[provider]:
kwargs["conversation"] = conversations[provider][conversation_id]

model = json_data.get('model')
model = model if model else models.default
Expand All @@ -211,7 +212,7 @@ def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
**kwargs
}

def _create_response_stream(self, kwargs, conversation_id: str) -> Iterator:
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator:
"""
Creates and returns a streaming response for the conversation.
Expand All @@ -231,7 +232,9 @@ def _create_response_stream(self, kwargs, conversation_id: str) -> Iterator:
first = False
yield self._format_json("provider", get_last_provider(True))
if isinstance(chunk, BaseConversation):
conversations[conversation_id] = chunk
if provider not in conversations:
conversations[provider] = {}
conversations[provider][conversation_id] = chunk
yield self._format_json("conversation", conversation_id)
elif isinstance(chunk, Exception):
logging.exception(chunk)
Expand Down
2 changes: 1 addition & 1 deletion g4f/gui/server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def handle_conversation(self):
kwargs = self._prepare_conversation_kwargs(json_data, kwargs)

return self.app.response_class(
self._create_response_stream(kwargs, json_data.get("conversation_id")),
self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")),
mimetype='text/event-stream'
)

Expand Down

0 comments on commit 03fd5ac

Please sign in to comment.