-
Notifications
You must be signed in to change notification settings - Fork 0
/
tx_manager.go
63 lines (49 loc) · 1.32 KB
/
tx_manager.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
package pgxtransactor
import (
	"context"
	"errors"
	"fmt"

	"github.com/jackc/pgx/v4"
)
// txKey is an unexported context-key type, so values stored under it cannot
// collide with context keys defined by other packages.
type txKey string

// key is the context key under which runTx stores the active transaction.
const key txKey = "tx"
// transactionManager runs callbacks inside database transactions and resolves,
// per context, whether queries should go to the active transaction or to the
// underlying database handle.
type transactionManager struct {
	// db is used to begin transactions and as the fallback query engine when
	// the context carries no transaction.
	db DB
}
// NewTransactionManager returns a transaction manager backed by db. Transactions
// are begun on db, and db is also what GetQueryEngine returns for contexts that
// carry no transaction.
func NewTransactionManager(db DB) *transactionManager {
	tm := new(transactionManager)
	tm.db = db
	return tm
}
// GetQueryEngine returns the query engine that work running under ctx should
// use. If ctx holds a value under key that implements QueryEngine (runTx stores
// the active transaction there), that value is returned. Otherwise the
// manager's db is returned.
func (tm *transactionManager) GetQueryEngine(ctx context.Context) QueryEngine {
	if engine, found := ctx.Value(key).(QueryEngine); found && engine != nil {
		return engine
	}
	return tm.db
}
// rollback rolls back tx after err has already occurred and returns err. If the
// rollback itself fails, its error is appended to err's message.
//
// pgx.ErrTxClosed from Rollback is ignored. It means the transaction was already
// finished (pgx closes a Tx once Commit returns, even when the commit failed),
// so there is nothing left to roll back, and reporting it would only obscure
// err. Callers can still match err via errors.Is/As. The rollback error is
// included in the message only.
func rollback(ctx context.Context, tx pgx.Tx, err error) error {
	rerr := tx.Rollback(ctx)
	if rerr == nil || errors.Is(rerr, pgx.ErrTxClosed) {
		return err
	}
	return fmt.Errorf("%w: rollback: %v", err, rerr)
}
// runTx begins a transaction on tm.db at isolation level txLevel. It then calls
// fx with a context that carries the transaction under key, so GetQueryEngine
// returns the transaction inside fx.
//
//   - If BeginTx fails, its error is returned and fx is not called.
//   - If fx returns an error, the transaction is rolled back and fx's error is
//     returned, with any rollback failure appended to the message.
//   - If fx returns nil, the transaction is committed and the commit error (if
//     any) is returned.
//   - If fx panics, the transaction is rolled back before the panic continues
//     to propagate.
func (tm *transactionManager) runTx(ctx context.Context, txLevel pgx.TxIsoLevel, fx func(ctxTX context.Context) error) error {
	tx, err := tm.db.BeginTx(ctx, pgx.TxOptions{IsoLevel: txLevel})
	if err != nil {
		return err
	}
	// Safety net so a panic in fx cannot leave the connection inside an open
	// transaction. pgx.Tx documents Rollback as safe to call after Commit or a
	// previous Rollback (it then just returns ErrTxClosed), so the result is
	// deliberately discarded.
	defer func() { _ = tx.Rollback(ctx) }()

	if err := fx(context.WithValue(ctx, key, tx)); err != nil {
		return rollback(ctx, tx, err)
	}
	// pgx closes the Tx once Commit returns, even on failure. Calling Rollback
	// after a failed commit would only add a spurious "tx is closed" to the
	// commit error, so the commit result is returned as-is.
	return tx.Commit(ctx)
}
func (tm *transactionManager) RunReadCommitted(ctx context.Context, fx func(dbCtx context.Context) error) error {
return tm.runTx(ctx, pgx.ReadCommitted, fx)
}
func (tm *transactionManager) RunRepeatableRead(ctx context.Context, fx func(ctxTX context.Context) error) error {
return tm.runTx(ctx, pgx.RepeatableRead, fx)
}