diff --git a/api/api.go b/api/api.go index 3cf68fb..f3213cb 100644 --- a/api/api.go +++ b/api/api.go @@ -188,7 +188,10 @@ func (a *API) orderRoutes(r *router) { r.With(addGetBody).Post("/", a.PaymentCreate) }) - r.Get("/downloads", a.DownloadList) + r.Route("/downloads", func(r *router) { + r.Get("/", a.DownloadList) + r.Post("/refresh", a.DownloadRefresh) + }) r.Get("/receipt", a.ReceiptView) r.Post("/receipt", a.ResendOrderReceipt) }) diff --git a/api/download.go b/api/download.go index 564a4cc..c19ebd3 100644 --- a/api/download.go +++ b/api/download.go @@ -83,7 +83,6 @@ func (a *API) DownloadList(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() orderID := gcontext.GetOrderID(ctx) log := getLogEntry(r) - claims := gcontext.GetClaims(ctx) order := &models.Order{} if orderID != "" { @@ -114,6 +113,7 @@ func (a *API) DownloadList(w http.ResponseWriter, r *http.Request) error { if order != nil { query = query.Where(orderTable+".id = ?", order.ID) } else { + claims := gcontext.GetClaims(ctx) query = query.Where(orderTable+".user_id = ?", claims.Subject) } @@ -130,3 +130,44 @@ func (a *API) DownloadList(w http.ResponseWriter, r *http.Request) error { log.WithField("download_count", len(downloads)).Debugf("Successfully retrieved %d downloads", len(downloads)) return sendJSON(w, http.StatusOK, downloads) } + +// DownloadRefresh makes sure downloads are up to date +func (a *API) DownloadRefresh(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + orderID := gcontext.GetOrderID(ctx) + config := gcontext.GetConfig(ctx) + log := getLogEntry(r) + + order := &models.Order{} + if orderID == "" { + return badRequestError("Order id missing") + } + + query := a.db.Where("id = ?", orderID). + Preload("LineItems"). + Preload("Downloads") + if result := query.First(order); result.Error != nil { + if result.RecordNotFound() { + return notFoundError("Download order not found") + } + return internalServerError("Error during database query").WithInternalError(result.Error) + } + + if !hasOrderAccess(ctx, order) { + return unauthorizedError("You don't have permission to access this order") + } + + if order.PaymentState != models.PaidState { + return unauthorizedError("This order has not been completed yet") + } + + if err := order.UpdateDownloads(config, log); err != nil { + return internalServerError("Error during updating downloads").WithInternalError(err) + } + + if result := a.db.Save(order); result.Error != nil { + return internalServerError("Error during saving order").WithInternalError(result.Error) + } + + return sendJSON(w, http.StatusOK, map[string]string{}) +} diff --git a/api/download_test.go b/api/download_test.go index 96c4ed9..c3267af 100644 --- a/api/download_test.go +++ b/api/download_test.go @@ -1,7 +1,11 @@ package api import ( + "encoding/json" + "fmt" + "io/ioutil" "net/http" + "net/http/httptest" "testing" "github.com/netlify/gocommerce/models" @@ -19,3 +23,70 @@ func TestDownloadList(t *testing.T) { assert.Len(t, downloads, 1) }) } + +func currentDownloads(test *RouteTest) []models.Download { + recorder := test.TestEndpoint(http.MethodGet, "/downloads", nil, test.Data.testUserToken) + + downloads := []models.Download{} + extractPayload(test.T, http.StatusOK, recorder, &downloads) + return downloads +} + +type DownloadMeta struct { + Title string `json:"title"` + URL string `json:"url"` +} + +func startTestSiteWithDownloads(t *testing.T, downloads []*DownloadMeta) *httptest.Server { + downloadsList, err := json.Marshal(downloads) + assert.NoError(t, err) + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/i/believe/i/can/fly": + fmt.Fprintf(w, productMetaFrame(` + {"sku": "123-i-can-fly-456", "downloads": %s}`), + string(downloadsList), + ) + } + })) +} + +func TestDownloadRefresh(t *testing.T) { + test := NewRouteTest(t) + downloadsBefore := currentDownloads(test) + + testSite := startTestSiteWithDownloads(t, []*DownloadMeta{ + &DownloadMeta{ + Title: "Updated Download", + URL: "/my/special/new/url", + }, + }) + defer testSite.Close() + test.Config.SiteURL = testSite.URL + + url := fmt.Sprintf("/orders/%s/downloads/refresh", test.Data.firstOrder.ID) + recorder := test.TestEndpoint(http.MethodPost, url, nil, test.Data.testUserToken) + body, err := ioutil.ReadAll(recorder.Body) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code, "Failure: %s", string(body)) + + downloadsAfter := currentDownloads(test) + + assert.Equal(t, len(downloadsBefore)+1, len(downloadsAfter)) + exists := false + for _, download := range downloadsAfter { + found := false + for _, prev := range downloadsBefore { + if download.ID == prev.ID { + found = true + break + } + } + if !found { + assert.Equal(t, "/my/special/new/url", download.URL) + assert.Equal(t, "123-i-can-fly-456", download.Sku) + exists = true + } + } + assert.True(t, exists) +} diff --git a/api/order.go b/api/order.go index b873a88..63654ae 100644 --- a/api/order.go +++ b/api/order.go @@ -7,7 +7,6 @@ import ( "net/http" "sync" - "github.com/PuerkitoBio/goquery" "github.com/go-chi/chi" "github.com/jinzhu/gorm" "github.com/mattes/vat" @@ -641,6 +640,13 @@ func (a *API) createLineItems(ctx context.Context, tx *gorm.DB, order *models.Or Path: orderItem.Path, OrderID: order.ID, } + + for _, addon := range orderItem.Addons { + lineItem.AddonItems = append(lineItem.AddonItems, &models.AddonItem{ + Sku: addon.Sku, + }) + } + order.LineItems = append(order.LineItems, lineItem) sem <- 1 wg.Add(1) @@ -654,7 +660,7 @@ func (a *API) createLineItems(ctx context.Context, tx *gorm.DB, order *models.Or return } - if err := a.processLineItem(ctx, order, item, orderItem); err != nil { + if err := a.processLineItem(ctx, order, item); err != nil { sharedErr.setError(err) } }(lineItem, orderItem) @@ -735,56 +741,11 @@ func (a *API) processAddress(tx *gorm.DB, order *models.Order, name string, addr return address, nil } -func (a *API) processLineItem(ctx context.Context, order *models.Order, item *models.LineItem, orderItem *orderLineItem) error { +func (a *API) processLineItem(ctx context.Context, order *models.Order, item *models.LineItem) error { config := gcontext.GetConfig(ctx) jwtClaims := gcontext.GetClaimsAsMap(ctx) - resp, err := a.httpClient.Get(config.SiteURL + item.Path) - if err != nil { - return err - } - defer resp.Body.Close() - - doc, err := goquery.NewDocumentFromResponse(resp) - if err != nil { - return err - } - - metaTag := doc.Find(".gocommerce-product") - if metaTag.Length() == 0 { - return fmt.Errorf("No script tag with class gocommerce-product tag found for '%v'", item.Title) - } - metaProducts := []*models.LineItemMetadata{} - var parsingErr error - metaTag.EachWithBreak(func(_ int, tag *goquery.Selection) bool { - meta := &models.LineItemMetadata{} - parsingErr = json.Unmarshal([]byte(tag.Text()), meta) - if parsingErr != nil { - return false - } - metaProducts = append(metaProducts, meta) - return true - }) - if parsingErr != nil { - return fmt.Errorf("Error parsing product metadata: %v", parsingErr) - } - - if len(metaProducts) == 1 && item.Sku == "" { - item.Sku = metaProducts[0].Sku - } - - for _, meta := range metaProducts { - if meta.Sku == item.Sku { - for _, addon := range orderItem.Addons { - item.AddonItems = append(item.AddonItems, &models.AddonItem{ - Sku: addon.Sku, - }) - } - - return item.Process(jwtClaims, order, meta) - } - } - return fmt.Errorf("No product Sku from path matched: %v", item.Sku) + return item.Process(config, jwtClaims, order) } func orderQuery(db *gorm.DB) *gorm.DB { diff --git a/api/utils_test.go b/api/utils_test.go index 55f0681..1b68d06 100644 --- a/api/utils_test.go +++ b/api/utils_test.go @@ -418,35 +418,34 @@ func signInstanceRequest(req *http.Request, instanceID string, jwtSecret string) // TEST SITE // ------------------------------------------------------------------------------------------------ +func productMetaFrame(meta string) string { + return fmt.Sprintf(` + +Test Product + + + +`, + meta) +} + func handleTestProducts(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/simple-product": - fmt.Fprintln(w, ` - - Test Product - - - - `) + fmt.Fprintln(w, productMetaFrame(` + {"sku": "product-1", "title": "Product 1", "type": "Book", "prices": [ + {"amount": "9.99", "currency": "USD"} + ]}`)) case "/bundle-product": - fmt.Fprintln(w, ` - - Test Product - - - - `) + ]}`)) default: w.WriteHeader(http.StatusNotFound) } diff --git a/models/line_item.go b/models/line_item.go index 54f5ded..50ed44b 100644 --- a/models/line_item.go +++ b/models/line_item.go @@ -4,12 +4,15 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strconv" "time" + "github.com/PuerkitoBio/goquery" "github.com/jinzhu/gorm" "github.com/netlify/gocommerce/calculator" "github.com/netlify/gocommerce/claims" + "github.com/netlify/gocommerce/conf" "github.com/pborman/uuid" ) @@ -248,7 +251,12 @@ func (i *LineItem) GetQuantity() uint64 { } // Process calculates the price of a LineItem. -func (i *LineItem) Process(userClaims map[string]interface{}, order *Order, meta *LineItemMetadata) error { +func (i *LineItem) Process(config *conf.Configuration, userClaims map[string]interface{}, order *Order) error { + meta, err := i.FetchMeta(config.SiteURL) + if err != nil { + return err + } + i.Sku = meta.Sku i.Title = meta.Title i.Description = meta.Description @@ -277,6 +285,59 @@ func (i *LineItem) Process(userClaims map[string]interface{}, order *Order, meta i.AddonItems[index].Price = lowestPrice.cents } + order.Downloads = i.MissingDownloads(order, meta) + + return i.calculatePrice(userClaims, meta.Prices, order.Currency) +} + +// FetchMeta determines the product metadata for the item based on its path +func (i *LineItem) FetchMeta(siteURL string) (*LineItemMetadata, error) { + resp, err := http.Get(siteURL + i.Path) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + doc, err := goquery.NewDocumentFromResponse(resp) + if err != nil { + return nil, err + } + + metaTag := doc.Find(".gocommerce-product") + if metaTag.Length() == 0 { + return nil, fmt.Errorf("No script tag with class gocommerce-product tag found for '%v'", i.Title) + } + metaProducts := []*LineItemMetadata{} + var parsingErr error + metaTag.EachWithBreak(func(_ int, tag *goquery.Selection) bool { + meta := &LineItemMetadata{} + parsingErr = json.Unmarshal([]byte(tag.Text()), meta) + if parsingErr != nil { + return false + } + metaProducts = append(metaProducts, meta) + return true + }) + if parsingErr != nil { + return nil, fmt.Errorf("Error parsing product metadata: %v", parsingErr) + } + + if len(metaProducts) == 1 && i.Sku == "" { + i.Sku = metaProducts[0].Sku + } + + for _, meta := range metaProducts { + if meta.Sku == i.Sku { + return meta, nil + } + } + + return nil, fmt.Errorf("No product Sku from path matched: %v", i.Sku) +} + +// MissingDownloads returns all downloads that are not yet listed in the order +func (i *LineItem) MissingDownloads(order *Order, meta *LineItemMetadata) []Download { + downloads := []Download{} for _, download := range meta.Downloads { alreadyCreated := false for _, d := range order.Downloads { @@ -292,10 +353,9 @@ func (i *LineItem) Process(userClaims map[string]interface{}, order *Order, meta download.OrderID = order.ID download.Title = i.Title download.Sku = i.Sku - order.Downloads = append(order.Downloads, download) + downloads = append(downloads, download) } - - return i.calculatePrice(userClaims, meta.Prices, order.Currency) + return downloads } func (i *LineItem) calculatePrice(userClaims map[string]interface{}, prices []PriceMetadata, currency string) error { diff --git a/models/order.go b/models/order.go index 69a6558..03c412f 100644 --- a/models/order.go +++ b/models/order.go @@ -7,6 +7,7 @@ import ( "github.com/jinzhu/gorm" "github.com/netlify/gocommerce/calculator" + "github.com/netlify/gocommerce/conf" "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -202,6 +203,18 @@ func (o *Order) CalculateTotal(settings *calculator.Settings, claims map[string] } } +// UpdateDownloads will refetch downloads for all line items in the order and +// update the downloads in the order +func (o *Order) UpdateDownloads(config *conf.Configuration, log logrus.FieldLogger) error { + updateMap := downloadRefreshItemSet{} + for _, item := range o.LineItems { + updateMap.Add(item, o) + } + updates, err := updateMap.Update(nil, config, log) + log.Debugf("Updated downloads of %d orders", len(updates)) + return err +} + func (o *Order) BeforeDelete(tx *gorm.DB) error { cascadeModels := map[string]interface{}{ "line item": &[]LineItem{}, @@ -224,3 +237,84 @@ func (o *Order) BeforeDelete(tx *gorm.DB) error { } return nil } + +type downloadRefreshItemSetEntry struct { + item *LineItem + orders []*Order +} +type downloadRefreshInstanceItems map[string]*downloadRefreshItemSetEntry +type downloadRefreshItemSet map[string]downloadRefreshInstanceItems + +// Add will take a line item and an order to persist in +// the list of orders to update +func (m downloadRefreshItemSet) Add(item *LineItem, order *Order) { + instance, ok := m[order.InstanceID] + if !ok { + instance = make(map[string]*downloadRefreshItemSetEntry) + m[order.InstanceID] = instance + } + + mapping, ok := instance[item.Sku] + if !ok { + mapping = &downloadRefreshItemSetEntry{ + item: item, + orders: []*Order{}, + } + instance[item.Sku] = mapping + } + + mapping.orders = append(mapping.orders, order) +} + +// UpdateDownloads fetches downloads for all line items and updates orders with new downloads +func (m downloadRefreshItemSet) Update(db *gorm.DB, config *conf.Configuration, log logrus.FieldLogger) (updates []*Order, err error) { + // @todo: run in parallel with goroutines, lock orders with mutexes + for instanceID, items := range m { + if config == nil { + if db == nil { + err = errors.New("Instance config or database connection missing") + return + } + instance := Instance{} + if queryErr := db.First(&instance, Instance{ID: instanceID}).Error; queryErr != nil { + err = errors.Wrap(queryErr, "Failed fetching instance for order") + return + } + config = instance.BaseConfig + } + + for _, entry := range items { + if entry.item.Sku == "" { + log.Warningf( + "Tried updating a line item without SKU at %s. Skipped to avoid memory update in FetchMeta", + entry.item.Path, + ) + continue + } + log.Debugf("Updating downloads for item with sku '%s'", entry.item.Sku) + meta, fetchErr := entry.item.FetchMeta(config.SiteURL) + if fetchErr != nil { + // item might not be offered anymore, preserve downloads + log.WithError(fetchErr). + WithFields(map[string]interface{}{ + "path": entry.item.Path, + "sku": entry.item.Sku, + }). + Warning("Fetching product metadata failed. Skipping item.") + continue + } + for _, order := range entry.orders { + downloads := entry.item.MissingDownloads(order, meta) + if len(downloads) == 0 { + continue + } + // @todo: Lock order mutex if run in goroutines + order.Downloads = append(order.Downloads, downloads...) + + updates = append(updates, order) + } + } + } + + return +}