From 0e7a9ea95a2cbcc37919f08474ba4c2655baf825 Mon Sep 17 00:00:00 2001 From: Alex Goth Date: Tue, 19 Nov 2024 13:03:04 +0100 Subject: [PATCH] feat(subscription): write first version of workflow service --- openmeter/entitlement/entitlement.go | 109 ++++ openmeter/subscription/entitlement.go | 16 +- openmeter/subscription/patch.go | 1 + openmeter/subscription/patch/additem.go | 16 + openmeter/subscription/patch/addphase.go | 16 + openmeter/subscription/patch/extendphase.go | 16 + openmeter/subscription/patch/removeitem.go | 12 + openmeter/subscription/patch/removephase.go | 12 + openmeter/subscription/plan.go | 2 +- .../subscription/repo/subscriptionitemrepo.go | 5 +- .../repo/subscriptionphaserepo.go | 5 +- openmeter/subscription/service/helpers.go | 8 +- openmeter/subscription/service/service.go | 8 - .../subscription/service/service_test.go | 52 +- openmeter/subscription/service/update.go | 38 +- openmeter/subscription/service/update_test.go | 500 ++++-------------- .../subscription/service/workflowservice.go | 132 +++++ .../service/workflowservice_test.go | 497 +++++++++++++++++ openmeter/subscription/subscriptionspec.go | 24 + openmeter/subscription/subscriptionview.go | 10 +- openmeter/subscription/testutils/patch.go | 55 ++ openmeter/subscription/testutils/plan.go | 8 +- openmeter/subscription/testutils/service.go | 76 ++- pkg/clock/clock.go | 2 +- pkg/models/model.go | 18 + 25 files changed, 1159 insertions(+), 479 deletions(-) create mode 100644 openmeter/subscription/service/workflowservice.go create mode 100644 openmeter/subscription/service/workflowservice_test.go create mode 100644 openmeter/subscription/testutils/patch.go diff --git a/openmeter/entitlement/entitlement.go b/openmeter/entitlement/entitlement.go index b6a63dea6a..684db737fc 100644 --- a/openmeter/entitlement/entitlement.go +++ b/openmeter/entitlement/entitlement.go @@ -2,6 +2,7 @@ package entitlement import ( "fmt" + "reflect" "slices" "time" @@ -38,6 +39,10 @@ type MeasureUsageFromInput struct { ts time.Time } +func (m MeasureUsageFromInput) Equal(other MeasureUsageFromInput) bool { + return m.ts.Equal(other.Get()) +} + func (m MeasureUsageFromInput) Get() time.Time { return m.ts } @@ -96,6 +101,98 @@ type CreateEntitlementInputs struct { SubscriptionManaged bool `json:"subscriptionManaged,omitempty"` } +func (c CreateEntitlementInputs) Equal(other CreateEntitlementInputs) bool { + if c.Namespace != other.Namespace { + return false + } + + if !reflect.DeepEqual(c.FeatureID, other.FeatureID) { + return false + } + + if !reflect.DeepEqual(c.FeatureKey, other.FeatureKey) { + return false + } + + if !reflect.DeepEqual(c.SubjectKey, other.SubjectKey) { + return false + } + + if c.EntitlementType != other.EntitlementType { + return false + } + + if !reflect.DeepEqual(c.Metadata, other.Metadata) { + return false + } + + if (c.ActiveFrom == nil) != (other.ActiveFrom == nil) { + return false + } + + if (c.ActiveFrom != nil && other.ActiveFrom != nil) && !c.ActiveFrom.Equal(*other.ActiveFrom) { + return false + } + + if (c.ActiveTo == nil) != (other.ActiveTo == nil) { + return false + } + + if (c.ActiveTo != nil && other.ActiveTo != nil) && !c.ActiveTo.Equal(*other.ActiveTo) { + return false + } + + if (c.MeasureUsageFrom == nil) != (other.MeasureUsageFrom == nil) { + return false + } + + if (c.MeasureUsageFrom != nil && other.MeasureUsageFrom != nil) && !c.MeasureUsageFrom.Equal(*other.MeasureUsageFrom) { + return false + } + + if !reflect.DeepEqual(c.IssueAfterReset, other.IssueAfterReset) { + return false + } + + if !reflect.DeepEqual(c.IssueAfterResetPriority, other.IssueAfterResetPriority) { + return false + } + + if !reflect.DeepEqual(c.IsSoftLimit, other.IsSoftLimit) { + return false + } + + if !reflect.DeepEqual(c.Config, other.Config) { + return false + } + + if (c.UsagePeriod == nil) != (other.UsagePeriod == nil) { + return false + } + + if (c.UsagePeriod != nil && other.UsagePeriod != nil) && !c.UsagePeriod.Equal(*other.UsagePeriod) { + return false + } + + if !reflect.DeepEqual(c.PreserveOverageAtReset, other.PreserveOverageAtReset) { + return false + } + + if c.SubscriptionManaged != other.SubscriptionManaged { + return false + } + + return true +} + +func (c CreateEntitlementInputs) Validate() error { + if c.FeatureID == nil && c.FeatureKey == nil { + return fmt.Errorf("feature id or key must be set") + } + + return nil +} + func (c CreateEntitlementInputs) GetType() EntitlementType { return c.EntitlementType } @@ -233,6 +330,18 @@ type GenericProperties struct { type UsagePeriod recurrence.Recurrence +func (u UsagePeriod) Equal(other UsagePeriod) bool { + if u.Interval != other.Interval { + return false + } + + if !u.Anchor.Equal(other.Anchor) { + return false + } + + return true +} + // The returned period is exclusive at the end end inclusive in the start func (u UsagePeriod) GetCurrentPeriodAt(at time.Time) (recurrence.Period, error) { rec := recurrence.Recurrence{ diff --git a/openmeter/subscription/entitlement.go b/openmeter/subscription/entitlement.go index ee5536c841..eaa58bf416 100644 --- a/openmeter/subscription/entitlement.go +++ b/openmeter/subscription/entitlement.go @@ -3,7 +3,6 @@ package subscription import ( "context" "fmt" - "reflect" "time" "github.com/openmeterio/openmeter/openmeter/entitlement" @@ -56,7 +55,20 @@ type ScheduleSubscriptionEntitlementInput struct { } func (s ScheduleSubscriptionEntitlementInput) Equal(other ScheduleSubscriptionEntitlementInput) bool { - return reflect.DeepEqual(s, other) + if s.SubscriptionID != other.SubscriptionID { + return false + } + if s.SubscriptionItemID != other.SubscriptionItemID { + return false + } + if !s.Cadence.Equal(other.Cadence) { + return false + } + if !s.EntitlementInputs.Equal(other.EntitlementInputs) { + return false + } + + return true } func (s ScheduleSubscriptionEntitlementInput) Validate() error { diff --git a/openmeter/subscription/patch.go b/openmeter/subscription/patch.go index a618a96409..52ef7826b3 100644 --- a/openmeter/subscription/patch.go +++ b/openmeter/subscription/patch.go @@ -146,6 +146,7 @@ func NewItemPath(phaseKey, itemKey string) PatchPath { type Patch interface { json.Marshaler Applies + Validate() error Op() PatchOperation Path() PatchPath } diff --git a/openmeter/subscription/patch/additem.go b/openmeter/subscription/patch/additem.go index bf7307f592..3dc770cb29 100644 --- a/openmeter/subscription/patch/additem.go +++ b/openmeter/subscription/patch/additem.go @@ -24,6 +24,22 @@ func (a PatchAddItem) Value() subscription.SubscriptionItemSpec { return a.CreateInput } +func (a PatchAddItem) Validate() error { + if err := a.Path().Validate(); err != nil { + return err + } + + if err := a.Op().Validate(); err != nil { + return err + } + + if err := a.CreateInput.Validate(); err != nil { + return err + } + + return nil +} + func (a PatchAddItem) ValueAsAny() any { return a.CreateInput } diff --git a/openmeter/subscription/patch/addphase.go b/openmeter/subscription/patch/addphase.go index 8d3b9e61d2..62ad8057f9 100644 --- a/openmeter/subscription/patch/addphase.go +++ b/openmeter/subscription/patch/addphase.go @@ -28,6 +28,22 @@ func (a PatchAddPhase) ValueAsAny() any { return a.CreateInput } +func (a PatchAddPhase) Validate() error { + if err := a.Path().Validate(); err != nil { + return err + } + + if err := a.Op().Validate(); err != nil { + return err + } + + if err := a.CreateInput.Validate(); err != nil { + return err + } + + return nil +} + var _ subscription.ValuePatch[subscription.CreateSubscriptionPhaseInput] = PatchAddPhase{} func (a PatchAddPhase) ApplyTo(spec *subscription.SubscriptionSpec, actx subscription.ApplyContext) error { diff --git a/openmeter/subscription/patch/extendphase.go b/openmeter/subscription/patch/extendphase.go index f4aedff939..b9facbd9ff 100644 --- a/openmeter/subscription/patch/extendphase.go +++ b/openmeter/subscription/patch/extendphase.go @@ -28,6 +28,22 @@ func (e PatchExtendPhase) ValueAsAny() any { return e.Duration } +func (e PatchExtendPhase) Validate() error { + if err := e.Path().Validate(); err != nil { + return err + } + + if err := e.Op().Validate(); err != nil { + return err + } + + if e.Duration.IsZero() { + return fmt.Errorf("duration cannot be zero") + } + + return nil +} + var _ subscription.ValuePatch[datex.Period] = PatchExtendPhase{} func (e PatchExtendPhase) ApplyTo(spec *subscription.SubscriptionSpec, actx subscription.ApplyContext) error { diff --git a/openmeter/subscription/patch/removeitem.go b/openmeter/subscription/patch/removeitem.go index 31e6eefd88..b7e00d8476 100644 --- a/openmeter/subscription/patch/removeitem.go +++ b/openmeter/subscription/patch/removeitem.go @@ -19,6 +19,18 @@ func (r PatchRemoveItem) Path() subscription.PatchPath { return subscription.NewItemPath(r.PhaseKey, r.ItemKey) } +func (r PatchRemoveItem) Validate() error { + if err := r.Path().Validate(); err != nil { + return err + } + + if err := r.Op().Validate(); err != nil { + return err + } + + return nil +} + var _ subscription.Patch = PatchRemoveItem{} func (r PatchRemoveItem) ApplyTo(spec *subscription.SubscriptionSpec, actx subscription.ApplyContext) error { diff --git a/openmeter/subscription/patch/removephase.go b/openmeter/subscription/patch/removephase.go index b29beb7ec6..910b671ef3 100644 --- a/openmeter/subscription/patch/removephase.go +++ b/openmeter/subscription/patch/removephase.go @@ -28,6 +28,18 @@ func (r PatchRemovePhase) ValueAsAny() any { return r.RemoveInput } +func (r PatchRemovePhase) Validate() error { + if err := r.Path().Validate(); err != nil { + return err + } + + if err := r.Op().Validate(); err != nil { + return err + } + + return nil +} + var _ subscription.ValuePatch[subscription.RemoveSubscriptionPhaseInput] = PatchRemovePhase{} func (r PatchRemovePhase) ApplyTo(spec *subscription.SubscriptionSpec, actx subscription.ApplyContext) error { diff --git a/openmeter/subscription/plan.go b/openmeter/subscription/plan.go index 5ecf43d396..7bf9da65ed 100644 --- a/openmeter/subscription/plan.go +++ b/openmeter/subscription/plan.go @@ -54,6 +54,6 @@ type PlanNotFoundError struct { Version int } -func (e *PlanNotFoundError) Error() string { +func (e PlanNotFoundError) Error() string { return fmt.Sprintf("plan %s@%d not found", e.Key, e.Version) } diff --git a/openmeter/subscription/repo/subscriptionitemrepo.go b/openmeter/subscription/repo/subscriptionitemrepo.go index 517d2f190c..3f27fd8beb 100644 --- a/openmeter/subscription/repo/subscriptionitemrepo.go +++ b/openmeter/subscription/repo/subscriptionitemrepo.go @@ -122,14 +122,15 @@ func (r *subscriptionItemRepo) Create(ctx context.Context, input subscription.Cr func (r *subscriptionItemRepo) Delete(ctx context.Context, input models.NamespacedID) error { _, err := entutils.TransactingRepo(ctx, r, func(ctx context.Context, repo *subscriptionItemRepo) (any, error) { + at := clock.Now() err := repo.db.SubscriptionItem.UpdateOneID(input.ID). Where( dbsubscriptionitem.Namespace(input.Namespace), dbsubscriptionitem.Or( dbsubscriptionitem.DeletedAtIsNil(), - dbsubscriptionitem.DeletedAtGT(clock.Now()), + dbsubscriptionitem.DeletedAtGT(at), ), - ).Exec(ctx) + ).SetDeletedAt(at).Exec(ctx) if db.IsNotFound(err) { return nil, &subscription.ItemNotFoundError{ID: input.ID} diff --git a/openmeter/subscription/repo/subscriptionphaserepo.go b/openmeter/subscription/repo/subscriptionphaserepo.go index a82b5a33c6..a2271e8bf6 100644 --- a/openmeter/subscription/repo/subscriptionphaserepo.go +++ b/openmeter/subscription/repo/subscriptionphaserepo.go @@ -73,14 +73,15 @@ func (r *subscriptionPhaseRepo) Create(ctx context.Context, phase subscription.C func (r *subscriptionPhaseRepo) Delete(ctx context.Context, id models.NamespacedID) error { _, err := entutils.TransactingRepo(ctx, r, func(ctx context.Context, repo *subscriptionPhaseRepo) (any, error) { + at := clock.Now() err := repo.db.SubscriptionPhase.UpdateOneID(id.ID). Where( dbsubscriptionphase.Namespace(id.Namespace), dbsubscriptionphase.Or( dbsubscriptionphase.DeletedAtIsNil(), - dbsubscriptionphase.DeletedAtGT(clock.Now()), + dbsubscriptionphase.DeletedAtGT(at), ), - ).Exec(ctx) + ).SetDeletedAt(at).Exec(ctx) if db.IsNotFound(err) { return nil, &subscription.PhaseNotFoundError{ID: id.ID} } diff --git a/openmeter/subscription/service/helpers.go b/openmeter/subscription/service/helpers.go index f1e7056c58..becca0a74b 100644 --- a/openmeter/subscription/service/helpers.go +++ b/openmeter/subscription/service/helpers.go @@ -20,7 +20,7 @@ func (s *service) createPhase( return transaction.Run(ctx, s.TransactionManager, func(ctx context.Context) (subscription.SubscriptionPhaseView, error) { res := subscription.SubscriptionPhaseView{ Spec: phaseSpec, - Items: make([]subscription.SubscriptionItemView, 0, len(phaseSpec.Items)), + Items: make(map[string]subscription.SubscriptionItemView, len(phaseSpec.Items)), } // First, let's create the phase itself @@ -43,7 +43,11 @@ func (s *service) createPhase( return res, fmt.Errorf("failed to create item: %w", err) } - res.Items = append(res.Items, item) + if _, exists := res.Items[item.SubscriptionItem.Key]; exists { + return res, fmt.Errorf("item %s already exists", item.SubscriptionItem.Key) + } + + res.Items[item.SubscriptionItem.Key] = item } return res, nil diff --git a/openmeter/subscription/service/service.go b/openmeter/subscription/service/service.go index 742d5266f3..c6f0ae0e1d 100644 --- a/openmeter/subscription/service/service.go +++ b/openmeter/subscription/service/service.go @@ -22,7 +22,6 @@ type ServiceConfig struct { // connectors CustomerService customer.Service // adapters - PlanAdapter subscription.PlanAdapter EntitlementAdapter subscription.EntitlementAdapter DiscountAdapter subscription.DiscountAdapter // framework @@ -76,13 +75,6 @@ func (s *service) Create(ctx context.Context, namespace string, spec subscriptio return def, &models.GenericConflictError{Message: "customer already has a subscription"} } - // Fetch the plan, check if it exists - // We don't actually use the plan for anything, we don't validate the spec against it, but we expect it to be a valid reference. - _, err = s.PlanAdapter.GetVersion(ctx, spec.Plan.Key, spec.Plan.Version) - if err != nil { - return def, err - } - return transaction.Run(ctx, s.TransactionManager, func(ctx context.Context) (subscription.Subscription, error) { // Create subscription entity sub, err := s.SubscriptionRepo.Create(ctx, spec.ToCreateSubscriptionEntityInput(namespace)) diff --git a/openmeter/subscription/service/service_test.go b/openmeter/subscription/service/service_test.go index 2347c3e000..51ea7c9b73 100644 --- a/openmeter/subscription/service/service_test.go +++ b/openmeter/subscription/service/service_test.go @@ -29,9 +29,9 @@ func TestCreation(t *testing.T) { dbDeps := subscriptiontestutils.SetupDBDeps(t) defer dbDeps.Cleanup() - service, deps := subscriptiontestutils.NewService(t, dbDeps) + services, deps := subscriptiontestutils.NewService(t, dbDeps) + service := services.Service - deps.PlanAdapter.AddPlan(subscriptiontestutils.GetExamplePlan()) cust := deps.CustomerAdapter.CreateExampleCustomer(t) _ = deps.FeatureConnector.CreateExampleFeature(t) @@ -64,7 +64,7 @@ func TestCreation(t *testing.T) { assert.Equal(t, sub.Currency, found.Currency) }) - t.Run("Should create subscription according to plan", func(t *testing.T) { + t.Run("Should create subscription as specced", func(t *testing.T) { found, err := service.GetView(ctx, models.NamespacedID{ID: sub.ID, Namespace: sub.Namespace}) assert.Nil(t, err) @@ -79,38 +79,38 @@ func TestCreation(t *testing.T) { // Test Phases - plan, err := deps.PlanAdapter.GetVersion(ctx, sub.Plan.Key, sub.Plan.Version) - require.Nil(t, err) - - planPhases := plan.GetPhases() foundPhases := found.Phases + specPhases := defaultSpecFromPlan.GetSortedPhases() - require.Equal(t, len(planPhases), len(foundPhases)) + require.Equal(t, len(specPhases), len(foundPhases)) - for i := range planPhases { - assert.Equal(t, planPhases[i].GetKey(), foundPhases[i].SubscriptionPhase.Key) - assert.Equal(t, planPhases[i].ToCreateSubscriptionPhasePlanInput().PhaseKey, foundPhases[i].SubscriptionPhase.Key) + for i := range specPhases { + assert.Equal(t, specPhases[i].PhaseKey, foundPhases[i].SubscriptionPhase.Key) - expectedStart, _ := planPhases[i].ToCreateSubscriptionPhasePlanInput().StartAfter.AddTo(foundSub.ActiveFrom) + expectedStart, _ := specPhases[i].StartAfter.AddTo(foundSub.ActiveFrom) assert.Equal(t, expectedStart.UTC(), foundPhases[i].ActiveFrom(foundSub.CadencedModel)) // Test Rate Cards of Phase - planPhase := planPhases[i] + specPhase := specPhases[i] foundPhase := foundPhases[i] - planRateCards := planPhase.GetRateCards() - foundRateCards := foundPhase.Items + specItems := specPhase.Items + foundItems := foundPhase.Items + + require.Equal(t, len(specItems), len(foundItems), "item count mismatch for phase %s", specPhase.PhaseKey) - require.Equal(t, len(planRateCards), len(foundRateCards), "rate card count mismatch for phase %s", planPhase.GetKey()) + for specItemKey := range specItems { + specItem, ok := specItems[specItemKey] + require.True(t, ok, "item %s not found in spec phase %s", specItemKey, specPhase.PhaseKey) + foundItem, foundTheItem := foundItems[specItemKey] + require.True(t, foundTheItem, "item %s not found in found phase %s", specItemKey, specPhase.PhaseKey) - for j := range planRateCards { - assert.Equal(t, planRateCards[j].GetKey(), foundRateCards[j].SubscriptionItem.Key) - assert.Equal(t, planRateCards[j].ToCreateSubscriptionItemPlanInput().ItemKey, foundRateCards[j].SubscriptionItem.RateCard.Key()) + assert.Equal(t, specItem.ItemKey, foundItem.SubscriptionItem.Key) - rcFeature, hasFeature := foundRateCards[j].GetFeature() + rcFeature, hasFeature := foundItem.GetFeature() - pFeatureKey := subscriptiontestutils.GetFeatureKeyOfRateCard(planRateCards[j].ToCreateSubscriptionItemPlanInput().RateCard) + pFeatureKey := subscriptiontestutils.GetFeatureKeyOfRateCard(specItem.RateCard) if hasFeature { require.NotNil(t, pFeatureKey) assert.Equal(t, *pFeatureKey, rcFeature.Key) @@ -118,10 +118,10 @@ func TestCreation(t *testing.T) { assert.Nil(t, pFeatureKey) } - rcInp := planRateCards[j].ToCreateSubscriptionItemPlanInput() + rcInp := specItem.CreateSubscriptionItemPlanInput if rcEnt := subscriptiontestutils.GetEntitlementOfRateCard(rcInp.RateCard); rcEnt != nil { - ent := foundRateCards[j].Entitlement + ent := foundItem.Entitlement exists := ent != nil require.True(t, exists) entInp := ent.ToScheduleSubscriptionEntitlementInput() @@ -140,9 +140,9 @@ func TestCreation(t *testing.T) { assert.Equal(t, foundPhase.ActiveFrom(found.Subscription.CadencedModel), *ent.Entitlement.ActiveFrom) // Validate that the entitlement is only active until the phase is scheduled to be - if i < len(planPhases)-1 { - nextPhase := planPhases[i+1] - nextPhaseStart, _ := nextPhase.ToCreateSubscriptionPhasePlanInput().StartAfter.AddTo(foundSub.ActiveFrom) + if i < len(specPhases)-1 { + nextPhase := specPhases[i+1] + nextPhaseStart, _ := nextPhase.StartAfter.AddTo(foundSub.ActiveFrom) require.NotNil(t, ent.Entitlement.ActiveTo) assert.Equal(t, nextPhaseStart.UTC(), *ent.Entitlement.ActiveTo) } diff --git a/openmeter/subscription/service/update.go b/openmeter/subscription/service/update.go index 40a6ea9683..990c075d5f 100644 --- a/openmeter/subscription/service/update.go +++ b/openmeter/subscription/service/update.go @@ -40,9 +40,11 @@ func (s *service) Update(ctx context.Context, subscriptionID models.NamespacedID } // 1. Subscription Cadence has to match - _, err := s.SubscriptionRepo.SetEndOfCadence(ctx, view.Subscription.NamespacedID, newSpec.ActiveTo) - if err != nil { - return def, fmt.Errorf("failed to set end of cadence: %w", err) + if !view.Subscription.CadencedModel.Equal(models.CadencedModel{ActiveFrom: newSpec.ActiveFrom, ActiveTo: newSpec.ActiveTo}) { + _, err := s.SubscriptionRepo.SetEndOfCadence(ctx, view.Subscription.NamespacedID, newSpec.ActiveTo) + if err != nil { + return def, fmt.Errorf("failed to set end of cadence: %w", err) + } } // 2. Anything that's changed or was removed has to be updated @@ -195,7 +197,19 @@ func (s *service) Update(ctx context.Context, subscriptionID models.NamespacedID curr := currentItemView.Entitlement.ToScheduleSubscriptionEntitlementInput() if hasCurr && currentEntitlementIntact { // We can compare the two to see if it needs changing - if !curr.Equal(new) { + // We have to be careful of feature comparison, the current will have feature ID informatino while the new will not + currToCompare := curr + if err := new.EntitlementInputs.Validate(); err != nil { + return def, fmt.Errorf("failed to validate new entitlement input: %w", err) + } + + if new.EntitlementInputs.FeatureID == nil { + currToCompare.EntitlementInputs.FeatureID = nil + } else if curr.EntitlementInputs.FeatureKey == nil { + currToCompare.EntitlementInputs.FeatureKey = nil + } + + if !currToCompare.Equal(new) { // First we need to delete the old entitlement if err := s.EntitlementAdapter.DeleteByItemID(ctx, currentItemView.SubscriptionItem.NamespacedID); err != nil { return def, fmt.Errorf("failed to delete entitlement: %w", err) @@ -205,15 +219,13 @@ func (s *service) Update(ctx context.Context, subscriptionID models.NamespacedID } } - // Sanity check - if currentEntitlementIntact { - return def, fmt.Errorf("entitlement should not be intact") + if !currentEntitlementIntact { + // Now we can create the new entitlement + if _, err := s.EntitlementAdapter.ScheduleEntitlement(ctx, new); err != nil { + return def, fmt.Errorf("failed to create entitlement: %w", err) + } } - // Now we can create the new entitlement - if _, err := s.EntitlementAdapter.ScheduleEntitlement(ctx, new); err != nil { - return def, fmt.Errorf("failed to create entitlement: %w", err) - } } } } @@ -250,9 +262,7 @@ func (s *service) Update(ctx context.Context, subscriptionID models.NamespacedID return def, fmt.Errorf("item is nil") } - _, foundMatchingItemInCurrentView := lo.Find(matchingPhaseInCurrentView.Items, func(i subscription.SubscriptionItemView) bool { - return i.SubscriptionItem.Key == item.ItemKey - }) + _, foundMatchingItemInCurrentView := matchingPhaseInCurrentView.Items[item.ItemKey] if !foundMatchingItemInCurrentView { itemCadence, error := newSpec.GetPhaseCadence(phase.PhaseKey) diff --git a/openmeter/subscription/service/update_test.go b/openmeter/subscription/service/update_test.go index 05b6bae16a..60e4346cc2 100644 --- a/openmeter/subscription/service/update_test.go +++ b/openmeter/subscription/service/update_test.go @@ -1,420 +1,112 @@ package service_test import ( - "context" "testing" + "time" "github.com/stretchr/testify/require" + customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" "github.com/openmeterio/openmeter/openmeter/subscription" subscriptiontestutils "github.com/openmeterio/openmeter/openmeter/subscription/testutils" "github.com/openmeterio/openmeter/openmeter/testutils" "github.com/openmeterio/openmeter/pkg/clock" - "github.com/openmeterio/openmeter/pkg/currencyx" ) func TestEdit(t *testing.T) { - t.Run("Should edit subscription of ExamplePlan", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - currentTime := testutils.GetRFC3339Time(t, "2021-01-01T00:00:00Z") - clock.SetTime(currentTime) - - dbDeps := subscriptiontestutils.SetupDBDeps(t) - defer dbDeps.Cleanup() - - service, deps := subscriptiontestutils.NewService(t, dbDeps) - - deps.PlanAdapter.AddPlan(subscriptiontestutils.GetExamplePlan()) - cust := deps.CustomerAdapter.CreateExampleCustomer(t) - _ = deps.FeatureConnector.CreateExampleFeature(t) - - spec, err := subscription.NewSpecFromPlan(subscriptiontestutils.GetExamplePlan(), subscription.CreateSubscriptionCustomerInput{ - CustomerId: cust.ID, - Currency: "USD", - ActiveFrom: currentTime, + type TDeps struct { + CurrentTime time.Time + Customer customerentity.Customer + ExamplePlan subscription.Plan + ServiceDeps subscriptiontestutils.ExposedServiceDeps + Service subscription.Service + } + + tt := []struct { + Name string + Handler func(t *testing.T, deps TDeps) + }{ + { + Name: "Should error if plan changes", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should error if customer changes", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should error if subscription start changes", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should update contents of future phase when phase end changes", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should update contents of future phase when phase start changes", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should delete item from future phase", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should add item to future phase", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should update item entitlement", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + { + Name: "Should update contents of current phase", + Handler: func(t *testing.T, deps TDeps) { + t.Skip("TODO") + }, + }, + } + + for _, tc := range tt { + t.Run(tc.Name, func(t *testing.T) { + currentTime := testutils.GetRFC3339Time(t, "2021-01-01T00:00:00Z") + clock.SetTime(currentTime) + + dbDeps := subscriptiontestutils.SetupDBDeps(t) + defer dbDeps.Cleanup() + + services, deps := subscriptiontestutils.NewService(t, dbDeps) + service := services.Service + + cust := deps.CustomerAdapter.CreateExampleCustomer(t) + require.NotNil(t, cust) + + _ = deps.FeatureConnector.CreateExampleFeature(t) + examplePlan := subscriptiontestutils.GetExamplePlan() + deps.PlanAdapter.AddPlan(t, examplePlan) + + tc.Handler(t, TDeps{ + CurrentTime: currentTime, + Customer: *cust, + ExamplePlan: examplePlan, + ServiceDeps: deps, + Service: service, + }) }) - require.Nil(t, err) - - sub, err := service.Create(ctx, subscriptiontestutils.ExampleNamespace, spec) - - require.Nil(t, err) - require.Equal(t, subscriptiontestutils.ExamplePlanRef, sub.Plan) - require.Equal(t, subscriptiontestutils.ExampleNamespace, sub.Namespace) - require.Equal(t, cust.ID, sub.CustomerId) - require.Equal(t, currencyx.Code("USD"), sub.Currency) - - // t.Run("Should work fine if no patches were provided", func(t *testing.T) { - // _, err := service.Update(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }, []subscription.Patch{}) - // require.Nil(t, err) - // }) - - // t.Run("Should add new items to an existing phase", func(t *testing.T) { - // // Let's assert that the base Subscription looks as we believe - // subView, err := service.GetView(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // // Let's add items to the last phase - - // sub := subView.Subscription - // phases := subView.Phases - // phaseKey := phases[len(phases)-1].SubscriptionPhase.Key - - // // Let's create a new featue so we can add it to the Subscription - // feat, err := deps.FeatureConnector.CreateFeature(ctx, feature.CreateFeatureInputs{ - // Name: "New Feature", - // Key: "new-feature", - // Namespace: subscriptiontestutils.ExampleNamespace, - // MeterSlug: &subscriptiontestutils.ExampleFeatureMeterSlug, - // }) - // require.Nil(t, err) - - // _, err = service.Edit(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }, []subscription.Patch{ - // // Let's add a new Item with a Price - // patch.PatchAddItem{ - // PhaseKey: phaseKey, - // ItemKey: "new-item-1", - // CreateInput: subscription.SubscriptionItemSpec{ - // CreateSubscriptionItemInput: subscription.CreateSubscriptionItemInput{ - // CreateSubscriptionItemPlanInput: subscription.CreateSubscriptionItemPlanInput{ - // PhaseKey: phaseKey, - // ItemKey: "new-item-1", - // RateCard: plan.NewRateCardFrom(plan.FlatFeeRateCard{ - // RateCardMeta: plan.RateCardMeta{ - // Key: "new-item-1", - // Name: "New Item 1", - // }, - // Price: plan.NewPriceFrom(plan.FlatPrice{ - // Amount: alpacadecimal.NewFromInt(int64(100)), - // PaymentTerm: plan.InAdvancePaymentTerm, - // }), - // }), - // }, - // CreateSubscriptionItemCustomerInput: subscription.CreateSubscriptionItemCustomerInput{}, - // }, - // }, - // }, - // // Let's add a new Item with an Entitlement - // patch.PatchAddItem{ - // PhaseKey: phaseKey, - // ItemKey: "new-item-2", - // CreateInput: subscription.SubscriptionItemSpec{ - - // CreateSubscriptionItemPlanInput: subscription.CreateSubscriptionItemPlanInput{ - // PhaseKey: phaseKey, - // ItemKey: "new-item-2", - // FeatureKey: &feat.Key, - // CreateEntitlementInput: &subscription.CreateSubscriptionEntitlementInput{ - // EntitlementType: entitlement.EntitlementTypeMetered, - // IssueAfterReset: lo.ToPtr(100.0), - // UsagePeriodISODuration: &oneMonthISO, - // }, - // }, - // }, - // }, - // }) - // require.Nil(t, err) - - // // Let's assert that the Subscription now has the new Items - // subView, err = service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // lastPhase := subView.Phases[len(phases)-1] - - // items := lastPhase.Items - // assert.True(t, lo.SomeBy(items, func(i subscription.SubscriptionItemView) bool { - // if i.SubscriptionItem.RateCard.Key() == "new-item-2" { - // e := i.Entitlement - // require.NotNil(t, e) - // assert.Equal(t, entitlement.EntitlementTypeMetered, e.Entitlement.EntitlementType) - // assert.Equal(t, lo.ToPtr(100.0), e.Entitlement.IssueAfterReset) - // assert.Equal(t, lastPhase.ActiveFrom(), e.Cadence.ActiveFrom) - // assert.Nil(t, e.Cadence.ActiveTo) - // return true - // } - // return false - // })) - // }) - - // t.Run("Should add new empty phase to end of subscription", func(t *testing.T) { - // // Let's assert that the base Subscription looks as we believe - // subView, err := service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // sub := subView.Subscription - // phases := subView.Phases - // oldPhaseCount := len(phases) - - // lastPhase := phases[oldPhaseCount-1] - - // newPhaseKey := "new-last-phase" - // newPhaseStartAfter, err := lastPhase.AsSpec().CreateSubscriptionPhaseInput.StartAfter.Add(datex.MustParse(t, "P1M")) - // require.Nil(t, err) - - // _, err = service.Edit(ctx, models.NamespacedID{ID: sub.ID, Namespace: sub.Namespace}, []subscription.Patch{ - // patch.PatchAddPhase{ - // PhaseKey: newPhaseKey, - // CreateInput: subscription.CreateSubscriptionPhaseInput{ - // CreateSubscriptionPhasePlanInput: subscription.CreateSubscriptionPhasePlanInput{ - // PhaseKey: newPhaseKey, - // StartAfter: newPhaseStartAfter, - // }, - // CreateSubscriptionPhaseCustomerInput: subscription.CreateSubscriptionPhaseCustomerInput{ - // CreateDiscountInput: nil, // TODO: implement - // }, - // Duration: datex.MustParse(t, "P3M"), - // }, - // }, - // }) - // require.Nil(t, err) - - // // Let's re-fetch the subscription - // // Let's assert that the Subscription now has the new Items - // subView, err = service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // require.Equal(t, oldPhaseCount+1, len(subView.Phases)) - - // newLastPhase := subView.Phases[oldPhaseCount] - // currentOldLastPhase := subView.Phases[oldPhaseCount-1] - - // require.Equal(t, lastPhase.SubscriptionPhase.Key, subView.Phases[oldPhaseCount-1].SubscriptionPhase.Key) - - // expectedActiveFrom, _ := newPhaseStartAfter.AddTo(sub.ActiveFrom) - - // // Should add new phase - // assert.Equal(t, newPhaseKey, newLastPhase.SubscriptionPhase.Key) - // assert.Equal(t, expectedActiveFrom, newLastPhase.ActiveFrom()) - - // // Should ignore duration when adding last phase - // // TODO: validate it gets ignored - - // // Should close entitlement and price of previous phase - // for _, item := range currentOldLastPhase.Items { - // if item.Entitlement != nil { - // ent := *item.Entitlement - // assert.NotNil(t, ent.Cadence.ActiveTo) - // assert.Equal(t, lo.ToPtr(newLastPhase.ActiveFrom()), ent.Cadence.ActiveTo) - // } - // } - // }) - - // t.Run("Should add new phase with item and delay subsequent phases in subscription", func(t *testing.T) { - // // Let's assert that the base Subscription looks as we believe - // subView, err := service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // sub := subView.Subscription - // phases := subView.Phases - // oldPhaseCount := len(phases) - - // require.Equal(t, 4, oldPhaseCount) // 3 original in ExamplePlan + 1 new phase added above - - // expected2ndPhaseStart, _ := datex.MustParse(t, "P2M").AddTo(sub.ActiveFrom) - // expected3rdPhaseStart, _ := datex.MustParse(t, "P6M").AddTo(sub.ActiveFrom) - - // // Lets assert these two phases start when we believe they do - // require.Equal(t, expected2ndPhaseStart, phases[1].ActiveFrom()) - // require.Equal(t, expected3rdPhaseStart, phases[2].ActiveFrom()) - - // // Let's add a new phase in between them - // newPhaseKey := "in-between-phase" - // newPhaseStartAfter := datex.MustParse(t, "P4M") - // newPhaseDuration := datex.MustParse(t, "P3M") - - // _, err = service.Edit(ctx, models.NamespacedID{ID: sub.ID, Namespace: sub.Namespace}, []subscription.Patch{ - // patch.PatchAddPhase{ - // PhaseKey: newPhaseKey, - // CreateInput: subscription.CreateSubscriptionPhaseInput{ - // CreateSubscriptionPhasePlanInput: subscription.CreateSubscriptionPhasePlanInput{ - // PhaseKey: newPhaseKey, - // StartAfter: newPhaseStartAfter, - // }, - // CreateSubscriptionPhaseCustomerInput: subscription.CreateSubscriptionPhaseCustomerInput{ - // CreateDiscountInput: nil, // TODO: implement - // }, - // Duration: newPhaseDuration, - // }, - // }, - // patch.PatchAddItem{ - // PhaseKey: newPhaseKey, - // ItemKey: "new-item-1", - // CreateInput: subscription.SubscriptionItemSpec{ - // CreateSubscriptionItemPlanInput: subscription.CreateSubscriptionItemPlanInput{ - // PhaseKey: newPhaseKey, - // ItemKey: "new-item-1", - // }, - // }, - // }, - // }) - // require.Nil(t, err) - // oldPhases := phases - - // // Let's re-fetch the subscription - // subView, err = service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // // Lets assert that the phase starts when specified - // newPhase := subView.Phases[2] - // require.Equal(t, newPhaseKey, newPhase.SubscriptionPhase.Key) - // expectedActiveFrom, _ := newPhaseStartAfter.AddTo(sub.ActiveFrom) - // require.Equal(t, expectedActiveFrom, newPhase.ActiveFrom()) - - // // Let's assert that it has a new Item with a price that starts and ends when expected - // items := newPhase.Items - // assert.Len(t, items, 1) - - // // Lets assert that the next and all subsequent phases were delayed by on month - // // 1 month, because it used to start after 6 months, no it starts after 4 + 3 = 7 months - // for i, phase := range subView.Phases { - // if i < 3 { - // continue - // } - - // var opv subscription.SubscriptionPhaseView - // for _, op := range oldPhases { - // if op.SubscriptionPhase.Key == phase.SubscriptionPhase.Key { - // opv = op - // break - // } - // } - // require.NotNil(t, opv) - - // expectedStart, _ := datex.MustParse(t, "P1M").AddTo(opv.ActiveFrom()) - // assert.Equal(t, expectedStart, phase.ActiveFrom()) - // } - - // // Instead of checking that each item has been drifted, we can just validate the view that everything aligns - // err = subView.Validate(true) - // require.Nil(t, err) - // }) - - // t.Run("Should remove phase and all items in it", func(t *testing.T) { - // // Let's assert that the base Subscription looks as we believe - // subView, err := service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // sub := subView.Subscription - // phases := subView.Phases - // oldPhaseCount := len(phases) - - // require.Equal(t, 5, oldPhaseCount) // 3 original in ExamplePlan + 2 new phase added above - // phaseToDelete := phases[2] - // oldPrevPhase := phases[1] - // oldNextPhase := phases[3] - // require.Equal(t, "in-between-phase", phaseToDelete.SubscriptionPhase.Key) - // require.Equal(t, 1, len(phaseToDelete.Items)) - - // // Let's delete the phase we added above - // _, err = service.Edit(ctx, models.NamespacedID{ID: sub.ID, Namespace: sub.Namespace}, []subscription.Patch{ - // patch.PatchRemovePhase{ - // PhaseKey: phaseToDelete.SubscriptionPhase.Key, - // RemoveInput: subscription.RemoveSubscriptionPhaseInput{ - // Shift: subscription.RemoveSubscriptionPhaseShiftPrev, - // }, - // }, - // }) - // require.Nil(t, err) - - // // Let's re-fetch the subscription - // subView, err = service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // // Lets assert that the phase is gone - // require.Equal(t, oldPhaseCount-1, len(subView.Phases)) - - // // Let's assert that the prev phase was extended to last until the end of the deleted phase - // prevPhase := subView.Phases[1] - // assert.Equal(t, oldPrevPhase.SubscriptionPhase.Key, prevPhase.SubscriptionPhase.Key) - // assert.Equal(t, oldPrevPhase.ActiveFrom(), prevPhase.ActiveFrom()) - // nextPhase := subView.Phases[2] - // assert.Equal(t, oldNextPhase.SubscriptionPhase.Key, nextPhase.SubscriptionPhase.Key) - // assert.Equal(t, oldNextPhase.ActiveFrom(), nextPhase.ActiveFrom()) - - // // Let's assert that the items of the prev phase were extended - // for _, item := range prevPhase.Items { - // if item.Entitlement != nil { - // ent := *item.Entitlement - // assert.Equal(t, lo.ToPtr(nextPhase.ActiveFrom()), ent.Cadence.ActiveTo) - // } - // } - // }) - - // t.Run("Let's remove last phase", func(t *testing.T) { - // // Let's assert that the base Subscription looks as we believe - // subView, err := service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // sub := subView.Subscription - // phases := subView.Phases - // oldPhaseCount := len(phases) - - // require.Equal(t, 4, oldPhaseCount) // 3 original in ExamplePlan + 1 new phase added above - // require.Equal(t, "new-last-phase", phases[3].SubscriptionPhase.Key) - - // // Let's delete the last phase - // _, err = service.Edit(ctx, models.NamespacedID{ID: sub.ID, Namespace: sub.Namespace}, []subscription.Patch{ - // patch.PatchRemovePhase{ - // PhaseKey: "new-last-phase", - // RemoveInput: subscription.RemoveSubscriptionPhaseInput{ - // Shift: subscription.RemoveSubscriptionPhaseShiftPrev, - // }, - // }, - // }) - // require.Nil(t, err) - - // // Let's re-fetch the subscription - // subView, err = service.Expand(ctx, models.NamespacedID{ - // ID: sub.ID, - // Namespace: sub.Namespace, - // }) - // require.Nil(t, err) - - // // Lets assert that the phase is gone - // require.Equal(t, oldPhaseCount-1, len(subView.Phases)) - - // // Let's assert that the items of the prev phase were extended - // prevPhase := subView.Phases[2] - // for _, item := range prevPhase.Items { - // if item.Entitlement != nil { - // ent := *item.Entitlement - // assert.Nil(t, ent.Cadence.ActiveTo) - // } - // } - // }) - }) + } } diff --git a/openmeter/subscription/service/workflowservice.go b/openmeter/subscription/service/workflowservice.go new file mode 100644 index 0000000000..4ca5b2f264 --- /dev/null +++ b/openmeter/subscription/service/workflowservice.go @@ -0,0 +1,132 @@ +package service + +import ( + "context" + "fmt" + + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/customer" + customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" + "github.com/openmeterio/openmeter/openmeter/subscription" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/framework/transaction" + "github.com/openmeterio/openmeter/pkg/models" +) + +type WorkflowServiceConfig struct { + Service subscription.Service + // connectors + CustomerService customer.Service + // adapters + PlanAdapter subscription.PlanAdapter + // framework + TransactionManager transaction.Creator +} + +type workflowService struct { + WorkflowServiceConfig +} + +func NewWorkflowService(cfg WorkflowServiceConfig) subscription.WorkflowService { + return &workflowService{ + WorkflowServiceConfig: cfg, + } +} + +var _ subscription.WorkflowService = &workflowService{} + +func (s *workflowService) CreateFromPlan(ctx context.Context, inp subscription.CreateFromPlanInput) (subscription.SubscriptionView, error) { + var def subscription.SubscriptionView + + // Let's validate the customer exists + cust, err := s.CustomerService.GetCustomer(ctx, customerentity.GetCustomerInput{ + Namespace: inp.Namespace, + ID: inp.CustomerID, + }) + if err != nil { + return def, fmt.Errorf("failed to fetch customer: %w", err) + } + + if cust == nil { + return def, fmt.Errorf("unexpected nil customer") + } + + // Let's validate the plan exists + plan, err := s.PlanAdapter.GetVersion(ctx, inp.Plan.Key, inp.Plan.Version) + if err != nil { + return def, fmt.Errorf("failed to fetch plan: %w", err) + } + + // Let's validate the patches + for i, patch := range inp.Customization { + if err := patch.Validate(); err != nil { + return def, fmt.Errorf("invalid patch at index %d: %w", i, err) + } + } + + // Let's create the new Spec + spec, err := subscription.NewSpecFromPlan(plan, subscription.CreateSubscriptionCustomerInput{ + CustomerId: cust.ID, + Currency: inp.Currency, + ActiveFrom: inp.ActiveFrom, + }) + if err != nil { + return def, fmt.Errorf("failed to create spec from plan: %w", err) + } + + // Let's apply the customizations + err = spec.ApplyPatches(lo.Map(inp.Customization, subscription.ToApplies), subscription.ApplyContext{ + Operation: subscription.SpecOperationCreate, + CurrentTime: clock.Now(), + }) + if err != nil { + return def, fmt.Errorf("failed to apply customizations: %w", err) + } + + // Finally, let's create the subscription + return transaction.Run(ctx, s.TransactionManager, func(ctx context.Context) (subscription.SubscriptionView, error) { + sub, err := s.Service.Create(ctx, inp.Namespace, spec) + if err != nil { + return def, fmt.Errorf("failed to create subscription: %w", err) + } + + return s.Service.GetView(ctx, sub.NamespacedID) + }) +} + +func (s *workflowService) EditRunning(ctx context.Context, subscriptionID models.NamespacedID, customizations []subscription.Patch) (subscription.SubscriptionView, error) { + // First, let's fetch the current state of the Subscription + curr, err := s.Service.GetView(ctx, subscriptionID) + if err != nil { + return subscription.SubscriptionView{}, fmt.Errorf("failed to fetch subscription: %w", err) + } + + // Let's validate the patches + for i, patch := range customizations { + if err := patch.Validate(); err != nil { + return subscription.SubscriptionView{}, fmt.Errorf("invalid patch at index %d: %w", i, err) + } + } + + // Let's apply the customizations + spec := curr.AsSpec() + + err = spec.ApplyPatches(lo.Map(customizations, subscription.ToApplies), subscription.ApplyContext{ + Operation: subscription.SpecOperationEdit, + CurrentTime: clock.Now(), + }) + if err != nil { + return subscription.SubscriptionView{}, fmt.Errorf("failed to apply customizations: %w", err) + } + + // Finally, let's update the subscription + return transaction.Run(ctx, s.TransactionManager, func(ctx context.Context) (subscription.SubscriptionView, error) { + sub, err := s.Service.Update(ctx, subscriptionID, spec) + if err != nil { + return subscription.SubscriptionView{}, fmt.Errorf("failed to update subscription: %w", err) + } + + return s.Service.GetView(ctx, sub.NamespacedID) + }) +} diff --git a/openmeter/subscription/service/workflowservice_test.go b/openmeter/subscription/service/workflowservice_test.go new file mode 100644 index 0000000000..77c4d5c4e5 --- /dev/null +++ b/openmeter/subscription/service/workflowservice_test.go @@ -0,0 +1,497 @@ +package service_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + customerentity "github.com/openmeterio/openmeter/openmeter/customer/entity" + "github.com/openmeterio/openmeter/openmeter/subscription" + "github.com/openmeterio/openmeter/openmeter/subscription/service" + subscriptiontestutils "github.com/openmeterio/openmeter/openmeter/subscription/testutils" + "github.com/openmeterio/openmeter/openmeter/testutils" + "github.com/openmeterio/openmeter/pkg/clock" + "github.com/openmeterio/openmeter/pkg/models" +) + +func TestCreateFromPlan(t *testing.T) { + type testCaseDeps struct { + CurrentTime time.Time + Customer customerentity.Customer + WorkflowService subscription.WorkflowService + DBDeps *subscriptiontestutils.DBDeps + } + + testCases := []struct { + Name string + Handler func(t *testing.T, deps testCaseDeps) + }{ + { + Name: "Should error if customer is not found", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := deps.WorkflowService.CreateFromPlan(ctx, subscription.CreateFromPlanInput{ + CustomerID: fmt.Sprintf("nonexistent-customer-%s", deps.Customer.ID), + Namespace: subscriptiontestutils.ExampleNamespace, + ActiveFrom: deps.CurrentTime, + Currency: "USD", + Plan: subscriptiontestutils.ExamplePlanRef, + }) + + assert.ErrorAs(t, err, &customerentity.NotFoundError{}, "expected customer not found error, got %T", err) + }, + }, + { + Name: "Should error if plan is not found", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := deps.WorkflowService.CreateFromPlan(ctx, subscription.CreateFromPlanInput{ + CustomerID: deps.Customer.ID, + Namespace: subscriptiontestutils.ExampleNamespace, + ActiveFrom: deps.CurrentTime, + Currency: "USD", + Plan: subscription.PlanRef{Key: "nonexistent-plan", Version: 1}, + }) + + // assert.ErrorAs does not recognize this error + _, isErr := lo.ErrorsAs[*subscription.PlanNotFoundError](err) + assert.True(t, isErr, "expected plan not found error, got %T", err) + }, + }, + { + Name: "Should error if a patch is invalid", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errorMsg := "this is an invalid patch" + + invalidPatch := subscriptiontestutils.TestPatch{ + ValdiateFn: func() error { + return fmt.Errorf(errorMsg) + }, + } + + _, err := deps.WorkflowService.CreateFromPlan(ctx, subscription.CreateFromPlanInput{ + CustomerID: deps.Customer.ID, + Namespace: subscriptiontestutils.ExampleNamespace, + ActiveFrom: deps.CurrentTime, + Currency: "USD", + Plan: subscriptiontestutils.ExamplePlanRef, + Customization: []subscription.Patch{&invalidPatch}, + }) + + assert.ErrorContains(t, err, errorMsg, "expected error message to contain %q, got %v", errorMsg, err) + }, + }, + { + Name: "Should apply the patch to the specs based on the plan", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // As we assert for the contextual time, we need to freeze the clock + clock.FreezeTime(deps.CurrentTime) + defer clock.UnFreeze() + + errMsg := "this custom patch apply failed" + + expectedPlanSpec, err := subscription.NewSpecFromPlan(subscriptiontestutils.GetExamplePlan(), subscription.CreateSubscriptionCustomerInput{ + CustomerId: deps.Customer.ID, + Currency: "USD", + ActiveFrom: deps.CurrentTime, + }) + require.Nil(t, err) + + patch1 := subscriptiontestutils.TestPatch{ + ApplyToFn: func(spec *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + // Let's assert that the correct spec is passed to the patch + assert.Equal(t, &expectedPlanSpec, spec, "expected spec to be equal to the plan spec") + assert.Equal(t, subscription.ApplyContext{ + CurrentTime: deps.CurrentTime, + Operation: subscription.SpecOperationCreate, + }, c, "apply context is incorrect") + + // Lets modify the spec to see if its passed to the next + spec.Plan.Key = "modified-plan" + + return nil + }, + } + + patch2 := subscriptiontestutils.TestPatch{ + ApplyToFn: func(spec *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + // Let's see if the modification is passed along + assert.Equal(t, "modified-plan", spec.Plan.Key, "expected plan key to be modified") + + return nil + }, + } + + patch3 := subscriptiontestutils.TestPatch{ + ApplyToFn: func(spec *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + // And let's test if errors are passed correctly + return fmt.Errorf(errMsg) + }, + } + + _, err = deps.WorkflowService.CreateFromPlan(ctx, subscription.CreateFromPlanInput{ + CustomerID: deps.Customer.ID, + Namespace: subscriptiontestutils.ExampleNamespace, + ActiveFrom: deps.CurrentTime, + Currency: "USD", + Plan: subscriptiontestutils.ExamplePlanRef, + Customization: []subscription.Patch{ + &patch1, + &patch2, + &patch3, + }, + }) + + // Let's validate the error is surfaced + assert.ErrorContains(t, err, errMsg, "expected error message to contain %q, got %v", errMsg, err) + }, + }, + { + Name: "Should use the output of patches without modifications", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + returnedSpec := subscription.SubscriptionSpec{ + CreateSubscriptionPlanInput: subscription.CreateSubscriptionPlanInput{ + Plan: subscription.PlanRef{ + Key: "returned-plan", + Version: 1, + }, + }, + CreateSubscriptionCustomerInput: subscription.CreateSubscriptionCustomerInput{ + CustomerId: "new-customer-id", + }, + Phases: map[string]*subscription.SubscriptionPhaseSpec{ + "phase-1": { + CreateSubscriptionPhasePlanInput: subscription.CreateSubscriptionPhasePlanInput{ + PhaseKey: "phase-1", + StartAfter: testutils.GetISODuration(t, "P1D"), + Name: "Phase 1", + }, + Items: map[string]*subscription.SubscriptionItemSpec{ + "item-1": { + CreateSubscriptionItemInput: subscription.CreateSubscriptionItemInput{ + CreateSubscriptionItemPlanInput: subscription.CreateSubscriptionItemPlanInput{ + ItemKey: "item-1", + PhaseKey: "phase-1", + }, + }, + }, + }, + }, + }, + } + + patch1 := subscriptiontestutils.TestPatch{ + ApplyToFn: func(spec *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + // Let's set the new values + + spec.CreateSubscriptionPlanInput = returnedSpec.CreateSubscriptionPlanInput + spec.CreateSubscriptionCustomerInput = returnedSpec.CreateSubscriptionCustomerInput + spec.Phases = returnedSpec.Phases + + return nil + }, + } + + sID := models.NamespacedID{ + Namespace: subscriptiontestutils.ExampleNamespace, + ID: "new-subscription-id", + } + + rView := subscription.SubscriptionView{ + Subscription: subscription.Subscription{ + CustomerId: "bogus-id", + }, + } + + mSvc := subscriptiontestutils.MockService{ + CreateFn: func(ctx context.Context, namespace string, spec subscription.SubscriptionSpec) (subscription.Subscription, error) { + // Let's validate that the spec is passed as is + assert.Equal(t, returnedSpec, spec, "expected spec to be equal to the returned spec") + + return subscription.Subscription{ + NamespacedID: sID, + }, nil + }, + GetViewFn: func(ctx context.Context, id models.NamespacedID) (subscription.SubscriptionView, error) { + assert.Equal(t, sID, id, "expected id to be equal to the returned id") + + return rView, nil + }, + } + + _, tuDeps := subscriptiontestutils.NewService(t, deps.DBDeps) + tuDeps.PlanAdapter.AddPlan(t, subscriptiontestutils.GetExamplePlan()) + + workflowService := service.NewWorkflowService(service.WorkflowServiceConfig{ + Service: &mSvc, + CustomerService: tuDeps.CustomerService, + PlanAdapter: tuDeps.PlanAdapter, + TransactionManager: tuDeps.CustomerAdapter, + }) + + res, err := workflowService.CreateFromPlan(ctx, subscription.CreateFromPlanInput{ + CustomerID: deps.Customer.ID, + Namespace: subscriptiontestutils.ExampleNamespace, + ActiveFrom: deps.CurrentTime, + Currency: "USD", + Plan: subscriptiontestutils.ExamplePlanRef, + Customization: []subscription.Patch{ + &patch1, + }, + }) + + assert.Nil(t, err) + + assert.Equal(t, rView, res) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + tcDeps := testCaseDeps{ + CurrentTime: testutils.GetRFC3339Time(t, "2021-01-01T00:00:00Z"), + } + + clock.SetTime(tcDeps.CurrentTime) + dbDeps := subscriptiontestutils.SetupDBDeps(t) + require.NotNil(t, dbDeps) + defer dbDeps.Cleanup() + + services, deps := subscriptiontestutils.NewService(t, dbDeps) + deps.PlanAdapter.AddPlan(t, subscriptiontestutils.GetExamplePlan()) + cust := deps.CustomerAdapter.CreateExampleCustomer(t) + require.NotNil(t, cust) + + tcDeps.Customer = *cust + + tcDeps.DBDeps = dbDeps + + tcDeps.WorkflowService = services.WorkflowService + + tc.Handler(t, tcDeps) + }) + } +} + +func TestEditRunning(t *testing.T) { + type testCaseDeps struct { + CurrentTime time.Time + SubView subscription.SubscriptionView + Customer customerentity.Customer + WorkflowService subscription.WorkflowService + Service subscription.Service + DBDeps *subscriptiontestutils.DBDeps + } + + testCases := []struct { + Name string + Handler func(t *testing.T, deps testCaseDeps) + }{ + { + Name: "Should error if subscription is not found", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := deps.WorkflowService.EditRunning(ctx, models.NamespacedID{ + ID: "nonexistent-subscription", + Namespace: subscriptiontestutils.ExampleNamespace, + }, nil) + + assert.ErrorAs(t, err, lo.ToPtr(&subscription.NotFoundError{}), "expected subscription not found error, got %T", err) + }, + }, + { + Name: "Should do nothing if no patches are provided", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + subView, err := deps.WorkflowService.EditRunning(ctx, deps.SubView.Subscription.NamespacedID, nil) + assert.Nil(t, err) + + assert.Equal(t, deps.SubView, subView) + }, + }, + { + Name: "Should validate the provided patches", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errorMsg := "this is an invalid patch" + + invalidPatch := subscriptiontestutils.TestPatch{ + ValdiateFn: func() error { + return fmt.Errorf(errorMsg) + }, + } + + _, err := deps.WorkflowService.EditRunning(ctx, deps.SubView.Subscription.NamespacedID, []subscription.Patch{&invalidPatch}) + assert.ErrorContains(t, err, errorMsg, "expected error message to contain %q, got %v", errorMsg, err) + }, + }, + { + Name: "Should apply the customizations on the current spec", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Let's increment time by 1 second + deps.CurrentTime = deps.CurrentTime.Add(time.Second) + + // We have to freeze time here for the assertions + clock.FreezeTime(deps.CurrentTime) + defer clock.UnFreeze() + + errMSg := "this custom patch apply failed" + + patch1 := subscriptiontestutils.TestPatch{ + ApplyToFn: func(spec *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + // Let's assert that the correct spec is passed to the patch + assert.Equal(t, lo.ToPtr(deps.SubView.AsSpec()), spec, "expected spec to be equal to the current spec") + assert.Equal(t, subscription.ApplyContext{ + CurrentTime: deps.CurrentTime, + Operation: subscription.SpecOperationEdit, + }, c, "apply context is incorrect") + + // Lets modify the spec to see if its passed to the next + spec.Plan.Key = "modified-plan" + + return nil + }, + } + + patch2 := subscriptiontestutils.TestPatch{ + ApplyToFn: func(spec *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + // Let's see if the modification is passed along + assert.Equal(t, "modified-plan", spec.Plan.Key, "expected plan key to be modified") + + // Let's return an error to see if it is surfaced + return fmt.Errorf(errMSg) + }, + } + + _, err := deps.WorkflowService.EditRunning( + ctx, + deps.SubView.Subscription.NamespacedID, + []subscription.Patch{&patch1, &patch2}, + ) + assert.ErrorContains(t, err, errMSg, "expected error message to contain %q, got %v", errMSg, err) + }, + }, + { + Name: "Should use the output of patches without modifications", + Handler: func(t *testing.T, deps testCaseDeps) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + returnedSpec := deps.SubView.AsSpec() + + pKey := "test-phase-1" + + require.Contains(t, returnedSpec.Phases, pKey, "expected %s to be present in the starting spec", pKey) + returnedSpec.Phases[pKey].Name = "New Phase 1 Name" + returnedSpec.Phases[pKey].Description = lo.ToPtr("This is a new description") + + patch1 := subscriptiontestutils.TestPatch{ + ApplyToFn: func(spec *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + // Let's set the new values + + spec.CreateSubscriptionPlanInput = returnedSpec.CreateSubscriptionPlanInput + spec.CreateSubscriptionCustomerInput = returnedSpec.CreateSubscriptionCustomerInput + spec.Phases = returnedSpec.Phases + + return nil + }, + } + + sID := deps.SubView.Subscription.NamespacedID + + mSvc := subscriptiontestutils.MockService{ + UpdateFn: func(ctx context.Context, id models.NamespacedID, spec subscription.SubscriptionSpec) (subscription.Subscription, error) { + // Let's validate that the spec is passed as is + assert.Equal(t, returnedSpec, spec, "expected spec to be equal to the returned spec") + + return deps.Service.Update(ctx, id, spec) + }, + GetViewFn: func(ctx context.Context, id models.NamespacedID) (subscription.SubscriptionView, error) { + assert.Equal(t, sID, id, "expected id to be equal to the returned id") + + return deps.Service.GetView(ctx, id) + }, + } + + _, tuDeps := subscriptiontestutils.NewService(t, deps.DBDeps) + tuDeps.PlanAdapter.AddPlan(t, subscriptiontestutils.GetExamplePlan()) + + workflowService := service.NewWorkflowService(service.WorkflowServiceConfig{ + Service: &mSvc, + CustomerService: tuDeps.CustomerService, + PlanAdapter: tuDeps.PlanAdapter, + TransactionManager: tuDeps.CustomerAdapter, + }) + + _, err := workflowService.EditRunning(ctx, sID, []subscription.Patch{&patch1}) + assert.Nil(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + tcDeps := testCaseDeps{ + CurrentTime: testutils.GetRFC3339Time(t, "2021-01-01T00:00:00Z"), + } + + clock.SetTime(tcDeps.CurrentTime) + + // Let's build the dependencies + dbDeps := subscriptiontestutils.SetupDBDeps(t) + require.NotNil(t, dbDeps) + defer dbDeps.Cleanup() + + services, deps := subscriptiontestutils.NewService(t, dbDeps) + deps.PlanAdapter.AddPlan(t, subscriptiontestutils.GetExamplePlan()) + deps.FeatureConnector.CreateExampleFeature(t) + cust := deps.CustomerAdapter.CreateExampleCustomer(t) + require.NotNil(t, cust) + + // Let's create an example subscription + sub, err := services.WorkflowService.CreateFromPlan(context.Background(), subscription.CreateFromPlanInput{ + CustomerID: cust.ID, + Namespace: subscriptiontestutils.ExampleNamespace, + ActiveFrom: tcDeps.CurrentTime, + Currency: "USD", + Plan: subscriptiontestutils.ExamplePlanRef, + }) + require.Nil(t, err) + + tcDeps.SubView = sub + tcDeps.Customer = *cust + tcDeps.DBDeps = dbDeps + tcDeps.Service = services.Service + tcDeps.WorkflowService = services.WorkflowService + + tc.Handler(t, tcDeps) + }) + } +} diff --git a/openmeter/subscription/subscriptionspec.go b/openmeter/subscription/subscriptionspec.go index 6c67a0ef3d..4e38167cb4 100644 --- a/openmeter/subscription/subscriptionspec.go +++ b/openmeter/subscription/subscriptionspec.go @@ -152,6 +152,16 @@ type CreateSubscriptionPhasePlanInput struct { // CreateDiscountInput *applieddiscount.CreateInput } +func (i CreateSubscriptionPhasePlanInput) Validate() error { + if i.PhaseKey == "" { + return fmt.Errorf("phase key is required") + } + if i.Name == "" { + return fmt.Errorf("name is required") + } + return nil +} + type CreateSubscriptionPhaseCustomerInput struct{} type RemoveSubscriptionPhaseShifting int @@ -179,6 +189,14 @@ type CreateSubscriptionPhaseInput struct { CreateSubscriptionPhaseCustomerInput } +func (i CreateSubscriptionPhaseInput) Validate() error { + if err := i.CreateSubscriptionPhasePlanInput.Validate(); err != nil { + return err + } + + return nil +} + type SubscriptionPhaseSpec struct { // Duration is not part of the Spec by design CreateSubscriptionPhasePlanInput @@ -327,6 +345,12 @@ func (s SubscriptionItemSpec) ToScheduleSubscriptionEntitlementInput( return def, true, fmt.Errorf("failed to get recurrence from ISO duration: %w", err) } scheduleInput.UsagePeriod = lo.ToPtr(entitlement.UsagePeriod(rec)) + mu := &entitlement.MeasureUsageFromInput{} + err = mu.FromTime(cadence.ActiveFrom) + if err != nil { + return def, true, fmt.Errorf("failed to get measure usage from time: %w", err) + } + scheduleInput.MeasureUsageFrom = mu default: return def, true, fmt.Errorf("unsupported entitlement type %s", t) } diff --git a/openmeter/subscription/subscriptionview.go b/openmeter/subscription/subscriptionview.go index 745d947bf0..ce4ff99d8b 100644 --- a/openmeter/subscription/subscriptionview.go +++ b/openmeter/subscription/subscriptionview.go @@ -54,7 +54,7 @@ func (s *SubscriptionView) Validate(includePhases bool) error { type SubscriptionPhaseView struct { SubscriptionPhase SubscriptionPhase Spec SubscriptionPhaseSpec - Items []SubscriptionItemView + Items map[string]SubscriptionItemView } func (s *SubscriptionPhaseView) ActiveFrom(subscriptionCadence models.CadencedModel) time.Time { @@ -211,7 +211,7 @@ func NewSubscriptionView( SubscriptionPhase: phase, } - itemViews := make([]SubscriptionItemView, 0, len(phaseSpec.Items)) + itemViews := make(map[string]SubscriptionItemView, len(phaseSpec.Items)) for _, itemSpec := range phaseSpec.Items { if itemSpec == nil { return nil, fmt.Errorf("item spec is nil") @@ -238,7 +238,11 @@ func NewSubscriptionView( Entitlement: subEnt, } - itemViews = append(itemViews, itemView) + if _, ok := itemViews[item.Key]; ok { + return nil, fmt.Errorf("item %s is duplicated", item.Key) + } + + itemViews[item.Key] = itemView } phaseView.Items = itemViews diff --git a/openmeter/subscription/testutils/patch.go b/openmeter/subscription/testutils/patch.go new file mode 100644 index 0000000000..dff1a15393 --- /dev/null +++ b/openmeter/subscription/testutils/patch.go @@ -0,0 +1,55 @@ +package subscriptiontestutils + +import "github.com/openmeterio/openmeter/openmeter/subscription" + +type TestPatch struct { + PatchValue any + PatchOperation subscription.PatchOperation + PatchPath subscription.PatchPath + + ApplyToFn func(s *subscription.SubscriptionSpec, c subscription.ApplyContext) error + ValdiateFn func() error +} + +var ( + _ subscription.Patch = &TestPatch{} + _ subscription.ValuePatch[any] = &TestPatch{} +) + +func (p *TestPatch) ApplyTo(s *subscription.SubscriptionSpec, c subscription.ApplyContext) error { + if p.ApplyToFn != nil { + return p.ApplyToFn(s, c) + } + return nil +} + +func (p *TestPatch) Op() subscription.PatchOperation { + return p.PatchOperation +} + +func (p *TestPatch) Path() subscription.PatchPath { + return p.PatchPath +} + +func (p *TestPatch) Validate() error { + if p.ValdiateFn != nil { + return p.ValdiateFn() + } + return nil +} + +func (p *TestPatch) MarshalJSON() ([]byte, error) { + panic("not implemented") +} + +func (p *TestPatch) UnmarshalJSON(data []byte) error { + panic("not implemented") +} + +func (p *TestPatch) Value() any { + return p.PatchValue +} + +func (p *TestPatch) ValueAsAny() any { + return p.PatchValue +} diff --git a/openmeter/subscription/testutils/plan.go b/openmeter/subscription/testutils/plan.go index 5da482f3ab..47d7fe4d67 100644 --- a/openmeter/subscription/testutils/plan.go +++ b/openmeter/subscription/testutils/plan.go @@ -117,7 +117,9 @@ func (a *planAdapter) GetVersion(ctx context.Context, k string, v int) (subscrip return version, nil } -func (a *planAdapter) AddPlan(plan *Plan) { +func (a *planAdapter) AddPlan(t *testing.T, plan *Plan) { + t.Helper() + if a.store == nil { a.store = make(map[string]map[int]*Plan) } @@ -129,7 +131,9 @@ func (a *planAdapter) AddPlan(plan *Plan) { a.store[plan.PlanInput.Plan.Key][plan.PlanInput.Plan.Version] = plan } -func (a *planAdapter) RemovePlan(ref subscription.PlanRef) { +func (a *planAdapter) RemovePlan(t *testing.T, ref subscription.PlanRef) { + t.Helper() + if _, ok := a.store[ref.Key]; !ok { return } diff --git a/openmeter/subscription/testutils/service.go b/openmeter/subscription/testutils/service.go index eef5ca50b5..4324a5576d 100644 --- a/openmeter/subscription/testutils/service.go +++ b/openmeter/subscription/testutils/service.go @@ -1,8 +1,11 @@ package subscriptiontestutils import ( + "context" "testing" + "time" + "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/meter" registrybuilder "github.com/openmeterio/openmeter/openmeter/registry/builder" streamingtestutils "github.com/openmeterio/openmeter/openmeter/streaming/testutils" @@ -14,14 +17,22 @@ import ( "github.com/openmeterio/openmeter/pkg/models" ) -type deps struct { - PlanAdapter *planAdapter +// TODO: we could use wire for this +type ExposedServiceDeps struct { CustomerAdapter *testCustomerRepo + CustomerService customer.Service FeatureConnector *testFeatureConnector EntitlementAdapter subscription.EntitlementAdapter + PlanAdapter *planAdapter + DBDeps *DBDeps +} + +type services struct { + Service subscription.Service + WorkflowService subscription.WorkflowService } -func NewService(t *testing.T, dbDeps *DBDeps) (subscription.Service, *deps) { +func NewService(t *testing.T, dbDeps *DBDeps) (services, ExposedServiceDeps) { t.Helper() logger := testutils.NewLogger(t) subRepo := NewSubscriptionRepo(t, dbDeps) @@ -49,25 +60,66 @@ func NewService(t *testing.T, dbDeps *DBDeps) (subscription.Service, *deps) { subscriptionEntitlementRepo, ) - planAdapter := NewMockPlanAdapter(t) - customerAdapter := NewCustomerAdapter(t, dbDeps) customer := NewCustomerService(t, dbDeps) - service := service.New(service.ServiceConfig{ + planAdapter := NewMockPlanAdapter(t) + + svc := service.New(service.ServiceConfig{ SubscriptionRepo: subRepo, SubscriptionPhaseRepo: subPhaseRepo, SubscriptionItemRepo: subItemRepo, CustomerService: customer, - PlanAdapter: planAdapter, EntitlementAdapter: entitlementAdapter, TransactionManager: subscriptionEntitlementRepo, }) - return service, &deps{ + workflowSvc := service.NewWorkflowService(service.WorkflowServiceConfig{ + Service: svc, + CustomerService: customer, PlanAdapter: planAdapter, - CustomerAdapter: customerAdapter, - FeatureConnector: NewTestFeatureConnector(entitlementRegistry.Feature), - EntitlementAdapter: entitlementAdapter, - } + TransactionManager: subscriptionEntitlementRepo, + }) + + return services{ + Service: svc, + WorkflowService: workflowSvc, + }, ExposedServiceDeps{ + CustomerAdapter: customerAdapter, + CustomerService: customer, + FeatureConnector: NewTestFeatureConnector(entitlementRegistry.Feature), + EntitlementAdapter: entitlementAdapter, + PlanAdapter: planAdapter, + DBDeps: dbDeps, + } +} + +type MockService struct { + CreateFn func(ctx context.Context, namespace string, spec subscription.SubscriptionSpec) (subscription.Subscription, error) + UpdateFn func(ctx context.Context, subscriptionID models.NamespacedID, target subscription.SubscriptionSpec) (subscription.Subscription, error) + CancelFn func(ctx context.Context, subscriptionID string, at time.Time) (subscription.Subscription, error) + GetFn func(ctx context.Context, subscriptionID models.NamespacedID) (subscription.Subscription, error) + GetViewFn func(ctx context.Context, subscriptionID models.NamespacedID) (subscription.SubscriptionView, error) +} + +var _ subscription.Service = &MockService{} + +func (s *MockService) Create(ctx context.Context, namespace string, spec subscription.SubscriptionSpec) (subscription.Subscription, error) { + return s.CreateFn(ctx, namespace, spec) +} + +func (s *MockService) Update(ctx context.Context, subscriptionID models.NamespacedID, target subscription.SubscriptionSpec) (subscription.Subscription, error) { + return s.UpdateFn(ctx, subscriptionID, target) +} + +func (s *MockService) Cancel(ctx context.Context, subscriptionID string, at time.Time) (subscription.Subscription, error) { + return s.CancelFn(ctx, subscriptionID, at) +} + +func (s *MockService) Get(ctx context.Context, subscriptionID models.NamespacedID) (subscription.Subscription, error) { + return s.GetFn(ctx, subscriptionID) +} + +func (s *MockService) GetView(ctx context.Context, subscriptionID models.NamespacedID) (subscription.SubscriptionView, error) { + return s.GetViewFn(ctx, subscriptionID) } diff --git a/pkg/clock/clock.go b/pkg/clock/clock.go index 225e2c8535..b0b2814fef 100644 --- a/pkg/clock/clock.go +++ b/pkg/clock/clock.go @@ -13,7 +13,7 @@ var ( func Now() time.Time { if atomic.LoadInt32(&frozen) == 1 { - return frozenTime.Load().(time.Time) + return frozenTime.Load().(time.Time).Round(0) } driftDuration := time.Duration(atomic.LoadInt64(&drift)) t := time.Now().Add(-driftDuration) diff --git a/pkg/models/model.go b/pkg/models/model.go index 15f640c5a6..592a8d47f0 100644 --- a/pkg/models/model.go +++ b/pkg/models/model.go @@ -201,6 +201,24 @@ type CadencedModel struct { ActiveTo *time.Time `json:"activeTo"` } +func (c CadencedModel) Equal(other CadencedModel) bool { + if !c.ActiveFrom.Equal(other.ActiveFrom) { + return false + } + + if (c.ActiveTo == nil) != (other.ActiveTo == nil) { + return false + } + + if c.ActiveTo != nil && other.ActiveTo != nil { + if !c.ActiveTo.Equal(*other.ActiveTo) { + return false + } + } + + return true +} + var _ Cadenced = CadencedModel{} func (c CadencedModel) cadenced() cadencedMarker {