diff --git a/api/subscriptions/views.py b/api/subscriptions/views.py index 2029972b57b..8f15677d643 100644 --- a/api/subscriptions/views.py +++ b/api/subscriptions/views.py @@ -24,8 +24,6 @@ RegistrationProvider, AbstractProvider, AbstractNode, - Preprint, - OSFUser, ) from osf.models.notification_type import NotificationType from osf.models.notification_subscription import NotificationSubscription @@ -48,10 +46,6 @@ def get_queryset(self): user_guid = self.request.user._id provider_ct = ContentType.objects.get(app_label='osf', model='abstractprovider') - provider_subquery = AbstractProvider.objects.filter( - id=Cast(OuterRef('object_id'), IntegerField()), - ).values('_id')[:1] - node_subquery = AbstractNode.objects.filter( id=Cast(OuterRef('object_id'), IntegerField()), ).values('guids___id')[:1] @@ -66,17 +60,17 @@ def get_queryset(self): ).annotate( event_name=Case( When( - notification_type=NotificationType.Type.USER_FILE_UPDATED.instance, + notification_type=NotificationType.Type.NODE_FILE_UPDATED.instance, then=Value('files_updated'), ), When( notification_type=NotificationType.Type.USER_FILE_UPDATED.instance, - then=Value(f'{user_guid}_global_file_updated'), + then=Value('global_file_updated'), ), When( Q(notification_type=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.instance) & Q(content_type=provider_ct), - then=Value('new_pending_submissions'), + then=Value('global_reviews'), ), ), legacy_id=Case( @@ -91,7 +85,7 @@ def get_queryset(self): When( Q(notification_type=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.instance) & Q(content_type=provider_ct), - then=Concat(Subquery(provider_subquery), Value('_new_pending_submissions')), + then=Value(f'{user_guid}_global_reviews'), ), ), ) @@ -133,41 +127,17 @@ def get_object(self): provider_ct = ContentType.objects.get(app_label='osf', model='abstractprovider') node_ct = ContentType.objects.get(app_label='osf', model='abstractnode') - provider_subquery = AbstractProvider.objects.filter( - id=Cast(OuterRef('object_id'), IntegerField()), - ).values('_id')[:1] - node_subquery = AbstractNode.objects.filter( id=Cast(OuterRef('object_id'), IntegerField()), ).values('guids___id')[:1] - guid_id, *event_parts = subscription_id.split('_') - event = '_'.join(event_parts) if event_parts else '' - - subscription_obj = AbstractNode.load(guid_id) or Preprint.load(guid_id) or OSFUser.load(guid_id) - - if event != 'global': - if subscription_obj is None: - subscription_obj = PreprintProvider.objects.get(_id=guid_id) - obj_filter = Q( - object_id=getattr(subscription_obj, 'id', None), - content_type=ContentType.objects.get_for_model(subscription_obj.__class__), - notification_type__in=[ - NotificationType.Type.USER_FILE_UPDATED.instance, - NotificationType.Type.NODE_FILE_UPDATED.instance, - NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.instance, - ], - ) - else: - obj_filter = Q() - try: - obj = NotificationSubscription.objects.annotate( + annotated_obj_qs = NotificationSubscription.objects.filter(user=self.request.user).annotate( legacy_id=Case( When( notification_type__name=NotificationType.Type.NODE_FILE_UPDATED.value, content_type=node_ct, - then=Concat(Subquery(node_subquery), Value('_file_updated')), + then=Concat(Subquery(node_subquery), Value('_files_updated')), ), When( notification_type__name=NotificationType.Type.USER_FILE_UPDATED.value, @@ -176,12 +146,13 @@ def get_object(self): When( notification_type__name=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.value, content_type=provider_ct, - then=Concat(Subquery(provider_subquery), Value('_new_pending_submissions')), + then=Value(f'{user_guid}_global_reviews'), ), default=Value(f'{user_guid}_global'), output_field=CharField(), ), - ).filter(obj_filter) + ) + obj = annotated_obj_qs.filter(legacy_id=subscription_id) except ObjectDoesNotExist: raise NotFound @@ -194,6 +165,28 @@ def get_object(self): self.check_object_permissions(self.request, obj) return obj + def update(self, request, *args, **kwargs): + """ + Update a notification subscription + """ + ret = super().update(request, *args, **kwargs) + # Copy global_reviews subscription changes to new_pending_submissions subscriptions [ENG-9666] + if self.get_object().notification_type.name == NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.value: + qs = NotificationSubscription.objects.filter( + user=self.request.user, + notification_type__name__in=[ + NotificationType.Type.PROVIDER_REVIEWS_SUBMISSION_CONFIRMATION.value, + NotificationType.Type.PROVIDER_REVIEWS_RESUBMISSION_CONFIRMATION.value, + NotificationType.Type.PROVIDER_NEW_PENDING_WITHDRAW_REQUESTS.value, + NotificationType.Type.REVIEWS_SUBMISSION_STATUS.value, + ], + ) + for instance in qs: + serializer = self.get_serializer(instance=instance, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + return ret + class AbstractProviderSubscriptionDetail(SubscriptionDetail): view_name = 'provider-notification-subscription-detail' diff --git a/api_tests/subscriptions/views/test_subscriptions_detail.py b/api_tests/subscriptions/views/test_subscriptions_detail.py index 889519b05a6..f14ca4e2522 100644 --- a/api_tests/subscriptions/views/test_subscriptions_detail.py +++ b/api_tests/subscriptions/views/test_subscriptions_detail.py @@ -29,8 +29,8 @@ def notification(self, user): ) @pytest.fixture() - def url(self, notification): - return f'/{API_BASE}subscriptions/{notification._id}/' + def url(self, user): + return f'/{API_BASE}subscriptions/{user._id}_global_file_updated/' @pytest.fixture() def url_invalid(self): @@ -119,8 +119,9 @@ def test_subscription_detail_invalid_payload_400( url, payload_invalid, auth=user.auth, - expect_errors=True + expect_errors=True, ) + assert res.status_code == 400 assert res.json['errors'][0]['detail'] == ('"invalid-frequency" is not a valid choice.') @@ -151,6 +152,7 @@ def test_subscription_detail_patch_no_user( def test_subscription_detail_patch( self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid ): + res = app.patch_json_api(url, payload, auth=user.auth) assert res.status_code == 200 assert res.json['data']['attributes']['frequency'] == 'none' diff --git a/api_tests/subscriptions/views/test_subscriptions_list.py b/api_tests/subscriptions/views/test_subscriptions_list.py index cd5e1699614..c376e260014 100644 --- a/api_tests/subscriptions/views/test_subscriptions_list.py +++ b/api_tests/subscriptions/views/test_subscriptions_list.py @@ -73,7 +73,7 @@ def test_list_complete( # There should only be 3 notifications: users' global, node's file updates and provider's preprint added. assert len(notification_ids) == 3 assert f'{user._id}_global_file_updated' in notification_ids - assert f'{provider._id}_new_pending_submissions' in notification_ids + assert f'{user._id}_global_reviews' in notification_ids assert f'{node._id}_file_updated' in notification_ids def test_unauthenticated(self, app, url): @@ -122,5 +122,5 @@ def test_value_filter_id( # Confirm it’s the expected subscription object attributes = data[0]['attributes'] - assert attributes['event_name'] is None # event names are legacy + assert attributes['event_name'] == 'files_updated' # event names are legacy assert attributes['frequency'] in ['instantly', 'daily', 'none']