diff --git a/cmd/import.go b/cmd/import.go index 63f6416..375949d 100644 --- a/cmd/import.go +++ b/cmd/import.go @@ -22,7 +22,13 @@ func importTask(db *sql.DB, taskSummary string) error { } now := time.Now() - _, err = pers.ImportTask(db, taskSummary, true, now, now) + task := types.Task{ + Summary: taskSummary, + Active: true, + CreatedAt: now, + UpdatedAt: now, + } + _, err = pers.InsertTasks(db, []types.Task{task}, true) return err } diff --git a/internal/persistence/queries.go b/internal/persistence/queries.go index 47347aa..aaf84b7 100644 --- a/internal/persistence/queries.go +++ b/internal/persistence/queries.go @@ -121,73 +121,6 @@ VALUES (?, true, ?, ?); return uint64(li), nil } -func ImportTask(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) (int64, error) { - tx, err := db.Begin() - if err != nil { - return -1, err - } - defer func() { - _ = tx.Rollback() - }() - - query := `INSERT INTO task (summary, active, created_at, updated_at) -VALUES (?, ?, ?, ?);` - - res, err := tx.Exec(query, summary, active, createdAt.UTC(), updatedAt.UTC()) - if err != nil { - return -1, err - } - - lastInsertID, err := res.LastInsertId() - if err != nil { - return -1, err - } - - var seq []byte - seqRow := tx.QueryRow("SELECT sequence from task_sequence where id=1;") - - err = seqRow.Scan(&seq) - if err != nil { - return -1, err - } - - var seqItems []int - err = json.Unmarshal(seq, &seqItems) - if err != nil { - return -1, err - } - - newTaskID := make([]int, 1) - newTaskID[0] = int(lastInsertID) - updatedSeqItems := append(newTaskID, seqItems...) - - sequenceJSON, err := json.Marshal(updatedSeqItems) - if err != nil { - return -1, err - } - - seqUpdateStmt, err := tx.Prepare(` -UPDATE task_sequence -SET sequence = ? -WHERE id = 1; -`) - if err != nil { - return -1, err - } - defer seqUpdateStmt.Close() - - _, err = seqUpdateStmt.Exec(sequenceJSON) - if err != nil { - return -1, err - } - - err = tx.Commit() - if err != nil { - return -1, err - } - return lastInsertID, nil -} - func InsertTasks(db *sql.DB, tasks []types.Task, insertAtTop bool) (int64, error) { tx, err := db.Begin() if err != nil { diff --git a/internal/persistence/queries_test.go b/internal/persistence/queries_test.go index 9d0d06a..5241190 100644 --- a/internal/persistence/queries_test.go +++ b/internal/persistence/queries_test.go @@ -127,61 +127,54 @@ WHERE id = 1; return na, ni } -func TestImportTask(t *testing.T) { - t.Cleanup(func() { cleanupDB(t) }) - - // GIVEN - na, _ := seedDB(t, testDB) - - // WHEN - summary := "prefix: an imported task" - now := time.Now().UTC() - lastID, err := ImportTask(testDB, summary, true, now, now) - require.NoError(t, err) - - // THEN - numActiveTasksAfter, err := fetchNumActiveTasks(testDB) - require.NoError(t, err) - assert.Equal(t, numActiveTasksAfter, na+1, "number of active tasks didn't increase by 1") - - task, err := fetchTaskByID(testDB, lastID) - require.NoError(t, err) - assert.True(t, task.Active) - assert.Equal(t, summary, task.Summary) - - seq, err := fetchTaskSequence(testDB) - require.NoError(t, err) - assert.Equal(t, seq, []uint64{6, 1, 2, 3}, "task sequence isn't correct") -} - func TestInsertTasksWorksWithEmptyTaskList(t *testing.T) { t.Cleanup(func() { cleanupDB(t) }) // GIVEN // WHEN - tasks, na, ni := getSampleTasks() + now := time.Now().UTC() + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new inactive task 1", + Active: false, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 3", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + } lastID, err := InsertTasks(testDB, tasks, true) - assert.Equal(t, lastID, int64(na+ni), "last ID is not correct") + assert.Equal(t, lastID, int64(3), "last ID is not correct") require.NoError(t, err) // THEN numActiveRes, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - assert.Equal(t, numActiveRes, na, "number of active tasks didn't increase by the correct amount") + assert.Equal(t, numActiveRes, 2, "number of active tasks didn't increase by the correct amount") numTotalRes, err := fetchNumTotalTasks(testDB) require.NoError(t, err) - assert.Equal(t, numTotalRes, na+ni, "number of total tasks didn't increase by the correct amount") + assert.Equal(t, numTotalRes, 3, "number of total tasks didn't increase by the correct amount") lastTask, err := fetchTaskByID(testDB, lastID) require.NoError(t, err) - assert.Equal(t, tasks[len(tasks)-1].Active, lastTask.Active) - assert.Equal(t, tasks[len(tasks)-1].Summary, lastTask.Summary) - assert.Equal(t, tasks[len(tasks)-1].Context, lastTask.Context) + assert.Equal(t, tasks[2].Active, lastTask.Active) + assert.Equal(t, tasks[2].Summary, lastTask.Summary) + assert.Equal(t, tasks[2].Context, lastTask.Context) seq, err := fetchTaskSequence(testDB) require.NoError(t, err) - assert.Equal(t, seq, []uint64{1, 2, 3}, "task sequence isn't correct") + assert.Equal(t, seq, []uint64{1, 3}, "task sequence isn't correct") } func TestInsertTasksAddsTasksAtTheTop(t *testing.T) {