diff --git a/README.md b/README.md
index 6d9c4ba0..a55563bc 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,8 @@ Parea python sdk
 
+[Python SDK Docs](https://docs.parea.ai/sdk/python)
+
 ## Installation
 
 ```bash
@@ -47,11 +49,10 @@ Alternatively, you can add the following code to your codebase to get started:
 
 ```python
 import os
 
-from parea import init, InMemoryCache
+from parea import Parea, InMemoryCache, trace
 from parea.schemas.log import Log
-from parea.utils.trace_utils import trace
 
-init(api_key=os.getenv("PAREA_API_KEY"), cache=InMemoryCache())  # use InMemoryCache if you don't have a Parea API key
+Parea(api_key=os.getenv("PAREA_API_KEY"), cache=InMemoryCache())  # use InMemoryCache if you don't have a Parea API key
 
 
 def locally_defined_eval_function(log: Log) -> float:
@@ -63,6 +64,32 @@ def function_to_evaluate(*args, **kwargs) -> ...:
     ...
 ```
 
+### Run Experiments
+
+You can run an experiment for your LLM application by defining the `Experiment` class and passing it the name, the data, and the
+function you want to run. You need to annotate the function with the `trace` decorator to trace its inputs, outputs, latency, etc.,
+as well as to specify which evaluation functions should be applied to it (as shown above).
+
+```python
+from parea import Experiment
+
+Experiment(
+    name="Experiment Name",  # Name of the experiment (str)
+    data=[{"n": "10"}],  # Data to run the experiment on (list of dicts)
+    func=function_to_evaluate,  # Function to run (callable)
+)
+```
+
+Then you can run the experiment by using the `experiment` command and giving it the path to the Python file.
+This will run your experiment with the specified inputs and create a report with the results, which can be viewed under
+the [Experiments tab](https://app.parea.ai/experiments).
+
+```bash
+parea experiment <path/to/experiment_file.py>
+```
+
+Full working example in our [docs](https://docs.parea.ai/testing/run-experiments).
+
 ## Debugging Chains & Agents
 
 You can iterate on your chains & agents much faster by using a local cache. This will allow you to make changes to your
@@ -71,26 +98,26 @@ code and start [a local redis cache](https://redis.io/docs/getting-started/install-stack/):
 
 ```python
-from parea import init, RedisCache
+from parea import Parea, RedisCache
 
-init(cache=RedisCache())
+Parea(cache=RedisCache())
 ```
 
 Above will use the default redis cache at `localhost:6379` with no password. You can also specify your redis database by:
 
 ```python
-from parea import init, RedisCache
+from parea import Parea, RedisCache
 
 cache = RedisCache(
     host=os.getenv("REDIS_HOST", "localhost"),  # default value
     port=int(os.getenv("REDIS_PORT", 6379)),  # default value
     password=os.getenv("REDIS_PASSWORT", None)  # default value
 )
-init(cache=cache)
+Parea(cache=cache)
 ```
 
-If you set `cache = None` for `init`, no cache will be used.
+If you set `cache = None` for `Parea`, no cache will be used.
 
 ### Benchmark your LLM app across many inputs
 
@@ -109,15 +136,15 @@ redis cache running. Please, raise a GitHub issue if you would like to use this
 
 ### Automatically log all your LLM call traces
 
 You can automatically log all your LLM traces to the Parea dashboard by setting the `PAREA_API_KEY` environment variable
-or specifying it in the `init` function.
+or specifying it in the `Parea` initialization.
 This will help you debug issues your customers are facing by stepping through the LLM call traces and recreating the
 issue in your local setup & code.
 ```python
-from parea import init
+from parea import Parea
 
-init(
+Parea(
     api_key=os.getenv("PAREA_API_KEY"),  # default value
     cache=...
 )
diff --git a/parea/__init__.py b/parea/__init__.py
index cfc4dc40..6c4cd08b 100644
--- a/parea/__init__.py
+++ b/parea/__init__.py
@@ -12,9 +12,10 @@
 from importlib import metadata as importlib_metadata
 
 from parea.cache import InMemoryCache, RedisCache
-from parea.client import Parea, init
-from parea.experiment.cli import experiment as experiment_cli
+from parea.client import Parea
+from parea.experiment.cli import experiment as _experiment_cli
 from parea.experiment.experiment import Experiment
+from parea.utils.trace_utils import trace
 
 
 def get_version() -> str:
@@ -30,7 +31,7 @@ def get_version() -> str:
 
 def main():
     args = sys.argv[1:]
     if args[0] == "experiment":
-        experiment_cli(args[1:])
+        _experiment_cli(args[1:])
     else:
         print(f"Unknown command: '{args[0]}'")
diff --git a/parea/client.py b/parea/client.py
index 49db8d85..a7dce144 100644
--- a/parea/client.py
+++ b/parea/client.py
@@ -34,9 +34,9 @@
 
 @define
 class Parea:
-    api_key: str = field(init=True, default="")
-    _client: HTTPClient = field(init=False, default=HTTPClient())
+    api_key: str = field(init=True, default=os.getenv("PAREA_API_KEY"))
     cache: Cache = field(init=True, default=None)
+    _client: HTTPClient = field(init=False, default=HTTPClient())
 
     def __attrs_post_init__(self):
         self._client.set_api_key(self.api_key)
@@ -144,10 +144,6 @@ async def aget_experiment_stats(self, experiment_uuid: str) -> ExperimentStatsSc
 _initialized_parea_wrapper = False
 
 
-def init(api_key: str = os.getenv("PAREA_API_KEY"), cache: Cache = None) -> None:
-    Parea(api_key=api_key, cache=cache)
-
-
 def _init_parea_wrapper(log: Callable = None, cache: Cache = None):
     global _initialized_parea_wrapper
     if _initialized_parea_wrapper:
diff --git a/parea/cookbook/tracing_and_evaluating_openai_endpoint.py b/parea/cookbook/tracing_and_evaluating_openai_endpoint.py
index 831085f7..ef207a28 100644
--- a/parea/cookbook/tracing_and_evaluating_openai_endpoint.py
+++ b/parea/cookbook/tracing_and_evaluating_openai_endpoint.py
@@ -8,7 +8,7 @@
 from attr import asdict
 from dotenv import load_dotenv
 
-from parea import InMemoryCache, init
+from parea import InMemoryCache, Parea
 from parea.evals.chat import goal_success_ratio_factory
 from parea.evals.utils import call_openai
 from parea.helpers import write_trace_logs_to_csv
@@ -22,7 +22,7 @@
 use_cache = True  # by using the in memory cache, you don't need a Parea API key
 cache = InMemoryCache() if use_cache else None
-init(api_key=os.getenv("PAREA_API_KEY"), cache=cache)
+Parea(api_key=os.getenv("PAREA_API_KEY"), cache=cache)
 
 
 def friendliness(log: Log) -> float:
diff --git a/parea/evals/chat/goal_success_ratio.py b/parea/evals/chat/goal_success_ratio.py
index d44025cd..e800ecdc 100644
--- a/parea/evals/chat/goal_success_ratio.py
+++ b/parea/evals/chat/goal_success_ratio.py
@@ -7,7 +7,16 @@
 def goal_success_ratio_factory(use_output: Optional[bool] = False, message_field: Optional[str] = None) -> Callable[[Log], float]:
-    """Factory function that returns a function that calculates the goal success ratio of a log.
+    """
+    This factory creates an evaluation function that measures the success ratio of a goal-oriented conversation.
+    Typically, a user interacts with a chatbot or AI assistant to achieve specific goals.
+    This motivates measuring the quality of a chatbot by counting how many messages a user has to send before they reach their goal.
+    One can further break this down by successful and unsuccessful goals to analyze user & LLM behavior.
+
+    Concretely:
+    1. Delineate the conversation into segments by splitting them by the goals the user wants to achieve.
+    2. Assess if every goal has been reached.
+    3. Calculate the average number of messages sent per segment.
 
     Args:
         use_output (Optional[bool], optional): Whether to use the output of the log to access the messages. Defaults to False.
diff --git a/parea/evals/general/answer_relevancy.py b/parea/evals/general/answer_relevancy.py
index 6d0e0ffb..62e83675 100644
--- a/parea/evals/general/answer_relevancy.py
+++ b/parea/evals/general/answer_relevancy.py
@@ -5,7 +5,23 @@
 def answer_relevancy_factory(question_field: str = "question", n_generations: int = 3) -> Callable[[Log], float]:
-    """Quantifies how much the generated answer relates to the query."""
+    """
+    This factory creates an evaluation function that measures how relevant the generated response is to the given question.
+    It is based on the paper [RAGAS: Automated Evaluation of Retrieval Augmented Generation](https://arxiv.org/abs/2309.15217)
+    which suggests using an LLM to generate multiple questions that fit the generated answer and measure the cosine
+    similarity of the generated questions with the original one.
+
+    Args:
+        question_field: The key name/field used for the question/query of the user. Defaults to "question".
+        n_generations: The number of questions which should be generated. Defaults to 3.
+
+    Returns:
+        Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating
+        if the generated response is relevant to the query.
+
+    Raises:
+        ImportError: If numpy is not installed.
+    """
     try:
         import numpy as np
     except ImportError:
diff --git a/parea/evals/general/llm_grader.py b/parea/evals/general/llm_grader.py
index aec5c3bd..49525fcd 100644
--- a/parea/evals/general/llm_grader.py
+++ b/parea/evals/general/llm_grader.py
@@ -11,7 +11,23 @@
 def llm_grader_factory(model: str, question_field: str = "question") -> Callable[[Log], float]:
-    """Measures the generated response quality by using a LLM on a scale of 1 to 10."""
+    """
+    This factory creates an evaluation function that uses an LLM to grade the response of an LLM to a given question.
+    It is based on the paper [Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena](https://arxiv.org/abs/2306.05685)
+    which introduces a general-purpose zero-shot prompt to rate responses from an LLM to a given question on a scale from 1-10.
+    They find that GPT-4's ratings agree as much with a human rater as a human annotator agrees with another one (>80%).
+    Further, they observe that the agreement with a human annotator increases as the response rating gets clearer.
+    Additionally, they investigated how much the evaluating LLM overestimated its responses and found that GPT-4 and
+    Claude-1 were the only models that didn't overestimate themselves.
+
+    Args:
+        model: The model which should be used for grading. Currently, only supports OpenAI chat models.
+        question_field: The key name/field used for the question/query of the user. Defaults to "question".
+
+    Returns:
+        Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 which is the
+        rating of the response on a scale from 1-10 divided by 10.
+ """ def llm_grader(log: Log) -> float: question = log.inputs[question_field] diff --git a/parea/evals/general/lm_vs_lm.py b/parea/evals/general/lm_vs_lm.py index 394f6dc7..efb3799c 100644 --- a/parea/evals/general/lm_vs_lm.py +++ b/parea/evals/general/lm_vs_lm.py @@ -5,8 +5,21 @@ def lm_vs_lm_factuality_factory(examiner_model: str = "gpt-3.5-turbo") -> Callable[[Log], float]: - """Using an examining LLM, measures the factuality of a claim. Examining LLM asks follow-up questions to the other - LLM until it reaches a conclusion.""" + """ + This factory creates an evaluation function that measures the factuality of an LLM's response to a given question. + It is based on the paper [LM vs LM: Detecting Factual Errors via Cross Examination](https://arxiv.org/abs/2305.13281) which proposes using + another LLM to assess an LLM response's factuality. To do this, the examining LLM generates follow-up questions to the + original response until it can confidently determine the factuality of the response. + This method outperforms prompting techniques such as asking the original model, "Are you sure?" or instructing the + model to say, "I don't know," if it is uncertain. + + Args: + examiner_model: The model which will examine the original model. Currently, only supports OpenAI chat models. + + Returns: + Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating + the factuality of the response. + """ def lm_vs_lm_factuality(log: Log) -> float: output = log.output diff --git a/parea/evals/general/self_check.py b/parea/evals/general/self_check.py index 492b7911..a6203f8d 100644 --- a/parea/evals/general/self_check.py +++ b/parea/evals/general/self_check.py @@ -3,7 +3,23 @@ def self_check(log: Log) -> float: - """Measures how consistent is the output of a model under resampling the response.""" + """ + Given that many API-based LLMs don't reliably give access to the log probabilities of the generated tokens, assessing + the certainty of LLM predictions via perplexity isn't possible. + The [SelfCheckGPT: Zero-Resource Black-Box Hallucination Detection for Generative Large Language Models](https://arxiv.org/abs/2303.08896) paper + suggests measuring the average factuality of every sentence in a generated response. They generate additional responses + from the LLM at a high temperature and check how much every sentence in the original answer is supported by the other generations. + The intuition behind this is that if the LLM knows a fact, it's more likely to sample it. The authors find that this + works well in detecting non-factual and factual sentences and ranking passages in terms of factuality. + The authors noted that correlation with human judgment doesn't increase after 4-6 additional + generations when using `gpt-3.5-turbo` to evaluate biography generations. + + Args: + log (Log): The log object to of the trace evaluate. + + Returns: + float: A score between 0 and 1 indicating the factuality of the response. 
+ """ if log.configuration is None or log.configuration.messages is None: return 0.0 diff --git a/parea/evals/rag/answer_context_faithfulness_binary.py b/parea/evals/rag/answer_context_faithfulness_binary.py index e4127792..db612524 100644 --- a/parea/evals/rag/answer_context_faithfulness_binary.py +++ b/parea/evals/rag/answer_context_faithfulness_binary.py @@ -9,7 +9,21 @@ def answer_context_faithfulness_binary_factory( context_field: Optional[str] = "context", model: Optional[str] = "gpt-3.5-turbo-16k", ) -> Callable[[Log], float]: - """Quantifies how much the generated answer can be inferred from the retrieved context.""" + """ + This factory creates an evaluation function that classifies if the generated answer was faithful to the given context. + It is based on the paper [Evaluating Correctness and Faithfulness of Instruction-Following Models for Question Answering](https://arxiv.org/abs/2307.16877) + which suggests using an LLM to flag any information in the generated answer that cannot be deduced from the given context. + They find that GPT-4 is the best model for this analysis as measured by correlation with human judgment. + + Args: + question_field: The key name/field used for the question/query of the user. Defaults to "question". + context_field: The key name/field used for the retrieved context. Defaults to "context". + model: The model which should be used for grading. Currently, only supports OpenAI chat models. Defaults to "gpt-4". + + Returns: + Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating + if the generated answer was faithful to the given context. + """ def answer_context_faithfulness_binary(log: Log) -> float: question = log.inputs[question_field] diff --git a/parea/evals/rag/answer_context_faithfulness_precision.py b/parea/evals/rag/answer_context_faithfulness_precision.py index 3b2a5e41..bfc3c4be 100644 --- a/parea/evals/rag/answer_context_faithfulness_precision.py +++ b/parea/evals/rag/answer_context_faithfulness_precision.py @@ -6,7 +6,18 @@ def answer_context_faithfulness_precision_factory(context_field: Optional[str] = "context") -> Callable[[Log], float]: - """Prop. of tokens in model generation which are also present in the retrieved context.""" + """ + This factory creates an evaluation function that calculates the how many tokens in the generated answer are also present in the retrieved context. + It is based on the paper [Evaluating Correctness and Faithfulness of Instruction-Following Models for Question Answering](https://arxiv.org/abs/2307.16877) + which finds that this method only slightly lags behind GPT-4 and outperforms GPT-3.5-turbo (see Table 4 from the above paper). + + Args: + context_field: The key name/field used for the retrieved context. Defaults to "context". + + Returns: + Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating + how many tokens in the generated answer are also present in the retrieved context. + """ def answer_context_faithfulness_precision(log: Log) -> float: """Prop. 
         """Prop. of tokens in model generation which are also present in the retrieved context."""
diff --git a/parea/evals/rag/answer_context_faithfulness_statement_level.py b/parea/evals/rag/answer_context_faithfulness_statement_level.py
index ac38986f..d35d5f51 100644
--- a/parea/evals/rag/answer_context_faithfulness_statement_level.py
+++ b/parea/evals/rag/answer_context_faithfulness_statement_level.py
@@ -5,7 +5,20 @@
 def answer_context_faithfulness_statement_level_factory(question_field: str = "question", context_fields: List[str] = ["context"]) -> Callable[[Log], float]:
-    """Quantifies how much the generated answer can be inferred from the retrieved context."""
+    """
+    This factory creates an evaluation function that measures the faithfulness of the generated answer to the given context
+    by measuring how many statements from the generated answer can be inferred from the given context. It is based on the paper
+    [RAGAS: Automated Evaluation of Retrieval Augmented Generation](https://arxiv.org/abs/2309.15217) which suggests using an LLM
+    to create a list of all statements in the generated answer and assessing whether the given context supports each statement.
+
+    Args:
+        question_field: The key name/field used for the question/query of the user. Defaults to "question".
+        context_fields: A list of key names/fields used for the retrieved contexts. Defaults to ["context"].
+
+    Returns:
+        Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating
+        the fraction of statements in the generated answer that can be inferred from the given context.
+    """
 
     def answer_context_faithfulness_statement_level(log: Log) -> float:
         """Quantifies how much the generated answer can be inferred from the retrieved context."""
diff --git a/parea/evals/rag/context_query_relevancy.py b/parea/evals/rag/context_query_relevancy.py
index ac07aaae..0448c32c 100644
--- a/parea/evals/rag/context_query_relevancy.py
+++ b/parea/evals/rag/context_query_relevancy.py
@@ -5,7 +5,20 @@
 def context_query_relevancy_factory(question_field: str = "question", context_fields: List[str] = ["context"]) -> Callable[[Log], float]:
-    """Quantifies how much the retrieved context relates to the query."""
+    """
+    This factory creates an evaluation function that measures how relevant the retrieved context is to the given question.
+    It is based on the paper [RAGAS: Automated Evaluation of Retrieval Augmented Generation](https://arxiv.org/abs/2309.15217)
+    which suggests using an LLM to extract any sentence from the retrieved context relevant to the query. Then, calculate
+    the ratio of relevant sentences to the total number of sentences in the retrieved context.
+
+    Args:
+        question_field: The key name/field used for the question/query of the user. Defaults to "question".
+        context_fields: A list of key names/fields used for the retrieved contexts. Defaults to ["context"].
+
+    Returns:
+        Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating
+        if the retrieved context is relevant to the query.
+ """ def context_query_relevancy(log: Log) -> float: """Quantifies how much the retrieved context relates to the query.""" diff --git a/parea/evals/rag/context_ranking_listwise.py b/parea/evals/rag/context_ranking_listwise.py index 9290e989..9d191a77 100644 --- a/parea/evals/rag/context_ranking_listwise.py +++ b/parea/evals/rag/context_ranking_listwise.py @@ -10,9 +10,12 @@ def context_ranking_listwise_factory( ranking_measurement="ndcg", n_contexts_to_rank=10, ) -> Callable[[Log], float]: - """Quantifies if the retrieved context is ranked by their relevancy by re-ranking the contexts. - - Paper: https://arxiv.org/abs/2305.02156 + """ + This factory creates an evaluation function that measures how well the retrieved contexts are ranked by relevancy to the given query + by listwise estimation of the relevancy of every context to the query. It is based on the paper + [Zero-Shot Listwise Document Reranking with a Large Language Model](https://arxiv.org/abs/2305.02156) which suggests using an LLM + to rerank a list of contexts and use that to evaluate how well the contexts are ranked by relevancy to the given query. + The authors used a progressive listwise reordering if the retrieved contexts don't fit into the context window of the LLM. Args: question_field (str): The name of the field in the log that contains the question. Defaults to "question". diff --git a/parea/evals/rag/context_ranking_pointwise.py b/parea/evals/rag/context_ranking_pointwise.py index 62a6c84c..6ae0e291 100644 --- a/parea/evals/rag/context_ranking_pointwise.py +++ b/parea/evals/rag/context_ranking_pointwise.py @@ -5,7 +5,25 @@ def context_ranking_pointwise_factory(question_field: str = "question", context_fields: List[str] = ["context"], ranking_measurement="average_precision") -> Callable[[Log], float]: - """Quantifies if the retrieved context is ranked by their relevancy""" + """ + This factory creates an evaluation function that measures how well the retrieved contexts are ranked by relevancy to the given query + by pointwise estimation of the relevancy of every context to the query. It is based on the paper + [RAGAS: Automated Evaluation of Retrieval Augmented Generation](https://arxiv.org/abs/2309.15217) which suggests using an LLM + to check if every extracted context is relevant. Then, they measure how well the contexts are ranked by calculating the + mean average precision. Note that this approach considers any two relevant contexts equally important/relevant to the query. + + Args: + question_field: The key name/field used for the question/query of the user. Defaults to "question". + context_fields: A list of key names/fields used for the retrieved contexts. Defaults to ["context"]. + ranking_measurement: Method to calculate ranking. Currently, only supports "average_precision". + + Returns: + Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating + how well the retrieved context is ranked by their relevancy. + + Raises: + ImportError: If numpy is not installed. 
+ """ try: import numpy as np except ImportError: diff --git a/parea/evals/summary/factual_inconsistency_binary.py b/parea/evals/summary/factual_inconsistency_binary.py index c23d9849..95073f7d 100644 --- a/parea/evals/summary/factual_inconsistency_binary.py +++ b/parea/evals/summary/factual_inconsistency_binary.py @@ -8,6 +8,22 @@ def factual_inconsistency_binary_factory( article_field: Optional[str] = "article", model: Optional[str] = "gpt-4", ) -> Callable[[Log], float]: + """ + This factory creates an evaluation function that classifies if a summary is factually inconsistent with the original text. + It is based on the paper [ChatGPT as a Factual Inconsistency Evaluator for Text Summarization](https://arxiv.org/abs/2303.15621) + which suggests using an LLM to assess the factuality of a summary by measuring how consistent the summary is with + the original text, posed as a binary classification. They find that `gpt-3.5-turbo-0301` outperforms + baseline methods such as SummaC and QuestEval when identifying factually inconsistent summaries. + + Args: + article_field: The key name/field used for the content which should be summarized. Defaults to "article". + model: The model which should be used for grading. Currently, only supports OpenAI chat models. Defaults to "gpt-4". + + Returns: + Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating + if the generated summary is factually consistent with the original text. + """ + def factual_inconsistency_binary(log: Log) -> float: article = log.inputs[article_field] output = log.output diff --git a/parea/evals/summary/factual_inconsistency_scale.py b/parea/evals/summary/factual_inconsistency_scale.py index a8a73e43..b7b24016 100644 --- a/parea/evals/summary/factual_inconsistency_scale.py +++ b/parea/evals/summary/factual_inconsistency_scale.py @@ -10,6 +10,21 @@ def factual_inconsistency_scale_factory( article_field: Optional[str] = "article", model: Optional[str] = "gpt-4", ) -> Callable[[Log], float]: + """ + This factory creates an evaluation function that grades the factual consistency of a summary with the article on a scale from 1 to 10. + It is based on the paper [ChatGPT as a Factual Inconsistency Evaluator for Text Summarization](https://arxiv.org/abs/2303.15621) + which finds that using `gpt-3.5-turbo-0301` leads to a higher correlation with human expert judgment when grading + the factuality of summaries on a scale from 1 to 10 than baseline methods such as SummaC and QuestEval. + + Args: + article_field: The key name/field used for the content which should be summarized. Defaults to "article". + model: The model which should be used for grading. Currently, only supports OpenAI chat models. Defaults to "gpt-4". + + Returns: + Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating + if the generated summary is factually consistent with the original text. 
+ """ + def factual_inconsistency_scale(log: Log) -> float: article = log.inputs[article_field] output = log.output diff --git a/parea/evals/summary/likert_scale.py b/parea/evals/summary/likert_scale.py index e705f3f9..852daa4b 100644 --- a/parea/evals/summary/likert_scale.py +++ b/parea/evals/summary/likert_scale.py @@ -10,6 +10,22 @@ def likert_scale_factory( article_field: Optional[str] = "article", model: Optional[str] = "gpt-4", ) -> Callable[[Log], float]: + """ + This factory creates an evaluation function that grades the quality of a summary on a Likert scale from 1-5 along + the dimensions of relevance, consistency, fluency, and coherence. It is based on the paper + [Human-like Summarization Evaluation with ChatGPT](https://arxiv.org/abs/2304.02554) which finds that using `gpt-3.5-0301` + leads to a highest correlation with human expert judgment when grading summaries on a Likert scale from 1-5 than baseline + methods. Noteworthy is that [BARTScore](https://arxiv.org/abs/2106.11520) was very competitive to `gpt-3.5-0301`. + + Args: + article_field: The key name/field used for the content which should be summarized. Defaults to "article". + model: The model which should be used for grading. Currently, only supports OpenAI chat models. Defaults to "gpt-4". + + Returns: + Callable[[Log], float]: A function that takes a log as input and returns a score between 0 and 1 indicating + the quality of the summary on a Likert scale from 1-5 along the dimensions of relevance, consistency, fluency, and coherence. + """ + def likert_scale(log: Log) -> float: article = log.inputs[article_field] output = log.output diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 4e2ee4e6..62784f8e 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Iterator, List +from typing import Callable, Dict, Iterable, List import asyncio import inspect @@ -50,7 +50,7 @@ def async_wrapper(fn, **kwargs): return asyncio.run(fn(**kwargs)) -def experiment(name: str, data: Iterator, func: Callable) -> ExperimentStatsSchema: +def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema: """Creates an experiment and runs the function on the data iterator.""" load_dotenv() @@ -81,7 +81,7 @@ def experiment(name: str, data: Iterator, func: Callable) -> ExperimentStatsSche @define class Experiment: name: str = field(init=True) - data: Iterator[Dict] = field(init=True) + data: Iterable[Dict] = field(init=True) func: Callable = field(init=True) experiment_stats: ExperimentStatsSchema = field(init=False, default=None) diff --git a/parea/schemas/log.py b/parea/schemas/log.py index 0a26cf15..bc860096 100644 --- a/parea/schemas/log.py +++ b/parea/schemas/log.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from enum import Enum diff --git a/pyproject.toml b/pyproject.toml index 29202c2b..f4e9f158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.26" +version = "0.2.27" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]