diff --git a/admin/billing.go b/admin/billing.go index 1d5f3e37752..d7f99d5d74c 100644 --- a/admin/billing.go +++ b/admin/billing.go @@ -261,9 +261,9 @@ func (s *Service) RaiseNewOrgBillingIssues(ctx context.Context, orgID, subID, pl return nil } -// CleanupTrialBillingIssues removes trial related billing issues and cancel associated jobs +// CleanupTrialBillingIssues removes trial related billing issues func (s *Service) CleanupTrialBillingIssues(ctx context.Context, orgID string) error { - bite, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeTrialEnded) + bite, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeTrialEnded) if err != nil { if !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("failed to find billing issue: %w", err) @@ -277,7 +277,7 @@ func (s *Service) CleanupTrialBillingIssues(ctx context.Context, orgID string) e } } - biot, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeOnTrial) + biot, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeOnTrial) if err != nil { if !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("failed to find billing issue: %w", err) @@ -294,9 +294,9 @@ func (s *Service) CleanupTrialBillingIssues(ctx context.Context, orgID string) e return nil } -// CleanupBillingErrorSubCancellation removes subscription cancellation related billing error and cancel associated job +// CleanupBillingErrorSubCancellation removes subscription cancellation related billing error func (s *Service) CleanupBillingErrorSubCancellation(ctx context.Context, orgID string) error { - bisc, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled) + bisc, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled) if err != nil { if !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("failed to find billing errors: %w", err) @@ -314,7 +314,7 @@ func (s *Service) CleanupBillingErrorSubCancellation(ctx context.Context, orgID } func (s *Service) CheckBillingErrors(ctx context.Context, orgID string) error { - be, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeTrialEnded) + be, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeTrialEnded) if err != nil { if !errors.Is(err, database.ErrNotFound) { return err @@ -325,7 +325,7 @@ func (s *Service) CheckBillingErrors(ctx context.Context, orgID string) error { return fmt.Errorf("trial has ended") } - be, err = s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypePaymentFailed) + be, err = s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypePaymentFailed) if err != nil { if !errors.Is(err, database.ErrNotFound) { return err @@ -336,7 +336,7 @@ func (s *Service) CheckBillingErrors(ctx context.Context, orgID string) error { return fmt.Errorf("invoice payment failed") } - be, err = s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled) + be, err = s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled) if err != nil { if !errors.Is(err, database.ErrNotFound) { return err diff --git a/admin/database/database.go b/admin/database/database.go index 79ab2d71721..2daad2062be 100644 --- a/admin/database/database.go +++ b/admin/database/database.go @@ -266,14 +266,14 @@ type DB interface { FindOrganizationForPaymentCustomerID(ctx context.Context, customerID string) (*Organization, error) FindOrganizationForBillingCustomerID(ctx context.Context, customerID string) (*Organization, error) - FindBillingIssues(ctx context.Context, orgID string) ([]*BillingIssue, error) - FindBillingIssueByType(ctx context.Context, orgID string, errorType BillingIssueType) (*BillingIssue, error) - FindBillingIssueByTypeNotOverdueProcessed(ctx context.Context, errorType BillingIssueType) ([]*BillingIssue, error) - FindBillingIssueByTypeNotOverdueProcessedForOrg(ctx context.Context, orgID string, errorType BillingIssueType) (*BillingIssue, error) + FindBillingIssuesForOrg(ctx context.Context, orgID string) ([]*BillingIssue, error) + FindBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType BillingIssueType) (*BillingIssue, error) + FindBillingIssueByType(ctx context.Context, errorType BillingIssueType) ([]*BillingIssue, error) + FindBillingIssueByTypeAndOverdueProcessed(ctx context.Context, errorType BillingIssueType, overdueProcessed bool) ([]*BillingIssue, error) UpsertBillingIssue(ctx context.Context, opts *UpsertBillingIssueOptions) (*BillingIssue, error) UpdateBillingIssueOverdueAsProcessed(ctx context.Context, id string) error DeleteBillingIssue(ctx context.Context, id string) error - DeleteBillingIssueByType(ctx context.Context, orgID string, errorType BillingIssueType) error + DeleteBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType BillingIssueType) error } // Tx represents a database transaction. It can only be used to commit and rollback transactions. diff --git a/admin/database/postgres/postgres.go b/admin/database/postgres/postgres.go index 46ab81ba3bc..0fbcbf15cf6 100644 --- a/admin/database/postgres/postgres.go +++ b/admin/database/postgres/postgres.go @@ -1948,7 +1948,7 @@ func (c *connection) FindOrganizationForBillingCustomerID(ctx context.Context, c return res, nil } -func (c *connection) FindBillingIssues(ctx context.Context, orgID string) ([]*database.BillingIssue, error) { +func (c *connection) FindBillingIssuesForOrg(ctx context.Context, orgID string) ([]*database.BillingIssue, error) { var res []*billingIssueDTO err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE org_id = $1`, orgID) if err != nil { @@ -1962,7 +1962,7 @@ func (c *connection) FindBillingIssues(ctx context.Context, orgID string) ([]*da return billingErrors, nil } -func (c *connection) FindBillingIssueByType(ctx context.Context, orgID string, errorType database.BillingIssueType) (*database.BillingIssue, error) { +func (c *connection) FindBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType database.BillingIssueType) (*database.BillingIssue, error) { res := &billingIssueDTO{} err := c.db.GetContext(ctx, res, `SELECT * FROM billing_issues WHERE org_id = $1 AND type = $2`, orgID, errorType) if err != nil { @@ -1971,9 +1971,9 @@ func (c *connection) FindBillingIssueByType(ctx context.Context, orgID string, e return res.AsModel(), nil } -func (c *connection) FindBillingIssueByTypeNotOverdueProcessed(ctx context.Context, errorType database.BillingIssueType) ([]*database.BillingIssue, error) { +func (c *connection) FindBillingIssueByType(ctx context.Context, errorType database.BillingIssueType) ([]*database.BillingIssue, error) { var res []*billingIssueDTO - err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE type = $1 AND overdue_processed = false`, errorType) + err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE type = $1`, errorType) if err != nil { return nil, parseErr("billing issues", err) } @@ -1985,13 +1985,18 @@ func (c *connection) FindBillingIssueByTypeNotOverdueProcessed(ctx context.Conte return billingErrors, nil } -func (c *connection) FindBillingIssueByTypeNotOverdueProcessedForOrg(ctx context.Context, orgID string, errorType database.BillingIssueType) (*database.BillingIssue, error) { - res := &billingIssueDTO{} - err := c.db.GetContext(ctx, res, `SELECT * FROM billing_issues WHERE org_id = $1 AND type = $2 AND overdue_processed = false`, orgID, errorType) +func (c *connection) FindBillingIssueByTypeAndOverdueProcessed(ctx context.Context, errorType database.BillingIssueType, overdueProcessed bool) ([]*database.BillingIssue, error) { + var res []*billingIssueDTO + err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE type = $1 AND overdue_processed = $2`, errorType, overdueProcessed) if err != nil { - return nil, parseErr("billing issue", err) + return nil, parseErr("billing issues", err) } - return res.AsModel(), nil + + var billingErrors []*database.BillingIssue + for _, dto := range res { + billingErrors = append(billingErrors, dto.AsModel()) + } + return billingErrors, nil } func (c *connection) UpsertBillingIssue(ctx context.Context, opts *database.UpsertBillingIssueOptions) (*database.BillingIssue, error) { @@ -2032,7 +2037,7 @@ func (c *connection) DeleteBillingIssue(ctx context.Context, id string) error { return checkDeleteRow("billing issue", res, err) } -func (c *connection) DeleteBillingIssueByType(ctx context.Context, orgID string, errorType database.BillingIssueType) error { +func (c *connection) DeleteBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType database.BillingIssueType) error { res, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM billing_issues WHERE org_id = $1 AND type = $2", orgID, errorType) return checkDeleteRow("billing issue", res, err) } diff --git a/admin/jobs/river/biller_event_handlers.go b/admin/jobs/river/biller_event_handlers.go index b8688df90ec..16f45b188ef 100644 --- a/admin/jobs/river/biller_event_handlers.go +++ b/admin/jobs/river/biller_event_handlers.go @@ -42,7 +42,7 @@ func (w *PaymentFailedWorker) Work(ctx context.Context, job *river.Job[PaymentFa return fmt.Errorf("failed to find organization of billing customer id %q: %w", job.Args.BillingCustomerID, err) } - be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypePaymentFailed) + be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypePaymentFailed) if err != nil { if !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("failed to find billing errors: %w", err) @@ -120,7 +120,7 @@ func (w *PaymentSuccessWorker) Work(ctx context.Context, job *river.Job[PaymentS } // check for existing billing error and delete it - be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypePaymentFailed) + be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypePaymentFailed) if err != nil { if errors.Is(err, database.ErrNotFound) { // no billing error, ignore @@ -191,7 +191,7 @@ func (w *PaymentFailedGracePeriodCheckWorker) Work(ctx context.Context, job *riv } func (w *PaymentFailedGracePeriodCheckWorker) paymentFailedGracePeriodCheck(ctx context.Context) error { - failures, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypePaymentFailed) + failures, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypePaymentFailed, false) if err != nil { if errors.Is(err, database.ErrNotFound) { // no orgs have this billing error diff --git a/admin/jobs/river/payment_provider_event_handlers.go b/admin/jobs/river/payment_provider_event_handlers.go index a41a9655fcc..feb5b29e273 100644 --- a/admin/jobs/river/payment_provider_event_handlers.go +++ b/admin/jobs/river/payment_provider_event_handlers.go @@ -36,7 +36,7 @@ func (w *PaymentMethodAddedWorker) Work(ctx context.Context, job *river.Job[Paym } // check for no payment method billing error - be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypeNoPaymentMethod) + be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypeNoPaymentMethod) if err != nil { if !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("failed to find billing errors: %w", err) @@ -128,7 +128,7 @@ func (w *CustomerAddressUpdatedWorker) Work(ctx context.Context, job *river.Job[ } // look for no billable address billing error and remove it - be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypeNoBillableAddress) + be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypeNoBillableAddress) if err != nil { if !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("failed to find billing errors: %w", err) diff --git a/admin/jobs/river/subscription_handlers.go b/admin/jobs/river/subscription_handlers.go index c8e15f38593..b8a8d079feb 100644 --- a/admin/jobs/river/subscription_handlers.go +++ b/admin/jobs/river/subscription_handlers.go @@ -127,7 +127,7 @@ func (w *SubscriptionCancellationCheckWorker) Work(ctx context.Context, job *riv } func (w *SubscriptionCancellationCheckWorker) subscriptionCancellationCheck(ctx context.Context) error { - cancelled, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeSubscriptionCancelled) + cancelled, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeSubscriptionCancelled, false) if err != nil { if errors.Is(err, database.ErrNotFound) { // no orgs have this billing issue diff --git a/admin/jobs/river/trial_checks.go b/admin/jobs/river/trial_checks.go index b211ea6353a..1a226b086a5 100644 --- a/admin/jobs/river/trial_checks.go +++ b/admin/jobs/river/trial_checks.go @@ -30,7 +30,7 @@ func (w *TrialEndingSoonWorker) Work(ctx context.Context, job *river.Job[TrialEn } func (w *TrialEndingSoonWorker) trialEndingSoon(ctx context.Context) error { - onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeOnTrial) + onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeOnTrial, false) if err != nil { if errors.Is(err, database.ErrNotFound) { // no orgs have this billing issue @@ -89,7 +89,7 @@ func (w *TrialEndCheckWorker) Work(ctx context.Context, job *river.Job[TrialEndC } func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error { - onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeOnTrial) + onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeOnTrial, true) if err != nil { if errors.Is(err, database.ErrNotFound) { // no orgs have this billing issue @@ -114,7 +114,13 @@ func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error { w.logger.Warn("trial period has ended", zap.String("org_id", org.ID), zap.String("org_name", org.Name)) gracePeriodEndDate := m.EndDate.AddDate(0, 0, gracePeriodDays) - _, err = w.admin.DB.UpsertBillingIssue(ctx, &database.UpsertBillingIssueOptions{ + + cctx, tx, err := w.admin.DB.NewTx(ctx) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + + _, err = w.admin.DB.UpsertBillingIssue(cctx, &database.UpsertBillingIssueOptions{ OrgID: org.ID, Type: database.BillingIssueTypeTrialEnded, Metadata: &database.BillingIssueMetadataTrialEnded{ @@ -123,9 +129,23 @@ func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error { EventTime: m.EndDate.AddDate(0, 0, 1), }) if err != nil { + err = tx.Rollback() + if err != nil { + return fmt.Errorf("failed to rollback transaction: %w", err) + } return fmt.Errorf("failed to add billing error: %w", err) } + // delete the on-trial billing issue + err = w.admin.DB.DeleteBillingIssue(cctx, o.ID) + if err != nil { + err = tx.Rollback() + if err != nil { + return fmt.Errorf("failed to rollback transaction: %w", err) + } + return fmt.Errorf("failed to delete billing issue: %w", err) + } + // send email err = w.admin.Email.SendTrialEnded(&email.TrialEnded{ ToEmail: org.BillingEmail, @@ -134,13 +154,16 @@ func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error { GracePeriodEndDate: gracePeriodEndDate, }) if err != nil { + err = tx.Rollback() + if err != nil { + return fmt.Errorf("failed to rollback transaction: %w", err) + } return fmt.Errorf("failed to send trial period ended email for org %q: %w", org.Name, err) } - // mark the billing issue as processed - err = w.admin.DB.UpdateBillingIssueOverdueAsProcessed(ctx, o.ID) + err = tx.Commit() if err != nil { - return fmt.Errorf("failed to update billing issue as processed: %w", err) + return fmt.Errorf("failed to commit transaction: %w", err) } } @@ -162,7 +185,7 @@ func (w *TrialGracePeriodCheckWorker) Work(ctx context.Context, job *river.Job[T } func (w *TrialGracePeriodCheckWorker) trialGracePeriodCheck(ctx context.Context) error { - trailEndedOrgs, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeTrialEnded) + trailEndedOrgs, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeTrialEnded, false) if err != nil { if errors.Is(err, database.ErrNotFound) { // no orgs have this billing issue diff --git a/admin/server/billing.go b/admin/server/billing.go index 3c4c7cb1817..208a1bd7eb0 100644 --- a/admin/server/billing.go +++ b/admin/server/billing.go @@ -114,7 +114,7 @@ func (s *Server) UpdateBillingSubscription(ctx context.Context, req *adminv1.Upd validationErrs = append(validationErrs, "no billing address found, click on update information to add billing address") } - be, err := s.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypePaymentFailed) + be, err := s.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypePaymentFailed) if err != nil { if !errors.Is(err, database.ErrNotFound) { return nil, status.Error(codes.Internal, err.Error()) @@ -407,7 +407,7 @@ func (s *Server) ListOrganizationBillingIssues(ctx context.Context, req *adminv1 return nil, status.Error(codes.PermissionDenied, "not allowed to read org billing errors") } - issues, err := s.admin.DB.FindBillingIssues(ctx, org.ID) + issues, err := s.admin.DB.FindBillingIssuesForOrg(ctx, org.ID) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } @@ -447,7 +447,7 @@ func (s *Server) SudoDeleteOrganizationBillingIssue(ctx context.Context, req *ad return nil, err } - err = s.admin.DB.DeleteBillingIssueByType(ctx, org.ID, t) + err = s.admin.DB.DeleteBillingIssueByTypeForOrg(ctx, org.ID, t) if err != nil { return nil, status.Error(codes.Internal, err.Error()) }