Skip to content

Commit

Permalink
Optimize access to the sqlite database
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbakker committed Apr 1, 2024
1 parent cd367a4 commit 63c7c96
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
git diff --exit-code
- name: Test
run: |
nix develop -c go test -tags sqlite_foreign_keys -v ./...
nix develop -c go test -v ./...
- name: Build
run: |
nix build --print-build-logs
Expand Down
22 changes: 11 additions & 11 deletions cmd/toxstatus/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cmd

import (
"context"
"database/sql"
"errors"
"fmt"
"log/slog"
Expand Down Expand Up @@ -38,6 +37,7 @@ var (
PprofAddr string
ToxUDPAddr string
DB string
DBCacheSize int
LogLevel string
Workers int
}{}
Expand All @@ -49,7 +49,8 @@ func init() {
Root.Flags().DurationVar(&rootFlags.HTTPClientTimeout, "http-client-timeout", 10*time.Second, "the http client timeout for requests to nodes.tox.chat")
Root.Flags().StringVar(&rootFlags.PprofAddr, "pprof-addr", "", "the network address to listen of for the pprof HTTP server")
Root.Flags().StringVar(&rootFlags.ToxUDPAddr, "tox-udp-addr", ":33450", "the UDP network address to listen on for Tox")
Root.Flags().StringVar(&rootFlags.DB, "db", "", "the sqlite database to use")
Root.Flags().StringVar(&rootFlags.DB, "db", "", "the sqlite database file to use")
Root.Flags().IntVar(&rootFlags.DBCacheSize, "db-cache-size", 100000, "the sqlite cache size to use (in KB)")
Root.Flags().StringVar(&rootFlags.LogLevel, "log-level", "info", "the log level to use")
Root.Flags().IntVar(&rootFlags.Workers, "workers", min(maxDefaultWorkers, runtime.NumCPU()), "the amount of workers to use")
Root.MarkFlagRequired("db")
Expand All @@ -71,17 +72,16 @@ func startRoot(cmd *cobra.Command, args []string) {
NoColor: !isatty.IsTerminal(os.Stderr.Fd()),
}))

dbConn, err := sql.Open("sqlite3", rootFlags.DB)
readConn, writeConn, err := db.OpenReadWrite(ctx, rootFlags.DB, db.OpenOptions{
CacheSize: rootFlags.DBCacheSize,
})
if err != nil {
logErrorAndExit(logger, "Unable to open db", slog.Any("err", err))
return
}
defer dbConn.Close()

if _, err := dbConn.ExecContext(ctx, db.Schema); err != nil {
logErrorAndExit(logger, "Unable to initialize db", slog.Any("err", err))
return
}
defer func() {
readConn.Close()
writeConn.Close()
}()

if rootFlags.PprofAddr != "" {
logger.Info("Starting pprof server")
Expand All @@ -106,7 +106,7 @@ func startRoot(cmd *cobra.Command, args []string) {
}()
}

