-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CM-9113] Introduce api object (#109)
* Add initial implementation of API and LLMTraceAPI * Add copyright comments * Fix lint errors * Expose new api to comet_llm namespace * Refactor if-else statements * Add API.query * Add type hint * Update log_metadata * Refactor code * Refactor some code * Add reading api key from environment * Fix lint errors
- Loading branch information
1 parent
d5a6b57
commit 5da0e02
Showing
5 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this package. | ||
# ******************************************************* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this package. | ||
# ******************************************************* | ||
|
||
from typing import List, Optional | ||
|
||
import comet_ml | ||
|
||
from .. import experiment_info, logging_messages, query_dsl | ||
from . import llm_trace_api | ||
|
||
# TODO: make the decision about dependencies from comet-ml. Address testing. | ||
|
||
|
||
class API: | ||
def __init__(self, api_key: Optional[str] = None) -> None: | ||
experiment_info_ = experiment_info.get( | ||
api_key, | ||
api_key_not_found_message=logging_messages.API_KEY_NOT_FOUND_MESSAGE | ||
% "API", | ||
) | ||
self._api = comet_ml.API(api_key=experiment_info_.api_key, cache=False) | ||
|
||
def get_llm_trace_by_key(self, trace_key: str) -> llm_trace_api.LLMTraceAPI: | ||
""" | ||
Get an API Trace object by key. | ||
Args: | ||
trace_key: str, key of the prompt or chain | ||
Returns: An LLMTraceAPI object that can be used to get or update trace data | ||
""" | ||
matching_trace = self._api.get_experiment_by_key(trace_key) | ||
|
||
if matching_trace is None: | ||
raise ValueError( | ||
f"Failed to find any matching traces with the key {trace_key}" | ||
) | ||
|
||
return llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(matching_trace) | ||
|
||
def get_llm_trace_by_name( | ||
self, workspace: str, project_name: str, trace_name: str | ||
) -> llm_trace_api.LLMTraceAPI: | ||
""" | ||
Get an API Trace object by name. | ||
Args: | ||
workspace: str, name of the workspace | ||
project_name: str, name of the project | ||
trace_name: str, name of the prompt or chain | ||
Returns: An LLMTraceAPI object that can be used to get or update trace data | ||
""" | ||
matching_trace = self._api.query( | ||
workspace, project_name, query_dsl.Other("Name") == trace_name | ||
) | ||
|
||
if len(matching_trace) == 0: | ||
raise ValueError( | ||
f"Failed to find any matching traces with the name {trace_name} in the project {project_name}" | ||
) | ||
elif len(matching_trace) > 1: | ||
raise ValueError( | ||
f"Found multiple traces with the name {trace_name} in the project {project_name}" | ||
) | ||
|
||
return llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(matching_trace[0]) | ||
|
||
def query( | ||
self, workspace: str, project_name: str, query: str | ||
) -> List[llm_trace_api.LLMTraceAPI]: | ||
""" | ||
Fetch LLM Trace based on a query. Currently it is only possible to use | ||
trace metadata or details fields to filter the traces. | ||
Args: | ||
workspace: str, name of the workspace | ||
project_name: str, name of the project | ||
query: str, name of the prompt or chain | ||
Returns: A list of LLMTrace objects | ||
Notes: | ||
The `query` object takes the form of (QUERY_VARIABLE OPERATOR VALUE) with: | ||
* QUERY_VARIABLE is either TraceMetadata, Duration, Timestamp. | ||
* OPERATOR is any standard mathematical operators `<=`, `>=`, `!=`, `<`, `>`. | ||
It is also possible to add multiple query conditions using `&`. | ||
If you are querying nested parameters, you should flatted the parameter name using the | ||
`.` operator. | ||
To query the duration, you can use Duration(). | ||
Example: | ||
```python | ||
# Find all traces where the metadata field `token` is greater than 50 | ||
api.query("workspace", "project", TraceMetadata("token") > 50) | ||
# Find all traces where the duration field is between 1 second and 2 seconds | ||
api.query("workspace", "project", (Duration() > 1) & (Duration() <= 2)) | ||
# Find all traces based on the timestamp | ||
api.query("workspace", "project", Timestamp() > datetime(2023, 9, 10)) | ||
``` | ||
""" | ||
matching_api_objects = self._api.query(workspace, project_name, query) | ||
|
||
return [ | ||
llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(api_object) | ||
for api_object in matching_api_objects | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this package. | ||
# ******************************************************* | ||
|
||
import io | ||
import json | ||
from typing import Dict, Optional | ||
|
||
import comet_ml | ||
|
||
from .. import convert | ||
from ..chains import deepmerge | ||
from ..types import JSONEncodable | ||
|
||
|
||
class LLMTraceAPI: | ||
_api_experiment: comet_ml.APIExperiment | ||
|
||
def __init__(self) -> None: | ||
raise NotImplementedError( | ||
"Please use API.get_llm_trace_by_key or API.get_llm_trace_by_name methods to get the instance" | ||
) | ||
|
||
@classmethod | ||
def __api__from_api_experiment__( | ||
cls, api_experiment: comet_ml.APIExperiment | ||
) -> "LLMTraceAPI": | ||
instance = object.__new__(cls) | ||
instance._api_experiment = api_experiment | ||
|
||
return instance | ||
|
||
def get_name(self) -> Optional[str]: | ||
""" | ||
Get the name of the trace | ||
""" | ||
return self._api_experiment.get_name() # type: ignore | ||
|
||
def get_key(self) -> str: | ||
""" | ||
Get the unique identifier for this trace | ||
""" | ||
return self._api_experiment.key # type: ignore | ||
|
||
def log_user_feedback(self, score: float) -> None: | ||
""" | ||
Log user feedback | ||
Args: | ||
score: float, the feedback score. Can be either 0, 0.0, 1 or 1.0 | ||
""" | ||
ALLOWED_SCORES = [0.0, 1.0] | ||
if score not in ALLOWED_SCORES: | ||
raise ValueError( | ||
f"Score it not valid, should be {ALLOWED_SCORES} but got {score}" | ||
) | ||
|
||
self._api_experiment.log_metric("user_feedback", score) | ||
|
||
def _get_trace_data(self) -> Dict[str, JSONEncodable]: | ||
try: | ||
asset_id = next( | ||
asset | ||
for asset in self._api_experiment.get_asset_list() | ||
if asset["fileName"] == "comet_llm_data.json" | ||
)["assetId"] | ||
except Exception as exception: | ||
raise ValueError( | ||
"Failed update metadata for this trace, metadata is not available" | ||
) from exception | ||
|
||
trace_data = json.loads(self._api_experiment.get_asset(asset_id)) | ||
|
||
return trace_data # type: ignore | ||
|
||
def get_metadata(self) -> Dict[str, JSONEncodable]: | ||
""" | ||
Get trace metadata | ||
""" | ||
trace_data = self._get_trace_data() | ||
|
||
return trace_data["metadata"] # type: ignore | ||
|
||
def log_metadata(self, metadata: Dict[str, JSONEncodable]) -> None: | ||
""" | ||
Update the metadata field for a trace, can be used to set or update metadata fields | ||
Args: | ||
metadata_dict: dict, dict in the form of {"metadata_name": value, ...}. Nested metadata is supported | ||
""" | ||
|
||
trace_data = self._get_trace_data() | ||
updated_trace_metadata = deepmerge.deepmerge( | ||
trace_data.get("metadata", {}), metadata | ||
) | ||
trace_data["metadata"] = updated_trace_metadata | ||
|
||
stream = io.StringIO() | ||
json.dump(trace_data, stream) | ||
stream.seek(0) | ||
self._api_experiment.log_asset( | ||
stream, overwrite=True, name="comet_llm_data.json" | ||
) | ||
|
||
self._api_experiment.log_parameters( | ||
convert.chain_metadata_to_flat_parameters(metadata) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this package. | ||
# ******************************************************* | ||
|
||
from comet_ml import api | ||
|
||
Duration = lambda: api.Metric("duration") # noqa: E731 | ||
Timestamp = lambda: api.Metadata("start_server_timestamp") # noqa: E731 | ||
TraceMetadata = api.Parameter | ||
TraceDetail = api.Metadata | ||
Other = api.Other |