diff --git a/src/Billing/Services/IPushNotificationAdapter.cs b/src/Billing/Services/IPushNotificationAdapter.cs index 2f74f35eecdc..dcab88596552 100644 --- a/src/Billing/Services/IPushNotificationAdapter.cs +++ b/src/Billing/Services/IPushNotificationAdapter.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Entities; namespace Bit.Billing.Services; @@ -8,4 +9,5 @@ public interface IPushNotificationAdapter Task NotifyBankAccountVerifiedAsync(Organization organization); Task NotifyBankAccountVerifiedAsync(Provider provider); Task NotifyEnabledChangedAsync(Organization organization); + Task NotifyPremiumStatusChangedAsync(User user); } diff --git a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs index 443227f7bfe3..c37def04332f 100644 --- a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs @@ -19,6 +19,7 @@ public class PaymentSucceededHandler( IOrganizationRepository organizationRepository, IStripeEventUtilityService stripeEventUtilityService, IUserService userService, + IUserRepository userRepository, IOrganizationEnableCommand organizationEnableCommand, IPricingClient pricingClient, IPushNotificationAdapter pushNotificationAdapter) @@ -109,12 +110,17 @@ public async Task HandleAsync(Event parsedEvent) } else if (userId.HasValue) { - if (subscription.Items.All(i => i.Plan.Id != IStripeEventUtilityService.PremiumPlanId)) + if (subscription.Items.All(i => i.Price.Id is not IStripeEventUtilityService.PremiumPlanId and not IStripeEventUtilityService.PremiumPlanIdAppStore)) { return; } await userService.EnablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd()); + var user = await userRepository.GetByIdAsync(userId.Value); + if (user != null) + { + await pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user); + } } } } diff --git a/src/Billing/Services/Implementations/PushNotificationAdapter.cs b/src/Billing/Services/Implementations/PushNotificationAdapter.cs index 673ae1415eee..a47c1753529d 100644 --- a/src/Billing/Services/Implementations/PushNotificationAdapter.cs +++ b/src/Billing/Services/Implementations/PushNotificationAdapter.cs @@ -2,6 +2,8 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Models; +using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Platform.Push; @@ -68,4 +70,18 @@ public Task NotifyEnabledChangedAsync(Organization organization) => }, ExcludeCurrentContext = false, }); + + public Task NotifyPremiumStatusChangedAsync(User user) => + pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PremiumStatusChanged, + Target = NotificationTarget.User, + TargetId = user.Id, + Payload = new PremiumStatusPushNotification + { + UserId = user.Id, + Premium = user.Premium + }, + ExcludeCurrentContext = false, + }); } diff --git a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs index c204cc502610..8a52a1cc2b86 100644 --- a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Extensions; +using Bit.Core.Repositories; using Bit.Core.Services; using Quartz; using Event = Stripe.Event; @@ -13,28 +14,34 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler { private readonly IStripeEventService _stripeEventService; private readonly IUserService _userService; + private readonly IUserRepository _userRepository; private readonly IStripeEventUtilityService _stripeEventUtilityService; private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IProviderRepository _providerRepository; private readonly IProviderService _providerService; private readonly ISchedulerFactory _schedulerFactory; + private readonly IPushNotificationAdapter _pushNotificationAdapter; public SubscriptionDeletedHandler( IStripeEventService stripeEventService, IUserService userService, + IUserRepository userRepository, IStripeEventUtilityService stripeEventUtilityService, IOrganizationDisableCommand organizationDisableCommand, IProviderRepository providerRepository, IProviderService providerService, - ISchedulerFactory schedulerFactory) + ISchedulerFactory schedulerFactory, + IPushNotificationAdapter pushNotificationAdapter) { _stripeEventService = stripeEventService; _userService = userService; + _userRepository = userRepository; _stripeEventUtilityService = stripeEventUtilityService; _organizationDisableCommand = organizationDisableCommand; _providerRepository = providerRepository; _providerService = providerService; _schedulerFactory = schedulerFactory; + _pushNotificationAdapter = pushNotificationAdapter; } /// @@ -80,6 +87,11 @@ public async Task HandleAsync(Event parsedEvent) else if (userId.HasValue) { await _userService.DisablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd()); + var user = await _userRepository.GetByIdAsync(userId.Value); + if (user != null) + { + await _pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user!); + } } } diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index e4710f7dcead..eeded8291a66 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -22,6 +22,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler private readonly IStripeFacade _stripeFacade; private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; private readonly IUserService _userService; + private readonly IUserRepository _userRepository; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand; @@ -37,6 +38,7 @@ public SubscriptionUpdatedHandler( IStripeFacade stripeFacade, IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand, IUserService userService, + IUserRepository userRepository, IOrganizationRepository organizationRepository, IOrganizationEnableCommand organizationEnableCommand, IOrganizationDisableCommand organizationDisableCommand, @@ -52,6 +54,7 @@ public SubscriptionUpdatedHandler( _stripeFacade = stripeFacade; _organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand; _userService = userService; + _userRepository = userRepository; _organizationRepository = organizationRepository; _providerRepository = providerRepository; _organizationEnableCommand = organizationEnableCommand; @@ -140,7 +143,15 @@ SubscriptionStatus.Incomplete or private Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => subscriberId.Match( - userId => _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd), + async userId => + { + await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); + var user = await _userRepository.GetByIdAsync(userId.Value); + if (user != null) + { + await _pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user); + } + }, async organizationId => { await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd); @@ -162,7 +173,15 @@ private Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? current private Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => subscriberId.Match( - userId => _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd), + async userId => + { + await _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd); + var user = await _userRepository.GetByIdAsync(userId.Value); + if (user != null) + { + await _pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user!); + } + }, async organizationId => { await _organizationEnableCommand.EnableAsync(organizationId.Value, currentPeriodEnd); diff --git a/src/Core/Billing/Models/PremiumStatusPushNotification.cs b/src/Core/Billing/Models/PremiumStatusPushNotification.cs new file mode 100644 index 000000000000..133bc4aef8c0 --- /dev/null +++ b/src/Core/Billing/Models/PremiumStatusPushNotification.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Models; + +public class PremiumStatusPushNotification +{ + public Guid UserId { get; set; } + public bool Premium { get; set; } +} diff --git a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs index 72373d0da3c7..4be3ab7b39f0 100644 --- a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs +++ b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; @@ -9,6 +10,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; +using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; @@ -54,7 +56,8 @@ public class UpgradePremiumToOrganizationCommand( IOrganizationUserRepository organizationUserRepository, IOrganizationApiKeyRepository organizationApiKeyRepository, ICollectionRepository collectionRepository, - IApplicationCacheService applicationCacheService) + IApplicationCacheService applicationCacheService, + IPushNotificationService pushNotificationService) : BaseBillingCommand(logger), IUpgradePremiumToOrganizationCommand { private readonly ILogger _logger = logger; @@ -278,6 +281,19 @@ await organizationApiKeyRepository.CreateAsync(new OrganizationApiKey user.RevisionDate = DateTime.UtcNow; await userService.SaveUserAsync(user); + await pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PremiumStatusChanged, + Target = NotificationTarget.User, + TargetId = user.Id, + Payload = new PremiumStatusPushNotification + { + UserId = user.Id, + Premium = user.Premium, + }, + ExcludeCurrentContext = false, + }); + return organization.Id; }); diff --git a/src/Core/Billing/Services/Implementations/LicensingService.cs b/src/Core/Billing/Services/Implementations/LicensingService.cs index 6f0cdec8f55b..77d1687ab6c9 100644 --- a/src/Core/Billing/Services/Implementations/LicensingService.cs +++ b/src/Core/Billing/Services/Implementations/LicensingService.cs @@ -9,11 +9,14 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Licenses.Models; using Bit.Core.Billing.Licenses.Services; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; +using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -36,6 +39,7 @@ public class LicensingService : ILicensingService private readonly ILogger _logger; private readonly ILicenseClaimsFactory _organizationLicenseClaimsFactory; private readonly ILicenseClaimsFactory _userLicenseClaimsFactory; + private readonly IPushNotificationService _pushNotificationService; private IDictionary _userCheckCache = new Dictionary(); @@ -47,7 +51,8 @@ public LicensingService( ILogger logger, IGlobalSettings globalSettings, ILicenseClaimsFactory organizationLicenseClaimsFactory, - ILicenseClaimsFactory userLicenseClaimsFactory) + ILicenseClaimsFactory userLicenseClaimsFactory, + IPushNotificationService pushNotificationService) { _userRepository = userRepository; _organizationRepository = organizationRepository; @@ -56,6 +61,7 @@ public LicensingService( _globalSettings = globalSettings; _organizationLicenseClaimsFactory = organizationLicenseClaimsFactory; _userLicenseClaimsFactory = userLicenseClaimsFactory; + _pushNotificationService = pushNotificationService; var certThumbprint = environment.IsDevelopment() ? "207E64A231E8AA32AAF68A61037C075EBEBD553F" : @@ -246,6 +252,19 @@ private async Task DisablePremiumAsync(User user, ILicense license, string reaso await _userRepository.ReplaceAsync(user); await _mailService.SendLicenseExpiredAsync(new List { user.Email }); + + await _pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PremiumStatusChanged, + Target = NotificationTarget.User, + TargetId = user.Id, + Payload = new PremiumStatusPushNotification + { + UserId = user.Id, + Premium = user.Premium, + }, + ExcludeCurrentContext = false, + }); } public bool VerifyLicense(ILicense license) diff --git a/src/Core/Platform/Push/PushType.cs b/src/Core/Platform/Push/PushType.cs index b08619530433..c1569d5108fd 100644 --- a/src/Core/Platform/Push/PushType.cs +++ b/src/Core/Platform/Push/PushType.cs @@ -102,4 +102,7 @@ public enum PushType : byte [NotificationInfo("@bitwarden/team-admin-console-dev", typeof(Models.AutoConfirmPushNotification))] AutoConfirm = 26, + + [NotificationInfo("@bitwarden/team-billing-dev", typeof(Billing.Models.PremiumStatusPushNotification))] + PremiumStatusChanged = 27, } diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index f9956a5de41c..aaa64cb80c67 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -1,4 +1,5 @@ using System.Text.Json; +using Bit.Core.Billing.Models; using Bit.Core.Enums; using Bit.Core.Models; using Microsoft.AspNetCore.SignalR; @@ -246,6 +247,18 @@ await _hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) await _hubContext.Clients.User(autoConfirmNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, autoConfirmNotification, cancellationToken); break; + case PushType.PremiumStatusChanged: + var premiumStatusNotification = + JsonSerializer.Deserialize>( + notificationJson, _deserializerOptions); + if (premiumStatusNotification is null) + { + break; + } + + await _hubContext.Clients.User(premiumStatusNotification.Payload.UserId.ToString()) + .SendAsync(_receiveMessageMethod, premiumStatusNotification, cancellationToken); + break; default: _logger.LogWarning("Notification type '{NotificationType}' has not been registered in HubHelpers and will not be pushed as as result", notification.Type); break; diff --git a/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs index de2d3ec0ed80..370a3ec99be7 100644 --- a/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs @@ -7,11 +7,14 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Extensions; +using Bit.Core.Entities; +using Bit.Core.Repositories; using Bit.Core.Services; using NSubstitute; using Quartz; using Stripe; using Xunit; +using Event = Stripe.Event; namespace Bit.Billing.Test.Services; @@ -19,11 +22,13 @@ public class SubscriptionDeletedHandlerTests { private readonly IStripeEventService _stripeEventService; private readonly IUserService _userService; + private readonly IUserRepository _userRepository; private readonly IStripeEventUtilityService _stripeEventUtilityService; private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IProviderRepository _providerRepository; private readonly IProviderService _providerService; private readonly ISchedulerFactory _schedulerFactory; + private readonly IPushNotificationAdapter _pushNotificationAdapter; private readonly IScheduler _scheduler; private readonly SubscriptionDeletedHandler _sut; @@ -31,21 +36,25 @@ public SubscriptionDeletedHandlerTests() { _stripeEventService = Substitute.For(); _userService = Substitute.For(); + _userRepository = Substitute.For(); _stripeEventUtilityService = Substitute.For(); _organizationDisableCommand = Substitute.For(); _providerRepository = Substitute.For(); _providerService = Substitute.For(); _schedulerFactory = Substitute.For(); + _pushNotificationAdapter = Substitute.For(); _scheduler = Substitute.For(); _schedulerFactory.GetScheduler().Returns(_scheduler); _sut = new SubscriptionDeletedHandler( _stripeEventService, _userService, + _userRepository, _stripeEventUtilityService, _organizationDisableCommand, _providerRepository, _providerService, - _schedulerFactory); + _schedulerFactory, + _pushNotificationAdapter); } [Fact] @@ -129,9 +138,12 @@ public async Task HandleAsync_UserSubscriptionCanceled_DisablesUserPremium() Metadata = new Dictionary { { "userId", userId.ToString() } } }; + var user = new User { Id = userId, Premium = false, PremiumExpirationDate = subscription.GetCurrentPeriodEnd() }; + _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) .Returns(Tuple.Create(null, userId, null)); + _userRepository.GetByIdAsync(userId).Returns(user); // Act await _sut.HandleAsync(stripeEvent); @@ -139,6 +151,8 @@ public async Task HandleAsync_UserSubscriptionCanceled_DisablesUserPremium() // Assert await _userService.Received(1) .DisablePremiumAsync(userId, subscription.GetCurrentPeriodEnd()); + await _userRepository.Received(1).GetByIdAsync(userId); + await _pushNotificationAdapter.Received(1).NotifyPremiumStatusChangedAsync(user); } [Fact] diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 9517802f1338..85479cf91583 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Entities; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -30,6 +31,7 @@ public class SubscriptionUpdatedHandlerTests private readonly IStripeFacade _stripeFacade; private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; private readonly IUserService _userService; + private readonly IUserRepository _userRepository; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand; @@ -47,6 +49,7 @@ public SubscriptionUpdatedHandlerTests() _stripeFacade = Substitute.For(); _organizationSponsorshipRenewCommand = Substitute.For(); _userService = Substitute.For(); + _userRepository = Substitute.For(); _providerService = Substitute.For(); _organizationRepository = Substitute.For(); _organizationEnableCommand = Substitute.For(); @@ -63,6 +66,7 @@ public SubscriptionUpdatedHandlerTests() _stripeFacade, _organizationSponsorshipRenewCommand, _userService, + _userRepository, _organizationRepository, _organizationEnableCommand, _organizationDisableCommand, @@ -599,18 +603,24 @@ public async Task HandleAsync_UnpaidUserSubscription_DisablesPremiumAndSetsCance } }; + var user = new User { Id = userId, Premium = false, PremiumExpirationDate = currentPeriodEnd }; + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, userId, null)); + _userRepository.GetByIdAsync(userId).Returns(user); + // Act await _sut.HandleAsync(parsedEvent); // Assert await _userService.Received(1) .DisablePremiumAsync(userId, currentPeriodEnd); + await _userRepository.Received(1).GetByIdAsync(userId); + await _pushNotificationAdapter.Received(1).NotifyPremiumStatusChangedAsync(user); await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => @@ -792,8 +802,11 @@ public async Task HandleAsync_ActiveUserSubscription_EnablesPremiumAndUpdatesExp } }; + var user = new User { Id = userId, Premium = true, PremiumExpirationDate = currentPeriodEnd }; + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); + _userRepository.GetByIdAsync(userId).Returns(user); // Act await _sut.HandleAsync(parsedEvent); @@ -803,6 +816,8 @@ await _userService.Received(1) .EnablePremiumAsync(userId, currentPeriodEnd); await _userService.Received(1) .UpdatePremiumExpirationAsync(userId, currentPeriodEnd); + await _userRepository.Received(1).GetByIdAsync(userId); + await _pushNotificationAdapter.Received(1).NotifyPremiumStatusChangedAsync(user); await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs index e4eb8f24e93d..9179d841bef4 100644 --- a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs @@ -250,7 +250,7 @@ public async Task Run_ValidPaymentMethodTypes_PayPal_Success( var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; - mockSubscription.Status = "active"; + mockSubscription.Status = "incomplete"; mockSubscription.LatestInvoiceId = "in_123"; var mockInvoice = Substitute.For(); diff --git a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs index bb0a3eccc171..e6f8b72ff47a 100644 --- a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs @@ -7,6 +7,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; +using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; @@ -136,6 +137,7 @@ private static List CreateTestPremiumPlansList() private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository = Substitute.For(); private readonly ICollectionRepository _collectionRepository = Substitute.For(); private readonly IApplicationCacheService _applicationCacheService = Substitute.For(); + private readonly IPushNotificationService _pushNotificationService = Substitute.For(); private readonly ILogger _logger = Substitute.For>(); private readonly UpgradePremiumToOrganizationCommand _command; @@ -150,7 +152,8 @@ public UpgradePremiumToOrganizationCommandTests() _organizationUserRepository, _organizationApiKeyRepository, _collectionRepository, - _applicationCacheService); + _applicationCacheService, + _pushNotificationService); } private static Core.Billing.Payment.Models.BillingAddress CreateTestBillingAddress() => @@ -277,6 +280,7 @@ await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.Premium == false && u.GatewaySubscriptionId == null && u.GatewayCustomerId == null)); + } [Theory, BitAutoData] diff --git a/test/Notifications.Test/HubHelpersTest.cs b/test/Notifications.Test/HubHelpersTest.cs index 2cd20858f3b7..b519b5617a5a 100644 --- a/test/Notifications.Test/HubHelpersTest.cs +++ b/test/Notifications.Test/HubHelpersTest.cs @@ -1,5 +1,6 @@ #nullable enable using System.Text.Json; +using Bit.Core.Billing.Models; using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Test.NotificationCenter.AutoFixture; @@ -249,6 +250,30 @@ await sutProvider.GetDependency>().Clients.Receive .Group(Arg.Any()); } + [Theory] + [BitAutoData] + public async Task SendNotificationToHubAsync_PremiumStatusChanged_SentToUser( + SutProvider sutProvider, + PremiumStatusPushNotification notification, + string contextId, + CancellationToken cancellationToken) + { + var json = ToNotificationJson(notification, PushType.PremiumStatusChanged, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + await sutProvider.GetDependency>().Clients.Received(1) + .User(notification.UserId.ToString()) + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && AssertPremiumStatusPushNotification(notification, objects[0], + PushType.PremiumStatusChanged, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).Group(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + private static string ToNotificationJson(object payload, PushType type, string contextId) { var notification = new PushNotificationData(type, payload, contextId); @@ -287,4 +312,18 @@ private static bool AssertSyncPolicyPushNotification(SyncPolicyPushNotification expected.Policy.Type == pushNotificationData.Payload.Policy.Type && expected.Policy.Enabled == pushNotificationData.Payload.Policy.Enabled; } + + private static bool AssertPremiumStatusPushNotification(PremiumStatusPushNotification expected, object? actual, + PushType type, string contextId) + { + if (actual is not PushNotificationData pushNotificationData) + { + return false; + } + + return pushNotificationData.Type == type && + pushNotificationData.ContextId == contextId && + expected.UserId == pushNotificationData.Payload.UserId && + expected.Premium == pushNotificationData.Payload.Premium; + } }