Skip to content

Commit

Permalink
[CM-9113] Introduce api object (#109)
Browse files Browse the repository at this point in the history
* Add initial implementation of API and LLMTraceAPI

* Add copyright comments

* Fix lint errors

* Expose new api to comet_llm namespace

* Refactor if-else statements

* Add API.query

* Add type hint

* Update log_metadata

* Refactor code

* Refactor some code

* Add reading api key from environment

* Fix lint errors
  • Loading branch information
alexkuzmik authored Dec 15, 2023
1 parent d5a6b57 commit 5da0e02
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/comet_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# *******************************************************

from . import app, autologgers, config, logging
from .api_objects.api import API
from .config import init, is_ready

if config.comet_disabled():
Expand All @@ -33,6 +34,7 @@
"is_ready",
"log_user_feedback",
"flush",
"API",
]

logging.setup()
Expand Down
13 changes: 13 additions & 0 deletions src/comet_llm/api_objects/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this package.
# *******************************************************
125 changes: 125 additions & 0 deletions src/comet_llm/api_objects/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this package.
# *******************************************************

from typing import List, Optional

import comet_ml

from .. import experiment_info, logging_messages, query_dsl
from . import llm_trace_api

# TODO: decide how this module should depend on comet-ml, and work out how to test it.


class API:
    """
    Entry point for fetching LLM traces that were already logged to Comet.

    All lookups are delegated to an underlying `comet_ml.API` instance and the
    results are wrapped into `LLMTraceAPI` objects.
    """

    def __init__(self, api_key: Optional[str] = None) -> None:
        """
        Create the API object.

        Args:
            api_key: Optional. Comet API key. It is handed to `experiment_info.get`
                and the key resolved there is used for the underlying
                `comet_ml.API` (created with `cache=False`).
                NOTE(review): when omitted, the key is presumably taken from the
                configuration or environment by `experiment_info.get` - confirm.
        """
        not_found_message = logging_messages.API_KEY_NOT_FOUND_MESSAGE % "API"
        resolved_info = experiment_info.get(
            api_key,
            api_key_not_found_message=not_found_message,
        )
        self._api = comet_ml.API(api_key=resolved_info.api_key, cache=False)

    def get_llm_trace_by_key(self, trace_key: str) -> llm_trace_api.LLMTraceAPI:
        """
        Get an API Trace object by key.

        Args:
            trace_key: str, key of the prompt or chain

        Returns: An LLMTraceAPI object that can be used to get or update trace data

        Raises:
            ValueError: if no trace is found for the given key
        """
        api_experiment = self._api.get_experiment_by_key(trace_key)

        if api_experiment is not None:
            return llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(
                api_experiment
            )

        raise ValueError(f"Failed to find any matching traces with the key {trace_key}")

    def get_llm_trace_by_name(
        self, workspace: str, project_name: str, trace_name: str
    ) -> llm_trace_api.LLMTraceAPI:
        """
        Get an API Trace object by name.

        Args:
            workspace: str, name of the workspace
            project_name: str, name of the project
            trace_name: str, name of the prompt or chain

        Returns: An LLMTraceAPI object that can be used to get or update trace data

        Raises:
            ValueError: if the name matches no trace, or more than one trace,
                in the given project
        """
        name_filter = query_dsl.Other("Name") == trace_name
        candidates = self._api.query(workspace, project_name, name_filter)
        candidates_count = len(candidates)

        # The name must identify exactly one trace; anything else is an error.
        if candidates_count > 1:
            raise ValueError(
                f"Found multiple traces with the name {trace_name} in the project {project_name}"
            )
        if candidates_count == 0:
            raise ValueError(
                f"Failed to find any matching traces with the name {trace_name} in the project {project_name}"
            )

        return llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(candidates[0])

    def query(
        self, workspace: str, project_name: str, query: str
    ) -> List[llm_trace_api.LLMTraceAPI]:
        """
        Fetch LLM Traces based on a query. Currently it is only possible to use
        trace metadata or details fields to filter the traces.

        Args:
            workspace: str, name of the workspace
            project_name: str, name of the project
            query: the filter to apply, forwarded as is to the underlying
                `comet_ml` query call. Although annotated as `str`, the examples
                below build it from the query variables (see Notes).

        Returns: A list of LLMTraceAPI objects, one per matching trace

        Notes:
            The `query` object takes the form of (QUERY_VARIABLE OPERATOR VALUE) with:
            * QUERY_VARIABLE is either TraceMetadata, Duration, Timestamp.
            * OPERATOR is any standard mathematical operators `<=`, `>=`, `!=`, `<`, `>`.
            It is also possible to add multiple query conditions using `&`.
            If you are querying nested parameters, you should flatten the parameter
            name using the `.` operator.
            To query the duration, you can use Duration().

        Example:
            ```python
            # Find all traces where the metadata field `token` is greater than 50
            api.query("workspace", "project", TraceMetadata("token") > 50)
            # Find all traces where the duration field is between 1 second and 2 seconds
            api.query("workspace", "project", (Duration() > 1) & (Duration() <= 2))
            # Find all traces based on the timestamp
            api.query("workspace", "project", Timestamp() > datetime(2023, 9, 10))
            ```
        """
        traces = []
        for api_experiment in self._api.query(workspace, project_name, query):
            traces.append(
                llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(api_experiment)
            )

        return traces
