Skip to content

Commit

Permalink
Merge pull request #114 from SAP/create_job_with_bb_template_id
Browse files Browse the repository at this point in the history
Create job with bb template
  • Loading branch information
rguru89 authored Oct 12, 2021
2 parents 5f9a231 + 5306647 commit e89b22e
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 10 deletions.
15 changes: 14 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]


## [0.9.0]

### Added

* Support for Business blueprint id in `create_job` and `create_job_and_wait` methods. [#114]
* This feature is **not supported** in DAR service yet, it is added for internal testing purposes.
* Either model_template_id or business_blueprint_id has to be specified in `create_job` method.
* Both model_template_id and business_blueprint_id are not allowed.

[#114]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/114

## [0.8.2]

### Fixed
Expand Down Expand Up @@ -203,7 +215,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* First public release

[Unreleased]: https://github.com/SAP/data-attribute-recommendation-python-sdk/compare/rel/0.8.2...HEAD
[Unreleased]: https://github.com/SAP/data-attribute-recommendation-python-sdk/compare/rel/0.9.0...HEAD
[0.9.0]: https://github.com/SAP/data-attribute-recommendation-python-sdk/compare/rel/0.8.2...rel/0.9.0
[0.8.2]: https://github.com/SAP/data-attribute-recommendation-python-sdk/compare/rel/0.8.1...rel/0.8.2
[0.8.1]: https://github.com/SAP/data-attribute-recommendation-python-sdk/compare/rel/0.8.0...rel/0.8.1
[0.8.0]: https://github.com/SAP/data-attribute-recommendation-python-sdk/compare/rel/0.7.1...rel/0.8.0
Expand Down
8 changes: 8 additions & 0 deletions sap/aibus/dar/client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ class DeploymentFailed(DARException):
pass


class CreateTrainingJobFailed(DARException):
"""
Create training job failed.
"""

pass


class ModelAlreadyExists(DARException):
"""
Model already exists and must be deleted first.
Expand Down
40 changes: 32 additions & 8 deletions sap/aibus/dar/client/model_manager_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TrainingJobTimeOut,
DeploymentTimeOut,
DeploymentFailed,
CreateTrainingJobFailed,
)
from sap.aibus.dar.client.model_manager_constants import (
JobStatus,
Expand Down Expand Up @@ -134,7 +135,8 @@ def create_job(
self,
model_name: str,
dataset_id: str,
model_template_id: str,
model_template_id: str = None,
business_blueprint_id: str = None,
) -> dict:
"""
Creates a training Job.
Expand All @@ -152,6 +154,9 @@ def create_job(
:param model_name: Name of the model to train
:param dataset_id: Id of previously uploaded, valid dataset
:param model_template_id: Model template ID for training
:param business_blueprint_id: Business Blueprint template ID for training
:raises CreateTrainingJobFailed: When business_blueprint_id
and model_template_id are provided or when both are not provided
:return: newly created Job as dict
"""
self.log.info(
Expand All @@ -160,12 +165,28 @@ def create_job(
dataset_id,
model_template_id,
)

payload = {
"modelName": model_name,
"datasetId": dataset_id,
"modelTemplateId": model_template_id,
}
if business_blueprint_id and model_template_id:
raise CreateTrainingJobFailed(
"Either model_template_id or business_blueprint_id"
" have to be specified, not both."
)
if not business_blueprint_id and not model_template_id:
raise CreateTrainingJobFailed(
"Either model_template_id or business_blueprint_id"
" have to be specified."
)
if business_blueprint_id:
payload = {
"modelName": model_name,
"datasetId": dataset_id,
"businessBlueprintId": business_blueprint_id,
}
elif model_template_id:
payload = {
"modelName": model_name,
"datasetId": dataset_id,
"modelTemplateId": model_template_id,
}
response = self.session.post_to_endpoint(
ModelManagerPaths.ENDPOINT_JOB_COLLECTION, payload=payload
)
Expand All @@ -178,7 +199,8 @@ def create_job_and_wait(
self,
model_name: str,
dataset_id: str,
model_template_id: str,
model_template_id: str = None,
business_blueprint_id: str = None,
):
"""
Starts a job and waits for the job to finish.
Expand All @@ -189,6 +211,7 @@ def create_job_and_wait(
:param model_name: Name of the model to train
:param dataset_id: Id of previously uploaded, valid dataset
:param model_template_id: Model template ID for training
:param business_blueprint_id: Business Blueprint ID for training
:raises TrainingJobFailed: When training job has status FAILED
:raises TrainingJobTimeOut: When training job takes too long
:return: API response as dict
Expand All @@ -197,6 +220,7 @@ def create_job_and_wait(
model_name=model_name,
dataset_id=dataset_id,
model_template_id=model_template_id,
business_blueprint_id=business_blueprint_id,
)
return self.wait_for_job(job_resource["id"])

Expand Down
157 changes: 157 additions & 0 deletions tests/sap/aibus/dar/client/test_model_manager_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
TrainingJobTimeOut,
DeploymentTimeOut,
DeploymentFailed,
CreateTrainingJobFailed,
)
from sap.aibus.dar.client.model_manager_client import ModelManagerClient
from sap.aibus.dar.client.util.polling import Polling, PollingTimeoutException
Expand Down Expand Up @@ -148,6 +149,113 @@ def test_create_job(self, model_manager_client):
== response
)

def test_create_job_with_business_blueprint_id(self, model_manager_client):
"""
Tests start_job correctly with business_blueprint_id
"""
business_blueprint_id = "d7810207-ca31-4d4d-9b5a-841a644fd81f"
dataset_id = "a2058037-2ae4-465e-8110-65381d47f3d4"
model_name = "my_test_model"

model_manager_client.session.post_to_endpoint.return_value.json.return_value = {
"id": "abcd"
}

response = model_manager_client.create_job(
model_template_id=None,
business_blueprint_id=business_blueprint_id,
dataset_id=dataset_id,
model_name=model_name,
)

expected_url = "/model-manager/api/v3/jobs"
expected_payload = {
"datasetId": dataset_id,
"businessBlueprintId": business_blueprint_id,
"modelName": model_name,
}

expected_post_call = [call(expected_url, payload=expected_payload)]

assert (
expected_post_call
== model_manager_client.session.post_to_endpoint.call_args_list
)

assert (
model_manager_client.session.post_to_endpoint.return_value.json.return_value
== response
)

def test_create_job_with_business_blueprint_id_and_model_template_id(
self, model_manager_client
):
"""
Tests it throws the exception if both model_template_id
and model_template_id is provided
"""
business_blueprint_id = "d7810207-ca31-4d4d-9b5a-841a644fd81f"
model_template_id = "d7810207-ca31-4d4d-9b5a-841a644fd81f"
dataset_id = "a2058037-2ae4-465e-8110-65381d47f3d4"
model_name = "my_test_model"

with pytest.raises(CreateTrainingJobFailed) as exception:
model_manager_client.create_job(
model_template_id=model_template_id,
business_blueprint_id=business_blueprint_id,
dataset_id=dataset_id,
model_name=model_name,
)
expected_message = (
"Either model_template_id or business_blueprint_id"
" have to be specified, not both."
)
assert str(exception.value) == expected_message

def test_create_job_without_business_blueprint_id_and_model_template_id(
self, model_manager_client
):
"""
Tests it throws the exception if both model_template_id and
model_template_id is provided
"""
dataset_id = "a2058037-2ae4-465e-8110-65381d47f3d4"
model_name = "my_test_model"

with pytest.raises(CreateTrainingJobFailed) as exception:
model_manager_client.create_job(
model_template_id=None,
business_blueprint_id=None,
dataset_id=dataset_id,
model_name=model_name,
)
expected_message = (
"Either model_template_id or business_blueprint_id have to be specified."
)
assert str(exception.value) == expected_message

def test_create_job_with_empty_business_blueprint_id_and_model_template_id(
self, model_manager_client
):
"""
Tests it throws the exception if both model_template_id and
model_template_id is provided
"""
dataset_id = "a2058037-2ae4-465e-8110-65381d47f3d4"
model_name = "my_test_model"

with pytest.raises(CreateTrainingJobFailed) as exception:
model_manager_client.create_job(
model_template_id="",
business_blueprint_id="",
dataset_id=dataset_id,
model_name=model_name,
)
expected_message = (
"Either model_template_id or business_blueprint_id have to be specified."
)
assert str(exception.value) == expected_message

def test_wait_for_job_uses_polling(self, model_manager_client: ModelManagerClient):
"""
Tests the interaction of the `wait_for_job` method with the Polling class.
Expand Down Expand Up @@ -271,6 +379,55 @@ def test_create_job_and_wait(self, model_manager_client: ModelManagerClient):
model_template_id=model_template_id,
dataset_id=dataset_id,
model_name=model_name,
business_blueprint_id=None,
)

assert model_manager_client.create_job.call_args_list == [
expected_create_job_call_args
]

expected_wait_for_job_call_args = call(job_resource["id"])

assert model_manager_client.wait_for_job.call_args_list == [
expected_wait_for_job_call_args
]

def test_create_job_and_wait_with_business_blueprint_id(
self, model_manager_client: ModelManagerClient
):
"""
Tests if start_job_and_wait correctly with business_blueprint_id
orchestrates start_job() and wait_for_job().
"""
business_blueprint_id = "d7810207-ca31-4d4d-9b5a-841a644fd81f"
dataset_id = "a2058037-2ae4-465e-8110-65381d47f3d4"
model_name = "my_test_model"

job_resource = self._make_job_resource("SUCCEEDED")

model_manager_client.create_job = create_autospec(
model_manager_client.create_job
)
model_manager_client.create_job.return_value = job_resource

model_manager_client.wait_for_job = create_autospec(
model_manager_client.wait_for_job
)

ret_val = model_manager_client.create_job_and_wait(
model_name=model_name,
dataset_id=dataset_id,
model_template_id="",
business_blueprint_id=business_blueprint_id,
)

assert ret_val == model_manager_client.wait_for_job.return_value

expected_create_job_call_args = call(
model_template_id="",
dataset_id=dataset_id,
model_name=model_name,
business_blueprint_id=business_blueprint_id,
)

assert model_manager_client.create_job.call_args_list == [
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.2
0.9.0

0 comments on commit e89b22e

Please sign in to comment.