Skip to content

Commit

Permalink
[CM-10066]: openai fix (#130)
Browse files Browse the repository at this point in the history
* CM-10066: make the messages serializble;

* CM-10066: fix lint issues;

* CM-10066: update code style;

* CM-10066: fix lint issues;
  • Loading branch information
aadereiko authored Apr 23, 2024
1 parent 385fd98 commit 26992cf
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
41 changes: 39 additions & 2 deletions src/comet_llm/autologgers/openai/chat_completion_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@
# *******************************************************

import inspect
import json
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union

import comet_llm.logging
from comet_llm.logging_messages import MESSAGE_IS_NOT_JSON_SERIALIZABLE
from comet_llm.types import JSONEncodable

from . import metadata

if TYPE_CHECKING:
from openai import Stream
from openai.openai_object import OpenAIObject
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessage

Inputs = Dict[str, Any]
Outputs = Dict[str, Any]
Expand All @@ -47,7 +50,9 @@ def parse_create_arguments(kwargs: Dict[str, Any]) -> Tuple[Inputs, Metadata]:
kwargs_copy = kwargs.copy()
inputs = {}

inputs["messages"] = kwargs_copy.pop("messages")
inputs["messages"] = _parse_create_list_argument_messages(
kwargs_copy.pop("messages")
)
if "function_call" in kwargs_copy:
inputs["function_call"] = kwargs_copy.pop("function_call")

Expand Down Expand Up @@ -102,3 +107,35 @@ def _v1_x_x__parse_create_result(
metadata["output_model"] = metadata.pop("model")

return outputs, metadata


def _parse_create_list_argument_messages(
messages: List[Union[Dict, "ChatCompletionMessage"]],
) -> JSONEncodable:
if _is_jsonable(messages):
return messages

result = []

for message in messages:
if _is_jsonable(message):
result.append(message)
continue

if hasattr(message, "to_dict"):
message_to_append_ = message.to_dict()
else:
LOGGER.debug(MESSAGE_IS_NOT_JSON_SERIALIZABLE)
message_to_append_ = message

result.append(message_to_append_)

return result


def _is_jsonable(x: Any) -> bool:
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False
2 changes: 2 additions & 0 deletions src/comet_llm/logging_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@
PARSE_API_KEY_TOO_MANY_PARTS = "Too many parts (%d) found in the Comet API key: %r"

BASE_URL_MISMATCH_CONFIG_API_KEY = "Comet URL conflict detected between config (%r) and API Key (%r). SDK will use config URL. Resolve by either removing config URL or set it to the same value."

MESSAGE_IS_NOT_JSON_SERIALIZABLE = "Message is not JSON serializable"
41 changes: 40 additions & 1 deletion tests/unit/autologgers/openai/test_chat_completion_parsers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Iterable

import box
import pytest
from testix import *
Expand Down Expand Up @@ -108,4 +110,41 @@ def test_parse_create_result__input_is_Stream__input_parsed_with_hardcoded_value
]
)
def test_create_arguments_supported__happyflow(inputs, result):
assert chat_completion_parsers.create_arguments_supported(inputs) is result
assert chat_completion_parsers.create_arguments_supported(inputs) is result


class FakeClassWithoutDict:
def __init__(self, a, b):
self.a = a
self.b = b

def __eq__(self, other):
return self.a == other.a and self.b == other.b


class FakeClassWithDict:
def __init__(self, a, b):
self.a = a
self.b = b

def __eq__(self, other):
return self.a == other.a and self.b == other.b

def to_dict(self):
return {"a": self.a, "b": self.b}



@pytest.mark.parametrize(
"messages,result",
[
([None], [None]),
("123", "123"),
([1, 2, 3], [1, 2, 3]),
([{"key": "value"}, {"key2": "value2"}], [{"key": "value"}, {"key2": "value2"}]),
([FakeClassWithoutDict(a="1", b="2"), FakeClassWithoutDict(a="3", b="4")], [FakeClassWithoutDict(a="1", b="2"), FakeClassWithoutDict(a="3", b="4")]),
([FakeClassWithDict(a="1", b="2")], [{"a": "1", "b": "2"}])
]
)
def test_parse_create_list_argument_messages(messages: Iterable, result: Iterable):
assert chat_completion_parsers._parse_create_list_argument_messages(messages=messages) == result

0 comments on commit 26992cf

Please sign in to comment.