-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fa5b8ea
commit 1d7e949
Showing
2 changed files
with
129 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package db_test | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/xataio/pgroll/pkg/db" | ||
"github.com/xataio/pgroll/pkg/testutils" | ||
) | ||
|
||
func TestMain(m *testing.M) { | ||
testutils.SharedTestMain(m) | ||
} | ||
|
||
func TestExecContext(t *testing.T) { | ||
t.Parallel() | ||
|
||
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { | ||
ctx := context.Background() | ||
// create a table on which an exclusive lock is held for 2 seconds | ||
setupTableLock(t, connStr, 2*time.Second) | ||
|
||
// set the lock timeout to 100ms | ||
ensureLockTimeout(t, conn, 100) | ||
|
||
// execute a query that should retry until the lock is released | ||
rdb := &db.RDB{DB: conn} | ||
_, err := rdb.ExecContext(ctx, "INSERT INTO test(id) VALUES (1)") | ||
require.NoError(t, err) | ||
}) | ||
} | ||
|
||
func TestWithRetryableTransaction(t *testing.T) { | ||
t.Parallel() | ||
|
||
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) { | ||
ctx := context.Background() | ||
|
||
// create a table on which an exclusive lock is held for 2 seconds | ||
setupTableLock(t, connStr, 2*time.Second) | ||
|
||
// set the lock timeout to 100ms | ||
ensureLockTimeout(t, conn, 100) | ||
|
||
// run a transaction that should retry until the lock is released | ||
rdb := &db.RDB{DB: conn} | ||
err := rdb.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { | ||
return tx.QueryRowContext(ctx, "SELECT 1 FROM test").Err() | ||
}) | ||
require.NoError(t, err) | ||
}) | ||
} | ||
|
||
// setupTableLock: | ||
// * connects to the database | ||
// * creates a table in the database | ||
// * starts a transaction that temporarily locks the table | ||
func setupTableLock(t *testing.T, connStr string, d time.Duration) { | ||
t.Helper() | ||
ctx := context.Background() | ||
|
||
// connect to the database | ||
conn2, err := sql.Open("postgres", connStr) | ||
require.NoError(t, err) | ||
|
||
// create a table in the database | ||
_, err = conn2.ExecContext(ctx, "CREATE TABLE test (id INT PRIMARY KEY)") | ||
require.NoError(t, err) | ||
|
||
// start a transaction that takes a temporary lock on the table | ||
errCh := make(chan error) | ||
go func() { | ||
// begin a transaction | ||
tx, err := conn2.Begin() | ||
if err != nil { | ||
errCh <- err | ||
return | ||
} | ||
|
||
// lock the table | ||
_, err = tx.ExecContext(ctx, "LOCK TABLE test IN ACCESS EXCLUSIVE MODE") | ||
if err != nil { | ||
errCh <- err | ||
return | ||
} | ||
|
||
// signal that the lock is obtained | ||
errCh <- nil | ||
|
||
// temporarily hold the lock | ||
time.Sleep(d) | ||
|
||
// commit the transaction | ||
tx.Commit() | ||
}() | ||
|
||
// wait for the lock to be obtained | ||
err = <-errCh | ||
require.NoError(t, err) | ||
} | ||
|
||
func ensureLockTimeout(t *testing.T, conn *sql.DB, ms int) { | ||
t.Helper() | ||
|
||
// Set the lock timeout | ||
query := fmt.Sprintf("SET lock_timeout = '%dms'", ms) | ||
_, err := conn.ExecContext(context.Background(), query) | ||
require.NoError(t, err) | ||
|
||
// Ensure the lock timeout is set | ||
var lockTimeout string | ||
err = conn.QueryRowContext(context.Background(), "SHOW lock_timeout").Scan(&lockTimeout) | ||
require.NoError(t, err) | ||
require.Equal(t, fmt.Sprintf("%dms", ms), lockTimeout) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters