Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix trial end check #5772

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions admin/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ func (s *Service) RaiseNewOrgBillingIssues(ctx context.Context, orgID, subID, pl
return nil
}

// CleanupTrialBillingIssues removes trial related billing issues and cancel associated jobs
// CleanupTrialBillingIssues removes trial related billing issues
func (s *Service) CleanupTrialBillingIssues(ctx context.Context, orgID string) error {
bite, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeTrialEnded)
bite, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeTrialEnded)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to find billing issue: %w", err)
Expand All @@ -277,7 +277,7 @@ func (s *Service) CleanupTrialBillingIssues(ctx context.Context, orgID string) e
}
}

biot, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeOnTrial)
biot, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeOnTrial)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to find billing issue: %w", err)
Expand All @@ -294,9 +294,9 @@ func (s *Service) CleanupTrialBillingIssues(ctx context.Context, orgID string) e
return nil
}

// CleanupBillingErrorSubCancellation removes subscription cancellation related billing error and cancel associated job
// CleanupBillingErrorSubCancellation removes subscription cancellation related billing error
func (s *Service) CleanupBillingErrorSubCancellation(ctx context.Context, orgID string) error {
bisc, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled)
bisc, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to find billing errors: %w", err)
Expand All @@ -314,7 +314,7 @@ func (s *Service) CleanupBillingErrorSubCancellation(ctx context.Context, orgID
}

func (s *Service) CheckBillingErrors(ctx context.Context, orgID string) error {
be, err := s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeTrialEnded)
be, err := s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeTrialEnded)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return err
Expand All @@ -325,7 +325,7 @@ func (s *Service) CheckBillingErrors(ctx context.Context, orgID string) error {
return fmt.Errorf("trial has ended")
}

be, err = s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypePaymentFailed)
be, err = s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypePaymentFailed)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return err
Expand All @@ -336,7 +336,7 @@ func (s *Service) CheckBillingErrors(ctx context.Context, orgID string) error {
return fmt.Errorf("invoice payment failed")
}

be, err = s.DB.FindBillingIssueByType(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled)
be, err = s.DB.FindBillingIssueByTypeForOrg(ctx, orgID, database.BillingIssueTypeSubscriptionCancelled)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return err
Expand Down
10 changes: 5 additions & 5 deletions admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,14 @@ type DB interface {
FindOrganizationForPaymentCustomerID(ctx context.Context, customerID string) (*Organization, error)
FindOrganizationForBillingCustomerID(ctx context.Context, customerID string) (*Organization, error)

FindBillingIssues(ctx context.Context, orgID string) ([]*BillingIssue, error)
FindBillingIssueByType(ctx context.Context, orgID string, errorType BillingIssueType) (*BillingIssue, error)
FindBillingIssueByTypeNotOverdueProcessed(ctx context.Context, errorType BillingIssueType) ([]*BillingIssue, error)
FindBillingIssueByTypeNotOverdueProcessedForOrg(ctx context.Context, orgID string, errorType BillingIssueType) (*BillingIssue, error)
FindBillingIssuesForOrg(ctx context.Context, orgID string) ([]*BillingIssue, error)
FindBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType BillingIssueType) (*BillingIssue, error)
FindBillingIssueByType(ctx context.Context, errorType BillingIssueType) ([]*BillingIssue, error)
FindBillingIssueByTypeAndOverdueProcessed(ctx context.Context, errorType BillingIssueType, overdueProcessed bool) ([]*BillingIssue, error)
UpsertBillingIssue(ctx context.Context, opts *UpsertBillingIssueOptions) (*BillingIssue, error)
UpdateBillingIssueOverdueAsProcessed(ctx context.Context, id string) error
DeleteBillingIssue(ctx context.Context, id string) error
DeleteBillingIssueByType(ctx context.Context, orgID string, errorType BillingIssueType) error
DeleteBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType BillingIssueType) error
}

// Tx represents a database transaction. It can only be used to commit and rollback transactions.
Expand Down
25 changes: 15 additions & 10 deletions admin/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ func (c *connection) FindOrganizationForBillingCustomerID(ctx context.Context, c
return res, nil
}

