Skip to content

Commit

Permalink
[COST-4534] Support filtering with multiple values (#4846)
Browse files Browse the repository at this point in the history
* Pass parameters in as an argument rather than urlencoding
* Add tests
  • Loading branch information
samdoran committed Jan 8, 2024
1 parent ef98d21 commit e4a8c29
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 12 deletions.
4 changes: 2 additions & 2 deletions koku/api/settings/cost_groups/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
class CostGroupFilterSerializer(FilterSerializer):
"""Serializer for Cost Group Settings."""

project = serializers.CharField(required=False)
group = serializers.CharField(required=False)
project = StringOrListField(child=serializers.CharField(), required=False)
group = StringOrListField(child=serializers.CharField(), required=False)
default = serializers.BooleanField(required=False)
cluster = StringOrListField(child=serializers.CharField(), required=False)

Expand Down
41 changes: 32 additions & 9 deletions koku/api/settings/test/cost_groups/test_query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
#
import json
from unittest.mock import patch
from urllib.parse import quote_plus
from urllib.parse import urlencode

from django.urls import reverse
from django_tenants.utils import schema_context
Expand Down Expand Up @@ -64,14 +62,13 @@ def test_get_cost_groups_invalid(self):

def test_get_cost_groups_filters(self):
"""Basic test to exercise the API endpoint"""
parameters = [{"group": self.default_cost_group}, {"default": True}, {"project": OCP_PLATFORM_NAMESPACE}, {}]
parameters = ({"group": self.default_cost_group}, {"default": True}, {"project": OCP_PLATFORM_NAMESPACE}, {})
for parameter in parameters:
with self.subTest(parameter=parameter):
for filter_option, filter_value in parameter.items():
param = {f"filter[{filter_option}]": filter_value}
url = self.url + "?" + urlencode(param, quote_via=quote_plus)
with schema_context(self.schema_name):
response = self.client.get(url, **self.headers)
response = self.client.get(self.url, param, **self.headers)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.data.get("data")
for item in data:
Expand All @@ -87,14 +84,41 @@ def test_get_cost_groups_filter_cluster(self):
for item in data:
self.assertIn(OCP_ON_GCP_CLUSTER_ID, item.get("clusters"))

def test_get_cost_groups_filters_multiple(self):
"""Test filtering with multiple values per field"""
test_matrix = (
{
"field": "group",
"value": ["Platform"],
"expected": {"Platform"},
},
{
"field": "project",
"value": [OCP_PLATFORM_NAMESPACE, "-PrOd"],
"expected": {"openshift-default", "koku-prod"},
},
)
for case in test_matrix:
with self.subTest(parameter=case["value"]):
params = {f"filter[{case['field']}]": case["value"]}
with schema_context(self.schema_name):
response = self.client.get(self.url, params, **self.headers)

data = response.data.get("data")
result = {}
if data:
result = {item[case["field"]] for item in data}

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue(result.issubset(case["expected"]))

def test_get_cost_groups_order(self):
"""Basic test to exercise the API endpoint"""

def spotcheck_first_data_element(option, value):
param = {f"order_by[{option}]": value}
url = self.url + "?" + urlencode(param, quote_via=quote_plus)
with schema_context(self.schema_name):
response = self.client.get(url, **self.headers)
response = self.client.get(self.url, param, **self.headers)

return response.status_code, response.data.get("data")[0]

Expand All @@ -119,9 +143,8 @@ def test_get_cost_groups_exclude_functionality(self):
with self.subTest(parameter=parameter):
for exclude_option, exclude_value in parameter.items():
param = {f"exclude[{exclude_option}]": exclude_value}
url = self.url + "?" + urlencode(param, quote_via=quote_plus)
with schema_context(self.schema_name):
response = self.client.get(url, **self.headers)
response = self.client.get(self.url, param, **self.headers)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.data.get("data")
for item in data:
Expand Down
2 changes: 1 addition & 1 deletion koku/api/settings/test/cost_groups/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_invalid_serializer_data(self):
self.assertFalse(serializer.is_valid())

def test_serialization(self):
instance_data = {"project": "example_project", "group": "example_group", "default": True}
instance_data = {"project": ["example_project"], "group": ["example_group"], "default": True}
serializer = CostGroupFilterSerializer(instance_data)
self.assertEqual(serializer.data, instance_data)

Expand Down

0 comments on commit e4a8c29

Please sign in to comment.