From 2f9f1badd1f1cf0144f7e2d69f013def18d0f9ab Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Sat, 24 Aug 2024 17:14:01 +0200 Subject: [PATCH] refactor: remove redundant func --- cmd/guide.go | 2 +- cmd/import.go | 12 +- internal/persistence/queries.go | 103 +++++------------ internal/persistence/queries_test.go | 166 +++++++++++++++++++-------- 4 files changed, 159 insertions(+), 124 deletions(-) diff --git a/cmd/guide.go b/cmd/guide.go index 36ab72d..052cb6a 100644 --- a/cmd/guide.go +++ b/cmd/guide.go @@ -90,7 +90,7 @@ func insertGuideTasks(db *sql.DB) error { } } - err = pers.InsertTasks(db, tasks) + _, err = pers.InsertTasks(db, tasks, true) return err } diff --git a/cmd/import.go b/cmd/import.go index 8a591ac..63f6416 100644 --- a/cmd/import.go +++ b/cmd/import.go @@ -7,6 +7,7 @@ import ( "time" pers "github.com/dhth/omm/internal/persistence" + "github.com/dhth/omm/internal/types" ) var errWillExceedCapacity = errors.New("import will exceed capacity") @@ -35,6 +36,15 @@ func importTasks(db *sql.DB, taskSummaries []string) error { } now := time.Now() - _, err = pers.ImportTaskSummaries(db, taskSummaries, true, now, now) + tasks := make([]types.Task, len(taskSummaries)) + for i, summ := range taskSummaries { + tasks[i] = types.Task{ + Summary: summ, + Active: true, + CreatedAt: now, + UpdatedAt: now, + } + } + _, err = pers.InsertTasks(db, tasks, true) return err } diff --git a/internal/persistence/queries.go b/internal/persistence/queries.go index 479e916..47347aa 100644 --- a/internal/persistence/queries.go +++ b/internal/persistence/queries.go @@ -36,6 +36,12 @@ func fetchNumActiveTasks(db *sql.DB) (int, error) { return rowCount, err } +func fetchNumTotalTasks(db *sql.DB) (int, error) { + var rowCount int + err := db.QueryRow("SELECT count(*) from task").Scan(&rowCount) + return rowCount, err +} + func fetchTaskByID(db *sql.DB, ID int64) (types.Task, error) { var entry types.Task row := db.QueryRow(` @@ -182,7 +188,7 @@ WHERE id = 1; return lastInsertID, nil } -func ImportTaskSummaries(db *sql.DB, tasks []string, active bool, createdAt, updatedAt time.Time) (int64, error) { +func InsertTasks(db *sql.DB, tasks []types.Task, insertAtTop bool) (int64, error) { tx, err := db.Begin() if err != nil { return -1, err @@ -191,21 +197,19 @@ func ImportTaskSummaries(db *sql.DB, tasks []string, active bool, createdAt, upd _ = tx.Rollback() }() - query := `INSERT INTO task (summary, active, created_at, updated_at) + query := `INSERT INTO task (summary, context, active, created_at, updated_at) VALUES ` values := make([]interface{}, 0, len(tasks)*4) - ca := createdAt.UTC() - ua := updatedAt.UTC() - - for i, ts := range tasks { + for i, t := range tasks { if i > 0 { query += "," } - query += "(?, ?, ?, ?)" - values = append(values, ts, active, ca, ua) + query += "(?, ?, ?, ?, ?)" + values = append(values, t.Summary, t.Context, t.Active, t.CreatedAt.UTC(), t.UpdatedAt.UTC()) } + query += ";" res, err := tx.Exec(query, values...) @@ -232,13 +236,21 @@ VALUES ` return -1, err } - newTaskIDs := make([]int, len(tasks)) - counter := 0 - for i := int(lastInsertID) - len(tasks) + 1; i <= int(lastInsertID); i++ { - newTaskIDs[counter] = i - counter++ + var newTaskIDs []int + taskID := int(lastInsertID) - len(tasks) + 1 + for _, t := range tasks { + if t.Active { + newTaskIDs = append(newTaskIDs, taskID) + } + taskID++ + } + + var updatedSeqItems []int + if insertAtTop { + updatedSeqItems = append(newTaskIDs, seqItems...) + } else { + updatedSeqItems = append(seqItems, newTaskIDs...) } - updatedSeqItems := append(newTaskIDs, seqItems...) sequenceJSON, err := json.Marshal(updatedSeqItems) if err != nil { @@ -267,69 +279,6 @@ WHERE id = 1; return lastInsertID, nil } -func InsertTasks(db *sql.DB, tasks []types.Task) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - _ = tx.Rollback() - }() - - query := `INSERT INTO task (summary, context, active, created_at, updated_at) -VALUES ` - - values := make([]interface{}, 0, len(tasks)*4) - - var seqItems []int - seqCounter := 1 - for i, t := range tasks { - if i > 0 { - query += "," - } - query += "(?, ?, ?, ?, ?)" - values = append(values, t.Summary, t.Context, t.Active, t.CreatedAt.UTC(), t.UpdatedAt.UTC()) - - if t.Active { - seqItems = append(seqItems, seqCounter) - } - seqCounter++ - } - - query += ";" - - _, err = tx.Exec(query, values...) - if err != nil { - return err - } - - sequenceJSON, err := json.Marshal(seqItems) - if err != nil { - return err - } - - seqUpdateStmt, err := tx.Prepare(` -UPDATE task_sequence -SET sequence = ? -WHERE id = 1; -`) - if err != nil { - return err - } - defer seqUpdateStmt.Close() - - _, err = seqUpdateStmt.Exec(sequenceJSON) - if err != nil { - return err - } - - err = tx.Commit() - if err != nil { - return err - } - return nil -} - func UpdateTaskSummary(db *sql.DB, id uint64, summary string, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task diff --git a/internal/persistence/queries_test.go b/internal/persistence/queries_test.go index fe49693..9d0d06a 100644 --- a/internal/persistence/queries_test.go +++ b/internal/persistence/queries_test.go @@ -14,11 +14,7 @@ import ( _ "modernc.org/sqlite" // sqlite driver ) -var ( - testDB *sql.DB - numSeedActive = 3 - numSeedInActive = 2 -) +var testDB *sql.DB func TestMain(m *testing.M) { var err error @@ -49,6 +45,10 @@ func cleanupDB(t *testing.T) { if err != nil { t.Fatalf("failed to clean up table %q: %v", tbl, err) } + _, err := testDB.Exec("DELETE FROM sqlite_sequence WHERE name=?;", tbl) + if err != nil { + t.Fatalf("failed to reset auto increment for table %q: %v", tbl, err) + } } _, err = testDB.Exec(`UPDATE task_sequence SET sequence = '[]' @@ -58,14 +58,15 @@ WHERE id = 1;`) } } -func seedDB(t *testing.T, db *sql.DB) { - t.Helper() +func getSampleTasks() ([]types.Task, int, int) { + numActive := 3 + numInactive := 2 - tasks := make([]types.Task, numSeedActive+numSeedInActive) - contexts := make([]string, numSeedActive+numSeedInActive) + tasks := make([]types.Task, numActive+numInactive) + contexts := make([]string, numActive+numInactive) now := time.Now().UTC() counter := 0 - for range numSeedActive { + for range numActive { contexts[counter] = fmt.Sprintf("context for task %d", counter) tasks[counter] = types.Task{ Summary: fmt.Sprintf("prefix: task %d", counter), @@ -76,7 +77,7 @@ func seedDB(t *testing.T, db *sql.DB) { } counter++ } - for range numSeedInActive { + for range numInactive { contexts[counter] = fmt.Sprintf("context for task %d", counter) tasks[counter] = types.Task{ Summary: fmt.Sprintf("prefix: task %d", counter), @@ -87,6 +88,15 @@ func seedDB(t *testing.T, db *sql.DB) { } counter++ } + + return tasks, numActive, numInactive +} + +func seedDB(t *testing.T, db *sql.DB) (int, int) { + t.Helper() + + tasks, na, ni := getSampleTasks() + for _, task := range tasks { _, err := db.Exec(` INSERT INTO task (summary, active, created_at, updated_at) @@ -96,8 +106,8 @@ VALUES (?, ?, ?, ?)`, task.Summary, task.Active, task.CreatedAt, task.UpdatedAt) } } - seqItems := make([]int, numSeedActive) - for i := range numSeedActive { + seqItems := make([]int, na) + for i := range na { seqItems[i] = i + 1 } sequenceJSON, err := json.Marshal(seqItems) @@ -113,15 +123,15 @@ WHERE id = 1; if err != nil { t.Fatalf("failed to insert data into table \"task_sequence\": %v", err) } + + return na, ni } func TestImportTask(t *testing.T) { t.Cleanup(func() { cleanupDB(t) }) // GIVEN - seedDB(t, testDB) - numActiveTasksBefore, err := fetchNumActiveTasks(testDB) - require.NoError(t, err) + na, _ := seedDB(t, testDB) // WHEN summary := "prefix: an imported task" @@ -132,7 +142,7 @@ func TestImportTask(t *testing.T) { // THEN numActiveTasksAfter, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+1, "number of active tasks didn't increase by 1") + assert.Equal(t, numActiveTasksAfter, na+1, "number of active tasks didn't increase by 1") task, err := fetchTaskByID(testDB, lastID) require.NoError(t, err) @@ -141,51 +151,117 @@ func TestImportTask(t *testing.T) { seq, err := fetchTaskSequence(testDB) require.NoError(t, err) - require.Equal(t, numActiveTasksAfter, len(seq), "number of tasks in task sequence doesn't match number of active tasks") - assert.Equal(t, seq[0], task.ID, "newly added task is not shown at the top of the list") + assert.Equal(t, seq, []uint64{6, 1, 2, 3}, "task sequence isn't correct") } -func TestImportTaskSummaries(t *testing.T) { +func TestInsertTasksWorksWithEmptyTaskList(t *testing.T) { t.Cleanup(func() { cleanupDB(t) }) // GIVEN - seedDB(t, testDB) - numActiveTasksBefore, err := fetchNumActiveTasks(testDB) + // WHEN + tasks, na, ni := getSampleTasks() + lastID, err := InsertTasks(testDB, tasks, true) + assert.Equal(t, lastID, int64(na+ni), "last ID is not correct") require.NoError(t, err) + // THEN + numActiveRes, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveRes, na, "number of active tasks didn't increase by the correct amount") + + numTotalRes, err := fetchNumTotalTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numTotalRes, na+ni, "number of total tasks didn't increase by the correct amount") + + lastTask, err := fetchTaskByID(testDB, lastID) + require.NoError(t, err) + assert.Equal(t, tasks[len(tasks)-1].Active, lastTask.Active) + assert.Equal(t, tasks[len(tasks)-1].Summary, lastTask.Summary) + assert.Equal(t, tasks[len(tasks)-1].Context, lastTask.Context) + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + assert.Equal(t, seq, []uint64{1, 2, 3}, "task sequence isn't correct") +} + +func TestInsertTasksAddsTasksAtTheTop(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + na, ni := seedDB(t, testDB) + // WHEN - newTaskSummaries := []string{ - "prefix: imported task 1", - "prefix: imported task 2", - "prefix: imported task 3", - } now := time.Now().UTC() - lastID, err := ImportTaskSummaries(testDB, newTaskSummaries, true, now, now) + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new inactive task 1", + Active: false, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 3", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + } + + _, err := InsertTasks(testDB, tasks, true) require.NoError(t, err) // THEN - numActiveTasksAfter, err := fetchNumActiveTasks(testDB) + numActiveRes, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+len(newTaskSummaries), "number of active tasks didn't increase by the correct amount") + assert.Equal(t, numActiveRes, na+2, "number of active tasks didn't increase by the correct amount") - task, err := fetchTaskByID(testDB, lastID) + numTotalRes, err := fetchNumTotalTasks(testDB) require.NoError(t, err) - assert.True(t, task.Active) - assert.Equal(t, newTaskSummaries[2], task.Summary) + assert.Equal(t, numTotalRes, na+ni+3, "number of total tasks didn't increase by the correct amount") seq, err := fetchTaskSequence(testDB) require.NoError(t, err) - require.Equal(t, numActiveTasksAfter, len(seq), "number of tasks in task sequence doesn't match number of active tasks") - - // ensure new task sequence is correct - // that is: - // imported task 1 - // imported task 2 - // imported task 3 - // ... old sequence - currentID := int(lastID) - len(newTaskSummaries) + 1 - for i := range len(newTaskSummaries) { - assert.Equal(t, currentID, int(seq[i]), "task at sequence position %d is incorrect", i+1) - currentID++ + assert.Equal(t, seq, []uint64{6, 8, 1, 2, 3}, "task sequence isn't correct") +} + +func TestInsertTasksAddsTasksAtTheEnd(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + na, _ := seedDB(t, testDB) + + // WHEN + now := time.Now().UTC() + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 2", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, } + + _, err := InsertTasks(testDB, tasks, false) + require.NoError(t, err) + + // THEN + numActiveRes, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveRes, na+2, "number of active tasks didn't increase by the correct amount") + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + assert.Equal(t, seq, []uint64{1, 2, 3, 6, 7}, "task sequence isn't correct") }