diff --git a/backend/api/serializers/model_serializers.py b/backend/api/serializers/model_serializers.py index 5dd68a2c9..4a329fec2 100644 --- a/backend/api/serializers/model_serializers.py +++ b/backend/api/serializers/model_serializers.py @@ -6,9 +6,6 @@ from django.conf import settings from django.db.models import Max, Min, Sum from drf_spectacular.utils import extend_schema_serializer -from rest_framework import serializers -from rest_framework.exceptions import ValidationError - from metering_billing.invoice import generate_balance_adjustment_invoice from metering_billing.models import ( AddOnSpecification, @@ -56,7 +53,6 @@ ) from metering_billing.utils import convert_to_date, now_utc from metering_billing.utils.enums import ( - CATEGORICAL_FILTER_OPERATORS, CUSTOMER_BALANCE_ADJUSTMENT_STATUS, FLAT_FEE_BEHAVIOR, INVOICE_STATUS_ENUM, @@ -70,6 +66,8 @@ USAGE_BEHAVIOR, USAGE_BILLING_BEHAVIOR, ) +from rest_framework import serializers +from rest_framework.exceptions import ValidationError SVIX_CONNECTOR = settings.SVIX_CONNECTOR logger = logging.getLogger("django.server") @@ -199,36 +197,18 @@ class Meta: comparison_value = serializers.ListField(child=serializers.CharField()) -class SubscriptionCategoricalFilterSerializer( - ConvertEmptyStringToNullMixin, TimezoneFieldMixin, serializers.ModelSerializer +class SubscriptionFilterSerializer( + ConvertEmptyStringToNullMixin, TimezoneFieldMixin, serializers.Serializer ): - class Meta: - model = CategoricalFilter - fields = ("value", "property_name") - extra_kwargs = { - "property_name": { - "required": True, - }, - "value": {"required": True}, - } - value = serializers.CharField() property_name = serializers.CharField( help_text="The string name of the property to filter on. Example: 'product_id'" ) - def create(self, validated_data): - comparison_value = validated_data.pop("value") - comparison_value = [comparison_value] - validated_data["comparison_value"] = comparison_value - return CategoricalFilter.objects.get_or_create( - **validated_data, operator=CATEGORICAL_FILTER_OPERATORS.ISIN - ) - def to_representation(self, instance): data = { - "property_name": instance.property_name, - "value": instance.comparison_value[0], + "property_name": instance[0], + "value": instance[1], } return data @@ -345,9 +325,7 @@ class Meta: } subscription_id = SubscriptionUUIDField(source="subscription_record_id") - subscription_filters = SubscriptionCategoricalFilterSerializer( - many=True, source="filters" - ) + subscription_filters = SubscriptionFilterSerializer(many=True) customer = LightweightCustomerSerializer() billing_plan = LightweightPlanVersionSerializer() addons = LightweightAddOnSubscriptionRecordSerializer( @@ -445,11 +423,11 @@ def get_adjustments(self, obj) -> InvoiceLineItemAdjustmentSerializer(many=True) def get_subscription_filters( self, obj - ) -> SubscriptionCategoricalFilterSerializer(many=True, allow_null=True): + ) -> SubscriptionFilterSerializer(many=True, allow_null=True): ass_sub_record = obj.associated_subscription_record if ass_sub_record: - return SubscriptionCategoricalFilterSerializer( - ass_sub_record.filters.all(), many=True + return SubscriptionFilterSerializer( + ass_sub_record.subscription_filters, many=True ).data return None @@ -1640,7 +1618,7 @@ class Meta: help_text="Whether the subscription automatically renews. Defaults to true.", ) is_new = serializers.BooleanField(required=False) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="Add filter key, value pairs that define which events will be applied to this plan subscription.", @@ -1683,22 +1661,9 @@ def create(self, validated_data): filters = validated_data.pop("subscription_filters", []) subscription_filters = [] for filter_data in filters: - sub_cat_filter_dict = { - "organization": validated_data["customer"].organization, - "property_name": filter_data["property_name"], - "operator": CATEGORICAL_FILTER_OPERATORS.ISIN, - "comparison_value": [filter_data["value"]], - } - try: - cf, _ = CategoricalFilter.objects.get_or_create(**sub_cat_filter_dict) - except CategoricalFilter.MultipleObjectsReturned: - cf = ( - CategoricalFilter.objects.filter(**sub_cat_filter_dict) - .first() - .delete() - ) - cf = CategoricalFilter.objects.filter(**sub_cat_filter_dict).first() - subscription_filters.append(cf) + subscription_filters.append( + [filter_data["property_name"], filter_data["value"]] + ) sr = SubscriptionRecord.create_subscription_record( start_date=validated_data["start_date"], end_date=validated_data.get("end_date"), @@ -1759,7 +1724,7 @@ class Meta: help_text="Whether the subscription automatically renews. Defaults to true.", ) is_new = serializers.BooleanField(required=False) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="Add filter key, value pairs that define which events will be applied to this plan subscription.", @@ -1825,22 +1790,9 @@ def create(self, validated_data): filters = validated_data.pop("subscription_filters", []) subscription_filters = [] for filter_data in filters: - sub_cat_filter_dict = { - "organization": validated_data["customer"].organization, - "property_name": filter_data["property_name"], - "operator": CATEGORICAL_FILTER_OPERATORS.ISIN, - "comparison_value": [filter_data["value"]], - } - try: - cf, _ = CategoricalFilter.objects.get_or_create(**sub_cat_filter_dict) - except CategoricalFilter.MultipleObjectsReturned: - cf = ( - CategoricalFilter.objects.filter(**sub_cat_filter_dict) - .first() - .delete() - ) - cf = CategoricalFilter.objects.filter(**sub_cat_filter_dict).first() - subscription_filters.append(cf) + subscription_filters.append( + [filter_data["property_name"], filter_data["value"]] + ) sr = SubscriptionRecord.create_subscription_record( start_date=validated_data["start_date"], end_date=validated_data.get("end_date"), @@ -1869,9 +1821,7 @@ class Meta(SubscriptionRecordSerializer.Meta): plan_detail = LightweightPlanVersionSerializer( source="billing_plan", read_only=True ) - subscription_filters = SubscriptionCategoricalFilterSerializer( - source="filters", many=True, read_only=True - ) + subscription_filters = SubscriptionFilterSerializer(many=True, read_only=True) class SubscriptionInvoiceSerializer(SubscriptionRecordSerializer): @@ -2137,7 +2087,7 @@ class SubscriptionRecordFilterSerializer(serializers.Serializer): required=True, help_text="Filter to a specific plan.", ) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="Filter to a specific set of subscription filters. If your billing model only allows for one subscription per customer, you very likely do not need this field. Must be formatted as a JSON-encoded + stringified list of dictionaries, where each dictionary has a key of 'property_name' and a key of 'value'.", @@ -2222,7 +2172,7 @@ class AddOnSubscriptionRecordFilterSerializer(serializers.Serializer): required=True, help_text="Filter to a specific plan.", ) - attached_subscription_filters = SubscriptionCategoricalFilterSerializer( + attached_subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="Filter to a specific set of subscription filters. If your billing model only allows for one subscription per customer, you very likely do not need this field. Must be formatted as a JSON-encoded + stringified list of dictionaries, where each dictionary has a key of 'property_name' and a key of 'value'.", diff --git a/backend/api/serializers/nonmodel_serializers.py b/backend/api/serializers/nonmodel_serializers.py index 85adba157..fe05c81f1 100644 --- a/backend/api/serializers/nonmodel_serializers.py +++ b/backend/api/serializers/nonmodel_serializers.py @@ -3,7 +3,7 @@ LightweightCustomerSerializer, LightweightMetricSerializer, LightweightPlanVersionSerializer, - SubscriptionCategoricalFilterSerializer, + SubscriptionFilterSerializer, ) from metering_billing.models import ( Customer, @@ -70,9 +70,7 @@ class Meta: "plan": {"required": True, "read_only": True}, } - subscription_filters = SubscriptionCategoricalFilterSerializer( - many=True, source="filters" - ) + subscription_filters = SubscriptionFilterSerializer(many=True) plan = LightweightPlanVersionSerializer(source="billing_plan") @@ -120,7 +118,7 @@ class MetricAccessRequestSerializer(serializers.Serializer): queryset=Metric.objects.all(), help_text="The metric_id of the metric you want to check access for.", ) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="Used if you want to restrict the access check to only plans that fulfill certain subscription filter criteria. If your billing model does not have the ability multiple plans or subscriptions per customer, this is likely not relevant for you. ", @@ -158,7 +156,7 @@ class FeatureAccessRequestSerializer(serializers.Serializer): queryset=Feature.objects.all(), help_text="The feature_id of the feature you want to check access for.", ) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="The subscription filters that are applied to this plan's relationship with the customer. If your billing model does not have the ability multiple plans or subscriptions per customer, this is likely not relevant for you. ", diff --git a/backend/api/views.py b/backend/api/views.py index 3fd443c01..432b7e1ed 100644 --- a/backend/api/views.py +++ b/backend/api/views.py @@ -31,7 +31,7 @@ ListPlanVersionsFilterSerializer, ListSubscriptionRecordFilter, PlanSerializer, - SubscriptionCategoricalFilterSerializer, + SubscriptionFilterSerializer, SubscriptionRecordCancelSerializer, SubscriptionRecordCreateSerializer, SubscriptionRecordCreateSerializerOld, @@ -94,7 +94,6 @@ from metering_billing.invoice_pdf import get_invoice_presigned_url from metering_billing.kafka.producer import Producer from metering_billing.models import ( - CategoricalFilter, ComponentChargeRecord, Customer, CustomerBalanceAdjustment, @@ -127,7 +126,6 @@ ) from metering_billing.utils import calculate_end_date, convert_to_datetime, now_utc from metering_billing.utils.enums import ( - CATEGORICAL_FILTER_OPERATORS, CUSTOMER_BALANCE_ADJUSTMENT_STATUS, INVOICING_BEHAVIOR, METRIC_STATUS, @@ -185,7 +183,6 @@ def get_queryset(self): ) .select_related("customer", "billing_plan", "billing_plan__plan") .prefetch_related( - "filters", "addon_subscription_records", "organization", ), @@ -797,19 +794,13 @@ def get_queryset(self): ) if serializer.validated_data.get("subscription_filters"): - filters = [] for filter in serializer.validated_data["subscription_filters"]: - m2m, _ = CategoricalFilter.objects.get_or_create( - organization=organization, - property_name=filter["property_name"], - comparison_value=[filter["value"]], - operator=CATEGORICAL_FILTER_OPERATORS.ISIN, + qs = qs.filter( + subscription_filters__contains=[ + [filter["property_name"], filter["value"]] + ] ) - filters.append(m2m) - query = reduce( - lambda acc, filter: acc & Q(filters=filter), filters, Q() - ) - qs = qs.filter(query) + return qs @extend_schema( @@ -1750,7 +1741,6 @@ def get(self, request, format=None): "billing_plan__plan_components__billable_metric", "billing_plan__plan_components__tiers", "billing_plan__plan", - "filters", ) return_dict = { "customer": customer, @@ -1760,7 +1750,7 @@ def get(self, request, format=None): } for sr in subscription_records.filter(billing_plan__addon_spec__isnull=True): if subscription_filters_set: - sr_filters_set = {(x.property_name, x.value) for x in sr.filters.all()} + sr_filters_set = {tuple(x) for x in sr.subscription_filters} if not subscription_filters_set.issubset(sr_filters_set): continue single_sr_dict = { @@ -1846,7 +1836,6 @@ def get(self, request, format=None): subscription_records = subscription_records.prefetch_related( "billing_plan__features", "billing_plan__plan", - "filters", ) return_dict = { "customer": customer, @@ -1856,7 +1845,7 @@ def get(self, request, format=None): } for sr in subscription_records.filter(billing_plan__addon_spec__isnull=True): if subscription_filters_set: - sr_filters_set = {(x.property_name, x.value) for x in sr.filters.all()} + sr_filters_set = {tuple(x) for x in sr.subscription_filters} if not subscription_filters_set.issubset(sr_filters_set): continue single_sr_dict = { @@ -2179,7 +2168,7 @@ class GetCustomerEventAccessRequestSerializer(serializers.Serializer): allow_null=True, help_text="The metric_id of the metric you are checking access for. Please note that you must porovide exactly one of event_name and metric_id are mutually; a validation error will be thrown if both or none are provided.", ) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="The subscription filters that are applied to this plan's relationship with the customer. If your billing model does not have the ability multiple plans or subscriptions per customer, this is likely not relevant for you. This must be passed in as a stringified JSON object.", @@ -2210,7 +2199,7 @@ class GetCustomerFeatureAccessRequestSerializer(serializers.Serializer): feature_name = serializers.CharField( help_text="Name of the feature to check access for." ) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, required=False, help_text="The subscription filters that are applied to this plan's relationship with the customer. If your billing model does not have the ability multiple plans or subscriptions per customer, this is likely not relevant for you. This must be passed in as a stringified JSON object.", @@ -2230,7 +2219,7 @@ class GetFeatureAccessSerializer(serializers.Serializer): plan_id = serializers.CharField( help_text="The plan_id of the plan we are checking that has access to this feature." ) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, help_text="The subscription filters that are applied to this plan's relationship with the customer. If your billing model does not have the ability multiple plans or subscriptions per customer, this is likely not relevant for you.", ) @@ -2264,7 +2253,7 @@ class GetEventAccessSerializer(serializers.Serializer): plan_id = serializers.CharField( help_text="The plan_id of the plan we are checking that has access to this feature." ) - subscription_filters = SubscriptionCategoricalFilterSerializer( + subscription_filters = SubscriptionFilterSerializer( many=True, help_text="The subscription filters that are applied to this plan's relationship with the customer. If your billing model does not have the ability multiple plans or subscriptions per customer, this is likely not relevant for you.", ) @@ -2324,17 +2313,18 @@ def get(self, request, format=None): for x in serializer.validated_data.get("subscription_filters", []) } for key, value in subscription_filters.items(): - key = f"properties__{key}" - subscriptions = subscriptions.filter(**{key: value}) + subscriptions = subscriptions.filter( + subscription_filters__contains=[[key, value]] + ) features = [] subscriptions = subscriptions.prefetch_related("billing_plan__features") for sub in subscriptions: subscription_filters = [] - for filter in sub.filters.all(): + for filter_arr in sub.subscription_filters: subscription_filters.append( { - "property_name": filter.property_name, - "value": filter.comparison_value[0], + "property_name": filter_arr[0], + "value": filter_arr[1], } ) sub_dict = { @@ -2407,22 +2397,22 @@ def get(self, request, format=None): for x in serializer.validated_data.get("subscription_filters", []) } for key, value in subscription_filters.items(): - key = f"properties__{key}" - subscription_records = subscription_records.filter(**{key: value}) + subscription_records = subscription_records.filter( + subscription_filters__contains=[[key, value]] + ) metrics = [] subscription_records = subscription_records.prefetch_related( "billing_plan__plan_components", "billing_plan__plan_components__billable_metric", "billing_plan__plan_components__tiers", - "filters", ) for sr in subscription_records: subscription_filters = [] - for filter in sr.filters.all(): + for filter_arr in sr.subscription_filters: subscription_filters.append( { - "property_name": filter.property_name, - "value": filter.comparison_value[0], + "property_name": filter_arr[0], + "value": filter_arr[1], } ) single_sub_dict = { diff --git a/backend/metering_billing/aggregation/billable_metrics.py b/backend/metering_billing/aggregation/billable_metrics.py index 2c7669783..85dc96f6e 100644 --- a/backend/metering_billing/aggregation/billable_metrics.py +++ b/backend/metering_billing/aggregation/billable_metrics.py @@ -12,7 +12,6 @@ from django.conf import settings from django.db import connection from jinja2 import Template - from metering_billing.exceptions import MetricValidationFailed from metering_billing.utils import ( convert_to_date, @@ -253,10 +252,8 @@ def _prepare_injection_dict( organization.provision_subscription_filter_settings() groupby = [] injection_dict["group_by"] = groupby - for filter in billing_record.subscription.filters.all(): - injection_dict["filter_properties"][ - filter.property_name - ] = filter.comparison_value + for filter in billing_record.subscription.subscription_filters: + injection_dict["filter_properties"][filter[0]] = [filter[1]] return injection_dict @staticmethod @@ -702,10 +699,8 @@ def get_billing_record_total_billable_usage( injection_dict["start_date"] = start injection_dict["end_date"] = end injection_dict["organization_id"] = organization.id - for filter in billing_record.subscription.filters.all(): - injection_dict["filter_properties"][ - filter.property_name - ] = filter.comparison_value + for filter in billing_record.subscription.subscription_filters: + injection_dict["filter_properties"][filter[0]] = [filter[1]] results = CustomHandler._run_query(metric.custom_sql, injection_dict) if len(results) == 0: return Decimal(0) @@ -1141,10 +1136,8 @@ def get_billing_record_total_billable_usage( ], "property_name": metric.property_name, } - for filter in billing_record.subscription.filters.all(): - injection_dict["filter_properties"][ - filter.property_name - ] = filter.comparison_value + for filter in billing_record.subscription.subscription_filters: + injection_dict["filter_properties"][filter[0]] = [filter[1]] if metric.event_type == "delta": query = Template(GAUGE_DELTA_GET_TOTAL_USAGE_WITH_PRORATION).render( **injection_dict @@ -1227,10 +1220,8 @@ def get_billing_record_current_usage( ], "property_name": metric.property_name, } - for filter in billing_record.subscription.filters.all(): - injection_dict["filter_properties"][ - filter.property_name - ] = filter.comparison_value + for filter in billing_record.subscription.subscription_filters: + injection_dict["filter_properties"][filter[0]] = [filter[1]] if metric.event_type == "delta": query = Template(GAUGE_DELTA_GET_CURRENT_USAGE).render(**injection_dict) elif metric.event_type == "total": @@ -1309,10 +1300,8 @@ def get_billing_record_daily_billable_usage( ], "property_name": metric.property_name, } - for filter in billing_record.subscription.filters.all(): - injection_dict["filter_properties"][ - filter.property_name - ] = filter.comparison_value + for filter in billing_record.subscription.subscription_filters: + injection_dict["filter_properties"][filter[0]] = [filter[1]] if metric.event_type == "delta": query = Template(GAUGE_DELTA_GET_TOTAL_USAGE_WITH_PRORATION_PER_DAY).render( **injection_dict @@ -1549,10 +1538,8 @@ def _rate_cagg_total_results(metric: Metric, billing_record: BillingRecord): organization.provision_subscription_filter_settings() groupby = [] injection_dict["group_by"] = groupby - for filter in billing_record.subscription.filters.all(): - injection_dict["filter_properties"][ - filter.property_name - ] = filter.comparison_value + for filter in billing_record.subscription.subscription_filters: + injection_dict["filter_properties"][filter[0]] = [filter[1]] query = Template(RATE_CAGG_TOTAL).render(**injection_dict) with connection.cursor() as cursor: cursor.execute(query) @@ -1618,10 +1605,8 @@ def get_billing_record_current_usage( organization.provision_subscription_filter_settings() groupby = [] injection_dict["group_by"] = groupby - for filter in billing_record.subscription.filters.all(): - injection_dict["filter_properties"][ - filter.property_name - ] = filter.comparison_value + for filter in billing_record.subscription.subscription_filters: + injection_dict["filter_properties"][filter[0]] = [filter[1]] query = Template(RATE_GET_CURRENT_USAGE).render(**injection_dict) with connection.cursor() as cursor: cursor.execute(query) diff --git a/backend/metering_billing/invoice.py b/backend/metering_billing/invoice.py index 36b2a6d34..c39dbb46c 100644 --- a/backend/metering_billing/invoice.py +++ b/backend/metering_billing/invoice.py @@ -389,7 +389,7 @@ def create_next_subscription_record(subscription_record, next_bp): billing_plan=next_bp, customer=subscription_record.customer, organization=subscription_record.organization, - subscription_filters=subscription_record.filters.all(), + subscription_filters=subscription_record.subscription_filters, is_new=False, quantity=subscription_record.quantity, component_fixed_charges_initial_units=component_fixed_charges_initial_units, diff --git a/backend/metering_billing/migrations/0229_historicalsubscriptionrecord_subscription_filters_and_more.py b/backend/metering_billing/migrations/0229_historicalsubscriptionrecord_subscription_filters_and_more.py new file mode 100644 index 000000000..080847c57 --- /dev/null +++ b/backend/metering_billing/migrations/0229_historicalsubscriptionrecord_subscription_filters_and_more.py @@ -0,0 +1,29 @@ +# Generated by Django 4.0.5 on 2023-03-18 21:12 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('metering_billing', '0228_auto_20230315_0342'), + ] + + operations = [ + migrations.AddField( + model_name='historicalsubscriptionrecord', + name='subscription_filters', + field=django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.TextField(), size=2), default=list, size=None), + ), + migrations.AddField( + model_name='subscriptionrecord', + name='subscription_filters', + field=django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.TextField(), size=2), default=list, size=None), + ), + migrations.AlterField( + model_name='webhooktrigger', + name='trigger_name', + field=models.CharField(choices=[('customer.created', 'customer.created'), ('invoice.created', 'invoice.created'), ('invoice.paid', 'invoice.paid'), ('invoice.past_due', 'invoice.past_due'), ('subscription.created', 'subscription.created'), ('usage_alert.triggered', 'usage_alert.triggered'), ('subscription.cancelled', 'subscription.cancelled'), ('subscription.renewed', 'subscription.renewed')], max_length=40), + ), + ] diff --git a/backend/metering_billing/migrations/0230_auto_20230318_2112.py b/backend/metering_billing/migrations/0230_auto_20230318_2112.py new file mode 100644 index 000000000..12bfa8a42 --- /dev/null +++ b/backend/metering_billing/migrations/0230_auto_20230318_2112.py @@ -0,0 +1,28 @@ +# Generated by Django 4.0.5 on 2023-03-18 21:12 + +from django.db import migrations + + +def transfer_filters_to_subscription_filters(apps, schema_editor): + SubscriptionRecord = apps.get_model("metering_billing", "SubscriptionRecord") + for subscription in SubscriptionRecord.objects.all(): + new_filters = [] + for sf in subscription.filters.all(): + property_name = sf.property_name + value = sf.comparison_value[0] + new_filters.append([property_name, value]) + subscription.subscription_filters = new_filters + subscription.save() + + +class Migration(migrations.Migration): + dependencies = [ + ( + "metering_billing", + "0229_historicalsubscriptionrecord_subscription_filters_and_more", + ), + ] + + operations = [ + migrations.RunPython(transfer_filters_to_subscription_filters), + ] diff --git a/backend/metering_billing/migrations/0231_remove_subscriptionrecord_filters.py b/backend/metering_billing/migrations/0231_remove_subscriptionrecord_filters.py new file mode 100644 index 000000000..2a5ac9e07 --- /dev/null +++ b/backend/metering_billing/migrations/0231_remove_subscriptionrecord_filters.py @@ -0,0 +1,17 @@ +# Generated by Django 4.0.5 on 2023-03-18 21:51 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('metering_billing', '0230_auto_20230318_2112'), + ] + + operations = [ + migrations.RemoveField( + model_name='subscriptionrecord', + name='filters', + ), + ] diff --git a/backend/metering_billing/models.py b/backend/metering_billing/models.py index 5a7625b46..91caae51e 100644 --- a/backend/metering_billing/models.py +++ b/backend/metering_billing/models.py @@ -22,7 +22,7 @@ MinValueValidator, ) from django.db import connection, models -from django.db.models import Count, F, FloatField, Prefetch, Q, QuerySet, Sum +from django.db.models import Count, F, FloatField, Q, Sum from django.db.models.constraints import CheckConstraint, UniqueConstraint from django.db.models.functions import Cast, Coalesce from django.utils.translation import gettext_lazy as _ @@ -1128,22 +1128,6 @@ class CategoricalFilter(models.Model): def __str__(self): return f"{self.property_name} {self.operator} {self.comparison_value}" - @staticmethod - def overlaps(filters1, filters2): - # Convert the inputs to sets of primary keys - if isinstance(filters1, (set, list)): - if all(isinstance(f, CategoricalFilter) for f in filters1): - filters1 = {f.pk for f in filters1} - if isinstance(filters2, (set, list)): - if all(isinstance(f, CategoricalFilter) for f in filters2): - filters2 = {f.pk for f in filters2} - if isinstance(filters1, QuerySet): - filters1 = set(filters1.values_list("pk", flat=True)) - if isinstance(filters2, QuerySet): - filters2 = set(filters2.values_list("pk", flat=True)) - # Check if there is an overlap between the sets - return filters1.issubset(filters2) or filters2.issubset(filters1) - class Metric(models.Model): organization = models.ForeignKey( @@ -2680,12 +2664,6 @@ class Meta: class SubscriptionRecordManager(models.Manager): - def create_with_filters(self, *args, **kwargs): - subscription_filters = kwargs.pop("subscription_filters", []) - sr = self.model(**kwargs) - sr.save(subscription_filters=subscription_filters) - return sr - def active(self, time=None): if time is None: time = now_utc() @@ -2754,10 +2732,12 @@ class SubscriptionRecord(models.Model): subscription_record_id = models.UUIDField( default=uuid.uuid4, editable=False, unique=True ) - filters = models.ManyToManyField( - CategoricalFilter, - blank=True, - help_text="Add filter key, value pairs that define which events will be applied to this plan subscription.", + subscription_filters = ArrayField( + ArrayField( + models.TextField(blank=False, null=False), + size=2, + ), + default=list, ) invoice_usage_charges = models.BooleanField(default=True) flat_fee_behavior = models.CharField( @@ -2813,7 +2793,7 @@ def create_subscription_record( assert ( billing_plan.addon_spec is None ), "Cannot create a base subscription record with an addon plan" - sr = SubscriptionRecord.objects.create_with_filters( + sr = SubscriptionRecord.objects.create( start_date=start_date, end_date=end_date, billing_plan=billing_plan, @@ -2846,14 +2826,14 @@ def create_addon_subscription_record( assert ( addon_billing_plan.addon_spec is not None ), "Cannot create an addon subscription record with a base plan" - sr = SubscriptionRecord.objects.create_with_filters( + sr = SubscriptionRecord.objects.create( parent=parent_subscription_record, start_date=now, end_date=parent_subscription_record.end_date, billing_plan=addon_billing_plan, customer=parent_subscription_record.customer, organization=parent_subscription_record.organization, - subscription_filters=parent_subscription_record.filters.all(), + subscription_filters=parent_subscription_record.subscription_filters, is_new=True, quantity=quantity, auto_renew=addon_billing_plan.addon_spec.billing_frequency @@ -2975,7 +2955,8 @@ def _create_recurring_charge_billing_records(self, recurring_charge): return brs def save(self, *args, **kwargs): - new_filters = kwargs.pop("subscription_filters", []) or [] + if self.subscription_filters is None: + self.subscription_filters = [] now = now_utc() timezone = self.customer.timezone if not self.end_date: @@ -2992,34 +2973,23 @@ def save(self, *args, **kwargs): ) new = self._state.adding is True if new: + new_filters = set(tuple(x) for x in self.subscription_filters) overlapping_subscriptions = SubscriptionRecord.objects.filter( Q(start_date__range=(self.start_date, self.end_date)) | Q(end_date__range=(self.start_date, self.end_date)), organization=self.organization, customer=self.customer, billing_plan=self.billing_plan, - ).prefetch_related( - Prefetch( - "filters", - queryset=CategoricalFilter.objects.filter( - organization=self.organization - ), - to_attr="filters_lst", - ) ) for subscription in overlapping_subscriptions: - old_filters = subscription.filters_lst - if CategoricalFilter.overlaps(old_filters, new_filters): + old_filters = set(tuple(x) for x in subscription.subscription_filters) + if old_filters.issubset(new_filters) or new_filters.issubset( + old_filters + ): raise OverlappingPlans( f"Overlapping subscriptions with the same filters are not allowed. \n Plan: {self.billing_plan} \n Customer: {self.customer}. \n New dates: ({self.start_date, self.end_date}) \n New subscription_filters: {new_filters} \n Old dates: ({self.start_date, self.end_date}) \n Old subscription_filters: {list(old_filters)}" ) super(SubscriptionRecord, self).save(*args, **kwargs) - for filter in new_filters: - self.filters.add(filter) - for filter in self.filters.all(): - if not filter.organization: - filter.organization = self.organization - filter.save() if new: alerts = UsageAlert.objects.filter( organization=self.organization, plan_version=self.billing_plan @@ -3035,9 +3005,7 @@ def save(self, *args, **kwargs): ) def get_filters_dictionary(self): - filters_dict = {} - for filter in self.filters.all(): - filters_dict[filter.property_name] = filter.comparison_value[0] + filters_dict = {f[0]: f[1] for f in self.subscription_filters} return filters_dict def amt_already_invoiced(self): diff --git a/backend/metering_billing/payment_processors.py b/backend/metering_billing/payment_processors.py index d384a2f9d..7c3f2740a 100644 --- a/backend/metering_billing/payment_processors.py +++ b/backend/metering_billing/payment_processors.py @@ -14,9 +14,6 @@ from django.conf import settings from django.core.cache import cache from django.db.models import F, Prefetch, Q -from rest_framework import serializers, status -from rest_framework.response import Response - from metering_billing.serializers.payment_processor_serializers import ( PaymentProcesorPostResponseSerializer, ) @@ -30,6 +27,8 @@ ORGANIZATION_SETTING_NAMES, PAYMENT_PROCESSORS, ) +from rest_framework import serializers, status +from rest_framework.response import Response logger = logging.getLogger("django.server") @@ -1130,10 +1129,10 @@ def create_payment_object(self, invoice) -> Optional[str]: metadata = {} if sr is not None: metadata["plan_name"] = sr.billing_plan.plan.plan_name - filters = sr.filters.all() + filters = sr.subscription_filters for f in filters: - metadata[f.property_name] = f.comparison_value[0] - name += f" - ({f.property_name} : {f.comparison_value[0]})" + metadata[f[0]] = f[1] + name += f" - ({f[0]} : {f[1]})" inv_dict = { "description": name, "amount": int(amount * 100), diff --git a/backend/metering_billing/serializers/model_serializers.py b/backend/metering_billing/serializers/model_serializers.py index ab10418ac..47102321c 100644 --- a/backend/metering_billing/serializers/model_serializers.py +++ b/backend/metering_billing/serializers/model_serializers.py @@ -765,11 +765,8 @@ class Meta(api_serializers.CategoricalFilterSerializer.Meta): fields = api_serializers.CategoricalFilterSerializer.Meta.fields -class SubscriptionCategoricalFilterDetailSerializer( - api_serializers.SubscriptionCategoricalFilterSerializer -): - class Meta(api_serializers.SubscriptionCategoricalFilterSerializer.Meta): - fields = api_serializers.SubscriptionCategoricalFilterSerializer.Meta.fields +class SubscriptionFilterDetailSerializer(api_serializers.SubscriptionFilterSerializer): + pass class NumericFilterDetailSerializer(api_serializers.NumericFilterSerializer): @@ -1949,7 +1946,7 @@ class InvoiceListFilterSerializer(api_serializers.InvoiceListFilterSerializer): class GroupedLineItemSerializer(serializers.Serializer): plan_name = serializers.CharField() - subscription_filters = SubscriptionCategoricalFilterDetailSerializer(many=True) + subscription_filters = SubscriptionFilterDetailSerializer(many=True) base = serializers.DecimalField(max_digits=10, decimal_places=2) start_date = serializers.DateTimeField() end_date = serializers.DateTimeField() @@ -1992,7 +1989,7 @@ def get_line_items(self, obj) -> GroupedLineItemSerializer(many=True): sr = line_items[0].associated_subscription_record grouped_line_item_dict = { "plan_name": sr.billing_plan.plan.plan_name, - "subscription_filters": sr.filters.all(), + "subscription_filters": sr.subscription_filters, "base": line_items.aggregate(Sum("amount"))["amount__sum"] or 0, "start_date": sr.start_date, "end_date": sr.end_date, diff --git a/backend/metering_billing/tasks.py b/backend/metering_billing/tasks.py index bf88adada..6b2c33183 100644 --- a/backend/metering_billing/tasks.py +++ b/backend/metering_billing/tasks.py @@ -92,7 +92,6 @@ def calculate_invoice(): "billing_plan__plan_components", "billing_plan__plan_components__billable_metric", "billing_plan__plan_components__tiers", - "filters", "billing_records", )