diff --git a/query/components.go b/query/components.go index fecf346e..a5d7db42 100644 --- a/query/components.go +++ b/query/components.go @@ -10,7 +10,7 @@ import ( func GetComponentsByIDs(ctx context.Context, ids []uuid.UUID) ([]models.Component, error) { var components []models.Component for i := range ids { - c, err := ComponentFromCache(ctx, ids[i].String()) + c, err := ComponentFromCache(ctx, ids[i].String(), false) if err != nil { return nil, err } diff --git a/query/components_cache.go b/query/components_cache.go index 7f59a84e..db73878f 100644 --- a/query/components_cache.go +++ b/query/components_cache.go @@ -53,7 +53,7 @@ func SyncComponentCache(ctx context.Context) error { return nil } -func ComponentFromCache(ctx context.Context, id string) (models.Component, error) { +func ComponentFromCache(ctx context.Context, id string, queryDeleted bool) (models.Component, error) { c, err := componentCache.Get(ctx, componentCacheKey(id)) if err != nil { var cacheErr *store.NotFound @@ -62,7 +62,11 @@ func ComponentFromCache(ctx context.Context, id string) (models.Component, error } var component models.Component - if err := ctx.DB().Where("id = ?", id).Where("deleted_at IS NULL").First(&component).Error; err != nil { + q := ctx.DB().Where("id = ?", id) + if !queryDeleted { + q = q.Where("deleted_at IS NULL") + } + if err := q.First(&component).Error; err != nil { return component, err } diff --git a/topology/save.go b/topology/save.go index 31af63e3..5837d4d4 100644 --- a/topology/save.go +++ b/topology/save.go @@ -7,10 +7,18 @@ import ( "github.com/flanksource/duty/db" "github.com/flanksource/duty/models" "github.com/flanksource/duty/query" + "github.com/google/uuid" "gorm.io/gorm/clause" ) -func SaveComponent(ctx context.Context, c *models.Component) error { +// Save the component and its children returing the ids that were inserted/updated +func SaveComponent(ctx context.Context, c *models.Component) ([]string, error) { + var ids []string + return saveComponentsRecursively(ctx, c, ids) +} + +// We keep a list of ids to track all the insert/updated ids +func saveComponentsRecursively(ctx context.Context, c *models.Component, ids []string) ([]string, error) { if c.ParentId != nil && !strings.Contains(c.Path, c.ParentId.String()) { if c.Path == "" { c.Path = c.ParentId.String() @@ -19,37 +27,43 @@ func SaveComponent(ctx context.Context, c *models.Component) error { } } - if existing, err := query.ComponentFromCache(ctx, c.ID.String()); err == nil { + if existing, err := query.ComponentFromCache(ctx, c.ID.String(), true); err == nil { // Update component if it exists if err := ctx.DB().UpdateColumns(c).Error; err != nil { - return db.ErrorDetails(err) + return nil, db.ErrorDetails(err) } // Unset deleted_at if it was non nil if existing.DeletedAt != nil && c.DeletedAt == nil { if err := ctx.DB().Update("deleted_at", nil).Error; err != nil { - return db.ErrorDetails(err) + return nil, db.ErrorDetails(err) } } + ids = append(ids, c.ID.String()) } else { + // We set this to nil so that the conflict clause returns correct ID + c.ID = uuid.Nil // Create new component handling conflicts if err := ctx.DB().Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "topology_id"}, {Name: "type"}, {Name: "name"}, {Name: "parent_id"}}, UpdateAll: true, - }).Create(c).Error; err != nil { - return db.ErrorDetails(err) + }, clause.Returning{Columns: []clause.Column{{Name: "id"}}}).Create(c).Error; err != nil { + return nil, db.ErrorDetails(err) } + ids = append(ids, c.ID.String()) } if len(c.Components) > 0 { for _, child := range c.Components { child.TopologyID = c.TopologyID child.ParentId = &c.ID - if err := SaveComponent(ctx, child); err != nil { - return err + returnedIDs, err := saveComponentsRecursively(ctx, child, ids) + if err != nil { + return nil, err } + ids = append(ids, returnedIDs...) } } - return nil + return ids, nil }