Skip to content

Commit

Permalink
feat(restapi) add snapshot id to /<resource>/{id}/draft PUT endpoint
Browse files Browse the repository at this point in the history
This commit adds an optional query parameter for changing a draft modifications base snapshot ID
to the PUT endpoint. If a snapshot ID is provided the value is updated, otherwise it is left
unchanged.

This feature is designed to enable the reconciliation of drafts by allowing the user to
signal that their changes are based on a more recent snapshot.

An existing test is modified to test this feature.
  • Loading branch information
keithmanville committed Dec 31, 2024
1 parent 9ce5991 commit 5c96dbf
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 64 deletions.
14 changes: 13 additions & 1 deletion src/dioptra/client/drafts.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,17 @@ def create(self, *resource_ids: str | int, **kwargs) -> T:
json_=self._validate_fields(kwargs),
)

def modify(self, *resource_ids: str | int, **kwargs) -> T:
def modify(
self,
*resource_ids: str | int,
resource_snapshot_id: int | None = None,
**kwargs,
) -> T:
"""Modify a resource modification draft.
Args:
*resource_ids: The parent resource ids that own the sub-collection.
*resource_snapshot_id: The id of the snapshot this draft is based on.
**kwargs: The draft fields to modify.
Returns:
Expand All @@ -387,8 +393,14 @@ def modify(self, *resource_ids: str | int, **kwargs) -> T:
DraftFieldsValidationError: If one or more draft fields are invalid or
missing.
"""

params: dict[str, Any] = dict()
if resource_snapshot_id is not None:
params["resourceSnapshot"] = resource_snapshot_id

return self._session.put(
self.build_sub_collection_url(*resource_ids),
params=params,
json_=self._validate_fields(kwargs),
)

Expand Down
3 changes: 1 addition & 2 deletions src/dioptra/client/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from .tags import TagsSubCollectionClient

DRAFT_FIELDS: Final[set[str]] = {"name", "description"}
MODIFY_DRAFT_FIELDS: Final[set[str]] = DRAFT_FIELDS | {"resourceSnapshot"}

T = TypeVar("T")