nodesRepo := repo.New(dbConn)
nodesRepo := repo.New(readConn, writeConn)
cr, err := crawler.New(nodesRepo, crawler.CrawlerOptions{
Logger: logger,
HTTPAddr: rootFlags.HTTPAddr,
Expand Down
2 changes: 0 additions & 2 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
subPackages = [ "cmd/toxstatus" ];
vendorHash = "sha256-5cVWDVroDrC32xq5p0DkeRBgxHGfA178JdfgiPvnAbw=";

tags = ["sqlite_foreign_keys"];

ldflags = let
pkgPath = "github.com/Tox/ToxStatus/internal/version";
in [
Expand Down
2 changes: 1 addition & 1 deletion internal/crawler/crawler.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ func (c *Crawler) receivePacket(ctx context.Context, data []byte, addr *net.UDPA
c.handleInfoChan <- &infoPacket{Addr: addr, Packet: bsPacket}
return nil
}
if err != nil && !errors.Is(err, bootstrap.ErrUnknownPacketType) {
if !errors.Is(err, bootstrap.ErrUnknownPacketType) {
return fmt.Errorf("bootstrap info packet check: %w", err)
}

Expand Down
68 changes: 68 additions & 0 deletions internal/db/open.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package db

import (
"context"
"database/sql"
"fmt"
"net/url"
"runtime"
)

type OpenOptions struct {
CacheSize int
Params map[string]string
}

func OpenReadWrite(ctx context.Context, dbFile string, opts OpenOptions) (rdb *sql.DB, wdb *sql.DB, err error) {
uri := &url.URL{
Scheme: "file",
Opaque: dbFile,
}
query := uri.Query()
if opts.Params != nil {
for k, v := range opts.Params {
query.Set(k, v)
}
}
query.Set("_txlock", "immediate")
uri.RawQuery = query.Encode()

readConn, err := sql.Open("sqlite3", uri.String())
if err != nil {
return nil, nil, err
}
defer func() {
if err != nil {
readConn.Close()
}
}()
readConn.SetMaxOpenConns(max(4, runtime.NumCPU()))

writeConn, err := sql.Open("sqlite3", uri.String())
if err != nil {
return nil, nil, err
}
defer func() {
if err != nil {
writeConn.Close()
}
}()
writeConn.SetMaxOpenConns(1)

if _, err = writeConn.ExecContext(ctx, fmt.Sprintf(`
PRAGMA journal_mode = WAL;
PRAGMA busy_timeout = 5000;
PRAGMA synchronous = NORMAL;
PRAGMA cache_size = -%d;
PRAGMA foreign_keys = true;
PRAGMA temp_store = memory;
`, opts.CacheSize)); err != nil {
return nil, nil, fmt.Errorf("configure db: %w", err)
}

if _, err = writeConn.ExecContext(ctx, Schema); err != nil {
return nil, nil, fmt.Errorf("init db: %w", err)
}

return readConn, writeConn, nil
}
5 changes: 0 additions & 5 deletions internal/db/types.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
//go:build sqlite_foreign_keys

// By specifying our go-sqlite3 build tags above, the build will fail if we
// forget to specify it in the go build/test command.

package db

import (
Expand Down
44 changes: 23 additions & 21 deletions internal/repo/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,26 @@ import (
var ErrNotFound = fmt.Errorf("not found: %w", sql.ErrNoRows)

type NodesRepo struct {
db *sql.DB
q *db.Queries
wdb *sql.DB
rq *db.Queries
wq *db.Queries
}

type nodeAddressCombo struct {
Node db.Node
NodeAddress db.NodeAddress
}

func New(sqldb *sql.DB) *NodesRepo {
func New(rdb *sql.DB, wdb *sql.DB) *NodesRepo {
return &NodesRepo{
db: sqldb,
q: db.New(sqldb),
wdb: wdb,
rq: db.New(rdb),
wq: db.New(wdb),
}
}

func (r *NodesRepo) GetNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) (*models.Node, error) {
rows, err := r.q.GetNodeByPublicKey(ctx, (*db.PublicKey)(pk))
rows, err := r.rq.GetNodeByPublicKey(ctx, (*db.PublicKey)(pk))
if err != nil {
return nil, err
}
Expand All @@ -53,7 +55,7 @@ func (r *NodesRepo) GetNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) (
}

func (r *NodesRepo) HasNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) (bool, error) {
res, err := r.q.HasNodeByPublicKey(ctx, (*db.PublicKey)(pk))
res, err := r.rq.HasNodeByPublicKey(ctx, (*db.PublicKey)(pk))
if err != nil {
return false, err
}
Expand All @@ -62,17 +64,17 @@ func (r *NodesRepo) HasNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) (
}

func (r *NodesRepo) GetNodeCount(ctx context.Context) (int64, error) {
return r.q.GetNodeCount(ctx)
return r.rq.GetNodeCount(ctx)
}

func (r *NodesRepo) TrackDHTNode(ctx context.Context, node *dht.Node) (*models.Node, error) {
tx, err := r.db.Begin()
tx, err := r.wdb.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()

q := r.q.WithTx(tx)
q := r.wq.WithTx(tx)
dbNode, err := q.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey))
if err != nil {
return nil, fmt.Errorf("upsert node: %w", err)
Expand All @@ -99,7 +101,7 @@ func (r *NodesRepo) TrackDHTNode(ctx context.Context, node *dht.Node) (*models.N
}

func (r *NodesRepo) getDHTNodeAddressID(ctx context.Context, node *dht.Node) (int64, error) {
return r.q.GetNodeAddress(ctx, &db.GetNodeAddressParams{
return r.rq.GetNodeAddress(ctx, &db.GetNodeAddressParams{
PublicKey: (*db.PublicKey)(node.PublicKey),
Net: node.Type.Net(),
Ip: node.IP.String(),
Expand All @@ -116,7 +118,7 @@ func (r *NodesRepo) PingDHTNode(ctx context.Context, node *dht.Node) error {
return err
}

return r.q.PingNodeAddress(ctx, id)
return r.rq.PingNodeAddress(ctx, id)
}

func (r *NodesRepo) PongDHTNode(ctx context.Context, node *dht.Node) error {
Expand All @@ -128,11 +130,11 @@ func (r *NodesRepo) PongDHTNode(ctx context.Context, node *dht.Node) error {
return err
}

return r.q.PongNodeAddress(ctx, id)
return r.rq.PongNodeAddress(ctx, id)
}

func (r *NodesRepo) GetNodesWithStaleBootstrapInfo(ctx context.Context) ([]*models.Node, error) {
rows, err := r.q.GetNodesWithStaleBootstrapInfo(ctx, &db.GetNodesWithStaleBootstrapInfoParams{
rows, err := r.rq.GetNodesWithStaleBootstrapInfo(ctx, &db.GetNodesWithStaleBootstrapInfoParams{
NodeTimeout: (5 * time.Minute).Seconds(),
InfoInterval: (1 * time.Minute).Seconds(),
})
Expand All @@ -156,13 +158,13 @@ func (r *NodesRepo) GetNodesWithStaleBootstrapInfo(ctx context.Context) ([]*mode
}

func (r *NodesRepo) UpdateNodeInfoRequestTime(ctx context.Context, addrReqTimes map[int64]time.Time) error {
tx, err := r.db.Begin()
tx, err := r.wdb.Begin()
if err != nil {
return err
}
defer tx.Rollback()

q := r.q.WithTx(tx)
q := r.wq.WithTx(tx)
for id, reqTime := range addrReqTimes {
if err := q.UpdateNodeInfoRequestTime(ctx, &db.UpdateNodeInfoRequestTimeParams{
ID: id,
Expand All @@ -176,7 +178,7 @@ func (r *NodesRepo) UpdateNodeInfoRequestTime(ctx context.Context, addrReqTimes
}

func (r *NodesRepo) UpdateNodeInfo(ctx context.Context, addr *net.UDPAddr, motd string, version uint32) error {
tx, err := r.db.Begin()
tx, err := r.wdb.Begin()
if err != nil {
return err
}
Expand All @@ -189,8 +191,8 @@ func (r *NodesRepo) UpdateNodeInfo(ctx context.Context, addr *net.UDPAddr, motd
nodeType = dht.NodeTypeUDPIP6
}

q := r.q.WithTx(tx)
node, err := r.q.GetNodeByInfoResponseAddress(ctx, &db.GetNodeByInfoResponseAddressParams{
q := r.wq.WithTx(tx)
node, err := q.GetNodeByInfoResponseAddress(ctx, &db.GetNodeByInfoResponseAddressParams{
InfoReqTimeout: (10 * time.Second).Seconds(),
Net: nodeType.Net(),
Ip: addr.IP.String(),
Expand All @@ -212,7 +214,7 @@ func (r *NodesRepo) UpdateNodeInfo(ctx context.Context, addr *net.UDPAddr, motd
}

func (r *NodesRepo) GetResponsiveDHTNodes(ctx context.Context) ([]*dht.Node, error) {
rows, err := r.q.GetResponsiveNodes(ctx)
rows, err := r.rq.GetResponsiveNodes(ctx)
if err != nil {
return nil, err
}
Expand All @@ -229,7 +231,7 @@ func (r *NodesRepo) GetResponsiveDHTNodes(ctx context.Context) ([]*dht.Node, err
}

func (r *NodesRepo) GetUnresponsiveDHTNodes(ctx context.Context, retryDelay time.Duration) ([]*dht.Node, error) {
rows, err := r.q.GetUnresponsiveNodes(ctx, retryDelay.Seconds())
rows, err := r.rq.GetUnresponsiveNodes(ctx, retryDelay.Seconds())
if err != nil {
return nil, err
}
Expand Down
25 changes: 16 additions & 9 deletions internal/repo/repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/rand"
"database/sql"
"errors"
"net"
"testing"
Expand All @@ -18,16 +17,24 @@ import (
var ctx = context.Background()

func initRepo(t *testing.T) (repo *NodesRepo, close func() error) {
dbConn, err := sql.Open("sqlite3", ":memory:")
readConn, writeConn, err := db.OpenReadWrite(ctx, ":memory:", db.OpenOptions{
CacheSize: 2000,
Params: map[string]string{"cache": "shared"},
})
if err != nil {
t.Fatal(err)
}

if _, err := dbConn.ExecContext(ctx, db.Schema); err != nil {
t.Fatal(err)
return New(readConn, writeConn), func() error {
var errs []error
if err := readConn.Close(); err != nil {
errs = append(errs, err)
}
if err := writeConn.Close(); err != nil {
errs = append(errs, err)
}
return errors.Join(errs...)
}

return New(dbConn), dbConn.Close
}

func generateNode(t *testing.T) *models.Node {
Expand Down Expand Up @@ -68,7 +75,7 @@ func TestAddNode(t *testing.T) {
defer close()

node := generateNode(t)
dbNode, err := repo.q.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey))
dbNode, err := repo.wq.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey))
if err != nil {
t.Fatal(err)
}
Expand All @@ -93,7 +100,7 @@ func TestHasNodeByPublicKey(t *testing.T) {
defer close()

node := generateNode(t)
_, err := repo.q.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey))
_, err := repo.wq.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey))
if err != nil {
t.Fatal(err)
}
Expand All @@ -120,7 +127,7 @@ func TestPongNonExistentNode(t *testing.T) {
defer close()

pk := generatePublicKey(t)
_, err := repo.q.UpsertNode(ctx, (*db.PublicKey)(pk))
_, err := repo.wq.UpsertNode(ctx, (*db.PublicKey)(pk))
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 63c7c96

Please sign in to comment.