Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Billing/Services/IPushNotificationAdapter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.Entities;

namespace Bit.Billing.Services;

Expand All @@ -8,4 +9,5 @@ public interface IPushNotificationAdapter
Task NotifyBankAccountVerifiedAsync(Organization organization);
Task NotifyBankAccountVerifiedAsync(Provider provider);
Task NotifyEnabledChangedAsync(Organization organization);
Task NotifyPremiumStatusChangedAsync(User user);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class PaymentSucceededHandler(
IOrganizationRepository organizationRepository,
IStripeEventUtilityService stripeEventUtilityService,
IUserService userService,
IUserRepository userRepository,
IOrganizationEnableCommand organizationEnableCommand,
IPricingClient pricingClient,
IPushNotificationAdapter pushNotificationAdapter)
Expand Down Expand Up @@ -109,12 +110,17 @@ public async Task HandleAsync(Event parsedEvent)
}
else if (userId.HasValue)
{
if (subscription.Items.All(i => i.Plan.Id != IStripeEventUtilityService.PremiumPlanId))
if (subscription.Items.All(i => i.Price.Id is not IStripeEventUtilityService.PremiumPlanId and not IStripeEventUtilityService.PremiumPlanIdAppStore))
{
return;
}

await userService.EnablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd());
var user = await userRepository.GetByIdAsync(userId.Value);
if (user != null)
{
await pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user);
}
}
}
}
15 changes: 15 additions & 0 deletions src/Billing/Services/Implementations/PushNotificationAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Bit.Core.AdminConsole.Entities.Provider;
using Bit.Core.AdminConsole.Enums.Provider;
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models;
using Bit.Core.Platform.Push;
Expand Down Expand Up @@ -68,4 +69,18 @@ public Task NotifyEnabledChangedAsync(Organization organization) =>
},
ExcludeCurrentContext = false,
});

public Task NotifyPremiumStatusChangedAsync(User user) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❌ Related to this comment, this should follow the pattern of the methods above.

pushNotificationService.PushAsync(new PushNotification<PremiumStatusPushNotification>
{
Type = PushType.PremiumStatusChanged,
Target = NotificationTarget.User,
TargetId = user.Id,
Payload = new PremiumStatusPushNotification
{
UserId = user.Id,
Premium = user.Premium
},
ExcludeCurrentContext = false,
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Extensions;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Quartz;
using Event = Stripe.Event;
Expand All @@ -13,28 +14,34 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler
{
private readonly IStripeEventService _stripeEventService;
private readonly IUserService _userService;
private readonly IUserRepository _userRepository;
private readonly IStripeEventUtilityService _stripeEventUtilityService;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IProviderRepository _providerRepository;
private readonly IProviderService _providerService;
private readonly ISchedulerFactory _schedulerFactory;
private readonly IPushNotificationAdapter _pushNotificationAdapter;

public SubscriptionDeletedHandler(
IStripeEventService stripeEventService,
IUserService userService,
IUserRepository userRepository,
IStripeEventUtilityService stripeEventUtilityService,
IOrganizationDisableCommand organizationDisableCommand,
IProviderRepository providerRepository,
IProviderService providerService,
ISchedulerFactory schedulerFactory)
ISchedulerFactory schedulerFactory,
IPushNotificationAdapter pushNotificationAdapter)
{
_stripeEventService = stripeEventService;
_userService = userService;
_userRepository = userRepository;
_stripeEventUtilityService = stripeEventUtilityService;
_organizationDisableCommand = organizationDisableCommand;
_providerRepository = providerRepository;
_providerService = providerService;
_schedulerFactory = schedulerFactory;
_pushNotificationAdapter = pushNotificationAdapter;
}

/// <summary>
Expand Down Expand Up @@ -80,6 +87,11 @@ public async Task HandleAsync(Event parsedEvent)
else if (userId.HasValue)
{
await _userService.DisablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd());
var user = await _userRepository.GetByIdAsync(userId.Value);
if (user != null)
{
await _pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user!);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler
private readonly IStripeFacade _stripeFacade;
private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand;
private readonly IUserService _userService;
private readonly IUserRepository _userRepository;
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationEnableCommand _organizationEnableCommand;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
Expand All @@ -37,6 +38,7 @@ public SubscriptionUpdatedHandler(
IStripeFacade stripeFacade,
IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand,
IUserService userService,
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationEnableCommand organizationEnableCommand,
IOrganizationDisableCommand organizationDisableCommand,
Expand All @@ -52,6 +54,7 @@ public SubscriptionUpdatedHandler(
_stripeFacade = stripeFacade;
_organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand;
_userService = userService;
_userRepository = userRepository;
_organizationRepository = organizationRepository;
_providerRepository = providerRepository;
_organizationEnableCommand = organizationEnableCommand;
Expand Down Expand Up @@ -140,7 +143,15 @@ SubscriptionStatus.Incomplete or

private Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) =>
subscriberId.Match(
userId => _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd),
async userId =>
{
await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd);
var user = await _userRepository.GetByIdAsync(userId.Value);
if (user != null)
{
await _pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user);
}
},
async organizationId =>
{
await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd);
Expand All @@ -162,7 +173,15 @@ private Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? current

private Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) =>
subscriberId.Match(
userId => _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd),
async userId =>
{
await _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd);
var user = await _userRepository.GetByIdAsync(userId.Value);
if (user != null)
{
await _pushNotificationAdapter.NotifyPremiumStatusChangedAsync(user!);
}
},
async organizationId =>
{
await _organizationEnableCommand.EnableAsync(organizationId.Value, currentPeriodEnd);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
using Bit.Core.Billing.Services;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Utilities;
Expand Down Expand Up @@ -55,7 +57,8 @@ public class UpgradePremiumToOrganizationCommand(
IOrganizationUserRepository organizationUserRepository,
IOrganizationApiKeyRepository organizationApiKeyRepository,
ICollectionRepository collectionRepository,
IApplicationCacheService applicationCacheService)
IApplicationCacheService applicationCacheService,
IPushNotificationService pushNotificationService)
: BaseBillingCommand<UpgradePremiumToOrganizationCommand>(logger), IUpgradePremiumToOrganizationCommand
{
private readonly ILogger<UpgradePremiumToOrganizationCommand> _logger = logger;
Expand Down Expand Up @@ -279,6 +282,19 @@ await organizationApiKeyRepository.CreateAsync(new OrganizationApiKey
user.RevisionDate = DateTime.UtcNow;
await userService.SaveUserAsync(user);

await pushNotificationService.PushAsync(new PushNotification<PremiumStatusPushNotification>
{
Type = PushType.PremiumStatusChanged,
Target = NotificationTarget.User,
TargetId = user.Id,
Payload = new PremiumStatusPushNotification
{
UserId = user.Id,
Premium = user.Premium,
},
ExcludeCurrentContext = false,
});

return organization.Id;
});

