Add model run name for model.run (#435)
* Add model run name for model.run

* Fix test

* Remove epoch

* Reduce assert payload size
ntamas92 authored Apr 2, 2024
1 parent 4139951 commit f0d8022
Showing 5 changed files with 19 additions and 41 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.17.4](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.4) - 2024-03-25
+
+### Modified
+- In `Model.run`, added the `model_run_name` parameter. This allows the creation of multiple model runs for datasets.
+
+
 ## [0.17.3] - 2024-02-29
 
 ### Added
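As a usage note: a minimal sketch of what the new parameter enables, following the client setup style from the `Model.run` docstring. The API key, model ID, dataset ID, and run names below are illustrative placeholders, not values from this commit.

```python
import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
model = client.get_model("prj_model_id")  # hypothetical model ID

# Two separately named runs of the same model against one dataset;
# the name makes repeat runs on a dataset distinguishable.
job_a = model.run("ds_id", model_run_name="baseline", slice_id=None)
job_b = model.run("ds_id", model_run_name="retrained", slice_id=None)
```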
15 changes: 12 additions & 3 deletions nucleus/model.py
@@ -211,7 +211,9 @@ def evaluate(self, scenario_test_names: List[str]) -> AsyncJob:
         )
         return AsyncJob.from_json(response, self._client)
 
-    def run(self, dataset_id: str, slice_id: Optional[str]) -> str:
+    def run(
+        self, dataset_id: str, model_run_name: str, slice_id: Optional[str]
+    ) -> str:
         """Runs inference on the bundle associated with the model on the dataset. ::
 
             import nucleus
@@ -222,11 +224,18 @@ def run(self, dataset_id: str, slice_id: Optional[str]) -> str:
         Args:
             dataset_id: The ID of the dataset to run inference on.
             job_id: The ID of the :class:`AsyncJob` used to track job progress.
+            model_run_name: The name of the model run.
             slice_id: The ID of the slice of the dataset to run inference on.
 
         Returns:
             job_id: The ID of the :class:`AsyncJob` used to track job progress.
         """
         response = self._client.make_request(
-            {"dataset_id": dataset_id, "slice_id": slice_id},
+            {
+                "dataset_id": dataset_id,
+                "slice_id": slice_id,
+                "model_run_name": model_run_name,
+            },
             f"model/run/{self.id}/",
             requests_command=requests.post,
         )
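For reference, a minimal sketch of the request this method now issues. The route comes from the diff above; the base URL, auth scheme, and response key are assumptions about the Nucleus API, not taken from this commit.

```python
import requests

payload = {
    "dataset_id": "ds_id",         # illustrative IDs
    "slice_id": None,
    "model_run_name": "baseline",  # the new field added by this commit
}
resp = requests.post(
    "https://api.scale.com/v1/nucleus/model/run/prj_model_id/",  # assumed base URL plus the route from the diff
    json=payload,
    auth=("YOUR_SCALE_API_KEY", ""),  # assumed basic auth with the API key as username
)
job_id = resp.json().get("job_id")  # assumed response key, mirroring the docstring's job_id
```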
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running
 
 [tool.poetry]
 name = "scale-nucleus"
-version = "0.17.3"
+version = "0.17.4"
 description = "The official Python client library for Nucleus, the Data Platform for AI"
 license = "MIT"
 authors = ["Scale AI Nucleus Team <[email protected]>"]
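With this bump the change ships as `scale-nucleus` 0.17.4, so it can be picked up via `pip install --upgrade scale-nucleus` (assuming the package's standard PyPI release flow).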
1 change: 0 additions & 1 deletion tests/test_annotation.py
@@ -824,7 +824,6 @@ def test_default_category_gt_upload_async(dataset):
         "status": "Completed",
         "message": {
             "annotation_upload": {
-                "epoch": 1,
                 "total": 1,
                 "errored": 0,
                 "ignored": 0,
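These expectations are checked with `assert_partial_equality`, so dropping keys such as `epoch` shrinks the assertion without weakening the fields that remain. A minimal sketch of that comparison idea follows, as an assumption about the helper rather than its actual implementation:

```python
def assert_partial_equality(expected, actual):
    """Require every key in `expected` to match `actual`; ignore extra keys."""
    for key, value in expected.items():
        assert key in actual, f"missing key: {key}"
        if isinstance(value, dict):
            assert_partial_equality(value, actual[key])  # recurse into nested dicts
        else:
            assert actual[key] == value, f"mismatch at {key!r}: {actual[key]!r} != {value!r}"
```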
36 changes: 0 additions & 36 deletions tests/test_dataset.py
@@ -380,24 +380,6 @@ def test_annotate_async(dataset: Dataset):
     expected = {
         "job_id": job.job_id,
         "status": "Completed",
-        "message": {
-            "annotation_upload": {
-                "epoch": 1,
-                "total": 4,
-                "errored": 0,
-                "ignored": 0,
-                "datasetId": dataset.id,
-                "processed": 4,
-            },
-            "segmentation_upload": {
-                "ignored": 0,
-                "n_errors": 0,
-                "processed": 1,
-            },
-        },
-        "job_progress": "1.00",
-        "completed_steps": 5,
-        "total_steps": 5,
     }
     assert_partial_equality(expected, status)
 
@@ -423,24 +405,6 @@ def test_annotate_async_with_error(dataset: Dataset):
     expected = {
         "job_id": job.job_id,
         "status": "Completed",
-        "message": {
-            "annotation_upload": {
-                "epoch": 1,
-                "total": 4,
-                "errored": 1,
-                "ignored": 0,
-                "datasetId": dataset.id,
-                "processed": 3,
-            },
-            "segmentation_upload": {
-                "ignored": 0,
-                "n_errors": 0,
-                "processed": 1,
-            },
-        },
-        "job_progress": "1.00",
-        "completed_steps": 5,
-        "total_steps": 5,
     }
     assert_partial_equality(expected, status)
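A hedged illustration of why trimming the payload is safe: with partial equality, the server can keep returning `message`, `job_progress`, and the step counters, and the slimmed assertion still passes. Values below are illustrative and reuse the sketch helper above.

```python
status = {
    "job_id": "job_123",     # illustrative values, not from a real test run
    "status": "Completed",
    "job_progress": "1.00",  # extra fields the trimmed test no longer pins down
    "completed_steps": 5,
}
expected = {"job_id": "job_123", "status": "Completed"}
assert_partial_equality(expected, status)  # passes: only expected keys are compared
```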