func (c *connection) FindBillingIssues(ctx context.Context, orgID string) ([]*database.BillingIssue, error) {
func (c *connection) FindBillingIssuesForOrg(ctx context.Context, orgID string) ([]*database.BillingIssue, error) {
var res []*billingIssueDTO
err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE org_id = $1`, orgID)
if err != nil {
Expand All @@ -1962,7 +1962,7 @@ func (c *connection) FindBillingIssues(ctx context.Context, orgID string) ([]*da
return billingErrors, nil
}

func (c *connection) FindBillingIssueByType(ctx context.Context, orgID string, errorType database.BillingIssueType) (*database.BillingIssue, error) {
func (c *connection) FindBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType database.BillingIssueType) (*database.BillingIssue, error) {
res := &billingIssueDTO{}
err := c.db.GetContext(ctx, res, `SELECT * FROM billing_issues WHERE org_id = $1 AND type = $2`, orgID, errorType)
if err != nil {
Expand All @@ -1971,9 +1971,9 @@ func (c *connection) FindBillingIssueByType(ctx context.Context, orgID string, e
return res.AsModel(), nil
}

func (c *connection) FindBillingIssueByTypeNotOverdueProcessed(ctx context.Context, errorType database.BillingIssueType) ([]*database.BillingIssue, error) {
func (c *connection) FindBillingIssueByType(ctx context.Context, errorType database.BillingIssueType) ([]*database.BillingIssue, error) {
var res []*billingIssueDTO
err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE type = $1 AND overdue_processed = false`, errorType)
err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE type = $1`, errorType)
if err != nil {
return nil, parseErr("billing issues", err)
}
Expand All @@ -1985,13 +1985,18 @@ func (c *connection) FindBillingIssueByTypeNotOverdueProcessed(ctx context.Conte
return billingErrors, nil
}

func (c *connection) FindBillingIssueByTypeNotOverdueProcessedForOrg(ctx context.Context, orgID string, errorType database.BillingIssueType) (*database.BillingIssue, error) {
res := &billingIssueDTO{}
err := c.db.GetContext(ctx, res, `SELECT * FROM billing_issues WHERE org_id = $1 AND type = $2 AND overdue_processed = false`, orgID, errorType)
func (c *connection) FindBillingIssueByTypeAndOverdueProcessed(ctx context.Context, errorType database.BillingIssueType, overdueProcessed bool) ([]*database.BillingIssue, error) {
var res []*billingIssueDTO
err := c.db.SelectContext(ctx, &res, `SELECT * FROM billing_issues WHERE type = $1 AND overdue_processed = $2`, errorType, overdueProcessed)
if err != nil {
return nil, parseErr("billing issue", err)
return nil, parseErr("billing issues", err)
}
return res.AsModel(), nil

var billingErrors []*database.BillingIssue
for _, dto := range res {
billingErrors = append(billingErrors, dto.AsModel())
}
return billingErrors, nil
}

func (c *connection) UpsertBillingIssue(ctx context.Context, opts *database.UpsertBillingIssueOptions) (*database.BillingIssue, error) {
Expand Down Expand Up @@ -2032,7 +2037,7 @@ func (c *connection) DeleteBillingIssue(ctx context.Context, id string) error {
return checkDeleteRow("billing issue", res, err)
}

func (c *connection) DeleteBillingIssueByType(ctx context.Context, orgID string, errorType database.BillingIssueType) error {
func (c *connection) DeleteBillingIssueByTypeForOrg(ctx context.Context, orgID string, errorType database.BillingIssueType) error {
res, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM billing_issues WHERE org_id = $1 AND type = $2", orgID, errorType)
return checkDeleteRow("billing issue", res, err)
}
Expand Down
6 changes: 3 additions & 3 deletions admin/jobs/river/biller_event_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (w *PaymentFailedWorker) Work(ctx context.Context, job *river.Job[PaymentFa
return fmt.Errorf("failed to find organization of billing customer id %q: %w", job.Args.BillingCustomerID, err)
}

be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypePaymentFailed)
be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypePaymentFailed)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to find billing errors: %w", err)
Expand Down Expand Up @@ -120,7 +120,7 @@ func (w *PaymentSuccessWorker) Work(ctx context.Context, job *river.Job[PaymentS
}

// check for existing billing error and delete it
be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypePaymentFailed)
be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypePaymentFailed)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
// no billing error, ignore
Expand Down Expand Up @@ -191,7 +191,7 @@ func (w *PaymentFailedGracePeriodCheckWorker) Work(ctx context.Context, job *riv
}

func (w *PaymentFailedGracePeriodCheckWorker) paymentFailedGracePeriodCheck(ctx context.Context) error {
failures, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypePaymentFailed)
failures, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypePaymentFailed, false)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
// no orgs have this billing error
Expand Down
4 changes: 2 additions & 2 deletions admin/jobs/river/payment_provider_event_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (w *PaymentMethodAddedWorker) Work(ctx context.Context, job *river.Job[Paym
}

// check for no payment method billing error
be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypeNoPaymentMethod)
be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypeNoPaymentMethod)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to find billing errors: %w", err)
Expand Down Expand Up @@ -128,7 +128,7 @@ func (w *CustomerAddressUpdatedWorker) Work(ctx context.Context, job *river.Job[
}

// look for no billable address billing error and remove it
be, err := w.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypeNoBillableAddress)
be, err := w.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypeNoBillableAddress)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("failed to find billing errors: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion admin/jobs/river/subscription_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (w *SubscriptionCancellationCheckWorker) Work(ctx context.Context, job *riv
}

func (w *SubscriptionCancellationCheckWorker) subscriptionCancellationCheck(ctx context.Context) error {
cancelled, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeSubscriptionCancelled)
cancelled, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeSubscriptionCancelled, false)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
// no orgs have this billing issue
Expand Down
37 changes: 30 additions & 7 deletions admin/jobs/river/trial_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (w *TrialEndingSoonWorker) Work(ctx context.Context, job *river.Job[TrialEn
}

func (w *TrialEndingSoonWorker) trialEndingSoon(ctx context.Context) error {
onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeOnTrial)
onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeOnTrial, false)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
// no orgs have this billing issue
Expand Down Expand Up @@ -89,7 +89,7 @@ func (w *TrialEndCheckWorker) Work(ctx context.Context, job *river.Job[TrialEndC
}

func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error {
onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeOnTrial)
onTrialOrgs, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeOnTrial, true)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
// no orgs have this billing issue
Expand All @@ -114,7 +114,13 @@ func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error {
w.logger.Warn("trial period has ended", zap.String("org_id", org.ID), zap.String("org_name", org.Name))

gracePeriodEndDate := m.EndDate.AddDate(0, 0, gracePeriodDays)
_, err = w.admin.DB.UpsertBillingIssue(ctx, &database.UpsertBillingIssueOptions{

cctx, tx, err := w.admin.DB.NewTx(ctx)
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}

_, err = w.admin.DB.UpsertBillingIssue(cctx, &database.UpsertBillingIssueOptions{
OrgID: org.ID,
Type: database.BillingIssueTypeTrialEnded,
Metadata: &database.BillingIssueMetadataTrialEnded{
Expand All @@ -123,9 +129,23 @@ func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error {
EventTime: m.EndDate.AddDate(0, 0, 1),
})
if err != nil {
err = tx.Rollback()
if err != nil {
return fmt.Errorf("failed to rollback transaction: %w", err)
}
return fmt.Errorf("failed to add billing error: %w", err)
}

// delete the on-trial billing issue
err = w.admin.DB.DeleteBillingIssue(cctx, o.ID)
if err != nil {
err = tx.Rollback()
if err != nil {
return fmt.Errorf("failed to rollback transaction: %w", err)
}
return fmt.Errorf("failed to delete billing issue: %w", err)
}

// send email
err = w.admin.Email.SendTrialEnded(&email.TrialEnded{
ToEmail: org.BillingEmail,
Expand All @@ -134,13 +154,16 @@ func (w *TrialEndCheckWorker) trialEndCheck(ctx context.Context) error {
GracePeriodEndDate: gracePeriodEndDate,
})
if err != nil {
err = tx.Rollback()
if err != nil {
return fmt.Errorf("failed to rollback transaction: %w", err)
}
return fmt.Errorf("failed to send trial period ended email for org %q: %w", org.Name, err)
}

// mark the billing issue as processed
err = w.admin.DB.UpdateBillingIssueOverdueAsProcessed(ctx, o.ID)
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to update billing issue as processed: %w", err)
return fmt.Errorf("failed to commit transaction: %w", err)
}
}

Expand All @@ -162,7 +185,7 @@ func (w *TrialGracePeriodCheckWorker) Work(ctx context.Context, job *river.Job[T
}

func (w *TrialGracePeriodCheckWorker) trialGracePeriodCheck(ctx context.Context) error {
trailEndedOrgs, err := w.admin.DB.FindBillingIssueByTypeNotOverdueProcessed(ctx, database.BillingIssueTypeTrialEnded)
trailEndedOrgs, err := w.admin.DB.FindBillingIssueByTypeAndOverdueProcessed(ctx, database.BillingIssueTypeTrialEnded, false)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
// no orgs have this billing issue
Expand Down
6 changes: 3 additions & 3 deletions admin/server/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (s *Server) UpdateBillingSubscription(ctx context.Context, req *adminv1.Upd
validationErrs = append(validationErrs, "no billing address found, click on update information to add billing address")
}

be, err := s.admin.DB.FindBillingIssueByType(ctx, org.ID, database.BillingIssueTypePaymentFailed)
be, err := s.admin.DB.FindBillingIssueByTypeForOrg(ctx, org.ID, database.BillingIssueTypePaymentFailed)
if err != nil {
if !errors.Is(err, database.ErrNotFound) {
return nil, status.Error(codes.Internal, err.Error())
Expand Down Expand Up @@ -407,7 +407,7 @@ func (s *Server) ListOrganizationBillingIssues(ctx context.Context, req *adminv1
return nil, status.Error(codes.PermissionDenied, "not allowed to read org billing errors")
}

issues, err := s.admin.DB.FindBillingIssues(ctx, org.ID)
issues, err := s.admin.DB.FindBillingIssuesForOrg(ctx, org.ID)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
Expand Down Expand Up @@ -447,7 +447,7 @@ func (s *Server) SudoDeleteOrganizationBillingIssue(ctx context.Context, req *ad
return nil, err
}

err = s.admin.DB.DeleteBillingIssueByType(ctx, org.ID, t)
err = s.admin.DB.DeleteBillingIssueByTypeForOrg(ctx, org.ID, t)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
Expand Down
Loading