Skip to content

Commit

Permalink
Merge pull request #987 from parea-ai/PAI-1352-edit-testcase-via-api
Browse files Browse the repository at this point in the history
add edit test case endpoint
  • Loading branch information
jalexanderII authored Jul 2, 2024
2 parents 803654d + 63a8b4e commit 6e8aabb
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 7 deletions.
39 changes: 39 additions & 0 deletions cookbook/async_enpoints_for_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import asyncio
import os

from dotenv import load_dotenv

from parea import Parea
from parea.schemas import TestCase, TestCaseCollection, UpdateTestCase

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))


data = [{"problem": "1+2", "target": 3, "tags": ["easy"]}, {"problem": "Solve the differential equation dy/dx = 3y.", "target": "y = c * e^(3x)", "tags": ["hard"]}]
new_data = [{"problem": "Evaluate the integral ∫x^2 dx from 0 to 3.", "target": 9, "tags": ["hard"]}]


async def update_test_case_example():
dataset: TestCaseCollection = await p.aget_collection("math_problems_v3")
test_cases: dict[int, TestCase] = dataset.test_cases
for test_case_id, test_case in test_cases.items():
if "easy" in test_case.tags:
# updated inputs must match the same k/v pair as original test case
await p.aupdate_test_case(
dataset_id=dataset.id,
test_case_id=test_case_id,
update_request=UpdateTestCase(inputs={"problem": "Evaluate the integral ∫x^6 dx from 0 to 9."}, target="((1/7)x^7)+C", tags=["hard"]),
)
break


async def main():
await p.acreate_test_collection(data, name="math_problems_v3")
await p.aadd_test_cases(new_data, dataset_id=182)
await update_test_case_example()


if __name__ == "__main__":
asyncio.run(main())
29 changes: 24 additions & 5 deletions cookbook/endpoints_for_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dotenv import load_dotenv

from parea import Parea
from parea.schemas import TestCase, TestCaseCollection, UpdateTestCase

load_dotenv()

Expand All @@ -11,14 +12,32 @@

data = [{"problem": "1+2", "target": 3, "tags": ["easy"]}, {"problem": "Solve the differential equation dy/dx = 3y.", "target": "y = c * e^(3x)", "tags": ["hard"]}]

# this will create a new dataset on Parea named "Math problems".
# this will create a new dataset on Parea named "math_problems_v4".
# The dataset will have one column named "problem", and two columns using the reserved names "target" and "tags".
# when using this dataset the expected prompt template should have a placeholder for the varible problem.
p.create_test_collection(data, name="Math problems")
p.create_test_collection(data, name="math_problems_v4")

new_data = [{"problem": "Evaluate the integral ∫x^2 dx from 0 to 3.", "target": 9, "tags": ["hard"]}]
# this will add the new test cases to the existing "Math problems" dataset.
# this will add the new test cases to the existing "math_problems_v4" dataset.
# New test cases must have the same columns as the existing dataset.
p.add_test_cases(new_data, name="Math problems")
p.add_test_cases(new_data, name="math_problems_v4")
# Or if you can use the dataset ID instead of the name
p.add_test_cases(new_data, dataset_id=121)
# p.add_test_cases(new_data, dataset_id=121)


def update_test_case_example():
dataset: TestCaseCollection = p.get_collection("math_problems_v4")
test_cases: dict[int, TestCase] = dataset.test_cases
for test_case_id, test_case in test_cases.items():
if "easy" in test_case.tags:
# updated inputs must match the same k/v pair as original test case
p.update_test_case(
dataset_id=dataset.id,
test_case_id=test_case_id,
update_request=UpdateTestCase(inputs={"problem": "Evaluate the integral ∫x^6 dx from 0 to 9."}, target="((1/7)x^7)+C", tags=["hard"]),
)
break


if __name__ == "__main__":
update_test_case_example()
47 changes: 47 additions & 0 deletions parea/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TestCaseCollection,
TraceLogFilters,
TraceLogTree,
UpdateTestCase,
UseDeployedPrompt,
UseDeployedPromptResponse,
)
Expand All @@ -54,6 +55,7 @@
GET_COLLECTION_ENDPOINT = "/collection/{test_collection_identifier}"
CREATE_COLLECTION_ENDPOINT = "/collection"
ADD_TEST_CASES_ENDPOINT = "/testcases"
UPDATE_TEST_CASE_ENDPOINT = "/update_test_case/{dataset_id}/{test_case_id}"
GET_TRACE_LOG_ENDPOINT = "/trace_log/{trace_id}"
LIST_EXPERIMENTS_ENDPOINT = "/experiments"
GET_EXPERIMENT_LOGS_ENDPOINT = "/experiment/{experiment_uuid}/trace_logs"
Expand Down Expand Up @@ -343,6 +345,51 @@ def add_test_cases(
data=asdict(request),
)

async def acreate_test_collection(self, data: List[Dict[str, Any]], name: Optional[str] = None) -> None:
request: CreateTestCaseCollection = create_test_collection(data, name)
await self._client.request_async(
"POST",
CREATE_COLLECTION_ENDPOINT,
data=asdict(request),
)

async def aadd_test_cases(
self,
data: List[Dict[str, Any]],
name: Optional[str] = None,
dataset_id: Optional[int] = None,
) -> None:
request = CreateTestCases(id=dataset_id, name=name, test_cases=create_test_cases(data))
await self._client.request_async(
"POST",
ADD_TEST_CASES_ENDPOINT,
data=asdict(request),
)

def update_test_case(
self,
dataset_id: int,
test_case_id: int,
update_request: UpdateTestCase,
) -> None:
self._client.request(
"POST",
UPDATE_TEST_CASE_ENDPOINT.format(dataset_id=dataset_id, test_case_id=test_case_id),
data=asdict(update_request),
)

async def aupdate_test_case(
self,
dataset_id: int,
test_case_id: int,
update_request: UpdateTestCase,
) -> None:
await self._client.request_async(
"POST",
UPDATE_TEST_CASE_ENDPOINT.format(dataset_id=dataset_id, test_case_id=test_case_id),
data=asdict(update_request),
)

def experiment(
self,
name: str,
Expand Down
9 changes: 8 additions & 1 deletion parea/schemas/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import json
from enum import Enum
Expand Down Expand Up @@ -330,6 +330,13 @@ class CreateTestCaseCollection(CreateTestCases):
column_names: List[str] = field(factory=list)


@define
class UpdateTestCase:
inputs: Optional[Dict[str, Any]] = None
target: Optional[Union[int, float, str, bool]] = None
tags: Optional[List[str]] = None


class ExperimentStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.183"
version = "0.2.184"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit 6e8aabb

Please sign in to comment.