Expand Down
21 changes: 20 additions & 1 deletion src/Core/Billing/Services/Implementations/LicensingService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
using Bit.Core.Billing.Models.Business;
using Bit.Core.Billing.Organizations.Models;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models;
using Bit.Core.Models.Business;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
Expand All @@ -36,6 +39,7 @@ public class LicensingService : ILicensingService
private readonly ILogger<LicensingService> _logger;
private readonly ILicenseClaimsFactory<Organization> _organizationLicenseClaimsFactory;
private readonly ILicenseClaimsFactory<User> _userLicenseClaimsFactory;
private readonly IPushNotificationService _pushNotificationService;

private IDictionary<Guid, DateTime> _userCheckCache = new Dictionary<Guid, DateTime>();

Expand All @@ -47,7 +51,8 @@ public LicensingService(
ILogger<LicensingService> logger,
IGlobalSettings globalSettings,
ILicenseClaimsFactory<Organization> organizationLicenseClaimsFactory,
ILicenseClaimsFactory<User> userLicenseClaimsFactory)
ILicenseClaimsFactory<User> userLicenseClaimsFactory,
IPushNotificationService pushNotificationService)
{
_userRepository = userRepository;
_organizationRepository = organizationRepository;
Expand All @@ -56,6 +61,7 @@ public LicensingService(
_globalSettings = globalSettings;
_organizationLicenseClaimsFactory = organizationLicenseClaimsFactory;
_userLicenseClaimsFactory = userLicenseClaimsFactory;
_pushNotificationService = pushNotificationService;

var certThumbprint = environment.IsDevelopment() ?
"207E64A231E8AA32AAF68A61037C075EBEBD553F" :
Expand Down Expand Up @@ -246,6 +252,19 @@ private async Task DisablePremiumAsync(User user, ILicense license, string reaso
await _userRepository.ReplaceAsync(user);

await _mailService.SendLicenseExpiredAsync(new List<string> { user.Email });

await _pushNotificationService.PushAsync(new PushNotification<PremiumStatusPushNotification>
{
Type = PushType.PremiumStatusChanged,
Target = NotificationTarget.User,
TargetId = user.Id,
Payload = new PremiumStatusPushNotification
{
UserId = user.Id,
Premium = user.Premium,
},
ExcludeCurrentContext = false,
});
}

public bool VerifyLicense(ILicense license)
Expand Down
6 changes: 6 additions & 0 deletions src/Core/Models/PushNotification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,9 @@ public class AutoConfirmPushNotification
/// </summary>
public Guid TargetOrganizationUserId { get; set; }
}

public class PremiumStatusPushNotification
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❌ Please review the comments on lines 7 and 8 of this file regarding the location this file should live in.

{
public Guid UserId { get; set; }
public bool Premium { get; set; }
}
3 changes: 3 additions & 0 deletions src/Core/Platform/Push/PushType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,7 @@ public enum PushType : byte

[NotificationInfo("@bitwarden/team-admin-console-dev", typeof(Models.AutoConfirmPushNotification))]
AutoConfirm = 26,

[NotificationInfo("@bitwarden/team-billing-dev", typeof(Models.PremiumStatusPushNotification))]
PremiumStatusChanged = 27,
}
12 changes: 12 additions & 0 deletions src/Notifications/HubHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,18 @@ await _hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString())
await _hubContext.Clients.User(autoConfirmNotification.Payload.UserId.ToString())
.SendAsync(_receiveMessageMethod, autoConfirmNotification, cancellationToken);
break;
case PushType.PremiumStatusChanged:
var premiumStatusNotification =
JsonSerializer.Deserialize<PushNotificationData<PremiumStatusPushNotification>>(
notificationJson, _deserializerOptions);
if (premiumStatusNotification is null)
{
break;
}

