From 7823eb823950717fff1372e9b228ece95592ff42 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 23 Jan 2024 18:04:54 -0800 Subject: [PATCH] sqlite: remove txn interface --- persist/sqlite/consensus.go | 30 ++++----- persist/sqlite/init.go | 6 +- persist/sqlite/migrations.go | 2 +- persist/sqlite/peers.go | 16 ++--- persist/sqlite/sql.go | 123 +++++++++++++++-------------------- persist/sqlite/store.go | 17 +++-- persist/sqlite/wallet.go | 26 ++++---- 7 files changed, 104 insertions(+), 116 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index ff9c1b4..9079fc8 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -17,12 +17,12 @@ type proofUpdater interface { UpdateElementProof(*types.StateElement) } -func insertChainIndex(tx txn, index types.ChainIndex) (id int64, err error) { +func insertChainIndex(tx *txn, index types.ChainIndex) (id int64, err error) { err = tx.QueryRow(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`, index.Height, encode(index.ID)).Scan(&id) return } -func applyEvents(tx txn, events []wallet.Event) error { +func applyEvents(tx *txn, events []wallet.Event) error { stmt, err := tx.Prepare(`INSERT INTO events (date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4) RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) @@ -64,7 +64,7 @@ func applyEvents(tx txn, events []wallet.Event) error { return nil } -func deleteSiacoinOutputs(tx txn, spent []types.SiacoinElement) error { +func deleteSiacoinOutputs(tx *txn, spent []types.SiacoinElement) error { addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -108,7 +108,7 @@ func deleteSiacoinOutputs(tx txn, spent []types.SiacoinElement) error { return nil } -func applySiacoinOutputs(tx txn, added map[types.Hash256]types.SiacoinElement) error { +func applySiacoinOutputs(tx *txn, added map[types.Hash256]types.SiacoinElement) error { addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -152,7 +152,7 @@ func applySiacoinOutputs(tx txn, added map[types.Hash256]types.SiacoinElement) e return nil } -func deleteSiafundOutputs(tx txn, spent []types.SiafundElement) error { +func deleteSiafundOutputs(tx *txn, spent []types.SiafundElement) error { addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -199,7 +199,7 @@ func deleteSiafundOutputs(tx txn, spent []types.SiafundElement) error { return nil } -func applySiafundOutputs(tx txn, added map[types.Hash256]types.SiafundElement) error { +func applySiafundOutputs(tx *txn, added map[types.Hash256]types.SiafundElement) error { addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -245,13 +245,13 @@ func applySiafundOutputs(tx txn, added map[types.Hash256]types.SiafundElement) e return nil } -func updateLastIndexedTip(tx txn, tip types.ChainIndex) error { +func updateLastIndexedTip(tx *txn, tip types.ChainIndex) error { _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(tip)) return err } -func getStateElementBatch(stmt *loggedStmt, offset, limit int) ([]types.StateElement, error) { - rows, err := stmt.Query(limit, offset) +func getStateElementBatch(s *stmt, offset, limit int) ([]types.StateElement, error) { + rows, err := s.Query(limit, offset) if err != nil { return nil, fmt.Errorf("failed to query siacoin elements: %w", err) } @@ -269,8 +269,8 @@ func getStateElementBatch(stmt *loggedStmt, offset, limit int) ([]types.StateEle return updated, nil } -func updateStateElement(stmt *loggedStmt, se types.StateElement) error { - res, err := stmt.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) +func updateStateElement(s *stmt, se types.StateElement) error { + res, err := s.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) if err != nil { return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) } else if n, err := res.RowsAffected(); err != nil { @@ -282,7 +282,7 @@ func updateStateElement(stmt *loggedStmt, se types.StateElement) error { } // how slow is this going to be 😬? -func updateElementProofs(tx txn, table string, updater proofUpdater) error { +func updateElementProofs(tx *txn, table string, updater proofUpdater) error { stmt, err := tx.Prepare(`SELECT id, merkle_proof, leaf_index FROM ` + table + ` LIMIT $1 OFFSET $2`) if err != nil { return fmt.Errorf("failed to prepare batch statement: %w", err) @@ -314,7 +314,7 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { } // applyChainUpdates applies the given chain updates to the database. -func applyChainUpdates(tx txn, updates []*chain.ApplyUpdate) error { +func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate) error { stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) @@ -404,7 +404,7 @@ func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) s.updates = append(s.updates, cau) if mayCommit { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if err := applyChainUpdates(tx, s.updates); err != nil { return err } @@ -424,7 +424,7 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { } // update has been committed, revert it - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go index bc6c1a4..24ae378 100644 --- a/persist/sqlite/init.go +++ b/persist/sqlite/init.go @@ -17,13 +17,13 @@ import ( //go:embed init.sql var initDatabase string -func initializeSettings(tx txn, target int64) error { +func initializeSettings(tx *txn, target int64) error { _, err := tx.Exec(`INSERT INTO global_settings (id, db_version, last_indexed_tip) VALUES (0, ?, ?)`, target, encode(types.ChainIndex{})) return err } func (s *Store) initNewDatabase(target int64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if _, err := tx.Exec(initDatabase); err != nil { return fmt.Errorf("failed to initialize database: %w", err) } else if err := initializeSettings(tx, target); err != nil { @@ -48,7 +48,7 @@ func (s *Store) upgradeDatabase(current, target int64) error { } }() - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { for _, fn := range migrations[current-1:] { current++ start := time.Now() diff --git a/persist/sqlite/migrations.go b/persist/sqlite/migrations.go index 7aa3692..99d01e1 100644 --- a/persist/sqlite/migrations.go +++ b/persist/sqlite/migrations.go @@ -7,4 +7,4 @@ import ( // migrations is a list of functions that are run to migrate the database from // one version to the next. Migrations are used to update existing databases to // match the schema in init.sql. -var migrations = []func(tx txn, log *zap.Logger) error{} +var migrations = []func(tx *txn, log *zap.Logger) error{} diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go index 3b59936..4d8de8d 100644 --- a/persist/sqlite/peers.go +++ b/persist/sqlite/peers.go @@ -13,14 +13,14 @@ import ( "go.uber.org/zap" ) -func getPeerInfo(tx txn, peer string) (syncer.PeerInfo, error) { +func getPeerInfo(tx *txn, peer string) (syncer.PeerInfo, error) { const query = `SELECT first_seen, last_connect, synced_blocks, sync_duration FROM syncer_peers WHERE peer_address=$1` var info syncer.PeerInfo err := tx.QueryRow(query, peer).Scan(decode(&info.FirstSeen), decode(&info.LastConnect), &info.SyncedBlocks, &info.SyncDuration) return info, err } -func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error { +func (s *Store) updatePeerInfo(tx *txn, peer string, info syncer.PeerInfo) error { const query = `UPDATE syncer_peers SET first_seen=$1, last_connect=$2, synced_blocks=$3, sync_duration=$4 WHERE peer_address=$5 RETURNING peer_address` err := tx.QueryRow(query, encode(info.FirstSeen), encode(info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer) return err @@ -28,7 +28,7 @@ func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error // AddPeer adds the given peer to the store. func (s *Store) AddPeer(peer string) { - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `INSERT INTO syncer_peers (peer_address, first_seen, last_connect, synced_blocks, sync_duration) VALUES ($1, $2, 0, 0, 0) ON CONFLICT (peer_address) DO NOTHING` _, err := tx.Exec(query, peer, encode(time.Now())) return err @@ -40,7 +40,7 @@ func (s *Store) AddPeer(peer string) { // Peers returns the addresses of all known peers. func (s *Store) Peers() (peers []string) { - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `SELECT peer_address FROM syncer_peers` rows, err := tx.Query(query) if err != nil { @@ -64,7 +64,7 @@ func (s *Store) Peers() (peers []string) { // UpdatePeerInfo updates the info for the given peer. func (s *Store) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { info, err := getPeerInfo(tx, peer) if err != nil { return fmt.Errorf("failed to get peer info: %w", err) @@ -81,7 +81,7 @@ func (s *Store) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { func (s *Store) PeerInfo(peer string) (syncer.PeerInfo, bool) { var info syncer.PeerInfo var err error - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { info, err = getPeerInfo(tx, peer) return err }) @@ -134,7 +134,7 @@ func (s *Store) Ban(peer string, duration time.Duration, reason string) { s.log.Error("failed to normalize peer", zap.Error(err)) return } - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `INSERT INTO syncer_bans (net_cidr, expiration, reason) VALUES ($1, $2, $3) ON CONFLICT (net_cidr) DO UPDATE SET expiration=EXCLUDED.expiration, reason=EXCLUDED.reason` _, err := tx.Exec(query, address, encode(time.Now().Add(duration)), reason) return err @@ -176,7 +176,7 @@ func (s *Store) Banned(peer string) (banned bool) { checkSubnets = append(checkSubnets, subnet.String()) } - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1` var subnet string diff --git a/persist/sqlite/sql.go b/persist/sqlite/sql.go index 6ea715f..fb253d1 100644 --- a/persist/sqlite/sql.go +++ b/persist/sqlite/sql.go @@ -22,122 +22,107 @@ type ( Scan(dest ...any) error } - // A txn is an interface for executing queries within a transaction. - txn interface { - // Exec executes a query without returning any rows. The args are for - // any placeholder parameters in the query. - Exec(query string, args ...any) (sql.Result, error) - // Prepare creates a prepared statement for later queries or executions. - // Multiple queries or executions may be run concurrently from the - // returned statement. The caller must call the statement's Close method - // when the statement is no longer needed. - Prepare(query string) (*loggedStmt, error) - // Query executes a query that returns rows, typically a SELECT. The - // args are for any placeholder parameters in the query. - Query(query string, args ...any) (*loggedRows, error) - // QueryRow executes a query that is expected to return at most one row. - // QueryRow always returns a non-nil value. Errors are deferred until - // Row's Scan method is called. If the query selects no rows, the *Row's - // Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the - // first selected row and discards the rest. - QueryRow(query string, args ...any) *loggedRow - } - - loggedStmt struct { + // A stmt wraps a *sql.Stmt, logging slow queries. + stmt struct { *sql.Stmt query string - log *zap.Logger + + log *zap.Logger } - loggedTxn struct { + // A txn wraps a *sql.Tx, logging slow queries. + txn struct { *sql.Tx log *zap.Logger } - loggedRow struct { + // A row wraps a *sql.Row, logging slow queries. + row struct { *sql.Row log *zap.Logger } - loggedRows struct { + // rows wraps a *sql.Rows, logging slow queries. + rows struct { *sql.Rows + log *zap.Logger } ) -func (lr *loggedRows) Next() bool { +func (r *rows) Next() bool { start := time.Now() - next := lr.Rows.Next() + next := r.Rows.Next() if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) } return next } -func (lr *loggedRows) Scan(dest ...any) error { +func (r *rows) Scan(dest ...any) error { start := time.Now() - err := lr.Rows.Scan(dest...) + err := r.Rows.Scan(dest...) if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) } return err } -func (lr *loggedRow) Scan(dest ...any) error { +func (r *row) Scan(dest ...any) error { start := time.Now() - err := lr.Row.Scan(dest...) + err := r.Row.Scan(dest...) if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) } return err } -func (ls *loggedStmt) Exec(args ...any) (sql.Result, error) { - return ls.ExecContext(context.Background(), args...) +func (s *stmt) Exec(args ...any) (sql.Result, error) { + return s.ExecContext(context.Background(), args...) } -func (ls *loggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { +func (s *stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { start := time.Now() - result, err := ls.Stmt.ExecContext(ctx, args...) + result, err := s.Stmt.ExecContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow exec", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow exec", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return result, err } -func (ls *loggedStmt) Query(args ...any) (*sql.Rows, error) { - return ls.QueryContext(context.Background(), args...) +func (s *stmt) Query(args ...any) (*sql.Rows, error) { + return s.QueryContext(context.Background(), args...) } -func (ls *loggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { +func (s *stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { start := time.Now() - rows, err := ls.Stmt.QueryContext(ctx, args...) + rows, err := s.Stmt.QueryContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow query", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow query", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return rows, err } -func (ls *loggedStmt) QueryRow(args ...any) *loggedRow { - return ls.QueryRowContext(context.Background(), args...) +func (s *stmt) QueryRow(args ...any) *row { + return s.QueryRowContext(context.Background(), args...) } -func (ls *loggedStmt) QueryRowContext(ctx context.Context, args ...any) *loggedRow { +func (s *stmt) QueryRowContext(ctx context.Context, args ...any) *row { start := time.Now() - row := ls.Stmt.QueryRowContext(ctx, args...) + r := s.Stmt.QueryRowContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow query row", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow query row", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRow{row, ls.log.Named("row")} + return &row{r, s.log.Named("row")} } // Exec executes a query without returning any rows. The args are for // any placeholder parameters in the query. -func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) { +func (tx *txn) Exec(query string, args ...any) (sql.Result, error) { start := time.Now() - result, err := lt.Tx.Exec(query, args...) + result, err := tx.Tx.Exec(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return result, err } @@ -146,30 +131,30 @@ func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) { // Multiple queries or executions may be run concurrently from the // returned statement. The caller must call the statement's Close method // when the statement is no longer needed. -func (lt *loggedTxn) Prepare(query string) (*loggedStmt, error) { +func (tx *txn) Prepare(query string) (*stmt, error) { start := time.Now() - stmt, err := lt.Tx.Prepare(query) + s, err := tx.Tx.Prepare(query) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } else if err != nil { return nil, err } - return &loggedStmt{ - Stmt: stmt, + return &stmt{ + Stmt: s, query: query, - log: lt.log.Named("statement"), + log: tx.log.Named("statement"), }, nil } // Query executes a query that returns rows, typically a SELECT. The // args are for any placeholder parameters in the query. -func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) { +func (tx *txn) Query(query string, args ...any) (*rows, error) { start := time.Now() - rows, err := lt.Tx.Query(query, args...) + r, err := tx.Tx.Query(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRows{rows, lt.log.Named("rows")}, err + return &rows{r, tx.log.Named("rows")}, err } // QueryRow executes a query that is expected to return at most one row. @@ -177,13 +162,13 @@ func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) { // Row's Scan method is called. If the query selects no rows, the *Row's // Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the // first selected row and discards the rest. -func (lt *loggedTxn) QueryRow(query string, args ...any) *loggedRow { +func (tx *txn) QueryRow(query string, args ...any) *row { start := time.Now() - row := lt.Tx.QueryRow(query, args...) + r := tx.Tx.QueryRow(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRow{row, lt.log.Named("row")} + return &row{r, tx.log.Named("row")} } func queryPlaceHolders(n int) string { @@ -220,7 +205,7 @@ func getDBVersion(db *sql.DB) (version int64) { } // setDBVersion sets the current version of the database. -func setDBVersion(tx txn, version int64) error { +func setDBVersion(tx *txn, version int64) error { const query = `UPDATE global_settings SET db_version=$1 RETURNING id;` var dbID int64 return tx.QueryRow(query, version).Scan(&dbID) diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go index 0db301e..e50fda5 100644 --- a/persist/sqlite/store.go +++ b/persist/sqlite/store.go @@ -3,6 +3,7 @@ package sqlite import ( "database/sql" "encoding/hex" + "errors" "fmt" "math" "strings" @@ -27,7 +28,7 @@ type ( // function returns an error, the transaction is rolled back. Otherwise, the // transaction is committed. If the transaction fails due to a busy error, it is // retried up to 10 times before returning. -func (s *Store) transaction(fn func(txn) error) error { +func (s *Store) transaction(fn func(*txn) error) error { var err error txnID := hex.EncodeToString(frand.Bytes(4)) log := s.log.Named("transaction").With(zap.String("id", txnID)) @@ -76,25 +77,27 @@ func sqliteFilepath(fp string) string { // doTransaction is a helper function to execute a function within a transaction. If fn returns // an error, the transaction is rolled back. Otherwise, the transaction is // committed. -func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx txn) error) error { +func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx *txn) error) error { start := time.Now() - tx, err := db.Begin() + dbtx, err := db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } - defer tx.Rollback() defer func() { + if err := dbtx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) { + log.Error("failed to rollback transaction", zap.Error(err)) + } // log the transaction if it took longer than txn duration if time.Since(start) > longTxnDuration { log.Debug("long transaction", zap.Duration("elapsed", time.Since(start)), zap.Stack("stack"), zap.Bool("failed", err != nil)) } }() - ltx := &loggedTxn{ - Tx: tx, + tx := &txn{ + Tx: dbtx, log: log, } - if err = fn(ltx); err != nil { + if err = fn(tx); err != nil { return err } else if err = tx.Commit(); err != nil { return fmt.Errorf("failed to commit transaction: %w", err) diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index 0fc0a05..f5f081d 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -10,7 +10,7 @@ import ( "go.sia.tech/walletd/wallet" ) -func insertAddress(tx txn, addr types.Address) (id int64, err error) { +func insertAddress(tx *txn, addr types.Address) (id int64, err error) { const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) VALUES ($1, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id` @@ -21,7 +21,7 @@ RETURNING id` // WalletEvents returns the events relevant to a wallet, sorted by height descending. func (s *Store) WalletEvents(walletID string, offset, limit int) (events []wallet.Event, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT ev.id, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data FROM events ev INNER JOIN chain_indices ci ON (ev.index_id = ci.id) @@ -79,7 +79,7 @@ LIMIT $2 OFFSET $3` // AddWallet adds a wallet to the database. func (s *Store) AddWallet(name string, info json.RawMessage) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { const query = `INSERT INTO wallets (id, extra_data) VALUES ($1, $2)` _, err := tx.Exec(query, name, info) @@ -93,7 +93,7 @@ func (s *Store) AddWallet(name string, info json.RawMessage) error { // DeleteWallet deletes a wallet from the database. This does not stop tracking // addresses that were previously associated with the wallet. func (s *Store) DeleteWallet(name string) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { _, err := tx.Exec(`DELETE FROM wallets WHERE id=$1`, name) return err }) @@ -102,7 +102,7 @@ func (s *Store) DeleteWallet(name string) error { // Wallets returns a map of wallet names to wallet extra data. func (s *Store) Wallets() (map[string]json.RawMessage, error) { wallets := make(map[string]json.RawMessage) - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `SELECT id, extra_data FROM wallets` rows, err := tx.Query(query) @@ -126,7 +126,7 @@ func (s *Store) Wallets() (map[string]json.RawMessage, error) { // AddAddress adds an address to a wallet. func (s *Store) AddAddress(walletID string, address types.Address, info json.RawMessage) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { addressID, err := insertAddress(tx, address) if err != nil { return fmt.Errorf("failed to insert address: %w", err) @@ -139,7 +139,7 @@ func (s *Store) AddAddress(walletID string, address types.Address, info json.Raw // RemoveAddress removes an address from a wallet. This does not stop tracking // the address. func (s *Store) RemoveAddress(walletID string, address types.Address) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { const query = `DELETE FROM wallet_addresses WHERE wallet_id=$1 AND address_id=(SELECT id FROM sia_addresses WHERE sia_address=$2)` _, err := tx.Exec(query, walletID, encode(address)) return err @@ -149,7 +149,7 @@ func (s *Store) RemoveAddress(walletID string, address types.Address) error { // Addresses returns a map of addresses to their extra data for a wallet. func (s *Store) Addresses(walletID string) (map[types.Address]json.RawMessage, error) { addresses := make(map[types.Address]json.RawMessage) - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `SELECT sa.sia_address, wa.extra_data FROM wallet_addresses wa INNER JOIN sia_addresses sa ON (sa.id = wa.address_id) @@ -176,7 +176,7 @@ WHERE wa.wallet_id=$1` // UnspentSiacoinOutputs returns the unspent siacoin outputs for a wallet. func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.SiacoinElement, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siacoin_value, sa.sia_address, se.maturity_height FROM siacoin_elements se INNER JOIN sia_addresses sa ON (se.address_id = sa.id) @@ -204,7 +204,7 @@ func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.Siacoin // UnspentSiafundOutputs returns the unspent siafund outputs for a wallet. func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.SiafundElement, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siafund_value, se.claim_start, sa.sia_address FROM siafund_elements se INNER JOIN sia_addresses sa ON (se.address_id = sa.id) @@ -231,7 +231,7 @@ func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.Siafund // WalletBalance returns the total balance of a wallet. func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT siacoin_balance, siafund_balance FROM sia_addresses sa INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) WHERE wa.wallet_id=$1` @@ -258,7 +258,7 @@ func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, er // AddressBalance returns the balance of a single address. func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT siacoin_balance, siafund_balance FROM address_balance WHERE sia_address=$1` return tx.QueryRow(query, encode(address)).Scan(decode(&sc), &sf) }) @@ -267,7 +267,7 @@ func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uin // Annotate annotates a list of transactions using the wallet's addresses. func (s *Store) Annotate(walletID string, txns []types.Transaction) (annotated []wallet.PoolTransaction, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT sa.id FROM sia_addresses sa INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) WHERE wa.wallet_id=$1 AND sa.sia_address=$2 LIMIT 1`