117 changes: 117 additions & 0 deletions src/comet_llm/api_objects/llm_trace_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this package.
# *******************************************************

import io
import json
from typing import Dict, Optional

import comet_ml

from .. import convert
from ..chains import deepmerge
from ..types import JSONEncodable


class LLMTraceAPI:
    """
    Handle for a single LLM trace (prompt or chain) that was already logged.

    Instances wrap a `comet_ml.APIExperiment` and are not meant to be created
    directly: use `API.get_llm_trace_by_key`, `API.get_llm_trace_by_name` or
    `API.query`.
    """

    # Name of the asset that stores the trace data, including its metadata.
    _TRACE_DATA_FILE_NAME = "comet_llm_data.json"

    _api_experiment: "comet_ml.APIExperiment"

    def __init__(self) -> None:
        raise NotImplementedError(
            "Please use API.get_llm_trace_by_key or API.get_llm_trace_by_name methods to get the instance"
        )

    @classmethod
    def __api__from_api_experiment__(
        cls, api_experiment: "comet_ml.APIExperiment"
    ) -> "LLMTraceAPI":
        """
        Internal constructor: wrap an existing API experiment object.

        Args:
            api_experiment: the `comet_ml.APIExperiment` backing this trace
        """
        # __init__ always raises, so the instance is allocated without calling it.
        instance = object.__new__(cls)
        instance._api_experiment = api_experiment

        return instance

    def get_name(self) -> Optional[str]:
        """
        Get the name of the trace
        """
        return self._api_experiment.get_name()  # type: ignore

    def get_key(self) -> str:
        """
        Get the unique identifier for this trace
        """
        return self._api_experiment.key  # type: ignore

    def log_user_feedback(self, score: float) -> None:
        """
        Log user feedback

        Args:
            score: float, the feedback score. Can be either 0, 0.0, 1 or 1.0

        Raises:
            ValueError: if the score is not one of the allowed values
        """
        ALLOWED_SCORES = [0.0, 1.0]
        if score not in ALLOWED_SCORES:
            raise ValueError(
                f"Score is not valid, should be one of {ALLOWED_SCORES} but got {score}"
            )

        self._api_experiment.log_metric("user_feedback", score)

    def _get_trace_data(self) -> Dict[str, "JSONEncodable"]:
        """
        Download and parse the trace data asset of this trace.

        Raises:
            ValueError: if the trace data asset could not be located
        """
        try:
            asset_id = next(
                asset
                for asset in self._api_experiment.get_asset_list()
                if asset["fileName"] == self._TRACE_DATA_FILE_NAME
            )["assetId"]
        except Exception as exception:
            # Kept broad on purpose: any failure while locating the asset
            # (including "no such asset") is reported to callers as ValueError.
            raise ValueError(
                "Failed to read the data of this trace, metadata is not available"
            ) from exception

        trace_data = json.loads(self._api_experiment.get_asset(asset_id))

        return trace_data  # type: ignore

    def get_metadata(self) -> Dict[str, "JSONEncodable"]:
        """
        Get trace metadata

        Returns: the metadata dictionary of the trace, or an empty dictionary
            if the trace was logged without metadata

        Raises:
            ValueError: if the trace data asset could not be located
        """
        trace_data = self._get_trace_data()

        # The "metadata" field can be absent or null for traces logged
        # without metadata; both cases are reported as "no metadata".
        return trace_data.get("metadata") or {}  # type: ignore

    def log_metadata(self, metadata: Dict[str, "JSONEncodable"]) -> None:
        """
        Update the metadata field for a trace, can be used to set or update metadata fields

        Args:
            metadata: dict, dict in the form of {"metadata_name": value, ...}. Nested metadata is supported

        Raises:
            ValueError: if the trace data asset could not be located
        """
        trace_data = self._get_trace_data()

        # Start from an empty dict when the stored metadata is absent or null,
        # so the merge always receives a dictionary.
        current_metadata = trace_data.get("metadata") or {}
        trace_data["metadata"] = deepmerge.deepmerge(current_metadata, metadata)

        # Re-upload the whole trace data document under the same asset name.
        stream = io.StringIO(json.dumps(trace_data))
        self._api_experiment.log_asset(
            stream, overwrite=True, name=self._TRACE_DATA_FILE_NAME
        )

        # Also log the new metadata as flattened parameters of the experiment.
        self._api_experiment.log_parameters(
            convert.chain_metadata_to_flat_parameters(metadata)
        )
21 changes: 21 additions & 0 deletions src/comet_llm/query_dsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this package.
# *******************************************************

from comet_ml import api

# Query vocabulary for filtering LLM traces. Every name below is a thin alias
# over a query variable class from `comet_ml.api`.

# Duration() takes no arguments and returns api.Metric("duration").
Duration = lambda: api.Metric("duration") # noqa: E731
# Timestamp() takes no arguments and returns api.Metadata("start_server_timestamp").
Timestamp = lambda: api.Metadata("start_server_timestamp") # noqa: E731
# TraceMetadata, TraceDetail and Other are the comet_ml classes themselves under
# trace-oriented names, so they take whatever arguments those classes take
# (used in this package as e.g. Other("Name") and TraceMetadata("token")).
TraceMetadata = api.Parameter
TraceDetail = api.Metadata
Other = api.Other

0 comments on commit 5da0e02

Please sign in to comment.