diff --git a/manager/sql/manager_sql.go b/manager/sql/manager_sql.go index 2aac6cd..293b4e9 100644 --- a/manager/sql/manager_sql.go +++ b/manager/sql/manager_sql.go @@ -190,15 +190,8 @@ func (s *SQLManager) Create(policy Policy) (err error) { } switch s.db.DriverName() { - case "postgres", "pgx": - if _, err = tx.Exec(s.db.Rebind("INSERT INTO ladon_policy (id, description, effect, conditions) VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING"), policy.GetID(), policy.GetDescription(), policy.GetEffect(), conditions); err != nil { - if err := tx.Rollback(); err != nil { - return errors.WithStack(err) - } - return errors.WithStack(err) - } - case "mysql": - if _, err = tx.Exec(s.db.Rebind("INSERT IGNORE INTO ladon_policy (id, description, effect, conditions) VALUES (?, ?, ?, ?)"), policy.GetID(), policy.GetDescription(), policy.GetEffect(), conditions); err != nil { + case "postgres", "pgx", "mysql": + if _, err = tx.Exec(s.db.Rebind("INSERT INTO ladon_policy (id, description, effect, conditions) SELECT ?, ?, ?, ? WHERE NOT EXISTS (SELECT 1 FROM ladon_policy WHERE id = ?)"), policy.GetID(), policy.GetDescription(), policy.GetEffect(), conditions, policy.GetID()); err != nil { if err := tx.Rollback(); err != nil { return errors.WithStack(err) } @@ -232,30 +225,15 @@ func (s *SQLManager) Create(policy Policy) (err error) { } switch s.db.DriverName() { - case "postgres", "pgx": - if _, err := tx.Exec(s.db.Rebind(fmt.Sprintf("INSERT INTO ladon_%s (id, template, compiled, has_regex) VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING", v.t)), id, template, compiled.String(), strings.Index(template, string(policy.GetStartDelimiter())) > -1); err != nil { - if err := tx.Rollback(); err != nil { - return errors.WithStack(err) - } - return errors.WithStack(err) - } - - if _, err := tx.Exec(s.db.Rebind(fmt.Sprintf("INSERT INTO ladon_policy_%s_rel (policy, %s) VALUES (?, ?) ON CONFLICT DO NOTHING", v.t, v.t)), policy.GetID(), id); err != nil { - if err := tx.Rollback(); err != nil { - return errors.WithStack(err) - } - return errors.WithStack(err) - } - break - case "mysql": - if _, err := tx.Exec(s.db.Rebind(fmt.Sprintf("INSERT IGNORE INTO ladon_%s (id, template, compiled, has_regex) VALUES (?, ?, ?, ?)", v.t)), id, template, compiled.String(), strings.Index(template, string(policy.GetStartDelimiter())) > -1); err != nil { + case "postgres", "pgx", "mysql": + if _, err := tx.Exec(s.db.Rebind(fmt.Sprintf("INSERT INTO ladon_%s (id, template, compiled, has_regex) SELECT ?, ?, ?, ? WHERE NOT EXISTS (SELECT 1 FROM ladon_%[1]s WHERE id = ?)", v.t)), id, template, compiled.String(), strings.Index(template, string(policy.GetStartDelimiter())) > -1, id); err != nil { if err := tx.Rollback(); err != nil { return errors.WithStack(err) } return errors.WithStack(err) } - if _, err := tx.Exec(s.db.Rebind(fmt.Sprintf("INSERT IGNORE INTO ladon_policy_%s_rel (policy, %s) VALUES (?, ?)", v.t, v.t)), policy.GetID(), id); err != nil { + if _, err := tx.Exec(s.db.Rebind(fmt.Sprintf("INSERT INTO ladon_policy_%s_rel (policy, %[1]s) SELECT ?, ? WHERE NOT EXISTS (SELECT 1 FROM ladon_policy_%[1]s_rel WHERE policy = ? AND %[1]s = ?)", v.t)), policy.GetID(), id, policy.GetID(), id); err != nil { if err := tx.Rollback(); err != nil { return errors.WithStack(err) }