diff --git a/docs/specs/openapi.json b/docs/specs/openapi.json index 66c6f4d554..61c7d48225 100644 --- a/docs/specs/openapi.json +++ b/docs/specs/openapi.json @@ -5593,7 +5593,9 @@ } } } - }, + } + }, + "/settings/cost/groups/add/": { "put": { "tags": [ "Settings", @@ -5619,14 +5621,19 @@ } } } - }, - "delete": { + } + }, + "/settings/cost-groups/remove/": { + "put": { "tags": [ "Settings", "Cost Groups" ], - "summary": "Remove projects from a coust group", - "operationId": "deleteSettingsCostGroups", + "summary": "Remove projects from a cost group", + "operationId": "putSettingsCostGroupsRemove", + "requestBody": { + "$ref": "#/components/requestBodies/CostGroupsBody" + }, "responses": { "204": { "description": "Cost groups updated" diff --git a/koku/api/settings/cost_groups/query_handler.py b/koku/api/settings/cost_groups/query_handler.py index c3787cdf02..fea6d37344 100644 --- a/koku/api/settings/cost_groups/query_handler.py +++ b/koku/api/settings/cost_groups/query_handler.py @@ -55,24 +55,25 @@ def _remove_default_projects(projects: list[dict[str, str]]) -> list[dict[str, s def put_openshift_namespaces(projects: list[dict[str, str]]) -> list[dict[str, str]]: projects = _remove_default_projects(projects) - # Build mapping of cost groups to cost category IDs in order to easiy get + # Build mapping of cost groups to cost category IDs in order to easily get # the ID of the cost group to update cost_groups = {item["name"]: item["id"] for item in OpenshiftCostCategory.objects.values("name", "id")} - namespaces_to_create = [ - OpenshiftCostCategoryNamespace( - namespace=new_project["project"], - system_default=False, - cost_category_id=cost_groups[new_project["group"]], - ) - for new_project in projects - ] - try: - # Perform bulk create - OpenshiftCostCategoryNamespace.objects.bulk_create(namespaces_to_create) - except IntegrityError as e: - # Handle IntegrityError (e.g., if a unique constraint is violated) - LOG.warning(f"IntegrityError: {e}") + # TODO: With Django 4.2, we can move back to using bulk_updates() since it allows updating conflicts + # https://docs.djangoproject.com/en/4.2/ref/models/querysets/#bulk-create + for new_project in projects: + try: + OpenshiftCostCategoryNamespace.objects.update_or_create( + namespace=new_project["project"], + system_default=False, + cost_category_id=cost_groups[new_project["group"]], + ) + except IntegrityError as e: + # The project already exists. Move it to a different cost group. + LOG.warning(f"IntegrityError: {e}") + OpenshiftCostCategoryNamespace.objects.filter(namespace=new_project["project"]).update( + cost_category_id=cost_groups[new_project["group"]] + ) return projects @@ -80,7 +81,7 @@ def put_openshift_namespaces(projects: list[dict[str, str]]) -> list[dict[str, s def delete_openshift_namespaces(projects: list[dict[str, str]]) -> list[dict[str, str]]: projects = _remove_default_projects(projects) projects_to_delete = [item["project"] for item in projects] - deleted_count, _ = ( + deleted_count, deletions = ( OpenshiftCostCategoryNamespace.objects.filter(namespace__in=projects_to_delete) .exclude(system_default=True) .delete() diff --git a/koku/api/settings/cost_groups/view.py b/koku/api/settings/cost_groups/view.py index ce71301b37..148d720652 100644 --- a/koku/api/settings/cost_groups/view.py +++ b/koku/api/settings/cost_groups/view.py @@ -70,24 +70,6 @@ def get(self, request: Request, **kwargs) -> Response: return paginator.paginated_response - def put(self, request: Request) -> Response: - serializer = CostGroupProjectSerializer(data=request.data, many=True) - serializer.is_valid(raise_exception=True) - - projects = put_openshift_namespaces(serializer.validated_data) - self._summarize_current_month(request.user.customer.schema_name, projects) - - return Response(status=status.HTTP_204_NO_CONTENT) - - def delete(self, request: Request) -> Response: - serializer = CostGroupProjectSerializer(data=request.data, many=True) - serializer.is_valid(raise_exception=True) - - projects = delete_openshift_namespaces(serializer.validated_data) - self._summarize_current_month(request.user.customer.schema_name, projects) - - return Response(status=status.HTTP_204_NO_CONTENT) - def _summarize_current_month(self, schema_name: str, projects: list[dict[str, str]]) -> list[str]: """Resummarize OCP data for the current month.""" projects_to_summarize = [proj["project"] for proj in projects] @@ -111,3 +93,25 @@ def _summarize_current_month(self, schema_name: str, projects: list[dict[str, st async_ids.append(str(async_result)) return async_ids + + +class CostGroupsAddView(CostGroupsView): + def put(self, request: Request) -> Response: + serializer = CostGroupProjectSerializer(data=request.data, many=True) + serializer.is_valid(raise_exception=True) + + projects = put_openshift_namespaces(serializer.validated_data) + self._summarize_current_month(request.user.customer.schema_name, projects) + + return Response(status=status.HTTP_204_NO_CONTENT) + + +class CostGroupsRemoveView(CostGroupsView): + def put(self, request: Request) -> Response: + serializer = CostGroupProjectSerializer(data=request.data, many=True) + serializer.is_valid(raise_exception=True) + + projects = delete_openshift_namespaces(serializer.validated_data) + self._summarize_current_month(request.user.customer.schema_name, projects) + + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/koku/api/settings/test/cost_groups/test_query_handler.py b/koku/api/settings/test/cost_groups/test_query_handler.py index b1dede6ac5..06a8d02c5b 100644 --- a/koku/api/settings/test/cost_groups/test_query_handler.py +++ b/koku/api/settings/test/cost_groups/test_query_handler.py @@ -20,6 +20,7 @@ from masu.processor.tasks import OCP_QUEUE from masu.processor.tasks import OCP_QUEUE_XL from reporting.provider.ocp.models import OCPProject +from reporting.provider.ocp.models import OpenshiftCostCategory from reporting.provider.ocp.models import OpenshiftCostCategoryNamespace @@ -40,6 +41,8 @@ def setUp(self): .first() ) + self.custom_cost_group = OpenshiftCostCategory.objects.create(name="Overhead", label=[]) + self.project = project_to_insert.get("project") self.provider_uuid = project_to_insert.get("cluster__provider__uuid") self.body_format = [{"project": self.project, "group": self.default_cost_group}] @@ -48,6 +51,14 @@ def setUp(self): def url(self): return reverse("settings-cost-groups") + @property + def add_url(self): + return reverse("settings-cost-groups-add") + + @property + def remove_url(self): + return reverse("settings-cost-groups-remove") + def test_get_cost_groups(self): """Basic test to exercise the API endpoint""" with schema_context(self.schema_name): @@ -163,13 +174,14 @@ def _add_additional_projects(schema_name): expected_count = current_rows + 1 put_openshift_namespaces(self.body_format) current_count = OpenshiftCostCategoryNamespace.objects.count() + if current_count != expected_count: raise FailedToPopulateDummyProjects("Failed to populate dummy data for deletion testing.") _add_additional_projects(self.schema_name) body = json.dumps(self.body_format) with schema_context(self.schema_name): - response = self.client.delete(self.url, body, content_type="application/json", **self.headers) + response = self.client.put(self.remove_url, body, content_type="application/json", **self.headers) current_count = OpenshiftCostCategoryNamespace.objects.filter(namespace=self.project).count() self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) @@ -192,21 +204,28 @@ def test_query_handler_remove_default_projects(self): self.assertEqual(_remove_default_projects(body_format), []) def test_put_catch_integrity_error(self): - """Test that we catch integrity errors on put.""" - self.body_format.append({"project": self.project, "group": self.default_cost_group}) + """Test that we catch integrity errors when moving a project to a + different Cost Group.""" + + with schema_context(self.schema_name): + # Create an entry with a project in the default Cost Group + put_openshift_namespaces(self.body_format) + with self.assertLogs(logger="api.settings.cost_groups.query_handler", level="WARNING") as log_warning: with schema_context(self.schema_name): - put_openshift_namespaces(self.body_format) + # Move a project to another Cost Group + put_openshift_namespaces([{"project": self.project, "group": self.custom_cost_group.name}]) + self.assertEqual(len(log_warning.records), 1) # Check that a warning log was generated self.assertIn("IntegrityError", log_warning.records[0].getMessage()) @patch("api.settings.cost_groups.view.update_summary_tables.s") @patch("api.settings.cost_groups.view.is_customer_large") - def test_put_new_records(self, mock_is_customer_large, mock_update_schedule): + def test_add_new_records(self, mock_is_customer_large, mock_update_schedule): mock_is_customer_large.return_value = False with schema_context(self.schema_name): body = json.dumps(self.body_format) - response = self.client.put(self.url, body, content_type="application/json", **self.headers) + response = self.client.put(self.add_url, body, content_type="application/json", **self.headers) current_count = OpenshiftCostCategoryNamespace.objects.filter(namespace=self.project).count() self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) @@ -220,13 +239,34 @@ def test_put_new_records(self, mock_is_customer_large, mock_update_schedule): ) mock_update_schedule.return_value.apply_async.assert_called_with(queue=OCP_QUEUE) + @patch("api.settings.cost_groups.view.update_summary_tables.s") + @patch("api.settings.cost_groups.view.is_customer_large", return_value=False) + def test_move_project_to_different_cost_group(self, mock_is_customer_large, mock_update_schedule): + """Test moving an existing project to a different Cost Group""" + + with schema_context(self.schema_name): + # Add a project to the default Cost Group + body = json.dumps(self.body_format) + self.client.put(self.add_url, body, content_type="application/json", **self.headers) + OpenshiftCostCategoryNamespace.objects.filter(namespace=self.project).count() + + # Move the project to a custom Cost Group + body = json.dumps([{"project": self.project, "group": self.custom_cost_group.name}]) + response = self.client.put(self.add_url, body, content_type="application/json", **self.headers) + current_count = OpenshiftCostCategoryNamespace.objects.filter( + cost_category_id=self.custom_cost_group.id + ).count() + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(current_count, 1) + @patch("api.settings.cost_groups.view.update_summary_tables.s") @patch("api.settings.cost_groups.view.is_customer_large") - def test_put_new_records_large(self, mock_is_customer_large, mock_update_schedule): + def test_add_new_records_large(self, mock_is_customer_large, mock_update_schedule): mock_is_customer_large.return_value = True with schema_context(self.schema_name): body = json.dumps(self.body_format) - response = self.client.put(self.url, body, content_type="application/json", **self.headers) + response = self.client.put(self.add_url, body, content_type="application/json", **self.headers) current_count = OpenshiftCostCategoryNamespace.objects.filter(namespace=self.project).count() self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) diff --git a/koku/api/urls.py b/koku/api/urls.py index 0d5d66410b..cf58f723f0 100644 --- a/koku/api/urls.py +++ b/koku/api/urls.py @@ -30,6 +30,8 @@ from api.views import AzureSubscriptionGuidView from api.views import AzureTagView from api.views import cloud_accounts +from api.views import CostGroupsAddView +from api.views import CostGroupsRemoveView from api.views import CostGroupsView from api.views import CostModelResourceTypesView from api.views import DataExportRequestViewSet @@ -350,6 +352,8 @@ path("settings/", deprecate_view(SettingsView.as_view()), name="settings"), path("settings/aws_category_keys/", SettingsAWSCategoryKeyView.as_view(), name="settings-aws-category-keys"), path("settings/cost-groups/", CostGroupsView.as_view(), name="settings-cost-groups"), + path("settings/cost-groups/add/", CostGroupsAddView.as_view(), name="settings-cost-groups-add"), + path("settings/cost-groups/remove/", CostGroupsRemoveView.as_view(), name="settings-cost-groups-remove"), path( "settings/aws_category_keys/enable/", SettingsEnableAWSCategoryKeyView.as_view(), diff --git a/koku/api/views.py b/koku/api/views.py index a9a975b12a..98d7aeafef 100644 --- a/koku/api/views.py +++ b/koku/api/views.py @@ -74,6 +74,8 @@ from api.settings.aws_category_keys.view import SettingsAWSCategoryKeyView from api.settings.aws_category_keys.view import SettingsDisableAWSCategoryKeyView from api.settings.aws_category_keys.view import SettingsEnableAWSCategoryKeyView +from api.settings.cost_groups.view import CostGroupsAddView +from api.settings.cost_groups.view import CostGroupsRemoveView from api.settings.cost_groups.view import CostGroupsView from api.settings.tags.view import SettingsDisableTagView from api.settings.tags.view import SettingsEnableTagView