diff --git a/cmd/toxstatus/cmd/root.go b/cmd/toxstatus/cmd/root.go index e94153c..b8a301b 100644 --- a/cmd/toxstatus/cmd/root.go +++ b/cmd/toxstatus/cmd/root.go @@ -72,9 +72,8 @@ func startRoot(cmd *cobra.Command, args []string) { NoColor: !isatty.IsTerminal(os.Stderr.Fd()), })) - readConn, writeConn, err := db.OpenReadWrite(ctx, rootFlags.DB, db.OpenOptions{ - CacheSize: rootFlags.DBCacheSize, - }) + db.RegisterPragmaHook(rootFlags.DBCacheSize) + readConn, writeConn, err := db.OpenReadWrite(ctx, rootFlags.DB, db.OpenOptions{}) if err != nil { logErrorAndExit(logger, "Unable to open db", slog.Any("err", err)) } diff --git a/internal/db/open.go b/internal/db/open.go index 532d063..9b626c5 100644 --- a/internal/db/open.go +++ b/internal/db/open.go @@ -6,11 +6,30 @@ import ( "fmt" "net/url" "runtime" + + "github.com/mattn/go-sqlite3" ) type OpenOptions struct { - CacheSize int - Params map[string]string + Params map[string]string +} + +func RegisterPragmaHook(cacheSize int) { + sql.Register("toxstatus_sqlite3", &sqlite3.SQLiteDriver{ + ConnectHook: func(c *sqlite3.SQLiteConn) error { + fmt.Println("Executing pragmas") + pragmas := fmt.Sprintf(` + PRAGMA journal_mode = WAL; + PRAGMA busy_timeout = 5000; + PRAGMA synchronous = NORMAL; + PRAGMA cache_size = -%d; + PRAGMA foreign_keys = true; + PRAGMA temp_store = memory; + `, cacheSize) + _, err := c.Exec(pragmas, nil) + return err + }, + }) } func OpenReadWrite(ctx context.Context, dbFile string, opts OpenOptions) (rdb *sql.DB, wdb *sql.DB, err error) { @@ -27,16 +46,7 @@ func OpenReadWrite(ctx context.Context, dbFile string, opts OpenOptions) (rdb *s query.Set("_txlock", "immediate") uri.RawQuery = query.Encode() - pragmas := fmt.Sprintf(` - PRAGMA journal_mode = WAL; - PRAGMA busy_timeout = 5000; - PRAGMA synchronous = NORMAL; - PRAGMA cache_size = -%d; - PRAGMA foreign_keys = true; - PRAGMA temp_store = memory; - `, opts.CacheSize) - - readConn, err := sql.Open("sqlite3", uri.String()) + readConn, err := sql.Open("toxstatus_sqlite3", uri.String()) if err != nil { return nil, nil, err } @@ -47,11 +57,7 @@ func OpenReadWrite(ctx context.Context, dbFile string, opts OpenOptions) (rdb *s }() readConn.SetMaxOpenConns(max(4, runtime.NumCPU())) - if _, err = readConn.ExecContext(ctx, pragmas); err != nil { - return nil, nil, fmt.Errorf("configure db conn: %w", err) - } - - writeConn, err := sql.Open("sqlite3", uri.String()) + writeConn, err := sql.Open("toxstatus_sqlite3", uri.String()) if err != nil { return nil, nil, err } @@ -62,10 +68,6 @@ func OpenReadWrite(ctx context.Context, dbFile string, opts OpenOptions) (rdb *s }() writeConn.SetMaxOpenConns(1) - if _, err = writeConn.ExecContext(ctx, pragmas); err != nil { - return nil, nil, fmt.Errorf("configure db conn: %w", err) - } - if _, err = writeConn.ExecContext(ctx, Schema); err != nil { return nil, nil, fmt.Errorf("init db: %w", err) } diff --git a/internal/repo/repo_test.go b/internal/repo/repo_test.go index 568ae1d..a546fcd 100644 --- a/internal/repo/repo_test.go +++ b/internal/repo/repo_test.go @@ -16,10 +16,13 @@ import ( var ctx = context.Background() +func init() { + db.RegisterPragmaHook(2000) +} + func initRepo(t *testing.T) (repo *NodesRepo, close func() error) { readConn, writeConn, err := db.OpenReadWrite(ctx, ":memory:", db.OpenOptions{ - CacheSize: 2000, - Params: map[string]string{"cache": "shared"}, + Params: map[string]string{"cache": "shared"}, }) if err != nil { t.Fatal(err)