Skip to content

Commit

Permalink
download pred probs as pd dataframe (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
clu0 committed Jul 19, 2023
1 parent a00820b commit 19de423
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
11 changes: 8 additions & 3 deletions cleanlab_studio/internal/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,21 @@ def download_cleanlab_columns(
return cleanset_pd


def download_numpy(api_key: str, cleanset_id: str, name: str) -> npt.NDArray[np.float_]:
def download_array(
api_key: str, cleanset_id: str, name: str
) -> Union[npt.NDArray[np.float_], pd.DataFrame]:
res = requests.get(
cli_base_url + f"/cleansets/{cleanset_id}/{name}",
headers=_construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()
if res_json["success"]:
np_data: npt.NDArray[np.float_] = np.array(res_json[name])
return np_data
if res_json["array_type"] == "numpy":
np_data: npt.NDArray[np.float_] = np.array(res_json[name])
return np_data
pd_data: pd.DataFrame = pd.read_json(res_json[name], orient="records")
return pd_data
raise APIError(f"{name} for cleanset {cleanset_id} not found")


Expand Down
17 changes: 13 additions & 4 deletions cleanlab_studio/studio/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,20 @@ def __init__(self, outer): # type: ignore
def download_pred_probs(
self,
cleanset_id: str,
) -> npt.NDArray[np.float_]:
return api.download_numpy(self._outer._api_key, cleanset_id, "pred_probs")
) -> Union[npt.NDArray[np.float_], pd.DataFrame]:
"""
Downloads predicted probabilities for a cleanset
Old pred_probs were saved as numpy arrays, which is still compatible
Newer pred_probs are saved as pd.DataFrames
"""
return api.download_array(self._outer._api_key, cleanset_id, "pred_probs")

def download_embeddings(
self,
cleanset_id: str,
) -> npt.NDArray[np.float_]:
return api.download_numpy(self._outer._api_key, cleanset_id, "embeddings")
) -> Union[npt.NDArray[np.float_], pd.DataFrame]:
"""
Downloads embeddings for a cleanset
The downloaded array will always be a numpy array, the above is just for typing purposes
"""
return api.download_array(self._outer._api_key, cleanset_id, "embeddings")

0 comments on commit 19de423

Please sign in to comment.