From 4ad576326eb7b61ea66cf349cd5f15444483e27b Mon Sep 17 00:00:00 2001 From: Ulyana Date: Tue, 30 Jul 2024 12:50:58 -0700 Subject: [PATCH] Add explainability to TLM (#267) --- cleanlab_studio/internal/api/api.py | 26 +++++++++--------- cleanlab_studio/internal/constants.py | 4 +-- cleanlab_studio/internal/tlm/validation.py | 8 +++--- cleanlab_studio/internal/types.py | 3 +-- .../studio/trustworthy_language_model.py | 27 ++++++++++--------- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/cleanlab_studio/internal/api/api.py b/cleanlab_studio/internal/api/api.py index a6607527..9e749205 100644 --- a/cleanlab_studio/internal/api/api.py +++ b/cleanlab_studio/internal/api/api.py @@ -2,7 +2,16 @@ import io import os import time -from typing import Callable, cast, List, Optional, Tuple, Dict, Union, Any +from io import StringIO +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast + +import aiohttp +import aiohttp.client_exceptions +import numpy as np +import numpy.typing as npt +import pandas as pd +import requests +from tqdm import tqdm from cleanlab_studio.errors import ( APIError, @@ -15,15 +24,6 @@ ) from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler -import aiohttp -import aiohttp.client_exceptions -import requests -from tqdm import tqdm -import pandas as pd -import numpy as np -import numpy.typing as npt -from io import StringIO - try: import snowflake import snowflake.snowpark as snowpark @@ -39,12 +39,10 @@ except ImportError: pyspark_exists = False +from cleanlab_studio.errors import NotInstalledError +from cleanlab_studio.internal.api.api_helper import check_uuid_well_formed from cleanlab_studio.internal.types import JSONDict, SchemaOverride from cleanlab_studio.version import __version__ -from cleanlab_studio.errors import NotInstalledError -from cleanlab_studio.internal.api.api_helper import ( - check_uuid_well_formed, -) base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api") cli_base_url = f"{base_url}/cli/v0" diff --git a/cleanlab_studio/internal/constants.py b/cleanlab_studio/internal/constants.py index 29230d5b..f4d1edac 100644 --- a/cleanlab_studio/internal/constants.py +++ b/cleanlab_studio/internal/constants.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Set +from typing import List, Set, Tuple # TLM constants # prepend constants with _ so that they don't show up in help.cleanlab.ai docs @@ -16,5 +16,5 @@ TLM_MAX_TOKEN_RANGE: Tuple[int, int] = (64, 512) # (min, max) TLM_NUM_CANDIDATE_RESPONSES_RANGE: Tuple[int, int] = (1, 20) # (min, max) TLM_NUM_CONSISTENCY_SAMPLES_RANGE: Tuple[int, int] = (0, 20) # (min, max) -TLM_VALID_LOG_OPTIONS: Set[str] = {"perplexity"} +TLM_VALID_LOG_OPTIONS: Set[str] = {"perplexity", "explanation"} TLM_VALID_GET_TRUSTWORTHINESS_SCORE_KWARGS: Set[str] = {"perplexity"} diff --git a/cleanlab_studio/internal/tlm/validation.py b/cleanlab_studio/internal/tlm/validation.py index dd374aad..909fd284 100644 --- a/cleanlab_studio/internal/tlm/validation.py +++ b/cleanlab_studio/internal/tlm/validation.py @@ -1,16 +1,16 @@ import os -from typing import Union, Sequence, List, Dict, Tuple, Any +from typing import Any, Dict, List, Sequence, Union + from cleanlab_studio.errors import ValidationError from cleanlab_studio.internal.constants import ( _VALID_TLM_MODELS, TLM_MAX_TOKEN_RANGE, TLM_NUM_CANDIDATE_RESPONSES_RANGE, TLM_NUM_CONSISTENCY_SAMPLES_RANGE, - TLM_VALID_LOG_OPTIONS, TLM_VALID_GET_TRUSTWORTHINESS_SCORE_KWARGS, + TLM_VALID_LOG_OPTIONS, ) - SKIP_VALIDATE_TLM_OPTIONS: bool = ( os.environ.get("CLEANLAB_STUDIO_SKIP_VALIDATE_TLM_OPTIONS", "false").lower() == "true" ) @@ -216,7 +216,6 @@ def process_response_and_kwargs( ) if val is not None and not 0 <= val <= 1: raise ValidationError("Perplexity values must be between 0 and 1") - elif isinstance(response, Sequence): if not isinstance(val, Sequence): raise ValidationError( @@ -235,7 +234,6 @@ def process_response_and_kwargs( if v is not None and not 0 <= v <= 1: raise ValidationError("Perplexity values must be between 0 and 1") - else: raise ValidationError( f"Invalid type {type(val)}, perplexity must be either a sequence or a float" diff --git a/cleanlab_studio/internal/types.py b/cleanlab_studio/internal/types.py index 63bb44f1..c6a2a9ee 100644 --- a/cleanlab_studio/internal/types.py +++ b/cleanlab_studio/internal/types.py @@ -1,5 +1,4 @@ -from typing import Any, Dict, Optional, TypedDict, Literal - +from typing import Any, Dict, Literal, Optional, TypedDict JSONDict = Dict[str, Any] diff --git a/cleanlab_studio/studio/trustworthy_language_model.py b/cleanlab_studio/studio/trustworthy_language_model.py index b380eb38..c7c27f70 100644 --- a/cleanlab_studio/studio/trustworthy_language_model.py +++ b/cleanlab_studio/studio/trustworthy_language_model.py @@ -10,29 +10,31 @@ import asyncio import sys -from typing import Coroutine, List, Optional, Union, cast, Sequence, Any, Dict -from tqdm.asyncio import tqdm_asyncio -import numpy as np +from typing import Any, Coroutine, Dict, List, Optional, Sequence, Union, cast import aiohttp -from typing_extensions import NotRequired, TypedDict # for Python <3.11 with (Not)Required +from tqdm.asyncio import tqdm_asyncio +from typing_extensions import ( # for Python <3.11 with (Not)Required + NotRequired, + TypedDict, +) +from cleanlab_studio.errors import ValidationError from cleanlab_studio.internal.api import api +from cleanlab_studio.internal.constants import ( + _TLM_MAX_RETRIES, + _VALID_TLM_QUALITY_PRESETS, +) from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler from cleanlab_studio.internal.tlm.validation import ( + process_response_and_kwargs, + validate_tlm_options, validate_tlm_prompt, - validate_tlm_try_prompt, validate_tlm_prompt_response, + validate_tlm_try_prompt, validate_try_tlm_prompt_response, - validate_tlm_options, - process_response_and_kwargs, ) from cleanlab_studio.internal.types import TLMQualityPreset -from cleanlab_studio.errors import ValidationError -from cleanlab_studio.internal.constants import ( - _VALID_TLM_QUALITY_PRESETS, - _TLM_MAX_RETRIES, -) class TLM: @@ -699,6 +701,7 @@ class TLMOptions(TypedDict): Setting this to False disables the use of self-reflection and may produce worse TLM trustworthiness scores, but will reduce costs/runtimes. log (List[str], default = None): optionally specify additional logs or metadata to return. + For instance, include "explanation" here to get explanations of why a response is scored with low trustworthiness. """ model: NotRequired[str]