await _hubContext.Clients.User(premiumStatusNotification.Payload.UserId.ToString())
.SendAsync(_receiveMessageMethod, premiumStatusNotification, cancellationToken);
break;
default:
_logger.LogWarning("Notification type '{NotificationType}' has not been registered in HubHelpers and will not be pushed as as result", notification.Type);
break;
Expand Down
16 changes: 15 additions & 1 deletion test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,54 @@
using Bit.Core.AdminConsole.Repositories;
using Bit.Core.AdminConsole.Services;
using Bit.Core.Billing.Extensions;
using Bit.Core.Entities;
using Bit.Core.Repositories;
using Bit.Core.Services;
using NSubstitute;
using Quartz;
using Stripe;
using Xunit;
using Event = Stripe.Event;

namespace Bit.Billing.Test.Services;

public class SubscriptionDeletedHandlerTests
{
private readonly IStripeEventService _stripeEventService;
private readonly IUserService _userService;
private readonly IUserRepository _userRepository;
private readonly IStripeEventUtilityService _stripeEventUtilityService;
private readonly IOrganizationDisableCommand _organizationDisableCommand;
private readonly IProviderRepository _providerRepository;
private readonly IProviderService _providerService;
private readonly ISchedulerFactory _schedulerFactory;
private readonly IPushNotificationAdapter _pushNotificationAdapter;
private readonly IScheduler _scheduler;
private readonly SubscriptionDeletedHandler _sut;

public SubscriptionDeletedHandlerTests()
{
_stripeEventService = Substitute.For<IStripeEventService>();
_userService = Substitute.For<IUserService>();
_userRepository = Substitute.For<IUserRepository>();
_stripeEventUtilityService = Substitute.For<IStripeEventUtilityService>();
_organizationDisableCommand = Substitute.For<IOrganizationDisableCommand>();
_providerRepository = Substitute.For<IProviderRepository>();
_providerService = Substitute.For<IProviderService>();
_schedulerFactory = Substitute.For<ISchedulerFactory>();
_pushNotificationAdapter = Substitute.For<IPushNotificationAdapter>();
_scheduler = Substitute.For<IScheduler>();
_schedulerFactory.GetScheduler().Returns(_scheduler);
_sut = new SubscriptionDeletedHandler(
_stripeEventService,
_userService,
_userRepository,
_stripeEventUtilityService,
_organizationDisableCommand,
_providerRepository,
_providerService,
_schedulerFactory);
_schedulerFactory,
_pushNotificationAdapter);
}

[Fact]
Expand Down Expand Up @@ -129,16 +138,21 @@ public async Task HandleAsync_UserSubscriptionCanceled_DisablesUserPremium()
Metadata = new Dictionary<string, string> { { "userId", userId.ToString() } }
};

var user = new User { Id = userId, Premium = false, PremiumExpirationDate = subscription.GetCurrentPeriodEnd() };

_stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription);
_stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata)
.Returns(Tuple.Create<Guid?, Guid?, Guid?>(null, userId, null));
_userRepository.GetByIdAsync(userId).Returns(user);

// Act
await _sut.HandleAsync(stripeEvent);

// Assert
await _userService.Received(1)
.DisablePremiumAsync(userId, subscription.GetCurrentPeriodEnd());
await _userRepository.Received(1).GetByIdAsync(userId);
await _pushNotificationAdapter.Received(1).NotifyPremiumStatusChangedAsync(user);
}

[Fact]
Expand Down
Loading
Loading