Expand Down Expand Up @@ -58,7 +57,7 @@ def __init__(self, session: DioptraSession[T]) -> None:
self._modify_resource_drafts = ModifyResourceDraftsSubCollectionClient[T](
session=session,
validate_fields_fn=make_draft_fields_validator(
draft_fields=MODIFY_DRAFT_FIELDS,
draft_fields=DRAFT_FIELDS,
resource_name=self.name,
),
root_collection=self,
Expand Down
70 changes: 35 additions & 35 deletions src/dioptra/restapi/v1/shared/drafts/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@
from flask_login import login_required
from flask_restx import Namespace, Resource
from injector import ClassAssistedBuilder, inject
from marshmallow import Schema, fields
from marshmallow import Schema
from structlog.stdlib import BoundLogger

from dioptra.restapi.db import models
from dioptra.restapi.v1 import utils
from dioptra.restapi.v1.schemas import IdStatusResponseSchema

from .schema import DraftGetQueryParameters, DraftPageSchema, DraftSchema
from .schema import (
DraftGetQueryParameters,
DraftPageSchema,
DraftSchema,
ModifyDraftPutQueryParameters,
)
from .service import (
ResourceDraftsIdService,
ResourceDraftsService,
Expand Down Expand Up @@ -230,17 +235,8 @@ def generate_resource_id_draft_endpoint(
model_name = "Draft" + "".join(
request_schema.__class__.__name__.rsplit("Schema", 1)
)
request_schema_class = Schema.from_dict(request_schema.dump_fields)
else:
model_name = "Draft" + "".join(request_schema.__name__.rsplit("Schema", 1))
request_schema_class = request_schema

class ModifyRequestSchema(request_schema_class):
resourceSnapshot = fields.Integer(
attribute="resource_snapshot_id",
data_key="resourceSnapshot",
metadata=dict(description="ID of the snapshot the draft is based on."),
)

@api.route("/<int:id>/draft")
@api.param("id", "ID for the resource.")
Expand Down Expand Up @@ -270,7 +266,7 @@ def get(self, id: int):
)

@login_required
@accepts(schema=ModifyRequestSchema, model_name=model_name, api=api)
@accepts(schema=request_schema, model_name=model_name, api=api)
@responds(schema=DraftSchema, api=api)
def post(self, id: int):
"""Creates a Draft for this resource."""
Expand All @@ -286,16 +282,29 @@ def post(self, id: int):
)

@login_required
@accepts(schema=ModifyRequestSchema, model_name=model_name, api=api)
@accepts(
schema=request_schema,
query_params_schema=ModifyDraftPutQueryParameters,
model_name=model_name,
api=api,
)
@responds(schema=DraftSchema, api=api)
def put(self, id: int):
"""Modifies the Draft for this resource."""
log = LOGGER.new(
request_id=str(uuid.uuid4()), resource="Draft", request_type="POST"
)

parsed_query_params = request.parsed_query_params # type: ignore[attr-defined]
resource_snapshot_id = parsed_query_params["resource_snapshot_id"]

parsed_obj = request.parsed_obj # type: ignore
draft, num_other_drafts = self._id_draft_service.modify(
id, payload=parsed_obj, error_if_not_found=True, log=log
id,
payload=parsed_obj,
resource_snapshot_id=resource_snapshot_id,
error_if_not_found=True,
log=log,
)
return utils.build_resource_draft(
cast(models.DraftResource, draft), request_schema, num_other_drafts
Expand Down Expand Up @@ -432,19 +441,10 @@ def generate_nested_resource_drafts_id_endpoint(
model_name = "NestedDraftsId" + "".join(
request_schema.__class__.__name__.rsplit("Schema", 1)
)
request_schema_class = Schema.from_dict(request_schema.dump_fields)
else:
model_name = "NestedDraftsId" + "".join(
request_schema.__name__.rsplit("Schema", 1)
)
request_schema_class = request_schema

class ModifyRequestSchema(request_schema_class):
resourceSnapshot = fields.Integer(
attribute="resource_snapshot_id",
data_key="resourceSnapshot",
metadata=dict(description="ID of the snapshot the draft is based on."),
)

@api.route(f"/<int:id>/{resource_route}/drafts/<int:draftId>")
@api.param("draftId", f"ID for the Draft of the {resource_name} resource.")
Expand Down Expand Up @@ -474,7 +474,7 @@ def get(self, id: int, draftId: int):
)

@login_required
@accepts(schema=ModifyRequestSchema, model_name=model_name, api=api)
@accepts(schema=request_schema, model_name=model_name, api=api)
@responds(schema=DraftSchema, api=api)
def put(self, id: int, draftId: int):
"""Modifies a Draft for the resource."""
Expand Down Expand Up @@ -526,19 +526,10 @@ def generate_nested_resource_id_draft_endpoint(
model_name = "NestedDraft" + "".join(
request_schema.__class__.__name__.rsplit("Schema", 1)
)
request_schema_class = Schema.from_dict(request_schema.dump_fields)
else:
model_name = "NestedDraft" + "".join(
request_schema.__name__.rsplit("Schema", 1)
)
request_schema_class = request_schema

class ModifyRequestSchema(request_schema_class):
resourceSnapshot = fields.Integer(
attribute="resource_snapshot_id",
data_key="resourceSnapshot",
metadata=dict(description="ID of the snapshot the draft is based on."),
)

@api.route(f"/<int:id>/{resource_route}/<int:{resource_id}>/draft")
@api.param("id", "ID for the resource.")
Expand Down Expand Up @@ -578,7 +569,7 @@ def get(self, id: int, **kwargs):
)

@login_required
@accepts(schema=ModifyRequestSchema, model_name=model_name, api=api)
@accepts(schema=request_schema, model_name=model_name, api=api)
@responds(schema=DraftSchema, api=api)
def post(self, id: int, **kwargs):
"""Creates a Draft for this resource."""
Expand All @@ -604,7 +595,12 @@ def post(self, id: int, **kwargs):
)

@login_required
@accepts(schema=ModifyRequestSchema, model_name=model_name, api=api)
@accepts(
schema=request_schema,
query_params_schema=ModifyDraftPutQueryParameters,
model_name=model_name,
api=api,
)
@responds(schema=DraftSchema, api=api)
def put(self, id: int, **kwargs):
"""Modifies the Draft for this resource."""
Expand All @@ -621,10 +617,14 @@ def put(self, id: int, **kwargs):
f"{list(unexpected_kwargs.keys())}"
)

parsed_query_params = request.parsed_query_params # type: ignore[attr-defined]
resource_snapshot_id = parsed_query_params["resource_snapshot_id"]

parsed_obj = request.parsed_obj # type: ignore
draft, num_other_drafts = self._id_draft_service.modify(
kwargs[resource_id],
payload=parsed_obj,
resource_snapshot_id=resource_snapshot_id,
error_if_not_found=True,
log=log,
)
Expand Down
4 changes: 3 additions & 1 deletion src/dioptra/restapi/v1/shared/drafts/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ class DraftSchema(Schema):
)


class ModifyDraftBaseSchema(Schema):
class ModifyDraftPutQueryParameters(Schema):
resourceSnapshot = fields.Integer(
attribute="resource_snapshot_id",
metadata=dict(description="ID of the resource snapshot this draft modifies."),
allow_none=True,
load_default=None,
)


