add data enrichment feature (#297)
novking authored Aug 23, 2024
1 parent 321885d commit b4cb32c
Showing 6 changed files with 985 additions and 6 deletions.
8 changes: 8 additions & 0 deletions cleanlab_studio/errors.py
@@ -154,5 +154,13 @@ def __init__(self, filepath: Union[str, pathlib.Path] = "") -> None:
super().__init__(f"File could not be found at {filepath}. Please check the file path.")


class InvalidCsvFilename(HandledError):
pass


class EnrichmentProjectError(InternalError):
pass


class InvalidInputError(HandledError):
pass
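
The three new classes split along the library's existing error hierarchy: InvalidCsvFilename and InvalidInputError are HandledError subclasses (user-correctable input problems), while EnrichmentProjectError is an InternalError. A hedged sketch of how a caller might treat them differently (assuming HandledError and InternalError are importable from cleanlab_studio.errors, where these subclasses live):

from cleanlab_studio.errors import HandledError, InternalError, InvalidInputError

def _demo_call():                         # hypothetical stand-in for any client call
    raise InvalidInputError("example: bad argument")

try:
    _demo_call()
except HandledError as err:
    print(f"Fix your input and retry: {err}")   # user-facing validation message
except InternalError:
    raise                                       # unexpected internal failure; surface it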
256 changes: 253 additions & 3 deletions cleanlab_studio/internal/api/api.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import io
import os
@@ -40,14 +42,18 @@
pyspark_exists = False

from cleanlab_studio.errors import NotInstalledError
from cleanlab_studio.internal.api.api_helper import check_uuid_well_formed
from cleanlab_studio.internal.types import JSONDict, SchemaOverride
from cleanlab_studio.internal.api.api_helper import (
check_uuid_well_formed,
check_valid_csv_filename,
)
from cleanlab_studio.internal.types import JSONDict, SchemaOverride, TLMQualityPreset
from cleanlab_studio.version import __version__

base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
cli_base_url = f"{base_url}/cli/v0"
upload_base_url = f"{base_url}/upload/v1"
dataset_base_url = f"{base_url}/datasets"
enrichment_base_url = f"{base_url}/enrichment/v0"
project_base_url = f"{base_url}/projects"
cleanset_base_url = f"{base_url}/cleansets"
model_base_url = f"{base_url}/v1/deployment"
@@ -77,7 +83,7 @@ def handle_api_error_from_json(res_json: JSONDict, status_code: Optional[int] =
else:
raise APIError(res_json["description"])

if res_json.get("error", None) is not None:
if isinstance(res_json, dict) and res_json.get("error", None) is not None:
error = res_json["error"]
if (
status_code == 422
@@ -87,6 +93,9 @@ def handle_api_error_from_json(res_json: JSONDict, status_code: Optional[int] =
raise InvalidProjectConfiguration(error["description"])
raise APIError(res_json["error"])

if status_code != 200:
raise APIError(f"API call failed with status code {status_code}")


def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None:
"""Catches 429 (rate limit) errors."""
@@ -652,6 +661,247 @@ def get_deployed_model_info(api_key: str, model_id: str) -> Dict[str, str]:
return cast(Dict[str, str], res.json())


def create_enrichment_project(
api_key: str,
dataset_id: str,
name: str,
) -> JSONDict:
"""Create a new enrichment project."""
check_uuid_well_formed(dataset_id, "dataset ID")
request_json = dict(
dataset_id=dataset_id,
name=name,
)
res = requests.post(
f"{enrichment_base_url}/projects",
headers=_construct_headers(api_key),
json=request_json,
)
handle_api_error(res)
return cast(JSONDict, res.json())


def delete_enrichment_project(api_key: str, project_id: str) -> None:
"""Delete an enrichment project."""
check_uuid_well_formed(project_id, "Enrichment Project ID")
res = requests.delete(
f"{enrichment_base_url}/projects/{project_id}", headers=_construct_headers(api_key)
)
handle_api_error(res)


def get_enrichment_project(api_key: str, project_id: str) -> JSONDict:
"""Get an existing enrichment project."""
check_uuid_well_formed(project_id, "Enrichment Project ID")
res = requests.get(
f"{enrichment_base_url}/projects/{project_id}",
headers=_construct_headers(api_key),
)
handle_api_error(res)
return cast(JSONDict, res.json())


def list_all_enrichment_projects(api_key: str) -> List[JSONDict]:
"""Get a list of all enrichment projects."""
all_projects: List[JSONDict] = []
page: Optional[int] = 1

while page is not None:
res = requests.get(
f"{enrichment_base_url}/projects?page={page}",
headers=_construct_headers(api_key),
)
handle_api_error(res)
res_json: JSONDict = res.json()

if "projects" in res_json:
all_projects.extend(res_json["projects"])

page = res_json.get("next_page", None)

return all_projects
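
A hedged usage sketch of the project helpers added above. The API key and dataset ID are placeholder assumptions (the dataset ID must be a well-formed UUID or check_uuid_well_formed raises), and the exact fields of the returned JSON are defined by the backend, not by this diff:

from cleanlab_studio.internal.api import api

API_KEY = "<your-studio-api-key>"                       # assumption: a valid API key
DATASET_ID = "123e4567-e89b-12d3-a456-426614174000"     # assumption: an existing dataset's UUID

project = api.create_enrichment_project(API_KEY, DATASET_ID, name="my-enrichment-project")
print(project)                                          # JSON payload describing the new project

projects = api.list_all_enrichment_projects(API_KEY)    # follows next_page pagination internally
print(f"{len(projects)} enrichment project(s) found")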


def enrichment_preview(
api_key: str,
new_column_name: str,
project_id: str,
prompt: str,
constrain_outputs: Optional[List[str]] = None,
extraction_pattern: Optional[str] = None,
indices: Optional[List[int]] = None,
optimize_prompt: Optional[bool] = True,
quality_preset: Optional[TLMQualityPreset] = "medium",
replacements: Optional[List[Dict[str, str]]] = [],
tlm_options: Optional[Dict[str, Any]] = {},
) -> JSONDict:
"""Call Enrichment Preview API and get response."""
check_uuid_well_formed(project_id, "project_id")
request_json = dict(
new_column_name=new_column_name,
project_id=project_id,
prompt=prompt,
constrain_outputs=constrain_outputs,
extraction_pattern=extraction_pattern,
indices=indices,
optimize_prompt=optimize_prompt,
replacements=replacements,
tlm_options=tlm_options,
tlm_quality_preset=quality_preset,
)

res = requests.post(
f"{enrichment_base_url}/preview",
headers=_construct_headers(api_key),
json=request_json,
)
handle_api_error(res)
return cast(JSONDict, res.json())
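
Continuing the sketch above (API_KEY as before), a preview call on a few rows. The project-ID lookup, prompt text, and how dataset columns are referenced inside the prompt are assumptions, not specified by this diff:

PROJECT_ID = project["id"]   # assumption: the create response exposes the project's UUID under "id"

preview = api.enrichment_preview(
    api_key=API_KEY,
    new_column_name="sentiment",
    project_id=PROJECT_ID,
    prompt="Classify the sentiment of this review as positive or negative.",
    constrain_outputs=["positive", "negative"],
    indices=[0, 1, 2],          # preview only the first three rows
    quality_preset="medium",
)
print(preview)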


def enrichment_run(
api_key: str,
new_column_name: str,
project_id: str,
prompt: str,
constrain_outputs: Optional[List[str]] = None,
extraction_pattern: Optional[str] = None,
optimize_prompt: Optional[bool] = True,
quality_preset: Optional[TLMQualityPreset] = "medium",
replacements: Optional[List[Dict[str, str]]] = [],
tlm_options: Optional[Dict[str, Any]] = {},
) -> JSONDict:
"""Call Enrichment Enrich_all API and get response."""
check_uuid_well_formed(project_id, "project_id")
request_json = dict(
new_column_name=new_column_name,
project_id=project_id,
prompt=prompt,
constrain_outputs=constrain_outputs,
extraction_pattern=extraction_pattern,
optimize_prompt=optimize_prompt,
replacements=replacements,
tlm_options=tlm_options,
tlm_quality_preset=quality_preset,
)

res = requests.post(
f"{enrichment_base_url}/enrich_all",
headers=_construct_headers(api_key),
json=request_json,
)
handle_api_error(res)
return cast(JSONDict, res.json())
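
Once the preview looks right, the same arguments (minus indices) can be sent to enrich_all to process the full dataset; the job-ID key in the response is an assumption about the backend payload:

run_response = api.enrichment_run(
    api_key=API_KEY,
    new_column_name="sentiment",
    project_id=PROJECT_ID,
    prompt="Classify the sentiment of this review as positive or negative.",
    constrain_outputs=["positive", "negative"],
)
job_id = run_response.get("job_id")   # assumption: exact key name is backend-defined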


def get_enrichment_job_status(api_key: str, job_id: str) -> JSONDict:
"""Get status of enrichment job."""
check_uuid_well_formed(job_id, "job_id")

res = requests.get(
f"{enrichment_base_url}/status/{job_id}",
headers=_construct_headers(api_key),
)
handle_api_error(res)
return cast(JSONDict, res.json())


def get_enrichment_job_result(
api_key: str, job_id: str, page: int, include_original_dataset: Optional[bool] = False
) -> List[JSONDict]:
"""Get result of enrichment job.
Args:
api_key (str): studio API key for auth
job_id (str): job id
page (int): page number
include_original_dataset (bool): whether to return only the enrichment results, or the results merged with the original dataset (merge performed by the backend)
"""
check_uuid_well_formed(job_id, "job_id")

res = requests.get(
f"{enrichment_base_url}/enrich_all/{job_id}",
headers=_construct_headers(api_key),
params=dict(page=page, include_original_dataset=include_original_dataset),
)
handle_api_error(res)
return cast(List[JSONDict], res.json())
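
A polling sketch for the status and result endpoints, continuing from the job started above; the status field name and terminal values are assumptions about the backend contract:

import time

while True:
    status = api.get_enrichment_job_status(API_KEY, job_id)
    print(status)                                         # exact fields are backend-defined
    if status.get("status") in ("SUCCEEDED", "FAILED"):   # assumption: field and values
        break
    time.sleep(10)

rows = api.get_enrichment_job_result(API_KEY, job_id, page=1)    # enrichment columns only
merged = api.get_enrichment_job_result(API_KEY, job_id, page=1, include_original_dataset=True)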


def list_enrichment_jobs(api_key: str, project_id: str) -> List[JSONDict]:
"""List all enrichment jobs for a project."""
check_uuid_well_formed(project_id, "project_id")
res = requests.get(
f"{enrichment_base_url}/projects/{project_id}/jobs",
headers=_construct_headers(api_key),
)
handle_api_error(res)
return cast(List[JSONDict], res.json())


def get_enrichment_job(api_key: str, job_id: str) -> JSONDict:
"""Get enrichment job."""
check_uuid_well_formed(job_id, "job_id")

res = requests.get(
f"{enrichment_base_url}/jobs/{job_id}",
headers=_construct_headers(api_key),
)
handle_api_error(res)
return cast(JSONDict, res.json())
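
Jobs can also be enumerated per project and fetched individually (continuing the sketch above); each entry is a JSON dict whose fields are backend-defined:

for job in api.list_enrichment_jobs(API_KEY, PROJECT_ID):
    print(job)

job = api.get_enrichment_job(API_KEY, job_id)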


def export_results(api_key: str, job_id: str, filename: Optional[str] = None) -> str:
"""
Exports the results of a job to a CSV file.
Args:
api_key (str): The API key used for authentication.
job_id (str): The unique identifier of the job whose results are to be exported.
filename (str): The name of the CSV file to save the results to. If None, a default filename is generated.
Returns:
str: The filename the results were saved to.
"""
check_uuid_well_formed(job_id, "job_id")
if filename is None:
filename = f"enrichment_results_{job_id}.csv"
check_valid_csv_filename(filename)
res = requests.get(
f"{enrichment_base_url}/export/{job_id}",
headers=_construct_headers(api_key),
)
if res.status_code == 200:
with open(filename, "wb") as file:
file.write(res.content)
else:
handle_api_error(res)
return filename
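
Exporting writes the CSV locally and returns the filename used; a custom filename must end in ".csv" or InvalidCsvFilename is raised before any request is made (continuing the sketch above):

saved = api.export_results(API_KEY, job_id)                        # defaults to enrichment_results_<job_id>.csv
saved = api.export_results(API_KEY, job_id, filename="my_results.csv")
print(f"Results written to {saved}")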


def pause_enrichment_job(api_key: str, job_id: str) -> None:
"""Pause enrichment job."""
check_uuid_well_formed(job_id, "job_id")

res = requests.post(
f"{enrichment_base_url}/enrich_all/{job_id}/pause",
headers=_construct_headers(api_key),
)
handle_api_error(res)


def resume_enrichment_job(api_key: str, job_id: str) -> JSONDict:
"""Resume enrichment job."""
check_uuid_well_formed(job_id, "job_id")

res = requests.post(
f"{enrichment_base_url}/enrich_all/{job_id}/resume",
headers=_construct_headers(api_key),
)
handle_api_error(res)
return cast(JSONDict, res.json())


def tlm_retry(func: Callable[..., Any]) -> Callable[..., Any]:
"""Implements TLM retry decorator, with special handling for rate limit retries."""

7 changes: 6 additions & 1 deletion cleanlab_studio/internal/api/api_helper.py
@@ -1,6 +1,6 @@
import uuid

from cleanlab_studio.errors import InvalidUUIDError
from cleanlab_studio.errors import InvalidCsvFilename, InvalidUUIDError


def check_uuid_well_formed(uuid_string: str, id_name: str) -> None:
@@ -10,3 +10,8 @@ def check_uuid_well_formed(uuid_string: str, id_name: str) -> None:
raise InvalidUUIDError(
f"{uuid_string} is not a well-formed {id_name}, please double check and try again."
)


def check_valid_csv_filename(filename: str) -> None:
if not filename.lower().endswith(".csv"):
raise InvalidCsvFilename(f"{filename} is not a valid csv filename.")
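
The check is purely an extension test, so it is case-insensitive; a quick illustration of the helper and the error it raises:

from cleanlab_studio.errors import InvalidCsvFilename
from cleanlab_studio.internal.api.api_helper import check_valid_csv_filename

check_valid_csv_filename("results.CSV")        # passes: the name is lowercased before comparison
try:
    check_valid_csv_filename("results.txt")
except InvalidCsvFilename as err:
    print(err)                                 # "results.txt is not a valid csv filename."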