Skip to content

Commit

Permalink
Merge pull request #228 from SkynetLabs/ivo/stripe_promote_on_checkout
Browse files Browse the repository at this point in the history
Stripe promote on checkout
  • Loading branch information
ro-tex authored Jul 25, 2022
2 parents 59e344d + 6a97405 commit 74d6623
Show file tree
Hide file tree
Showing 7 changed files with 502 additions and 40 deletions.
1 change: 1 addition & 0 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (api *API) buildHTTPRoutes() {
// `POST /stripe/billing` is deprecated. Please use `GET /stripe/billing`.
api.staticRouter.POST("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingHANDLER, false)))
api.staticRouter.POST("/stripe/checkout", api.WithDBSession(api.withAuth(api.stripeCheckoutPOST, false)))
api.staticRouter.GET("/stripe/checkout/:checkout_id", api.WithDBSession(api.withAuth(api.stripeCheckoutIDGET, false)))
api.staticRouter.GET("/stripe/prices", api.noAuth(api.stripePricesGET))
api.staticRouter.POST("/stripe/webhook", api.WithDBSession(api.noAuth(api.stripeWebhookPOST)))

Expand Down
192 changes: 182 additions & 10 deletions api/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/stripe/stripe-go/v72/sub"
"github.com/stripe/stripe-go/v72/webhook"
"gitlab.com/NebulousLabs/errors"
"gitlab.com/SkynetLabs/skyd/build"
)

const (
Expand All @@ -34,8 +35,22 @@ var (
// `https://account.` prepended to it).
DashboardURL = "https://account.siasky.net"

// True is a helper for when we need to pass a *bool to Stripe.
True = true
// ErrCheckoutWithoutCustomer is the error returned when a checkout session
// doesn't have an associated customer
ErrCheckoutWithoutCustomer = errors.New("this checkout session does not have an associated customer")
// ErrCheckoutWithoutSub is the error returned when a checkout session doesn't
// have an associated subscription
ErrCheckoutWithoutSub = errors.New("this checkout session does not have an associated subscription")
// ErrCheckoutDoesNotBelongToUser is returned when the given checkout
// session does not belong to the current user. This might be a mistake or
// might be an attempt for fraud.
ErrCheckoutDoesNotBelongToUser = errors.New("checkout session does not belong to current user")
// ErrSubNotActive is returned when the given subscription is not active, so
// we cannot do anything based on it.
ErrSubNotActive = errors.New("subscription not active")
// ErrSubWithoutPrice is returned when the subscription doesn't have a
// price, so we cannot determine the user's tier based on it.
ErrSubWithoutPrice = errors.New("subscription does not have a price")

// stripePageSize defines the number of records we are going to request from
// endpoints that support pagination.
Expand All @@ -62,7 +77,7 @@ var (
)

type (
// StripePrice ...
// StripePrice describes a Stripe price item.
StripePrice struct {
ID string `json:"id"`
Name string `json:"name"`
Expand All @@ -74,6 +89,42 @@ type (
ProductID string `json:"productId"`
LiveMode bool `json:"livemode"`
}
// SubscriptionGET describes a Stripe subscription for our front end needs.
SubscriptionGET struct {
Created int64 `json:"created"`
CurrentPeriodStart int64 `json:"currentPeriodStart"`
Discount *SubscriptionDiscountGET `json:"discount"`
ID string `json:"id"`
Plan *SubscriptionPlanGET `json:"plan"`
StartDate int64 `json:"startDate"`
Status string `json:"status"`
}
// SubscriptionDiscountGET describes a Stripe subscription discount for our
// front end needs.
SubscriptionDiscountGET struct {
AmountOff int64 `json:"amountOff"`
Currency string `json:"currency"`
Duration string `json:"duration"`
DurationInMonths int64 `json:"durationInMonths"`
Name string `json:"name"`
PercentOff float64 `json:"percentOff"`
}
// SubscriptionPlanGET describes a Stripe subscription plan for our front
// end needs.
SubscriptionPlanGET struct {
Amount int64 `json:"amount"`
Currency string `json:"currency"`
Interval string `json:"interval"`
IntervalCount int64 `json:"intervalCount"`
Price string `json:"price"`
Product *SubscriptionProductGET `json:"product"`
}
// SubscriptionProductGET describes a Stripe subscription product for our
// front end needs.
SubscriptionProductGET struct {
Description string `json:"description"`
Name string `json:"name"`
}
)

// processStripeSub reads the information about the user's subscription and
Expand Down Expand Up @@ -120,8 +171,8 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
}
// Cancel all subs aside from the latest one.
p := stripe.SubscriptionCancelParams{
InvoiceNow: &True,
Prorate: &True,
InvoiceNow: stripe.Bool(true),
Prorate: stripe.Bool(true),
}
for _, subsc := range subs {
if subsc == nil || (mostRecentSub != nil && subsc.ID == mostRecentSub.ID) {
Expand Down Expand Up @@ -199,7 +250,7 @@ func (api *API) stripeCheckoutPOST(u *database.User, w http.ResponseWriter, req
cancelURL := DashboardURL + "/payments"
successURL := DashboardURL + "/payments?session_id={CHECKOUT_SESSION_ID}"
params := stripe.CheckoutSessionParams{
AllowPromotionCodes: &True,
AllowPromotionCodes: stripe.Bool(true),
CancelURL: &cancelURL,
ClientReferenceID: &u.Sub,
Customer: &u.StripeID,
Expand All @@ -226,26 +277,147 @@ func (api *API) stripeCheckoutPOST(u *database.User, w http.ResponseWriter, req
api.WriteJSON(w, response)
}

// stripeCheckoutIDGET checks the status of a checkout session. If the checkout
// is successful and results in a higher tier sub than the current one, we
// upgrade the user to the new tier.
func (api *API) stripeCheckoutIDGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
checkoutSessionID := ps.ByName("checkout_id")
subStr := "subscription"
subDiscountStr := "subscription.discount"
subPlanProductStr := "subscription.plan.product"
params := &stripe.CheckoutSessionParams{
Params: stripe.Params{
Expand: []*string{&subStr, &subDiscountStr, &subPlanProductStr},
},
}
cos, err := cosession.Get(checkoutSessionID, params)
if err != nil {
api.WriteError(w, err, http.StatusInternalServerError)
return
}
if cos.Customer == nil {
api.WriteError(w, ErrCheckoutWithoutCustomer, http.StatusBadRequest)
return
}
if cos.Customer.ID != u.StripeID {
api.WriteError(w, ErrCheckoutDoesNotBelongToUser, http.StatusForbidden)
return
}
coSub := cos.Subscription
if coSub == nil {
api.WriteError(w, ErrCheckoutWithoutSub, http.StatusBadRequest)
return
}
if coSub.Status != stripe.SubscriptionStatusActive {
api.WriteError(w, ErrSubNotActive, http.StatusBadRequest)
return
}
// Get the subscription price.
if coSub.Items == nil || len(coSub.Items.Data) == 0 || coSub.Items.Data[0].Price == nil {
api.WriteError(w, ErrSubWithoutPrice, http.StatusBadRequest)
return
}
coSubPrice := coSub.Items.Data[0].Price
tier, exists := StripePrices()[coSubPrice.ID]
if !exists {
err = fmt.Errorf("invalid price id '%s'", coSubPrice.ID)
api.WriteError(w, err, http.StatusInternalServerError)
build.Critical(errors.AddContext(err, "We somehow received an invalid price ID from Stripe. This might be caused by mismatched test/prod tokens or a breakdown in our Stripe setup."))
return
}
// Promote the user, if needed.
if tier > u.Tier {
err = api.staticDB.UserSetTier(req.Context(), u, tier)
if err != nil {
api.WriteError(w, errors.AddContext(err, "failed to promote user"), http.StatusInternalServerError)
return
}
}
// Build the response DTO.
var discountInfo *SubscriptionDiscountGET
if coSub.Discount != nil {
var coupon *stripe.Coupon
// We can potentially fetch the discount coupon from two places - the
// discount itself or its promotional code. We'll check them in order.
if coSub.Discount.Coupon != nil {
coupon = coSub.Discount.Coupon
} else if coSub.Discount.PromotionCode != nil && coSub.Discount.PromotionCode.Coupon != nil {
coupon = coSub.Discount.PromotionCode.Coupon
}
if coupon != nil {
discountInfo = &SubscriptionDiscountGET{
AmountOff: coupon.AmountOff,
Currency: string(coupon.Currency),
Duration: string(coupon.Duration),
DurationInMonths: coupon.DurationInMonths,
Name: coupon.Name,
PercentOff: coupon.PercentOff,
}
}
}
var planInfo *SubscriptionPlanGET
if coSub.Plan != nil {
var productInfo *SubscriptionProductGET
if coSub.Plan.Product != nil {
productInfo = &SubscriptionProductGET{
Description: coSub.Plan.Product.Description,
Name: coSub.Plan.Product.Name,
}
}
planInfo = &SubscriptionPlanGET{
Amount: coSub.Plan.Amount,
Currency: string(coSub.Plan.Currency),
Interval: string(coSub.Plan.Interval),
IntervalCount: coSub.Plan.IntervalCount,
Price: coSub.Plan.ID, // plan ID and price ID are the same
Product: productInfo,
}
}

subInfo := SubscriptionGET{
Created: coSub.Created,
CurrentPeriodStart: coSub.CurrentPeriodStart,
Discount: discountInfo,
ID: coSub.ID,
Plan: planInfo,
StartDate: coSub.StartDate,
Status: string(coSub.Status),
}
api.WriteJSON(w, subInfo)
}

// stripeCreateCustomer creates a Stripe customer record for this user and
// updates the user in the database.
func (api *API) stripeCreateCustomer(ctx context.Context, u *database.User) (string, error) {
cus, err := customer.New(&stripe.CustomerParams{})
if err != nil {
return "", errors.AddContext(err, "failed to create Stripe customer")
}
stripeID := cus.ID
err = api.staticDB.UserSetStripeID(ctx, u, stripeID)
// We'll try to update the customer with the user's email and sub. We only
// do this as an optional step, so we can match Stripe customers to local
// users more easily. We do not care if this step fails - it's entirely
// optional. It requires an additional round-trip to Stripe and we don't
// need to wait for it to finish, so we'll do it in a separate goroutine.
go func() {
email := u.Email.String()
updateParams := stripe.CustomerParams{
Description: &u.Sub,
Email: &email,
}
_, _ = customer.Update(cus.ID, &updateParams)
}()
err = api.staticDB.UserSetStripeID(ctx, u, cus.ID)
if err != nil {
return "", errors.AddContext(err, "failed to save user's StripeID")
}
return stripeID, nil
return cus.ID, nil
}

// stripePricesGET returns a list of plans and prices.
func (api *API) stripePricesGET(_ *database.User, w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
var sPrices []StripePrice
params := &stripe.PriceListParams{
Active: &True,
Active: stripe.Bool(true),
ListParams: stripe.ListParams{
Limit: &stripePageSize,
},
Expand Down
18 changes: 10 additions & 8 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ require (
github.com/julienschmidt/httprouter v1.3.0
github.com/lestrrat-go/jwx v1.2.25
github.com/sirupsen/logrus v1.8.1
github.com/stripe/stripe-go/v72 v72.115.0
github.com/stripe/stripe-go/v72 v72.117.0
gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975
gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40
gitlab.com/SkynetLabs/skyd v1.5.10
gitlab.com/SkynetLabs/skyd v1.6.0
go.mongodb.org/mongo-driver v1.9.1
go.sia.tech/siad v1.5.8
go.sia.tech/siad v1.5.9-rc1
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d
gopkg.in/h2non/gock.v1 v1.1.2
gopkg.in/mail.v2 v2.3.1
)

Expand All @@ -24,10 +25,11 @@ require (
github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
github.com/go-stack/stack v1.8.1 // indirect
github.com/goccy/go-json v0.9.7 // indirect
github.com/goccy/go-json v0.9.8 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/klauspost/compress v1.15.6 // indirect
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
github.com/klauspost/compress v1.15.7 // indirect
github.com/klauspost/cpuid/v2 v2.0.14 // indirect
github.com/klauspost/reedsolomon v1.10.0 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
Expand All @@ -50,11 +52,11 @@ require (
gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4 // indirect
gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877 // indirect
gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e // indirect
gitlab.com/NebulousLabs/siamux v0.0.0-20220616144115-9831ef867730 // indirect
gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259 // indirect
gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213 // indirect
golang.org/x/net v0.0.0-20220622184535-263ec571b305 // indirect
golang.org/x/net v0.0.0-20220706163947-c90051bbdb60 // indirect
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 // indirect
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect
golang.org/x/text v0.3.7 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
Expand Down
Loading

0 comments on commit 74d6623

Please sign in to comment.