diff --git a/koku/api/settings/tags/mapping/serializers.py b/koku/api/settings/tags/mapping/serializers.py index 669f57b77d..47ca49ee57 100644 --- a/koku/api/settings/tags/mapping/serializers.py +++ b/koku/api/settings/tags/mapping/serializers.py @@ -12,14 +12,6 @@ from reporting.provider.all.models import TagMapping -class ParentSerializer(serializers.ModelSerializer): - cost_model_id = serializers.UUIDField() - - class Meta: - model = EnabledTagKeys - fields = ["uuid", "key", "cost_model_id"] - - class ViewOptionsSerializer(serializers.ModelSerializer): """Intended to be used in conjuntion with the CostModelAnnotationMixin.""" @@ -55,22 +47,8 @@ class AddChildSerializer(serializers.Serializer): parent = serializers.UUIDField() children = serializers.ListField(child=serializers.UUIDField()) - def _unify_parent_key(self, data): - """Unifies duplicate parents keys under a single uuid.""" - enabled_row = EnabledTagKeys.objects.filter(uuid=data["parent"]).first() - if not enabled_row: - return data - tag_map = TagMapping.objects.filter(parent__key=enabled_row.key).first() - if not tag_map: - return data - if tag_map.parent_id == data["parent"]: - return data - data["parent"] = tag_map.parent_id - return data - def validate(self, data): """This function validates the options and returns the enabled tag rows.""" - data = self._unify_parent_key(data) children_list = data["children"] combined_list = [data["parent"]] + children_list enabled_rows = EnabledTagKeys.objects.filter(uuid__in=combined_list, enabled=True) diff --git a/koku/api/settings/tags/mapping/view.py b/koku/api/settings/tags/mapping/view.py index 237f02620a..96dd94f592 100644 --- a/koku/api/settings/tags/mapping/view.py +++ b/koku/api/settings/tags/mapping/view.py @@ -21,7 +21,6 @@ from api.common.permissions.settings_access import SettingsAccessPermission from api.settings.tags.mapping.query_handler import Relationship from api.settings.tags.mapping.serializers import AddChildSerializer -from api.settings.tags.mapping.serializers import ParentSerializer from api.settings.tags.mapping.serializers import TagMappingSerializer from api.settings.tags.mapping.serializers import ViewOptionsSerializer from api.settings.tags.mapping.utils import resummarize_current_month_by_tag_keys @@ -52,6 +51,16 @@ class Meta: default_ordering = ["parent"] +class SettingsEnabledTagKeysFilter(TagMappingFilters): + key = NonValidatedMultipleChoiceFilter(method="filter_by_key") + source_type = NonValidatedMultipleChoiceFilter(field_name="provider_type", method="filter_by_source_type") + + class Meta: + model = EnabledTagKeys + fields = ("key", "source_type") + default_ordering = ["key", "-enabled"] + + class SettingsTagMappingView(generics.GenericAPIView): queryset = TagMapping.objects.all() serializer_class = TagMappingSerializer @@ -71,16 +80,6 @@ def get(self, request: Request, **kwargs): return response -class ChildViewFilter(TagMappingFilters): - key = NonValidatedMultipleChoiceFilter(method="filter_by_key") - source_type = NonValidatedMultipleChoiceFilter(field_name="provider_type", method="filter_by_source_type") - - class Meta: - model = EnabledTagKeys - fields = ("key", "source_type") - default_ordering = ["key", "-enabled"] - - class SettingsTagMappingChildView(CostModelAnnotationMixin, generics.GenericAPIView): queryset = ( EnabledTagKeys.objects.exclude(parent__isnull=False).exclude(child__parent__isnull=False).filter(enabled=True) @@ -88,7 +87,7 @@ class SettingsTagMappingChildView(CostModelAnnotationMixin, generics.GenericAPIV serializer_class = ViewOptionsSerializer permission_classes = (SettingsAccessPermission,) filter_backends = (DjangoFilterBackend,) - filterset_class = ChildViewFilter + filterset_class = SettingsEnabledTagKeysFilter @method_decorator(never_cache) def get(self, request: Request, **kwargs): @@ -100,21 +99,12 @@ def get(self, request: Request, **kwargs): return response -class ParentViewFilter(TagMappingFilters): - key = NonValidatedMultipleChoiceFilter(method="filter_by_key") - - class Meta: - model = EnabledTagKeys - fields = ("key",) - default_ordering = ["key", "child__parent", "provider_type"] - - class SettingsTagMappingParentView(CostModelAnnotationMixin, generics.GenericAPIView): - queryset = EnabledTagKeys.objects.exclude(child__parent__isnull=False).filter(enabled=True).distinct("key") - serializer_class = ParentSerializer + queryset = EnabledTagKeys.objects.exclude(child__parent__isnull=False).filter(enabled=True) + serializer_class = ViewOptionsSerializer permission_classes = (SettingsAccessPermission,) filter_backends = (DjangoFilterBackend,) - filterset_class = ParentViewFilter + filterset_class = SettingsEnabledTagKeysFilter @method_decorator(never_cache) def get(self, request: Request, **kwargs): diff --git a/koku/api/settings/test/tags/mappings/test_view.py b/koku/api/settings/test/tags/mappings/test_view.py index a2a54dae8f..14c060e947 100644 --- a/koku/api/settings/test/tags/mappings/test_view.py +++ b/koku/api/settings/test/tags/mappings/test_view.py @@ -2,7 +2,6 @@ # Copyright 2024 Red Hat Inc. # SPDX-License-Identifier: Apache-2.0 # -import logging from collections import defaultdict from unittest.mock import patch @@ -19,8 +18,6 @@ from reporting.provider.all.models import EnabledTagKeys from reporting.provider.all.models import TagMapping -LOG = logging.getLogger(__name__) - class TestSettingsTagMappingView(MasuTestCase): def setUp(self): @@ -30,38 +27,52 @@ def setUp(self): with tenant_context(self.tenant): self.enabled_uuid_list = list(EnabledTagKeys.objects.filter(enabled=True).values_list("uuid", flat=True)) - self.tag_mapping_url = reverse("tags-mapping") - self.parent_get_url = reverse("tags-mapping-parent") - self.parent_remove_url = reverse("tags-mapping-parent-remove") - self.child_get_url = reverse("tags-mapping-child") - self.child_add_url = reverse("tags-mapping-child-add") - self.child_remove_url = reverse("tags-mapping-child-remove") def test_get_method(self): """Test the get method for the tag mapping view""" - for url in [self.tag_mapping_url, self.parent_get_url, self.child_get_url]: - with self.subTest(url=url): - response = self.client.get(url, **self.headers) - self.assertEqual(response.status_code, status.HTTP_200_OK) + url = reverse("tags-mapping") + response = self.client.get(url, **self.headers) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_get_method_with_filter(self): """Test the get method for the tag mapping view with a filter""" # Check that the response data is filtered correctly (with AWS example) - url = self.tag_mapping_url + "?source_type=aWs" # also testing case sensitivity + url = reverse("tags-mapping") + "?source_type=aWs" # also testing case sensitivity response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) for item in response.data["data"]: self.assertEqual(item["source_type"], "AWS") # Check that the response data is filtered correctly (with OCP example) - url = self.tag_mapping_url + "?source_type=ocP" # also testing case sensitivity + url = reverse("tags-mapping") + "?source_type=ocP" # also testing case sensitivity response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) for item in response.data["data"]: self.assertEqual(item["source_type"], "OCP") + def test_get_child_tag_key(self): + """Test the get method for the tag mapping Child view""" + url = reverse("tags-mapping-child") + response = self.client.get(url, **self.headers) + self.assertEqual(response.status_code, status.HTTP_200_OK) + def test_get_child_with_filter(self): """Test the get method for the tag mapping Child view with a filter""" - url = self.child_get_url + "?source_type=aWs" + url = reverse("tags-mapping-child") + "?source_type=aWs" + response = self.client.get(url, **self.headers) + self.assertEqual(response.status_code, status.HTTP_200_OK) + # Check that the response data is filtered correctly + for item in response.data["data"]: + self.assertEqual(item["source_type"], "AWS") + + def test_get_parent(self): + """Test the get method for the tag mapping Parent view""" + url = reverse("tags-mapping-parent") + response = self.client.get(url, **self.headers) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_get_parent_with_filter(self): + """Test the get method for the tag mapping Parent view with a filter""" + url = reverse("tags-mapping-parent") + "?source_type=aWs" response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) # Check that the response data is filtered correctly @@ -70,87 +81,96 @@ def test_get_child_with_filter(self): def test_put_method_invalid_uuid(self): """Test the put method for the tag mapping view with an invalid uuid""" + url = reverse("tags-mapping-child-add") data = {"parent": "29f738e4-38f4-4ed8-a9f4-beed48165220", "children": ["29f738e4-38f4-4ed8-a9f4-beed48165229"]} - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_put_method_validate_parent(self): """Test if a parent can be added as a child.""" + url = reverse("tags-mapping-child-add") data = { "parent": self.enabled_uuid_list[0], "children": [self.enabled_uuid_list[1], self.enabled_uuid_list[2], self.enabled_uuid_list[3]], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) # Adding a parent as child data = {"parent": self.enabled_uuid_list[4], "children": [self.enabled_uuid_list[0]]} - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_put_method_validate_child(self): """Test if a child can be added as a parent.""" + url = reverse("tags-mapping-child-add") data = { "parent": self.enabled_uuid_list[0], "children": [self.enabled_uuid_list[1], self.enabled_uuid_list[2]], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) # Adding a child as parent data = {"parent": self.enabled_uuid_list[2], "children": [self.enabled_uuid_list[4]]} - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) # Add one more additional child data = { "parent": self.enabled_uuid_list[0], "children": [self.enabled_uuid_list[1], self.enabled_uuid_list[2], self.enabled_uuid_list[3]], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_put_method_add_multiple_children(self): """Test adding multiple children (list).""" + url = reverse("tags-mapping-child-add") data = { "parent": self.enabled_uuid_list[0], "children": [self.enabled_uuid_list[1], self.enabled_uuid_list[2]], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_put_method_remove_children(self): """Test removing children.""" + url = reverse("tags-mapping-child-add") data = { "parent": self.enabled_uuid_list[0], "children": [self.enabled_uuid_list[1], self.enabled_uuid_list[2], self.enabled_uuid_list[3]], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) # Removing children + url = reverse("tags-mapping-child-remove") data = {"ids": [self.enabled_uuid_list[1], self.enabled_uuid_list[3]]} - response = self.client.put(self.child_remove_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_put_method_remove_parent(self): """Test removing parent.""" + url = reverse("tags-mapping-child-add") data = { "parent": self.enabled_uuid_list[0], "children": [self.enabled_uuid_list[1], self.enabled_uuid_list[2], self.enabled_uuid_list[3]], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) # Removing parent + url = reverse("tags-mapping-parent-remove") data = {"ids": [self.enabled_uuid_list[0]]} - response = self.client.put(self.parent_remove_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_filter_by_source_type(self): """Test the filter by source_type.""" # Get an already inserted provider type to check if the filter is working with tenant_context(self.tenant): + url = reverse("tags-mapping-child-add") data = { "parent": self.enabled_uuid_list[0], "children": [self.enabled_uuid_list[1], self.enabled_uuid_list[2], self.enabled_uuid_list[3]], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) parent_provider_types = TagMapping.objects.values_list("parent__provider_type", flat=True).distinct() test_filter = parent_provider_types[0] @@ -161,7 +181,7 @@ def test_filter_by_source_type(self): self.assertNotEqual(len(result), 0) filter = "?filter[source_type]=random" - url = self.tag_mapping_url + filter + url = reverse("tags-mapping") + filter response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["data"]), 0) @@ -257,8 +277,9 @@ def test_cached_tag_rate_mapping(self, mock_get): def test_removal_of_unmapped_key(self): """Test that we error when we try to remove an unmapped key.""" - for url in [self.parent_remove_url, self.child_remove_url]: - with self.subTest(url=url): + for url_key in ["tags-mapping-parent-remove", "tags-mapping-child-remove"]: + with self.subTest(url_key=url_key): + url = reverse(url_key) data = {"ids": [self.enabled_uuid_list[0]]} response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -283,39 +304,42 @@ def test_adding_a_child_connected_to_a_cost_model(self, mock_tag_rates): "cost_model_id": "91417cce-4b66-4b41-8...137bdb1620", }, } + url = reverse("tags-mapping-child-add") data = { "parent": self.enabled_uuid_list[0], "children": [str(self.enabled_uuid_list[1])], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_empty_child_returns_400(self): """Test empty child returns 400""" data = {"parent": self.enabled_uuid_list[0], "children": []} - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(reverse("tags-mapping-child-add"), data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_filter_by_parent_and_child(self): """Test that you can filter by parent & child.""" with tenant_context(self.tenant): child, parent = EnabledTagKeys.objects.all()[:2] + url = reverse("tags-mapping-child-add") data = { "parent": str(parent.uuid), "children": [str(child.uuid)], } - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - url = self.tag_mapping_url + f"?parent={parent.key}" + url = reverse("tags-mapping") + f"?parent={parent.key}" response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) - url = self.tag_mapping_url + f"?child={child.key}" + url = reverse("tags-mapping") + f"?child={child.key}" response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_order_by_fake_value(self): """Test the get method for the tag mapping view""" - url = self.tag_mapping_url + "?order_by[parent]=FAKE" + url = reverse("tags-mapping") + url = url + "?order_by[parent]=FAKE" response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -333,18 +357,19 @@ def test_multi_source_type_filter(self): {"parent": aws_uuids[1], "children": [azure_uuids[1]]}, {"parent": azure_uuids[1], "children": [azure_uuids[3]]}, ] + url = reverse("tags-mapping-child-add") for data in body_metadata: - response = self.client.put(self.child_add_url, data, format="json", **self.headers) + response = self.client.put(url, data, format="json", **self.headers) # Test multiple source_type filters test_matrix = [ f"?filter[source_type]={Provider.PROVIDER_AWS}&filter[source_type]={Provider.PROVIDER_AZURE}", f"?filter[source_type]={Provider.PROVIDER_AWS}&filter[source_type]={Provider.PROVIDER_OCP}", ] for multi_filter in test_matrix: - for url in [self.child_get_url, self.tag_mapping_url]: - with self.subTest(multi_filter=multi_filter, url=url): - filtered_url = url + multi_filter - response = self.client.get(filtered_url, **self.headers) + for endpoint in ["tags-mapping-parent", "tags-mapping-child", "tags-mapping"]: + with self.subTest(multi_filter=multi_filter, endpoint=endpoint): + url = reverse(endpoint) + multi_filter + response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotEqual(len(response.data["data"]), 0) @@ -356,22 +381,24 @@ def test_multi_key_filter(self): f"?filter[key]={enabled_keys[1].key}&filter[key]={enabled_keys[3].key}", ] for multi_filter in test_matrix: - for url in [self.parent_get_url, self.child_get_url]: - with self.subTest(multi_filter=multi_filter, url=url): - filtered_url = url + multi_filter - response = self.client.get(filtered_url, **self.headers) + for endpoint in ["tags-mapping-parent", "tags-mapping-child"]: + with self.subTest(multi_filter=multi_filter, endpoint=endpoint): + url = reverse(endpoint) + multi_filter + response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotEqual(len(response.data["data"]), 0) def test_multi_key_parent_and_child_filter(self): """Test that you can filter by parent & child keys.""" + endpoint = "tags-mapping" enabled_keys = EnabledTagKeys.objects.filter(enabled=True) test_populate = [ {"parent": enabled_keys[0].uuid, "children": [enabled_keys[1].uuid, enabled_keys[2].uuid]}, {"parent": enabled_keys[3].uuid, "children": [enabled_keys[4].uuid, enabled_keys[5].uuid]}, ] + url = reverse("tags-mapping-child-add") for populate in test_populate: - self.client.put(self.child_add_url, populate, format="json", **self.headers) + self.client.put(url, populate, format="json", **self.headers) # test parent filter test_matrix = [ ("parent", enabled_keys[0].key, enabled_keys[3].key), @@ -380,33 +407,8 @@ def test_multi_key_parent_and_child_filter(self): for test in test_matrix: filter_key, key_one, key_two = test filter = f"?filter[{filter_key}]={key_one}&filter[{filter_key}]={key_two}" - url = self.tag_mapping_url + filter + url = reverse(endpoint) + filter with self.subTest(url=url): response = self.client.get(url, **self.headers) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotEqual(len(response.data["data"]), 0) - - def test_deduplicate_and_unify_parent_keys(self): - """Test that we have deduplicated app keys.""" - key = "test_deduplication" - providers = [Provider.PROVIDER_AWS, Provider.PROVIDER_AZURE, Provider.PROVIDER_GCP] - parent_keys = [EnabledTagKeys(key=key, enabled=True, provider_type=provider) for provider in providers] - with tenant_context(self.tenant): - EnabledTagKeys.objects.bulk_create(parent_keys) - url = self.parent_get_url + f"?filter[key]={key}" - response = self.client.get(url, **self.headers) - self.assertEqual(len(response.data.get("data")), 1) - # Test unification under one parent key - test_populate = [ - {"parent": parent_keys[0].uuid, "children": [self.enabled_uuid_list[0]]}, - {"parent": parent_keys[1].uuid, "children": [self.enabled_uuid_list[1]]}, - {"parent": parent_keys[2].uuid, "children": [self.enabled_uuid_list[2]]}, - ] - for populate in test_populate: - self.client.put(self.child_add_url, populate, format="json", **self.headers) - url = self.tag_mapping_url + f"?filter[parent]={key}" - response = self.client.get(url, **self.headers) - data = response.data.get("data") - self.assertEqual(len(data), 1) - data_dict = data[0] - self.assertTrue(len(data_dict.get("children")), len(test_populate))