fix: openai init with chat models (#42)
* fix: openai init with chat models

* fix tests

---------

Co-authored-by: Douglas Reid <[email protected]>
douglas-reid and Douglas Reid authored Apr 28, 2023
1 parent 23c9a35 commit d5861ec
Showing 4 changed files with 40 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/steamship_langchain/llms/__init__.py
@@ -1,4 +1,4 @@
"""Provides Steamship-compatible LLMs for use in langchain (🦜️🔗) chains and agents."""
from .openai import OpenAI
from .openai import OpenAI, OpenAIChat

__all__ = ["OpenAI"]
__all__ = ["OpenAI", "OpenAIChat"]
18 changes: 15 additions & 3 deletions src/steamship_langchain/llms/openai.py
@@ -1,6 +1,7 @@
 import hashlib
 import json
 import logging
+import warnings
 from collections import defaultdict
 from typing import Any, Dict, Generator, List, Mapping, Optional

@@ -45,6 +46,18 @@ class OpenAI(BaseOpenAI):
     client: Steamship  # We can use validate_environment to add the client here
     batch_task_timeout_seconds: int = 10 * 60  # 10 minute limit on generation tasks

+    def __new__(cls, **data: Any):
+        """Initialize the OpenAI object."""
+        model_name = data.get("model_name", "")
+        if model_name.startswith("gpt-3.5-turbo") or model_name.startswith("gpt-4"):
+            warnings.warn(
+                "You are trying to use a chat model. This way of initializing it is "
+                "no longer supported. Instead, please use: "
+                "`from steamship_langchain.llms import OpenAIChat`"
+            )
+            return OpenAIChat(**data)
+        return super().__new__(cls)
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         return {
@@ -186,7 +199,6 @@ def _batch(
             logging.error(f"could not generate from OpenAI LLM: {e}")
             # TODO(douglas-reid): determine appropriate action here.
             # for now, if an error is encountered, just swallow.
-            pass

         return generations, token_usage

@@ -202,8 +214,8 @@ class Config:
     def __init__(
         self, client: Steamship, model_name: str = "gpt-4", moderate_output: bool = True, **kwargs
     ):
-        super().__init__(client=client, model_name=model_name, **kwargs)
-        plugin_config = {"model": self.model_name, "moderate_output": moderate_output}
+        super().__init__(client=client, **kwargs)
+        plugin_config = {"model": model_name, "moderate_output": moderate_output}
         if self.openai_api_key:
             plugin_config["openai_api_key"] = self.openai_api_key
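With this change, the old initialization style keeps working for chat models: the `__new__` hook warns and reroutes construction to `OpenAIChat`. A minimal sketch of the resulting behavior (assuming a configured Steamship client; not part of this diff):

from steamship import Steamship
from steamship_langchain.llms import OpenAI, OpenAIChat

client = Steamship()  # assumes a Steamship workspace is configured
llm = OpenAI(client=client, model_name="gpt-4")  # emits a warning pointing at OpenAIChat
assert isinstance(llm, OpenAIChat)  # construction was redirected by __new__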
2 changes: 1 addition & 1 deletion tests/file_loaders/test_youtube.py
@@ -10,8 +10,8 @@


 @pytest.mark.usefixtures("client")
+@pytest.mark.skip()  # YT loader implementation is failing at the moment due to YT changes.
 def test_youtube_loader(client: Steamship):
-
     loader_under_test = YouTubeFileLoader(client=client)
     files = loader_under_test.load(TEST_URL)

24 changes: 22 additions & 2 deletions tests/llms/test_openai.py
@@ -128,13 +128,13 @@ def test_openai_chat_llm(client: Steamship) -> None:
"""Test Chat version of the LLM"""
llm = OpenAIChat(client=client)
llm_result = llm.generate(
prompts=["Please say the Pledge of Allegiance"], stop=["flag", "Flag"]
prompts=["Please print the words of the Pledge of Allegiance"], stop=["flag", "Flag"]
)
assert len(llm_result.generations) == 1
generation = llm_result.generations[0]
assert len(generation) == 1
text_response = generation[0].text
assert text_response.strip() == "I pledge allegiance to the"
assert text_response.strip(' "') == "I pledge allegiance to the"


@pytest.mark.usefixtures("client")
@@ -155,3 +155,23 @@ def test_openai_chat_llm_with_prefixed_messages(client: Steamship) -> None:
     assert len(generation) == 1
     text_response = generation[0].text
     assert text_response.strip() == "What is the meaning of life?"
+
+
+@pytest.mark.usefixtures("client")
+def test_openai_llm_with_chat_model_init(client: Steamship) -> None:
+    """Test Chat version of the LLM, with old init style"""
+    messages = [
+        {
+            "role": "system",
+            "content": "You are EchoGPT. For every prompt you receive, you reply with the exact same text.",
+        },
+        {"role": "user", "content": "This is a test."},
+        {"role": "assistant", "content": "This is a test."},
+    ]
+    llm = OpenAI(client=client, prefix_messages=messages, model_name="gpt-4")
+    llm_result = llm.generate(prompts=["What is the meaning of life?"])
+    assert len(llm_result.generations) == 1
+    generation = llm_result.generations[0]
+    assert len(generation) == 1
+    text_response = generation[0].text
+    assert text_response.strip() == "What is the meaning of life?"
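The new test exercises the rerouting end to end through `generate`. If one also wanted to assert that the warning itself fires, a hedged sketch (not part of this commit) using pytest's warning capture could look like:

import pytest

# Hypothetical extra assertion: old-style init with a chat model should emit
# a UserWarning (warnings.warn defaults to UserWarning) pointing at OpenAIChat.
with pytest.warns(UserWarning, match="please use"):
    llm = OpenAI(client=client, prefix_messages=messages, model_name="gpt-4")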
