Skip to content

Commit

Permalink
refactor: remove redundant function
Browse files Browse the repository at this point in the history
  • Loading branch information
dhth committed Aug 24, 2024
1 parent 2f9f1ba commit 7befdf8
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 103 deletions.
8 changes: 7 additions & 1 deletion cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ func importTask(db *sql.DB, taskSummary string) error {
}

now := time.Now()
_, err = pers.ImportTask(db, taskSummary, true, now, now)
task := types.Task{
Summary: taskSummary,
Active: true,
CreatedAt: now,
UpdatedAt: now,
}
_, err = pers.InsertTasks(db, []types.Task{task}, true)
return err
}

Expand Down
67 changes: 0 additions & 67 deletions internal/persistence/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,73 +121,6 @@ VALUES (?, true, ?, ?);
return uint64(li), nil
}

func ImportTask(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) (int64, error) {
tx, err := db.Begin()
if err != nil {
return -1, err
}
defer func() {
_ = tx.Rollback()
}()

query := `INSERT INTO task (summary, active, created_at, updated_at)
VALUES (?, ?, ?, ?);`

res, err := tx.Exec(query, summary, active, createdAt.UTC(), updatedAt.UTC())
if err != nil {
return -1, err
}

lastInsertID, err := res.LastInsertId()
if err != nil {
return -1, err
}

var seq []byte
seqRow := tx.QueryRow("SELECT sequence from task_sequence where id=1;")

err = seqRow.Scan(&seq)
if err != nil {
return -1, err
}

var seqItems []int
err = json.Unmarshal(seq, &seqItems)
if err != nil {
return -1, err
}

newTaskID := make([]int, 1)
newTaskID[0] = int(lastInsertID)
updatedSeqItems := append(newTaskID, seqItems...)

sequenceJSON, err := json.Marshal(updatedSeqItems)
if err != nil {
return -1, err
}

seqUpdateStmt, err := tx.Prepare(`
UPDATE task_sequence
SET sequence = ?
WHERE id = 1;
`)
if err != nil {
return -1, err
}
defer seqUpdateStmt.Close()

_, err = seqUpdateStmt.Exec(sequenceJSON)
if err != nil {
return -1, err
}

err = tx.Commit()
if err != nil {
return -1, err
}
return lastInsertID, nil
}

func InsertTasks(db *sql.DB, tasks []types.Task, insertAtTop bool) (int64, error) {
tx, err := db.Begin()
if err != nil {
Expand Down
63 changes: 28 additions & 35 deletions internal/persistence/queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,61 +127,54 @@ WHERE id = 1;
return na, ni
}

func TestImportTask(t *testing.T) {
t.Cleanup(func() { cleanupDB(t) })

// GIVEN
na, _ := seedDB(t, testDB)

// WHEN
summary := "prefix: an imported task"
now := time.Now().UTC()
lastID, err := ImportTask(testDB, summary, true, now, now)
require.NoError(t, err)

// THEN
numActiveTasksAfter, err := fetchNumActiveTasks(testDB)
require.NoError(t, err)
assert.Equal(t, numActiveTasksAfter, na+1, "number of active tasks didn't increase by 1")

task, err := fetchTaskByID(testDB, lastID)
require.NoError(t, err)
assert.True(t, task.Active)
assert.Equal(t, summary, task.Summary)

seq, err := fetchTaskSequence(testDB)
require.NoError(t, err)
assert.Equal(t, seq, []uint64{6, 1, 2, 3}, "task sequence isn't correct")
}

func TestInsertTasksWorksWithEmptyTaskList(t *testing.T) {
t.Cleanup(func() { cleanupDB(t) })

// GIVEN
// WHEN
tasks, na, ni := getSampleTasks()
now := time.Now().UTC()
tasks := []types.Task{
{
Summary: "prefix: new task 1",
Active: true,
CreatedAt: now,
UpdatedAt: now,
},
{
Summary: "prefix: new inactive task 1",
Active: false,
CreatedAt: now,
UpdatedAt: now,
},
{
Summary: "prefix: new task 3",
Active: true,
CreatedAt: now,
UpdatedAt: now,
},
}
lastID, err := InsertTasks(testDB, tasks, true)
assert.Equal(t, lastID, int64(na+ni), "last ID is not correct")
assert.Equal(t, lastID, int64(3), "last ID is not correct")
require.NoError(t, err)

// THEN
numActiveRes, err := fetchNumActiveTasks(testDB)
require.NoError(t, err)
assert.Equal(t, numActiveRes, na, "number of active tasks didn't increase by the correct amount")
assert.Equal(t, numActiveRes, 2, "number of active tasks didn't increase by the correct amount")

numTotalRes, err := fetchNumTotalTasks(testDB)
require.NoError(t, err)
assert.Equal(t, numTotalRes, na+ni, "number of total tasks didn't increase by the correct amount")
assert.Equal(t, numTotalRes, 3, "number of total tasks didn't increase by the correct amount")

lastTask, err := fetchTaskByID(testDB, lastID)
require.NoError(t, err)
assert.Equal(t, tasks[len(tasks)-1].Active, lastTask.Active)
assert.Equal(t, tasks[len(tasks)-1].Summary, lastTask.Summary)
assert.Equal(t, tasks[len(tasks)-1].Context, lastTask.Context)
assert.Equal(t, tasks[2].Active, lastTask.Active)
assert.Equal(t, tasks[2].Summary, lastTask.Summary)
assert.Equal(t, tasks[2].Context, lastTask.Context)

seq, err := fetchTaskSequence(testDB)
require.NoError(t, err)
assert.Equal(t, seq, []uint64{1, 2, 3}, "task sequence isn't correct")
assert.Equal(t, seq, []uint64{1, 3}, "task sequence isn't correct")
}

func TestInsertTasksAddsTasksAtTheTop(t *testing.T) {
Expand Down

0 comments on commit 7befdf8

Please sign in to comment.