Expand Down
11 changes: 6 additions & 5 deletions src/dioptra/restapi/v1/shared/drafts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,8 @@ def modify(
if draft is None:
return None

current_timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
draft.payload["resource_data"] = payload
draft.last_modified_on = current_timestamp
draft.last_modified_on = datetime.datetime.now(tz=datetime.timezone.utc)

if commit:
db.session.commit()
Expand Down Expand Up @@ -437,6 +436,7 @@ def modify(
self,
resource_id: int,
payload: dict[str, Any],
resource_snapshot_id: int | None = None,
error_if_not_found: bool = False,
commit: bool = True,
**kwargs,
Expand All @@ -446,6 +446,7 @@ def modify(
Args:
resource_id: The unique id of the resource.
payload: The contents of the draft.
resource_snapshot_id: The id of the snapshot this draft is based on.
error_if_not_found: If True, raise an error if the group is not found.
Defaults to False.
commit: If True, commit the transaction. Defaults to True.
Expand All @@ -466,10 +467,10 @@ def modify(
if draft is None:
return None, num_other_drafts

current_timestamp = datetime.datetime.now(tz=datetime.timezone.utc)
draft.payload["resource_snapshot_id"] = payload.pop("resource_snapshot_id")
draft.payload["resource_data"] = payload
draft.last_modified_on = current_timestamp
if resource_snapshot_id is not None:
draft.payload["resource_snapshot_id"] = resource_snapshot_id
draft.last_modified_on = datetime.datetime.now(tz=datetime.timezone.utc)

if commit:
db.session.commit()
Expand Down
1 change: 0 additions & 1 deletion tests/unit/restapi/lib/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def assert_draft_response_contents_matches_expectations(
Args:
response: The actual response from the API.
expected_contents: The expected response from the API.
existing_draft: If the draft is of an existing resource or not.
Raises:
AssertionError: If the API response does not match the expected response
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/restapi/lib/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def run_existing_resource_drafts_tests(
)

# Modify operation tests
draft_mod["resourceSnapshotId"] = response["resourceSnapshot"]
response = client.modify(*resource_ids, **draft_mod).json()
response = client.modify(*resource_ids, resource_snapshot_id=99, **draft_mod).json()
asserts.assert_draft_response_contents_matches_expectations(
response, draft_mod_expected
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/restapi/v1/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def test_manage_existing_entrypoint_draft(
"user_id": auth_account["id"],
"group_id": entrypoint["group"]["id"],
"resource_id": entrypoint["id"],
"resource_snapshot_id": entrypoint["snapshot"],
"resource_snapshot_id": 99,
"num_other_drafts": 0,
"payload": draft_mod,
}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/restapi/v1/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def test_manage_existing_experiment_draft(
"user_id": auth_account["id"],
"group_id": experiment["group"]["id"],
"resource_id": experiment["id"],
"resource_snapshot_id": experiment["snapshot"],
"resource_snapshot_id": 99,
"num_other_drafts": 0,
"payload": draft_mod,
}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/restapi/v1/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def test_manage_existing_model_draft(
"user_id": auth_account["id"],
"group_id": model["group"]["id"],
"resource_id": model["id"],
"resource_snapshot_id": model["snapshot"],
"resource_snapshot_id": 99,
"num_other_drafts": 0,
"payload": draft_mod,
}
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/restapi/v1/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def test_manage_existing_plugin_draft(
"user_id": auth_account["id"],
"group_id": plugin["group"]["id"],
"resource_id": plugin["id"],
"resource_snapshot_id": plugin["snapshot"],
"resource_snapshot_id": 99,
"num_other_drafts": 0,
"payload": draft_mod,
}
Expand Down Expand Up @@ -1375,7 +1375,7 @@ def hello_world(name: str) -> str:
"user_id": auth_account["id"],
"group_id": plugin_file["group"]["id"],
"resource_id": plugin_file["id"],
"resource_snapshot_id": plugin_file["snapshot"],
"resource_snapshot_id": 99,
"num_other_drafts": 0,
"payload": draft_mod,
}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/restapi/v1/test_plugin_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def test_manage_existing_plugin_parameter_type_draft(
"user_id": auth_account["id"],
"group_id": plugin_param_type["group"]["id"],
"resource_id": plugin_param_type["id"],
"resource_snapshot_id": plugin_param_type["snapshot"],
"resource_snapshot_id": 99,
"num_other_drafts": 0,
"payload": draft_mod,
}
Expand Down
14 changes: 3 additions & 11 deletions tests/unit/restapi/v1/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,16 +587,8 @@ def test_manage_existing_queue_draft(
description = "description"

# test creation
draft = {
"name": name,
"description": description,
"resourceSnapshot": queue["snapshot"],
}
draft_mod = {
"name": new_name,
"description": description,
"resourceSnapshot": queue["snapshot"],
}
draft = {"name": name, "description": description}
draft_mod = {"name": new_name, "description": description}

# Expected responses
draft_expected = {
Expand All @@ -611,7 +603,7 @@ def test_manage_existing_queue_draft(
"user_id": auth_account["id"],
"group_id": queue["group"]["id"],
"resource_id": queue["id"],
"resource_snapshot_id": queue["snapshot"],
"resource_snapshot_id": 99,
"num_other_drafts": 0,
"payload": draft_mod,
}
Expand Down

0 comments on commit 5c96dbf

Please sign in to comment.