From b8e51e456559328e08352d49f81912d979883e7c Mon Sep 17 00:00:00 2001 From: Alexander Bakker Date: Mon, 1 Apr 2024 11:38:02 +0200 Subject: [PATCH] Optimize access to the sqlite database --- .github/workflows/build.yaml | 2 +- cmd/toxstatus/cmd/root.go | 22 +++++------ flake.nix | 2 - internal/crawler/crawler.go | 2 +- internal/db/db.go | 2 +- internal/db/models.go | 2 +- internal/db/open.go | 74 ++++++++++++++++++++++++++++++++++++ internal/db/queries.sql.go | 2 +- internal/db/types.go | 5 --- internal/repo/repo.go | 44 +++++++++++---------- internal/repo/repo_test.go | 25 +++++++----- 11 files changed, 129 insertions(+), 53 deletions(-) create mode 100644 internal/db/open.go diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 57db377..8a1f725 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -19,7 +19,7 @@ jobs: git diff --exit-code - name: Test run: | - nix develop -c go test -tags sqlite_foreign_keys -v ./... + nix develop -c go test -v ./... - name: Build run: | nix build --print-build-logs diff --git a/cmd/toxstatus/cmd/root.go b/cmd/toxstatus/cmd/root.go index 08722e8..e94153c 100644 --- a/cmd/toxstatus/cmd/root.go +++ b/cmd/toxstatus/cmd/root.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "database/sql" "errors" "fmt" "log/slog" @@ -38,6 +37,7 @@ var ( PprofAddr string ToxUDPAddr string DB string + DBCacheSize int LogLevel string Workers int }{} @@ -49,7 +49,8 @@ func init() { Root.Flags().DurationVar(&rootFlags.HTTPClientTimeout, "http-client-timeout", 10*time.Second, "the http client timeout for requests to nodes.tox.chat") Root.Flags().StringVar(&rootFlags.PprofAddr, "pprof-addr", "", "the network address to listen of for the pprof HTTP server") Root.Flags().StringVar(&rootFlags.ToxUDPAddr, "tox-udp-addr", ":33450", "the UDP network address to listen on for Tox") - Root.Flags().StringVar(&rootFlags.DB, "db", "", "the sqlite database to use") + Root.Flags().StringVar(&rootFlags.DB, "db", "", "the sqlite database file to use") + Root.Flags().IntVar(&rootFlags.DBCacheSize, "db-cache-size", 100000, "the sqlite cache size to use (in KB)") Root.Flags().StringVar(&rootFlags.LogLevel, "log-level", "info", "the log level to use") Root.Flags().IntVar(&rootFlags.Workers, "workers", min(maxDefaultWorkers, runtime.NumCPU()), "the amount of workers to use") Root.MarkFlagRequired("db") @@ -71,17 +72,16 @@ func startRoot(cmd *cobra.Command, args []string) { NoColor: !isatty.IsTerminal(os.Stderr.Fd()), })) - dbConn, err := sql.Open("sqlite3", rootFlags.DB) + readConn, writeConn, err := db.OpenReadWrite(ctx, rootFlags.DB, db.OpenOptions{ + CacheSize: rootFlags.DBCacheSize, + }) if err != nil { logErrorAndExit(logger, "Unable to open db", slog.Any("err", err)) - return - } - defer dbConn.Close() - - if _, err := dbConn.ExecContext(ctx, db.Schema); err != nil { - logErrorAndExit(logger, "Unable to initialize db", slog.Any("err", err)) - return } + defer func() { + readConn.Close() + writeConn.Close() + }() if rootFlags.PprofAddr != "" { logger.Info("Starting pprof server") @@ -106,7 +106,7 @@ func startRoot(cmd *cobra.Command, args []string) { }() } - nodesRepo := repo.New(dbConn) + nodesRepo := repo.New(readConn, writeConn) cr, err := crawler.New(nodesRepo, crawler.CrawlerOptions{ Logger: logger, HTTPAddr: rootFlags.HTTPAddr, diff --git a/flake.nix b/flake.nix index 8d861a3..ae44558 100644 --- a/flake.nix +++ b/flake.nix @@ -18,8 +18,6 @@ subPackages = [ "cmd/toxstatus" ]; vendorHash = "sha256-5cVWDVroDrC32xq5p0DkeRBgxHGfA178JdfgiPvnAbw="; - tags = ["sqlite_foreign_keys"]; - ldflags = let pkgPath = "github.com/Tox/ToxStatus/internal/version"; in [ diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go index 9b2d4e7..4626599 100644 --- a/internal/crawler/crawler.go +++ b/internal/crawler/crawler.go @@ -595,7 +595,7 @@ func (c *Crawler) receivePacket(ctx context.Context, data []byte, addr *net.UDPA c.handleInfoChan <- &infoPacket{Addr: addr, Packet: bsPacket} return nil } - if err != nil && !errors.Is(err, bootstrap.ErrUnknownPacketType) { + if !errors.Is(err, bootstrap.ErrUnknownPacketType) { return fmt.Errorf("bootstrap info packet check: %w", err) } diff --git a/internal/db/db.go b/internal/db/db.go index bdb151c..17d86e9 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package db diff --git a/internal/db/models.go b/internal/db/models.go index 541e5f1..212e572 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package db diff --git a/internal/db/open.go b/internal/db/open.go new file mode 100644 index 0000000..532d063 --- /dev/null +++ b/internal/db/open.go @@ -0,0 +1,74 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "net/url" + "runtime" +) + +type OpenOptions struct { + CacheSize int + Params map[string]string +} + +func OpenReadWrite(ctx context.Context, dbFile string, opts OpenOptions) (rdb *sql.DB, wdb *sql.DB, err error) { + uri := &url.URL{ + Scheme: "file", + Opaque: dbFile, + } + query := uri.Query() + if opts.Params != nil { + for k, v := range opts.Params { + query.Set(k, v) + } + } + query.Set("_txlock", "immediate") + uri.RawQuery = query.Encode() + + pragmas := fmt.Sprintf(` + PRAGMA journal_mode = WAL; + PRAGMA busy_timeout = 5000; + PRAGMA synchronous = NORMAL; + PRAGMA cache_size = -%d; + PRAGMA foreign_keys = true; + PRAGMA temp_store = memory; + `, opts.CacheSize) + + readConn, err := sql.Open("sqlite3", uri.String()) + if err != nil { + return nil, nil, err + } + defer func() { + if err != nil { + readConn.Close() + } + }() + readConn.SetMaxOpenConns(max(4, runtime.NumCPU())) + + if _, err = readConn.ExecContext(ctx, pragmas); err != nil { + return nil, nil, fmt.Errorf("configure db conn: %w", err) + } + + writeConn, err := sql.Open("sqlite3", uri.String()) + if err != nil { + return nil, nil, err + } + defer func() { + if err != nil { + writeConn.Close() + } + }() + writeConn.SetMaxOpenConns(1) + + if _, err = writeConn.ExecContext(ctx, pragmas); err != nil { + return nil, nil, fmt.Errorf("configure db conn: %w", err) + } + + if _, err = writeConn.ExecContext(ctx, Schema); err != nil { + return nil, nil, fmt.Errorf("init db: %w", err) + } + + return readConn, writeConn, nil +} diff --git a/internal/db/queries.sql.go b/internal/db/queries.sql.go index 80237ff..7a1e016 100644 --- a/internal/db/queries.sql.go +++ b/internal/db/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 // source: queries.sql package db diff --git a/internal/db/types.go b/internal/db/types.go index 797fa18..174682b 100644 --- a/internal/db/types.go +++ b/internal/db/types.go @@ -1,8 +1,3 @@ -//go:build sqlite_foreign_keys - -// By specifying our go-sqlite3 build tags above, the build will fail if we -// forget to specify it in the go build/test command. - package db import ( diff --git a/internal/repo/repo.go b/internal/repo/repo.go index 0f1debe..4edea6e 100644 --- a/internal/repo/repo.go +++ b/internal/repo/repo.go @@ -17,8 +17,9 @@ import ( var ErrNotFound = fmt.Errorf("not found: %w", sql.ErrNoRows) type NodesRepo struct { - db *sql.DB - q *db.Queries + wdb *sql.DB + rq *db.Queries + wq *db.Queries } type nodeAddressCombo struct { @@ -26,15 +27,16 @@ type nodeAddressCombo struct { NodeAddress db.NodeAddress } -func New(sqldb *sql.DB) *NodesRepo { +func New(rdb *sql.DB, wdb *sql.DB) *NodesRepo { return &NodesRepo{ - db: sqldb, - q: db.New(sqldb), + wdb: wdb, + rq: db.New(rdb), + wq: db.New(wdb), } } func (r *NodesRepo) GetNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) (*models.Node, error) { - rows, err := r.q.GetNodeByPublicKey(ctx, (*db.PublicKey)(pk)) + rows, err := r.rq.GetNodeByPublicKey(ctx, (*db.PublicKey)(pk)) if err != nil { return nil, err } @@ -53,7 +55,7 @@ func (r *NodesRepo) GetNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) ( } func (r *NodesRepo) HasNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) (bool, error) { - res, err := r.q.HasNodeByPublicKey(ctx, (*db.PublicKey)(pk)) + res, err := r.rq.HasNodeByPublicKey(ctx, (*db.PublicKey)(pk)) if err != nil { return false, err } @@ -62,17 +64,17 @@ func (r *NodesRepo) HasNodeByPublicKey(ctx context.Context, pk *dht.PublicKey) ( } func (r *NodesRepo) GetNodeCount(ctx context.Context) (int64, error) { - return r.q.GetNodeCount(ctx) + return r.rq.GetNodeCount(ctx) } func (r *NodesRepo) TrackDHTNode(ctx context.Context, node *dht.Node) (*models.Node, error) { - tx, err := r.db.Begin() + tx, err := r.wdb.Begin() if err != nil { return nil, err } defer tx.Rollback() - q := r.q.WithTx(tx) + q := r.wq.WithTx(tx) dbNode, err := q.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey)) if err != nil { return nil, fmt.Errorf("upsert node: %w", err) @@ -99,7 +101,7 @@ func (r *NodesRepo) TrackDHTNode(ctx context.Context, node *dht.Node) (*models.N } func (r *NodesRepo) getDHTNodeAddressID(ctx context.Context, node *dht.Node) (int64, error) { - return r.q.GetNodeAddress(ctx, &db.GetNodeAddressParams{ + return r.rq.GetNodeAddress(ctx, &db.GetNodeAddressParams{ PublicKey: (*db.PublicKey)(node.PublicKey), Net: node.Type.Net(), Ip: node.IP.String(), @@ -116,7 +118,7 @@ func (r *NodesRepo) PingDHTNode(ctx context.Context, node *dht.Node) error { return err } - return r.q.PingNodeAddress(ctx, id) + return r.rq.PingNodeAddress(ctx, id) } func (r *NodesRepo) PongDHTNode(ctx context.Context, node *dht.Node) error { @@ -128,11 +130,11 @@ func (r *NodesRepo) PongDHTNode(ctx context.Context, node *dht.Node) error { return err } - return r.q.PongNodeAddress(ctx, id) + return r.rq.PongNodeAddress(ctx, id) } func (r *NodesRepo) GetNodesWithStaleBootstrapInfo(ctx context.Context) ([]*models.Node, error) { - rows, err := r.q.GetNodesWithStaleBootstrapInfo(ctx, &db.GetNodesWithStaleBootstrapInfoParams{ + rows, err := r.rq.GetNodesWithStaleBootstrapInfo(ctx, &db.GetNodesWithStaleBootstrapInfoParams{ NodeTimeout: (5 * time.Minute).Seconds(), InfoInterval: (1 * time.Minute).Seconds(), }) @@ -156,13 +158,13 @@ func (r *NodesRepo) GetNodesWithStaleBootstrapInfo(ctx context.Context) ([]*mode } func (r *NodesRepo) UpdateNodeInfoRequestTime(ctx context.Context, addrReqTimes map[int64]time.Time) error { - tx, err := r.db.Begin() + tx, err := r.wdb.Begin() if err != nil { return err } defer tx.Rollback() - q := r.q.WithTx(tx) + q := r.wq.WithTx(tx) for id, reqTime := range addrReqTimes { if err := q.UpdateNodeInfoRequestTime(ctx, &db.UpdateNodeInfoRequestTimeParams{ ID: id, @@ -176,7 +178,7 @@ func (r *NodesRepo) UpdateNodeInfoRequestTime(ctx context.Context, addrReqTimes } func (r *NodesRepo) UpdateNodeInfo(ctx context.Context, addr *net.UDPAddr, motd string, version uint32) error { - tx, err := r.db.Begin() + tx, err := r.wdb.Begin() if err != nil { return err } @@ -189,8 +191,8 @@ func (r *NodesRepo) UpdateNodeInfo(ctx context.Context, addr *net.UDPAddr, motd nodeType = dht.NodeTypeUDPIP6 } - q := r.q.WithTx(tx) - node, err := r.q.GetNodeByInfoResponseAddress(ctx, &db.GetNodeByInfoResponseAddressParams{ + q := r.wq.WithTx(tx) + node, err := q.GetNodeByInfoResponseAddress(ctx, &db.GetNodeByInfoResponseAddressParams{ InfoReqTimeout: (10 * time.Second).Seconds(), Net: nodeType.Net(), Ip: addr.IP.String(), @@ -212,7 +214,7 @@ func (r *NodesRepo) UpdateNodeInfo(ctx context.Context, addr *net.UDPAddr, motd } func (r *NodesRepo) GetResponsiveDHTNodes(ctx context.Context) ([]*dht.Node, error) { - rows, err := r.q.GetResponsiveNodes(ctx) + rows, err := r.rq.GetResponsiveNodes(ctx) if err != nil { return nil, err } @@ -229,7 +231,7 @@ func (r *NodesRepo) GetResponsiveDHTNodes(ctx context.Context) ([]*dht.Node, err } func (r *NodesRepo) GetUnresponsiveDHTNodes(ctx context.Context, retryDelay time.Duration) ([]*dht.Node, error) { - rows, err := r.q.GetUnresponsiveNodes(ctx, retryDelay.Seconds()) + rows, err := r.rq.GetUnresponsiveNodes(ctx, retryDelay.Seconds()) if err != nil { return nil, err } diff --git a/internal/repo/repo_test.go b/internal/repo/repo_test.go index 6e63c8c..568ae1d 100644 --- a/internal/repo/repo_test.go +++ b/internal/repo/repo_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/rand" - "database/sql" "errors" "net" "testing" @@ -18,16 +17,24 @@ import ( var ctx = context.Background() func initRepo(t *testing.T) (repo *NodesRepo, close func() error) { - dbConn, err := sql.Open("sqlite3", ":memory:") + readConn, writeConn, err := db.OpenReadWrite(ctx, ":memory:", db.OpenOptions{ + CacheSize: 2000, + Params: map[string]string{"cache": "shared"}, + }) if err != nil { t.Fatal(err) } - if _, err := dbConn.ExecContext(ctx, db.Schema); err != nil { - t.Fatal(err) + return New(readConn, writeConn), func() error { + var errs []error + if err := readConn.Close(); err != nil { + errs = append(errs, err) + } + if err := writeConn.Close(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) } - - return New(dbConn), dbConn.Close } func generateNode(t *testing.T) *models.Node { @@ -68,7 +75,7 @@ func TestAddNode(t *testing.T) { defer close() node := generateNode(t) - dbNode, err := repo.q.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey)) + dbNode, err := repo.wq.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey)) if err != nil { t.Fatal(err) } @@ -93,7 +100,7 @@ func TestHasNodeByPublicKey(t *testing.T) { defer close() node := generateNode(t) - _, err := repo.q.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey)) + _, err := repo.wq.UpsertNode(ctx, (*db.PublicKey)(node.PublicKey)) if err != nil { t.Fatal(err) } @@ -120,7 +127,7 @@ func TestPongNonExistentNode(t *testing.T) { defer close() pk := generatePublicKey(t) - _, err := repo.q.UpsertNode(ctx, (*db.PublicKey)(pk)) + _, err := repo.wq.UpsertNode(ctx, (*db.PublicKey)(pk)) if err != nil { t.Fatal(err) }