diff --git a/sequel.go b/sequel.go index 8cf9c45..92473d7 100644 --- a/sequel.go +++ b/sequel.go @@ -455,6 +455,11 @@ func (t *Tx) Query(query string, args ...any) (*sql.Rows, error) { return t.tx.Query(query, args...) } +// QueryContext works like Query but with context. +func (t *Tx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + return t.tx.QueryContext(ctx, query, args...) +} + // QueryRow executes a query that is expected to return at most one row. // QueryRowContext always returns a non-nil value. Errors are deferred until // Row's Scan method is called. @@ -466,12 +471,22 @@ func (t *Tx) QueryRow(query string, args ...any) *sql.Row { return t.tx.QueryRow(query, args...) } +// QueryRowContext works like QueryRow but with context. +func (t *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + return t.tx.QueryRowContext(ctx, query, args...) +} + // Exec executes a query without returning any rows. The args are for any // placeholder parameters in the query. func (t *Tx) Exec(query string, args ...any) (sql.Result, error) { return t.tx.Exec(query, args...) } +// ExecContext works like Exec but with context. +func (t *Tx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + // Query executes a query that returns rows, typically a SELECT. The query is // rebound from `?` to the DB driver's bind type. The args are for any // placeholder parameters in the query. diff --git a/sequel_test.go b/sequel_test.go index 06a9dd7..f1d912f 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/go-sqlx/sqlx" + "github.com/google/uuid" "github.com/jackc/pgx/v5/pgconn" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" @@ -500,6 +501,10 @@ func TestTxQueries(t *testing.T) { Email: NullString("jolly@example.com"), }, } + p3 := &personModel{ + Name: "Kelly Klimber", + Email: NullString("kelly@example.com"), + } t.Run("rebind", func(t *testing.T) { tx, err := db.Begin(ctx) @@ -543,6 +548,21 @@ func TestTxQueries(t *testing.T) { assert.NoError(t, tx.Commit()) }) + t.Run("queryContext", func(t *testing.T) { + tx, err := db.Begin(ctx) + require.NoError(t, err) + rows, err := tx.QueryContext(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, err) + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + assertEqualPerson(t, p1, &p) + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + assert.NoError(t, tx.Commit()) + }) + t.Run("queryRow", func(t *testing.T) { var p personModel tx, err := db.Begin(ctx) @@ -554,6 +574,17 @@ func TestTxQueries(t *testing.T) { assert.NoError(t, tx.Commit()) }) + t.Run("queryRowContext", func(t *testing.T) { + var p personModel + tx, err := db.Begin(ctx) + require.NoError(t, err) + row := tx.QueryRowContext(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + assertEqualPerson(t, p1, &p) + assert.NoError(t, tx.Commit()) + }) + t.Run("rebindQuery", func(t *testing.T) { tx, err := db.Begin(ctx) require.NoError(t, err) @@ -742,6 +773,21 @@ func TestTxQueries(t *testing.T) { assert.NoError(t, tx.Commit()) }) + t.Run("execContext", func(t *testing.T) { + tx, err := db.Begin(ctx) + require.NoError(t, err) + defer func() { + assert.Error(t, tx.Rollback()) + }() + + res, err := tx.Exec(personExecQ, uuid.NewString(), p3.CreatedAt, p3.UpdatedAt, nil, p3.Name, p3.Email) + require.NoError(t, err) + n, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(1), n) + assert.NoError(t, tx.Commit()) + }) + t.Run("exec (clear table)", func(t *testing.T) { _, err := db.Exec(ctx, "DELETE FROM person_test") assert.NoError(t, err)