diff --git a/api/features/serializers.py b/api/features/serializers.py index d922dd11ae47..d08296cb19b9 100644 --- a/api/features/serializers.py +++ b/api/features/serializers.py @@ -23,7 +23,10 @@ ) from integrations.github.constants import GitHubEventType from integrations.github.github import call_github_task -from metadata.serializers import MetadataSerializer, MetadataSerializerMixin +from metadata.serializers import ( + MetadataSerializer, + with_metadata, +) from projects.code_references.serializers import ( FeatureFlagCodeReferencesRepositoryCountSerializer, ) @@ -344,7 +347,12 @@ def get_last_modified_in_current_environment( return getattr(instance, "last_modified_in_current_environment", None) -class FeatureSerializerWithMetadata(MetadataSerializerMixin, CreateFeatureSerializer): +@with_metadata( + lambda self, attrs: ( + self.instance.project if self.instance else self.context["project"] + ).organisation +) +class FeatureSerializerWithMetadata(CreateFeatureSerializer): metadata = MetadataSerializer(required=False, many=True) code_references_counts = FeatureFlagCodeReferencesRepositoryCountSerializer( @@ -358,25 +366,6 @@ class Meta(CreateFeatureSerializer.Meta): "code_references_counts", ) - def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: - attrs = super().validate(attrs) - project = self.instance.project if self.instance else self.context["project"] # type: ignore[union-attr] - organisation = project.organisation - self._validate_required_metadata(organisation, attrs.get("metadata", [])) - return attrs - - def create(self, validated_data: dict[str, Any]) -> Feature: - metadata_data = validated_data.pop("metadata", []) - feature = super().create(validated_data) - self._update_metadata(feature, metadata_data) - return feature - - def update(self, feature: Feature, validated_data: dict[str, Any]) -> Feature: - metadata = validated_data.pop("metadata", []) - feature = super().update(feature, validated_data) - self._update_metadata(feature, metadata) - return feature - class UpdateFeatureSerializerWithMetadata(FeatureSerializerWithMetadata): """prevent users from changing certain values after creation""" diff --git a/api/metadata/serializers.py b/api/metadata/serializers.py index 8cfca2adb5b4..e971a37d7bcb 100644 --- a/api/metadata/serializers.py +++ b/api/metadata/serializers.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Callable from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ObjectDoesNotExist @@ -150,3 +150,126 @@ def _update_metadata( with transaction.atomic(): existing_metadata.delete() Metadata.objects.bulk_create(new_metadata) + + +# Helper functions for the decorator +def _validate_required_metadata_for_model( + model_class: type[Model], organisation: Organisation, metadata: list[dict[str, Any]] +) -> None: + """Validate that all required metadata fields are present.""" + content_type = ContentType.objects.get_for_model(model_class) + requirements = MetadataModelFieldRequirement.objects.filter( + model_field__content_type=content_type, + model_field__field__organisation=organisation, + ).select_related("model_field__field") + + metadata_fields = {field["model_field"] for field in metadata} + for requirement in requirements: + if requirement.model_field not in metadata_fields: + field_name = requirement.model_field.field.name + raise serializers.ValidationError( + {"metadata": f"Missing required metadata field: {field_name}"} + ) + + +def _update_metadata_for_instance( + instance: Model, metadata_data: list[dict[str, Any]] +) -> None: + """Update metadata for a model instance.""" + content_type = ContentType.objects.get_for_model(type(instance)) + existing_metadata = Metadata.objects.filter( + object_id=instance.pk, + content_type=content_type, + ) + + new_metadata = [ + Metadata( + model_field=model_field, + content_type=content_type, + object_id=instance.pk, + field_value=metadata_item["field_value"], + ) + for metadata_item in metadata_data + if (model_field := metadata_item["model_field"]) and model_field.pk + ] + + with transaction.atomic(): + existing_metadata.delete() + Metadata.objects.bulk_create(new_metadata) + + +def with_metadata( + get_organisation_fn: Callable[[Any, dict[str, Any]], Organisation], +) -> Callable[[type], type]: + """ + Decorator that adds metadata handling to a serializer. + + Automatically handles metadata in create(), update(), and validate() methods. + Eliminates the need to manually override these methods and coordinate metadata handling. + + Args: + get_organisation_fn: A callable that takes (self, attrs) and returns the Organisation + for validation context. + + Usage: + @with_metadata(lambda self, attrs: self.context['project'].organisation) + class FeatureSerializer(CreateFeatureSerializer): + metadata = MetadataSerializer(required=False, many=True) + + The decorator will: + - Pop metadata from validated_data in create() and update() + - Call the appropriate metadata update function + - Validate required metadata fields in validate() + """ + + def decorator(cls: type) -> type: + # Store original methods + original_create = cls.create if hasattr(cls, "create") else None + original_update = cls.update if hasattr(cls, "update") else None + original_validate = cls.validate if hasattr(cls, "validate") else None + + def create(self: Any, validated_data: dict[str, Any]) -> Model: + """Override create to handle metadata.""" + metadata_data = validated_data.pop("metadata", []) + if original_create: + instance = original_create(self, validated_data) + else: + # Fallback to parent class + instance = super(cls, self).create(validated_data) + _update_metadata_for_instance(instance, metadata_data) + return instance # type: ignore[no-any-return] + + def update(self: Any, instance: Model, validated_data: dict[str, Any]) -> Model: + """Override update to handle metadata.""" + metadata_data = validated_data.pop("metadata", []) + if original_update: + instance = original_update(self, instance, validated_data) + else: + # Fallback to parent class + instance = super(cls, self).update(instance, validated_data) + _update_metadata_for_instance(instance, metadata_data) + return instance # type: ignore[no-any-return] + + def validate(self: Any, attrs: dict[str, Any]) -> dict[str, Any]: + """Override validate to check required metadata.""" + if original_validate: + attrs = original_validate(self, attrs) + else: + # Fallback to parent class + attrs = super(cls, self).validate(attrs) + + # Validate required metadata + organisation = get_organisation_fn(self, attrs) + _validate_required_metadata_for_model( + self.Meta.model, organisation, attrs.get("metadata", []) + ) + return attrs + + # Replace methods on the class + cls.create = create # type: ignore[assignment] + cls.update = update # type: ignore[assignment] + cls.validate = validate # type: ignore[assignment] + + return cls + + return decorator