From 4895bc723267d94200202e66b70b75f305133a36 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 2 Jan 2024 12:09:01 -0800 Subject: [PATCH 01/24] deps: update go version, add sqlite --- cmd/walletd/node.go | 5 +- go.mod | 3 +- go.sum | 2 + persist/sqlite/consensus.go | 510 +++++++++++++++++++++++++++++++ persist/sqlite/consts_default.go | 12 + persist/sqlite/consts_testing.go | 12 + persist/sqlite/init.go | 89 ++++++ persist/sqlite/init.sql | 70 +++++ persist/sqlite/migrations.go | 10 + persist/sqlite/sql.go | 232 ++++++++++++++ persist/sqlite/store.go | 120 ++++++++ persist/sqlite/types.go | 135 ++++++++ persist/sqlite/wallet.go | 292 ++++++++++++++++++ wallet/state.go | 83 +++++ wallet/wallet.go | 25 +- 15 files changed, 1587 insertions(+), 13 deletions(-) create mode 100644 persist/sqlite/consensus.go create mode 100644 persist/sqlite/consts_default.go create mode 100644 persist/sqlite/consts_testing.go create mode 100644 persist/sqlite/init.go create mode 100644 persist/sqlite/init.sql create mode 100644 persist/sqlite/migrations.go create mode 100644 persist/sqlite/sql.go create mode 100644 persist/sqlite/store.go create mode 100644 persist/sqlite/types.go create mode 100644 persist/sqlite/wallet.go create mode 100644 wallet/state.go diff --git a/cmd/walletd/node.go b/cmd/walletd/node.go index eaf0d5d..3249fe6 100644 --- a/cmd/walletd/node.go +++ b/cmd/walletd/node.go @@ -17,6 +17,7 @@ import ( "go.sia.tech/coreutils/syncer" "go.sia.tech/walletd/internal/syncerutil" "go.sia.tech/walletd/internal/walletutil" + "go.uber.org/zap" "lukechampine.com/upnp" ) @@ -163,9 +164,7 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) UniqueID: gateway.GenerateUniqueID(), NetAddress: syncerAddr, } - - s := syncer.New(l, cm, ps, header) - + s := syncer.New(l, cm, ps, header, syncer.WithLogger(zap.NewNop())) wm, err := walletutil.NewJSONWalletManager(dir, cm) if err != nil { return nil, err diff --git a/go.mod b/go.mod index e6df28b..b50ae10 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module go.sia.tech/walletd go 1.21 require ( + github.com/mattn/go-sqlite3 v1.14.21 go.sia.tech/core v0.2.1-0.20240130145801-8067f34b2ecc go.sia.tech/coreutils v0.0.0-20240130201319-8303550528d7 go.sia.tech/jape v0.9.0 go.sia.tech/web/walletd v0.16.0 + go.uber.org/zap v1.26.0 golang.org/x/term v0.6.0 lukechampine.com/flagg v1.1.1 lukechampine.com/frand v1.4.2 @@ -20,7 +22,6 @@ require ( go.sia.tech/mux v1.2.0 // indirect go.sia.tech/web v0.0.0-20230628194305-c6e1696bad89 // indirect go.uber.org/multierr v1.10.0 // indirect - go.uber.org/zap v1.26.0 // indirect golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122 // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/tools v0.7.0 // indirect diff --git a/go.sum b/go.sum index f3a3ac8..e25fb79 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/mattn/go-sqlite3 v1.14.21 h1:IXocQLOykluc3xPE0Lvy8FtggMz1G+U3mEjg+0zGizc= +github.com/mattn/go-sqlite3 v1.14.21/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go new file mode 100644 index 0000000..03f8f0a --- /dev/null +++ b/persist/sqlite/consensus.go @@ -0,0 +1,510 @@ +package sqlite + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "log" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/walletd/wallet" +) + +const updateProofBatchSize = 1000 + +type proofUpdater interface { + UpdateElementProof(*types.StateElement) +} + +func insertChainIndex(tx txn, index types.ChainIndex) (id int64, err error) { + err = tx.QueryRow(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`, index.Height, encode(index.ID)).Scan(&id) + return +} + +func applyEvents(tx txn, events []wallet.Event) error { + stmt, err := tx.Prepare(`INSERT INTO events (date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4) RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + addRelevantAddrStmt, err := tx.Prepare(`INSERT INTO event_addresses (event_id, address_id, block_height) VALUES ($1, $2, $3)`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer addRelevantAddrStmt.Close() + + for _, event := range events { + id, err := insertChainIndex(tx, event.Index) + if err != nil { + return fmt.Errorf("failed to create chain index: %w", err) + } + + buf, err := json.Marshal(event.Val) + if err != nil { + return fmt.Errorf("failed to marshal event: %w", err) + } + + var eventID int64 + err = stmt.QueryRow(sqlTime(event.Timestamp), id, event.Val.EventType(), buf).Scan(&eventID) + if err != nil { + return fmt.Errorf("failed to execute statement: %w", err) + } + + for _, addr := range event.Relevant { + addressID, err := insertAddress(tx, addr) + if err != nil { + return fmt.Errorf("failed to insert address: %w", err) + } else if _, err := addRelevantAddrStmt.Exec(eventID, addressID, event.Index.Height); err != nil { + return fmt.Errorf("failed to add relevant address: %w", err) + } + log.Println("added relevant address", eventID, addr) + } + } + return nil +} + +func deleteSiacoinOutputs(tx txn, spent []types.SiacoinElement) error { + addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare lookup statement: %w", err) + } + defer addrStmt.Close() + + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1 WHERE id=$2`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateBalanceStmt.Close() + + deleteStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer deleteStmt.Close() + + for _, se := range spent { + // query the address database ID and balance + var addressID int64 + var balance types.Currency + err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, (*sqlCurrency)(&balance)) + if err != nil { + return fmt.Errorf("failed to lookup address %q: %w", se.SiacoinOutput.Address, err) + } + + // update the balance + balance = balance.Sub(se.SiacoinOutput.Value) + _, err = updateBalanceStmt.Exec((*sqlCurrency)(&balance), addressID) + if err != nil { + return fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) + } + + var dummy types.Hash256 + err = deleteStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy, 32)) + if err != nil { + return fmt.Errorf("failed to delete output %q: %w", se.ID, err) + } + } + return nil +} + +func applySiacoinOutputs(tx txn, added map[types.Hash256]types.SiacoinElement) error { + addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare lookup statement: %w", err) + } + defer addrStmt.Close() + + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1 WHERE id=$2`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateBalanceStmt.Close() + + addStmt, err := tx.Prepare(`INSERT INTO siacoin_elements (id, address_id, siacoin_value, merkle_proof, leaf_index, maturity_height) VALUES ($1, $2, $3, $4, $5, $6)`) + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer addStmt.Close() + + for _, se := range added { + // query the address database ID and balance + var addressID int64 + var balance types.Currency + err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, (*sqlCurrency)(&balance)) + if err != nil { + return fmt.Errorf("failed to lookup address %q: %w", se.SiacoinOutput.Address, err) + } + + // update the balance + balance = balance.Add(se.SiacoinOutput.Value) + _, err = updateBalanceStmt.Exec((*sqlCurrency)(&balance), addressID) + if err != nil { + return fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) + } + + // insert the created utxo + _, err = addStmt.Exec(encode(se.ID), addressID, sqlCurrency(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.MaturityHeight, se.LeafIndex) + if err != nil { + return fmt.Errorf("failed to insert output %q: %w", se.ID, err) + } + } + return nil +} + +func deleteSiafundOutputs(tx txn, spent []types.SiafundElement) error { + addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare lookup statement: %w", err) + } + defer addrStmt.Close() + + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateBalanceStmt.Close() + + spendStmt, err := tx.Prepare(`DELETE FROM siafund_elements WHERE id=$1 RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer spendStmt.Close() + + for _, se := range spent { + // query the address database ID and balance + var addressID int64 + var balance uint64 + err := addrStmt.QueryRow(encode(se.SiafundOutput.Address)).Scan(&addressID, balance) + if err != nil { + return fmt.Errorf("failed to lookup address %q: %w", se.SiafundOutput.Address, err) + } + + // update the balance + if balance < se.SiafundOutput.Value { + panic("siafund balance is negative") // developer error + } + balance -= se.SiafundOutput.Value + _, err = updateBalanceStmt.Exec(balance, addressID) + if err != nil { + return fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) + } + + var dummy types.Hash256 + err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy, 32)) + if err != nil { + return fmt.Errorf("failed to delete output %q: %w", se.ID, err) + } + } + return nil +} + +func applySiafundOutputs(tx txn, added map[types.Hash256]types.SiafundElement) error { + addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare lookup statement: %w", err) + } + defer addrStmt.Close() + + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateBalanceStmt.Close() + + addStmt, err := tx.Prepare(`INSERT INTO siafund_elements (id, address_id, claim_start, siafund_value, merkle_proof, leaf_index) VALUES ($1, $2, $3, $4, $5, $6)`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer addStmt.Close() + + for _, se := range added { + // query the address database ID and balance + var addressID int64 + var balance uint64 + err := addrStmt.QueryRow(encode(se.SiafundOutput.Address)).Scan(&addressID, balance) + if err != nil { + return fmt.Errorf("failed to lookup address %q: %w", se.SiafundOutput.Address, err) + } + + // update the balance + if balance < se.SiafundOutput.Value { + panic("siafund balance is negative") // developer error + } + balance -= se.SiafundOutput.Value + _, err = updateBalanceStmt.Exec(balance, addressID) + if err != nil { + return fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) + } + + _, err = addStmt.Exec(encode(se.ID), addressID, sqlCurrency(se.ClaimStart), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex) + if err != nil { + return fmt.Errorf("failed to insert output %q: %w", se.ID, err) + } + } + return nil +} + +func updateLastIndexedTip(tx txn, tip types.ChainIndex) error { + _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(tip.ID)) + return err +} + +// how slow is this going to be 😬? +// +// todo: determine if it's feasible for exchange mode to keep everything in +// memory. +func updateElementProofs(tx txn, table string, updater proofUpdater) error { + stmt, err := tx.Prepare(`SELECT id, merkle_proof, leaf_index FROM ` + table + ` LIMIT $1 OFFSET $2`) + if err != nil { + return fmt.Errorf("failed to prepare batch statement: %w", err) + } + defer stmt.Close() + + updateStmt, err := tx.Prepare(`UPDATE ` + table + ` SET merkle_proof=$1, leaf_index=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + var updated []types.StateElement + for offset := 0; ; offset += updateProofBatchSize { + updated = updated[:0] + + more, err := func(n int) (bool, error) { + rows, err := stmt.Query(updateProofBatchSize, n) + if err != nil { + return false, fmt.Errorf("failed to query siacoin elements: %w", err) + } + defer rows.Close() + + var more bool + for rows.Next() { + // if we get here, there may be more rows to process + more = true + + var se types.StateElement + err := rows.Scan(decode(&se.ID, 32), decodeSlice(&se.MerkleProof, 32*1000), &se.LeafIndex) + if err != nil { + return false, fmt.Errorf("failed to scan state element: %w", err) + } + updater.UpdateElementProof(&se) + updated = append(updated, se) + } + return more, nil + }(offset) + if err != nil { + return err + } + + for _, se := range updated { + _, err := updateStmt.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) + if err != nil { + return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) + } + } + + if !more { + break + } + } + + return nil +} + +func applyChainUpdates(tx txn, updates []*chain.ApplyUpdate) error { + stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + // note: this would be more performant for small wallets to load all + // addresses into memory. However, for larger wallets (> 10K addresses), + // this is time consuming. Instead, the database is queried for each + // address. Monitor performance and consider changing this in the + // future. From a memory perspective, it would be fine to lazy load all + // addresses into memory. + ownsAddress := func(address types.Address) bool { + var dbID int64 + err := stmt.QueryRow(encode(address)).Scan(&dbID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + panic(err) // database error + } + return err == nil + } + + for _, update := range updates { + events := wallet.AppliedEvents(update.State, update.Block, update, ownsAddress) + if err := applyEvents(tx, events); err != nil { + return fmt.Errorf("failed to apply events: %w", err) + } + + var spentSiacoinOutputs []types.SiacoinElement + newSiacoinOutputs := make(map[types.Hash256]types.SiacoinElement) + update.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + if !ownsAddress(se.SiacoinOutput.Address) { + return + } + + if spent { + spentSiacoinOutputs = append(spentSiacoinOutputs, se) + delete(newSiacoinOutputs, se.ID) + } else { + newSiacoinOutputs[se.ID] = se + } + }) + + if err := deleteSiacoinOutputs(tx, spentSiacoinOutputs); err != nil { + return fmt.Errorf("failed to delete siacoin outputs: %w", err) + } else if err := applySiacoinOutputs(tx, newSiacoinOutputs); err != nil { + return fmt.Errorf("failed to apply siacoin outputs: %w", err) + } + + var spentSiafundOutputs []types.SiafundElement + newSiafundOutputs := make(map[types.Hash256]types.SiafundElement) + update.ForEachSiafundElement(func(sf types.SiafundElement, spent bool) { + if !ownsAddress(sf.SiafundOutput.Address) { + return + } + + if spent { + spentSiafundOutputs = append(spentSiafundOutputs, sf) + delete(newSiafundOutputs, sf.ID) + } else { + newSiafundOutputs[sf.ID] = sf + } + }) + + if err := deleteSiafundOutputs(tx, spentSiafundOutputs); err != nil { + return fmt.Errorf("failed to delete siafund outputs: %w", err) + } else if err := applySiafundOutputs(tx, newSiafundOutputs); err != nil { + return fmt.Errorf("failed to apply siafund outputs: %w", err) + } + + // update proofs + if err := updateElementProofs(tx, "siacoin_elements", update); err != nil { + return fmt.Errorf("failed to update siacoin element proofs: %w", err) + } else if err := updateElementProofs(tx, "siafund_elements", update); err != nil { + return fmt.Errorf("failed to update siafund element proofs: %w", err) + } + } + + lastTip := updates[len(updates)-1].State.Index + if err := updateLastIndexedTip(tx, lastTip); err != nil { + return fmt.Errorf("failed to update last indexed tip: %w", err) + } + return nil +} + +func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) error { + s.updates = append(s.updates, cau) + + if mayCommit { + return s.transaction(func(tx txn) error { + if err := applyChainUpdates(tx, s.updates); err != nil { + return err + } + s.updates = nil + return nil + }) + } + return nil +} + +func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { + // update hasn't been committed yet + if len(s.updates) > 0 && s.updates[len(s.updates)-1].Block.ID() == cru.Block.ID() { + s.updates = s.updates[:len(s.updates)-1] + return nil + } + + // update has been committed, revert it + return s.transaction(func(tx txn) error { + stmt, err := tx.Prepare(`SELECT sia_address FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + // note: this would be more performant for small wallets to load all + // addresses into memory. However, for larger wallets (> 10K addresses), + // this is time consuming. Instead, the database is queried for each + // address. Monitor performance and consider changing this in the + // future. From a memory perspective, it would be fine to lazy load all + // addresses into memory. + ownsAddress := func(address types.Address) bool { + var dbID int64 + err := stmt.QueryRow(encode(address)).Scan(&dbID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + panic(err) // database error + } + return err == nil + } + + var spentSiacoinOutputs []types.SiacoinElement + var spentSiafundOutputs []types.SiafundElement + addedSiacoinOutputs := make(map[types.Hash256]types.SiacoinElement) + addedSiafundOutputs := make(map[types.Hash256]types.SiafundElement) + + cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + if !ownsAddress(se.SiacoinOutput.Address) { + return + } + + if !spent { + spentSiacoinOutputs = append(spentSiacoinOutputs, se) + } else { + addedSiacoinOutputs[se.ID] = se + } + }) + + cru.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { + if !ownsAddress(se.SiafundOutput.Address) { + return + } + + if !spent { + spentSiafundOutputs = append(spentSiafundOutputs, se) + } else { + addedSiafundOutputs[se.ID] = se + } + }) + + // revert siacoin outputs + if err := deleteSiacoinOutputs(tx, spentSiacoinOutputs); err != nil { + return fmt.Errorf("failed to delete siacoin outputs: %w", err) + } else if err := applySiacoinOutputs(tx, addedSiacoinOutputs); err != nil { + return fmt.Errorf("failed to apply siacoin outputs: %w", err) + } + + // revert siafund outputs + if err := deleteSiafundOutputs(tx, spentSiafundOutputs); err != nil { + return fmt.Errorf("failed to delete siafund outputs: %w", err) + } else if err := applySiafundOutputs(tx, addedSiafundOutputs); err != nil { + return fmt.Errorf("failed to apply siafund outputs: %w", err) + } + + // revert events + _, err = tx.Exec(`DELETE FROM chain_indices WHERE block_id=$1`, cru.Block.ID()) + if err != nil { + return fmt.Errorf("failed to delete chain index: %w", err) + } + + // update proofs + if err := updateElementProofs(tx, "siacoin_elements", cru); err != nil { + return fmt.Errorf("failed to update siacoin element proofs: %w", err) + } else if err := updateElementProofs(tx, "siafund_elements", cru); err != nil { + return fmt.Errorf("failed to update siafund element proofs: %w", err) + } + return nil + }) +} + +// LastCommittedIndex returns the last chain index that was committed. +func (s *Store) LastCommittedIndex() (index types.ChainIndex, err error) { + err = s.db.QueryRow(`SELECT last_indexed_tip FROM global_settings`).Scan(decode(&index, 40)) + return +} diff --git a/persist/sqlite/consts_default.go b/persist/sqlite/consts_default.go new file mode 100644 index 0000000..50b7330 --- /dev/null +++ b/persist/sqlite/consts_default.go @@ -0,0 +1,12 @@ +//go:build !testing + +package sqlite + +import "time" + +const ( + busyTimeout = 10000 // 10 seconds + maxRetryAttempts = 30 // 30 attempts + factor = 1.8 // factor ^ retryAttempts = backoff time in milliseconds + maxBackoff = 15 * time.Second +) diff --git a/persist/sqlite/consts_testing.go b/persist/sqlite/consts_testing.go new file mode 100644 index 0000000..f4911e3 --- /dev/null +++ b/persist/sqlite/consts_testing.go @@ -0,0 +1,12 @@ +//go:build testing + +package sqlite + +import "time" + +const ( + busyTimeout = 100 // 100ms + maxRetryAttempts = 10 // 10 attempts + factor = 2.0 // factor ^ retryAttempts = backoff time in milliseconds + maxBackoff = 15 * time.Second +) diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go new file mode 100644 index 0000000..bc6c1a4 --- /dev/null +++ b/persist/sqlite/init.go @@ -0,0 +1,89 @@ +package sqlite + +import ( + "database/sql" + _ "embed" // for init.sql + "errors" + "time" + + "fmt" + + "go.sia.tech/core/types" + "go.uber.org/zap" +) + +// init queries are run when the database is first created. +// +//go:embed init.sql +var initDatabase string + +func initializeSettings(tx txn, target int64) error { + _, err := tx.Exec(`INSERT INTO global_settings (id, db_version, last_indexed_tip) VALUES (0, ?, ?)`, target, encode(types.ChainIndex{})) + return err +} + +func (s *Store) initNewDatabase(target int64) error { + return s.transaction(func(tx txn) error { + if _, err := tx.Exec(initDatabase); err != nil { + return fmt.Errorf("failed to initialize database: %w", err) + } else if err := initializeSettings(tx, target); err != nil { + return fmt.Errorf("failed to initialize settings: %w", err) + } + return nil + }) +} + +func (s *Store) upgradeDatabase(current, target int64) error { + log := s.log.Named("migrations") + log.Info("migrating database", zap.Int64("current", current), zap.Int64("target", target)) + + // disable foreign key constraints during migration + if _, err := s.db.Exec("PRAGMA foreign_keys = OFF"); err != nil { + return fmt.Errorf("failed to disable foreign key constraints: %w", err) + } + defer func() { + // re-enable foreign key constraints + if _, err := s.db.Exec("PRAGMA foreign_keys = ON"); err != nil { + log.Panic("failed to enable foreign key constraints", zap.Error(err)) + } + }() + + return s.transaction(func(tx txn) error { + for _, fn := range migrations[current-1:] { + current++ + start := time.Now() + if err := fn(tx, log.With(zap.Int64("version", current))); err != nil { + return fmt.Errorf("failed to migrate database to version %v: %w", current, err) + } + // check that no foreign key constraints were violated + if err := tx.QueryRow("PRAGMA foreign_key_check").Scan(); !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("foreign key constraints are not satisfied") + } + log.Debug("migration complete", zap.Int64("current", current), zap.Int64("target", target), zap.Duration("elapsed", time.Since(start))) + } + + // set the final database version + return setDBVersion(tx, target) + }) +} + +func (s *Store) init() error { + // calculate the expected final database version + target := int64(len(migrations) + 1) + // disable foreign key constraints during migration + if _, err := s.db.Exec("PRAGMA foreign_keys = OFF"); err != nil { + return fmt.Errorf("failed to disable foreign key constraints: %w", err) + } + + version := getDBVersion(s.db) + switch { + case version == 0: + return s.initNewDatabase(target) + case version < target: + return s.upgradeDatabase(version, target) + case version > target: + return fmt.Errorf("database version %v is newer than expected %v. database downgrades are not supported", version, target) + } + // nothing to do + return nil +} diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql new file mode 100644 index 0000000..d6a8619 --- /dev/null +++ b/persist/sqlite/init.sql @@ -0,0 +1,70 @@ +CREATE TABLE chain_indices ( + id INTEGER PRIMARY KEY, + block_id BLOB UNIQUE NOT NULL, + height INTEGER UNIQUE NOT NULL +); + +CREATE TABLE sia_addresses ( + id INTEGER PRIMARY KEY, + sia_address BLOB UNIQUE NOT NULL, + siacoin_balance BLOB NOT NULL, + siafund_balance INTEGER NOT NULL +); + +CREATE TABLE siacoin_elements ( + id BLOB PRIMARY KEY, + siacoin_value BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index INTEGER NOT NULL, + maturity_height INTEGER NOT NULL, /* stored as int64 for easier querying */ + address_id INTEGER NOT NULL REFERENCES sia_addresses (id) +); +CREATE INDEX siacoin_elements_address_id ON siacoin_elements (address_id); + +CREATE TABLE siafund_elements ( + id BLOB PRIMARY KEY, + claim_start BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index INTEGER NOT NULL, + siafund_value INTEGER NOT NULL, + address_id INTEGER NOT NULL REFERENCES sia_addresses (id) +); +CREATE INDEX siafund_elements_address_id ON siafund_elements (address_id); + +CREATE TABLE wallets ( + id TEXT PRIMARY KEY NOT NULL, + extra_data BLOB NOT NULL +); + +CREATE TABLE wallet_addresses ( + wallet_id TEXT NOT NULL REFERENCES wallets (id), + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + extra_data BLOB NOT NULL, + UNIQUE (wallet_id, address_id) +); +CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id); + +CREATE TABLE events ( + id INTEGER PRIMARY KEY, + date_created INTEGER NOT NULL, + index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, + event_type TEXT NOT NULL, + event_data TEXT NOT NULL +); + +CREATE TABLE event_addresses ( + id INTEGER PRIMARY KEY, + event_id INTEGER NOT NULL REFERENCES events (id) ON DELETE CASCADE, + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + block_height INTEGER NOT NULL, /* prevents extra join when querying for events */ + UNIQUE (event_id, address_id) +); +CREATE INDEX event_addresses_event_id_idx ON event_addresses (event_id); +CREATE INDEX event_addresses_address_id_idx ON event_addresses (address_id); +CREATE INDEX event_addresses_event_id_address_id_block_height ON event_addresses(event_id, address_id, block_height DESC); + +CREATE TABLE global_settings ( + id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row + db_version INTEGER NOT NULL, -- used for migrations + last_indexed_tip BLOB -- the last chain index that was processed +); diff --git a/persist/sqlite/migrations.go b/persist/sqlite/migrations.go new file mode 100644 index 0000000..7aa3692 --- /dev/null +++ b/persist/sqlite/migrations.go @@ -0,0 +1,10 @@ +package sqlite + +import ( + "go.uber.org/zap" +) + +// migrations is a list of functions that are run to migrate the database from +// one version to the next. Migrations are used to update existing databases to +// match the schema in init.sql. +var migrations = []func(tx txn, log *zap.Logger) error{} diff --git a/persist/sqlite/sql.go b/persist/sqlite/sql.go new file mode 100644 index 0000000..6ea715f --- /dev/null +++ b/persist/sqlite/sql.go @@ -0,0 +1,232 @@ +package sqlite + +import ( + "context" + "database/sql" + "math/rand" + "strings" + "time" + + _ "github.com/mattn/go-sqlite3" // import sqlite3 driver + "go.uber.org/zap" +) + +const ( + longQueryDuration = 10 * time.Millisecond + longTxnDuration = 10 * time.Millisecond +) + +type ( + // A scanner is an interface that wraps the Scan method of sql.Rows and sql.Row + scanner interface { + Scan(dest ...any) error + } + + // A txn is an interface for executing queries within a transaction. + txn interface { + // Exec executes a query without returning any rows. The args are for + // any placeholder parameters in the query. + Exec(query string, args ...any) (sql.Result, error) + // Prepare creates a prepared statement for later queries or executions. + // Multiple queries or executions may be run concurrently from the + // returned statement. The caller must call the statement's Close method + // when the statement is no longer needed. + Prepare(query string) (*loggedStmt, error) + // Query executes a query that returns rows, typically a SELECT. The + // args are for any placeholder parameters in the query. + Query(query string, args ...any) (*loggedRows, error) + // QueryRow executes a query that is expected to return at most one row. + // QueryRow always returns a non-nil value. Errors are deferred until + // Row's Scan method is called. If the query selects no rows, the *Row's + // Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the + // first selected row and discards the rest. + QueryRow(query string, args ...any) *loggedRow + } + + loggedStmt struct { + *sql.Stmt + query string + log *zap.Logger + } + + loggedTxn struct { + *sql.Tx + log *zap.Logger + } + + loggedRow struct { + *sql.Row + log *zap.Logger + } + + loggedRows struct { + *sql.Rows + log *zap.Logger + } +) + +func (lr *loggedRows) Next() bool { + start := time.Now() + next := lr.Rows.Next() + if dur := time.Since(start); dur > longQueryDuration { + lr.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return next +} + +func (lr *loggedRows) Scan(dest ...any) error { + start := time.Now() + err := lr.Rows.Scan(dest...) + if dur := time.Since(start); dur > longQueryDuration { + lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return err +} + +func (lr *loggedRow) Scan(dest ...any) error { + start := time.Now() + err := lr.Row.Scan(dest...) + if dur := time.Since(start); dur > longQueryDuration { + lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return err +} + +func (ls *loggedStmt) Exec(args ...any) (sql.Result, error) { + return ls.ExecContext(context.Background(), args...) +} + +func (ls *loggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { + start := time.Now() + result, err := ls.Stmt.ExecContext(ctx, args...) + if dur := time.Since(start); dur > longQueryDuration { + ls.log.Debug("slow exec", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return result, err +} + +func (ls *loggedStmt) Query(args ...any) (*sql.Rows, error) { + return ls.QueryContext(context.Background(), args...) +} + +func (ls *loggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { + start := time.Now() + rows, err := ls.Stmt.QueryContext(ctx, args...) + if dur := time.Since(start); dur > longQueryDuration { + ls.log.Debug("slow query", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return rows, err +} + +func (ls *loggedStmt) QueryRow(args ...any) *loggedRow { + return ls.QueryRowContext(context.Background(), args...) +} + +func (ls *loggedStmt) QueryRowContext(ctx context.Context, args ...any) *loggedRow { + start := time.Now() + row := ls.Stmt.QueryRowContext(ctx, args...) + if dur := time.Since(start); dur > longQueryDuration { + ls.log.Debug("slow query row", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return &loggedRow{row, ls.log.Named("row")} +} + +// Exec executes a query without returning any rows. The args are for +// any placeholder parameters in the query. +func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) { + start := time.Now() + result, err := lt.Tx.Exec(query, args...) + if dur := time.Since(start); dur > longQueryDuration { + lt.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return result, err +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. The caller must call the statement's Close method +// when the statement is no longer needed. +func (lt *loggedTxn) Prepare(query string) (*loggedStmt, error) { + start := time.Now() + stmt, err := lt.Tx.Prepare(query) + if dur := time.Since(start); dur > longQueryDuration { + lt.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } else if err != nil { + return nil, err + } + return &loggedStmt{ + Stmt: stmt, + query: query, + log: lt.log.Named("statement"), + }, nil +} + +// Query executes a query that returns rows, typically a SELECT. The +// args are for any placeholder parameters in the query. +func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) { + start := time.Now() + rows, err := lt.Tx.Query(query, args...) + if dur := time.Since(start); dur > longQueryDuration { + lt.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return &loggedRows{rows, lt.log.Named("rows")}, err +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. If the query selects no rows, the *Row's +// Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the +// first selected row and discards the rest. +func (lt *loggedTxn) QueryRow(query string, args ...any) *loggedRow { + start := time.Now() + row := lt.Tx.QueryRow(query, args...) + if dur := time.Since(start); dur > longQueryDuration { + lt.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return &loggedRow{row, lt.log.Named("row")} +} + +func queryPlaceHolders(n int) string { + if n == 0 { + return "" + } else if n == 1 { + return "?" + } + var b strings.Builder + b.Grow(((n - 1) * 2) + 1) // ?,? + for i := 0; i < n-1; i++ { + b.WriteString("?,") + } + b.WriteString("?") + return b.String() +} + +func queryArgs[T any](args []T) []any { + if len(args) == 0 { + return nil + } + out := make([]any, len(args)) + for i, arg := range args { + out[i] = arg + } + return out +} + +// getDBVersion returns the current version of the database. +func getDBVersion(db *sql.DB) (version int64) { + // error is ignored -- the database may not have been initialized yet. + db.QueryRow(`SELECT db_version FROM global_settings;`).Scan(&version) + return +} + +// setDBVersion sets the current version of the database. +func setDBVersion(tx txn, version int64) error { + const query = `UPDATE global_settings SET db_version=$1 RETURNING id;` + var dbID int64 + return tx.QueryRow(query, version).Scan(&dbID) +} + +// jitterSleep sleeps for a random duration between t and t*1.5. +func jitterSleep(t time.Duration) { + time.Sleep(t + time.Duration(rand.Int63n(int64(t/2)))) +} diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go new file mode 100644 index 0000000..0db301e --- /dev/null +++ b/persist/sqlite/store.go @@ -0,0 +1,120 @@ +package sqlite + +import ( + "database/sql" + "encoding/hex" + "fmt" + "math" + "strings" + "time" + + "go.sia.tech/coreutils/chain" + "go.uber.org/zap" + "lukechampine.com/frand" +) + +type ( + // A Store is a persistent store that uses a SQL database as its backend. + Store struct { + db *sql.DB + log *zap.Logger + + updates []*chain.ApplyUpdate + } +) + +// transaction executes a function within a database transaction. If the +// function returns an error, the transaction is rolled back. Otherwise, the +// transaction is committed. If the transaction fails due to a busy error, it is +// retried up to 10 times before returning. +func (s *Store) transaction(fn func(txn) error) error { + var err error + txnID := hex.EncodeToString(frand.Bytes(4)) + log := s.log.Named("transaction").With(zap.String("id", txnID)) + start := time.Now() + attempt := 1 + for ; attempt < maxRetryAttempts; attempt++ { + attemptStart := time.Now() + log := log.With(zap.Int("attempt", attempt)) + err = doTransaction(s.db, log, fn) + if err == nil { + // no error, break out of the loop + return nil + } + + // return immediately if the error is not a busy error + if !strings.Contains(err.Error(), "database is locked") { + break + } + // exponential backoff + sleep := time.Duration(math.Pow(factor, float64(attempt))) * time.Millisecond + if sleep > maxBackoff { + sleep = maxBackoff + } + log.Debug("database locked", zap.Duration("elapsed", time.Since(attemptStart)), zap.Duration("totalElapsed", time.Since(start)), zap.Stack("stack"), zap.Duration("retry", sleep)) + jitterSleep(sleep) + } + return fmt.Errorf("transaction failed (attempt %d): %w", attempt, err) +} + +// Close closes the underlying database. +func (s *Store) Close() error { + return s.db.Close() +} + +func sqliteFilepath(fp string) string { + params := []string{ + fmt.Sprintf("_busy_timeout=%d", busyTimeout), + "_foreign_keys=true", + "_journal_mode=WAL", + "_secure_delete=false", + "_cache_size=-65536", // 64MiB + } + return "file:" + fp + "?" + strings.Join(params, "&") +} + +// doTransaction is a helper function to execute a function within a transaction. If fn returns +// an error, the transaction is rolled back. Otherwise, the transaction is +// committed. +func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx txn) error) error { + start := time.Now() + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + defer func() { + // log the transaction if it took longer than txn duration + if time.Since(start) > longTxnDuration { + log.Debug("long transaction", zap.Duration("elapsed", time.Since(start)), zap.Stack("stack"), zap.Bool("failed", err != nil)) + } + }() + + ltx := &loggedTxn{ + Tx: tx, + log: log, + } + if err = fn(ltx); err != nil { + return err + } else if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + return nil +} + +// OpenDatabase creates a new SQLite store and initializes the database. If the +// database does not exist, it is created. +func OpenDatabase(fp string, log *zap.Logger) (*Store, error) { + db, err := sql.Open("sqlite3", sqliteFilepath(fp)) + if err != nil { + return nil, err + } + store := &Store{ + db: db, + log: log, + } + if err := store.init(); err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + return store, nil +} diff --git a/persist/sqlite/types.go b/persist/sqlite/types.go new file mode 100644 index 0000000..44f1f10 --- /dev/null +++ b/persist/sqlite/types.go @@ -0,0 +1,135 @@ +package sqlite + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "time" + + "go.sia.tech/core/types" +) + +type ( + sqlCurrency types.Currency + sqlTime time.Time +) + +// Scan implements the sql.Scanner interface. +func (sc *sqlCurrency) Scan(src any) error { + buf, ok := src.([]byte) + if !ok { + return fmt.Errorf("cannot scan %T to Currency", src) + } else if len(buf) != 16 { + return fmt.Errorf("cannot scan %d bytes to Currency", len(buf)) + } + + sc.Lo = binary.LittleEndian.Uint64(buf[:8]) + sc.Hi = binary.LittleEndian.Uint64(buf[8:]) + return nil +} + +// Value implements the driver.Valuer interface. +func (sc sqlCurrency) Value() (driver.Value, error) { + buf := make([]byte, 16) + binary.LittleEndian.PutUint64(buf[:8], sc.Lo) + binary.LittleEndian.PutUint64(buf[8:], sc.Hi) + return buf, nil +} + +func (st *sqlTime) Scan(src any) error { + switch src := src.(type) { + case int64: + *st = sqlTime(time.Unix(src, 0)) + return nil + default: + return fmt.Errorf("cannot scan %T to Time", src) + } +} + +func (st sqlTime) Value() (driver.Value, error) { + return time.Time(st).Unix(), nil +} + +func encode[T types.EncoderTo](v T) []byte { + var buf bytes.Buffer + enc := types.NewEncoder(&buf) + v.EncodeTo(enc) + if err := enc.Flush(); err != nil { + panic(err) + } + return buf.Bytes() +} + +func encodeSlice[T types.EncoderTo](v []T) []byte { + var buf bytes.Buffer + enc := types.NewEncoder(&buf) + enc.WritePrefix(len(v)) + for _, e := range v { + e.EncodeTo(enc) + } + if err := enc.Flush(); err != nil { + panic(err) + } + return buf.Bytes() +} + +type decodableSlice[T any] struct { + v *[]T + n int64 +} + +func (d *decodableSlice[T]) Scan(src any) error { + switch src := src.(type) { + case []byte: + dec := types.NewDecoder(io.LimitedReader{ + R: bytes.NewReader(src), + N: d.n, + }) + s := make([]T, dec.ReadPrefix()) + for i := range s { + dv, ok := any(&s[i]).(types.DecoderFrom) + if !ok { + panic(fmt.Errorf("cannot decode %T", s[i])) + } + dv.DecodeFrom(dec) + } + if err := dec.Err(); err != nil { + return err + } + *d.v = s + return nil + default: + return fmt.Errorf("cannot scan %T to []byte", src) + } +} + +func decodeSlice[T any](v *[]T, maxLen int64) sql.Scanner { + return &decodableSlice[T]{v: v, n: maxLen} +} + +type decodable[T types.DecoderFrom] struct { + v T + n int64 +} + +func (d *decodable[T]) Scan(src any) error { + switch src := src.(type) { + case []byte: + dec := types.NewDecoder(io.LimitedReader{ + R: bytes.NewReader(src), + N: d.n, + }) + + d.v.DecodeFrom(dec) + return dec.Err() + default: + return fmt.Errorf("cannot scan %T to []byte", src) + } +} + +func decode[T types.DecoderFrom](v T, maxLen int64) sql.Scanner { + return &decodable[T]{v, maxLen} +} diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go new file mode 100644 index 0000000..56ba3fd --- /dev/null +++ b/persist/sqlite/wallet.go @@ -0,0 +1,292 @@ +package sqlite + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + + "go.sia.tech/core/types" + "go.sia.tech/walletd/wallet" +) + +func insertAddress(tx txn, addr types.Address) (id int64, err error) { + const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) +VALUES ($1, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address +RETURNING id` + + err = tx.QueryRow(query, encode(addr), (*sqlCurrency)(&types.ZeroCurrency)).Scan(&id) + return +} + +func (s *Store) WalletEvents(walletID string, offset, limit int) (events []wallet.Event, err error) { + err = s.transaction(func(tx txn) error { + const query = `SELECT ev.id, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data +FROM events ev +INNER JOIN chain_indices ci ON (ev.index_id = ci.id) +WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)) +ORDER BY ci.height DESC, ev.id ASC +LIMIT $2 OFFSET $3` + + rows, err := tx.Query(query, walletID, limit, offset) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var eventID int64 + var event wallet.Event + var eventType string + var eventBuf []byte + + err := rows.Scan(&eventID, (*sqlTime)(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID, 32), &eventType, &eventBuf) + if err != nil { + return fmt.Errorf("failed to scan event: %w", err) + } + + switch eventType { + case wallet.EventTypeTransaction: + var tx wallet.EventTransaction + if err = json.Unmarshal(eventBuf, &tx); err != nil { + return fmt.Errorf("failed to unmarshal transaction event: %w", err) + } + event.Val = &tx + case wallet.EventTypeMissedFileContract: + var m wallet.EventMissedFileContract + if err = json.Unmarshal(eventBuf, &m); err != nil { + return fmt.Errorf("failed to unmarshal missed file contract event: %w", err) + } + event.Val = &m + case wallet.EventTypeMinerPayout: + var m wallet.EventMinerPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return fmt.Errorf("failed to unmarshal payout event: %w", err) + } + event.Val = &m + default: + return fmt.Errorf("unknown event type: %s", eventType) + } + + // event.Relevant = relevantAddresses[eventID] + events = append(events, event) + } + return nil + }) + return +} + +func (s *Store) AddWallet(name string, info json.RawMessage) error { + return s.transaction(func(tx txn) error { + const query = `INSERT INTO wallets (id, extra_data) VALUES ($1, $2)` + + _, err := tx.Exec(query, name, info) + if err != nil { + return fmt.Errorf("failed to insert wallet: %w", err) + } + return nil + }) +} + +func (s *Store) DeleteWallet(name string) error { + return s.transaction(func(tx txn) error { + _, err := tx.Exec(`DELETE FROM wallets WHERE id=$1`, name) + return err + }) +} + +func (s *Store) Wallets() (map[string]json.RawMessage, error) { + wallets := make(map[string]json.RawMessage) + err := s.transaction(func(tx txn) error { + const query = `SELECT id, extra_data FROM wallets` + + rows, err := tx.Query(query) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var friendlyName string + var extraData json.RawMessage + if err := rows.Scan(&friendlyName, &extraData); err != nil { + return fmt.Errorf("failed to scan wallet: %w", err) + } + wallets[friendlyName] = extraData + } + return nil + }) + return wallets, err +} + +func (s *Store) AddAddress(walletID string, address types.Address, info json.RawMessage) error { + return s.transaction(func(tx txn) error { + addressID, err := insertAddress(tx, address) + if err != nil { + return fmt.Errorf("failed to insert address: %w", err) + } + _, err = tx.Exec(`INSERT INTO wallet_addresses (wallet_id, extra_data, address_id) VALUES ($1, $2, $3)`, walletID, info, addressID) + return err + }) +} + +func (s *Store) RemoveAddress(walletID string, address types.Address) error { + return s.transaction(func(tx txn) error { + const query = `DELETE FROM wallet_addresses WHERE wallet_id=$1 AND address_id=(SELECT id FROM sia_addresses WHERE sia_address=$2)` + _, err := tx.Exec(query, walletID, encode(address)) + return err + }) +} + +func (s *Store) Addresses(walletID string) (map[types.Address]json.RawMessage, error) { + addresses := make(map[types.Address]json.RawMessage) + err := s.transaction(func(tx txn) error { + const query = `SELECT sa.sia_address, wa.extra_data +FROM wallet_addresses wa +INNER JOIN sia_addresses sa ON (sa.id = wa.address_id) +WHERE wa.wallet_id=$1` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var address types.Address + var extraData json.RawMessage + if err := rows.Scan(decode(&address, 32), &extraData); err != nil { + return fmt.Errorf("failed to scan address: %w", err) + } + addresses[address] = extraData + } + return nil + }) + return addresses, err +} + +func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.SiacoinElement, err error) { + err = s.transaction(func(tx txn) error { + const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siacoin_value, sa.sia_address, se.maturity_height + FROM siacoin_elements se + INNER JOIN sia_addresses sa ON (se.address_id = sa.id) + WHERE se.address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var siacoin types.SiacoinElement + var proof []byte + + err := rows.Scan(decode(&siacoin.ID, 32), &siacoin.LeafIndex, &proof, (*sqlCurrency)(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address, 32), &siacoin.MaturityHeight) + if err != nil { + return fmt.Errorf("failed to scan siacoin element: %w", err) + } + siacoins = append(siacoins, siacoin) + } + return nil + }) + return +} + +func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.SiafundElement, err error) { + err = s.transaction(func(tx txn) error { + const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siafund_value, se.claim_start, sa.sia_address + FROM siafund_elements se + INNER JOIN sia_addresses sa ON (se.address_id = sa.id) + WHERE se.address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var siafund types.SiafundElement + var proof []byte + + err := rows.Scan(decode(&siafund.ID, 32), &siafund.LeafIndex, &proof, &siafund.SiafundOutput.Value, (*sqlCurrency)(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address, 32)) + if err != nil { + return fmt.Errorf("failed to scan siacoin element: %w", err) + } + siafunds = append(siafunds, siafund) + } + return nil + }) + return +} + +// WalletBalance returns the total balance of a wallet. +func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) { + err = s.transaction(func(tx txn) error { + const query = `SELECT siacoin_balance, siafund_balance FROM sia_addresses sa + INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) + WHERE wa.wallet_id=$1` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + + for rows.Next() { + var siacoin types.Currency + var siafund uint64 + + if err := rows.Scan((*sqlCurrency)(&siacoin), &siafund); err != nil { + return fmt.Errorf("failed to scan address balance: %w", err) + } + sc = sc.Add(siacoin) + sf += siafund + } + return nil + }) + return +} + +// AddressBalance returns the balance of a single address. +func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { + err = s.transaction(func(tx txn) error { + const query = `SELECT siacoin_balance, siafund_balance FROM address_balance WHERE sia_address=$1` + return tx.QueryRow(query, encode(address)).Scan((*sqlCurrency)(&sc), &sf) + }) + return +} + +func (s *Store) Annotate(walletID string, txns []types.Transaction) (annotated []wallet.PoolTransaction, err error) { + err = s.transaction(func(tx txn) error { + stmt, err := tx.Prepare(`SELECT sia_address FROM wallet_addresses WHERE wallet_id=$1 AND sia_address=$2 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + // note: this would be more performant for small wallets to load all + // addresses into memory. However, for larger wallets (> 10K addresses), + // this is time consuming. Instead, the database is queried for each + // address. Monitor performance and consider changing this in the + // future. From a memory perspective, it would be fine to lazy load all + // addresses into memory. + ownsAddress := func(address types.Address) bool { + var dbID int64 + err := stmt.QueryRow(walletID, encode(address)).Scan(dbID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + panic(err) // database error + } + return err == nil + } + + for _, txn := range txns { + ptxn := wallet.Annotate(txn, ownsAddress) + if ptxn.Type != "unrelated" { + annotated = append(annotated, ptxn) + } + } + return nil + }) + return +} diff --git a/wallet/state.go b/wallet/state.go new file mode 100644 index 0000000..219a7f2 --- /dev/null +++ b/wallet/state.go @@ -0,0 +1,83 @@ +package wallet + +import ( + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" +) + +// A Midstate is a snapshot of unapplied consensus changes. +type Midstate struct { + SpentSiacoinOutputs map[types.Hash256]bool + SpentSiafundOutputs map[types.Hash256]bool + + NewSiacoinOutputs map[types.Hash256]types.SiacoinElement + NewSiafundOutputs map[types.Hash256]types.SiafundElement + + Events []Event +} + +func (ms *Midstate) Apply(cau *chain.ApplyUpdate, ownsAddress func(types.Address) bool) { + events := AppliedEvents(cau.State, cau.Block, cau, ownsAddress) + ms.Events = append(ms.Events, events...) + + cau.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + if !ownsAddress(se.SiacoinOutput.Address) { + return + } + + if spent { + ms.SpentSiacoinOutputs[se.ID] = true + delete(ms.NewSiacoinOutputs, se.ID) + } else { + ms.NewSiacoinOutputs[se.ID] = se + } + }) + + cau.ForEachSiafundElement(func(sf types.SiafundElement, spent bool) { + if !ownsAddress(sf.SiafundOutput.Address) { + return + } + + if spent { + ms.SpentSiafundOutputs[sf.ID] = true + delete(ms.NewSiafundOutputs, sf.ID) + } else { + ms.NewSiafundOutputs[sf.ID] = sf + } + }) +} + +func (ms *Midstate) Revert(cru *chain.RevertUpdate, ownsAddress func(types.Address) bool) { + revertedBlockID := cru.Block.ID() + for i := len(ms.Events) - 1; i >= 0; i-- { + // working backwards, revert all events until the block ID no longer + // matches. + if ms.Events[i].Index.ID != revertedBlockID { + break + } + ms.Events = ms.Events[:i] + } + + cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + if !ownsAddress(se.SiacoinOutput.Address) { + return + } + + if !spent { + delete(ms.SpentSiacoinOutputs, se.ID) + } + }) + + cru.ForEachSiafundElement(func(sf types.SiafundElement, spent bool) { + if !ownsAddress(sf.SiafundOutput.Address) { + return + } + + if spent { + ms.SpentSiafundOutputs[sf.ID] = true + delete(ms.NewSiafundOutputs, sf.ID) + } else { + ms.NewSiafundOutputs[sf.ID] = sf + } + }) +} diff --git a/wallet/wallet.go b/wallet/wallet.go index ecc9754..d1a458c 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -9,6 +9,13 @@ import ( "go.sia.tech/core/types" ) +const ( + // transactions + EventTypeTransaction = "transaction" + EventTypeMinerPayout = "miner payout" + EventTypeMissedFileContract = "missed file contract" +) + // StandardTransactionSignature is the most common form of TransactionSignature. // It covers the entire transaction, references a sole public key, and has no // timelock. @@ -143,12 +150,12 @@ type Event struct { Index types.ChainIndex Timestamp time.Time Relevant []types.Address - Val interface{ eventType() string } + Val interface{ EventType() string } } -func (*EventTransaction) eventType() string { return "transaction" } -func (*EventMinerPayout) eventType() string { return "miner payout" } -func (*EventMissedFileContract) eventType() string { return "missed file contract" } +func (*EventTransaction) EventType() string { return EventTypeTransaction } +func (*EventMinerPayout) EventType() string { return EventTypeMinerPayout } +func (*EventMissedFileContract) EventType() string { return EventTypeMissedFileContract } // MarshalJSON implements json.Marshaler. func (e Event) MarshalJSON() ([]byte, error) { @@ -163,7 +170,7 @@ func (e Event) MarshalJSON() ([]byte, error) { Timestamp: e.Timestamp, Index: e.Index, Relevant: e.Relevant, - Type: e.Val.eventType(), + Type: e.Val.EventType(), Val: val, }) } @@ -184,11 +191,11 @@ func (e *Event) UnmarshalJSON(data []byte) error { e.Index = s.Index e.Relevant = s.Relevant switch s.Type { - case (*EventTransaction)(nil).eventType(): + case (*EventTransaction)(nil).EventType(): e.Val = new(EventTransaction) - case (*EventMinerPayout)(nil).eventType(): + case (*EventMinerPayout)(nil).EventType(): e.Val = new(EventMinerPayout) - case (*EventMissedFileContract)(nil).eventType(): + case (*EventMissedFileContract)(nil).EventType(): e.Val = new(EventMissedFileContract) } if e.Val == nil { @@ -259,7 +266,7 @@ type ChainUpdate interface { // AppliedEvents extracts a list of relevant events from a chain update. func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant func(types.Address) bool) []Event { var events []Event - addEvent := func(v interface{ eventType() string }, relevant []types.Address) { + addEvent := func(v interface{ EventType() string }, relevant []types.Address) { // dedup relevant addresses seen := make(map[types.Address]bool) unique := relevant[:0] From 46db8cc74b1f858fd1898d32c7d6d364d969696c Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 9 Jan 2024 15:50:09 -0800 Subject: [PATCH 02/24] wallet: add manager --- wallet/manager.go | 173 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 wallet/manager.go diff --git a/wallet/manager.go b/wallet/manager.go new file mode 100644 index 0000000..e4a2380 --- /dev/null +++ b/wallet/manager.go @@ -0,0 +1,173 @@ +package wallet + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "go.sia.tech/core/chain" + "go.sia.tech/core/types" + "go.uber.org/zap" +) + +type ( + ChainManager interface { + AddSubscriber(chain.Subscriber, types.ChainIndex) error + RemoveSubscriber(chain.Subscriber) + + BestIndex(height uint64) (types.ChainIndex, bool) + } + + Store interface { + chain.Subscriber + + WalletEvents(name string, offset, limit int) ([]Event, error) + AddWallet(name string, info json.RawMessage) error + DeleteWallet(name string) error + Wallets() (map[string]json.RawMessage, error) + + AddAddress(walletID string, address types.Address, info json.RawMessage) error + RemoveAddress(walletID string, address types.Address) error + Addresses(walletID string) (map[types.Address]json.RawMessage, error) + UnspentSiacoinOutputs(walletID string) ([]types.SiacoinElement, error) + UnspentSiafundOutputs(walletID string) ([]types.SiafundElement, error) + Annotate(walletID string, txns []types.Transaction) ([]PoolTransaction, error) + WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) + + AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) + + LastCommittedIndex() (types.ChainIndex, error) + } + + // A Manager manages wallets. + Manager struct { + chain ChainManager + store Store + log *zap.Logger + + mu sync.Mutex + used map[types.Hash256]bool + } +) + +// AddWallet adds the given wallet. +func (m *Manager) AddWallet(name string, info json.RawMessage) error { + return m.store.AddWallet(name, info) +} + +// DeleteWallet deletes the given wallet. +func (m *Manager) DeleteWallet(name string) error { + return m.store.DeleteWallet(name) +} + +// Wallets returns the wallets of the wallet manager. +func (m *Manager) Wallets() (map[string]json.RawMessage, error) { + return m.store.Wallets() +} + +// AddAddress adds the given address to the given wallet. +func (m *Manager) AddAddress(name string, addr types.Address, info json.RawMessage) error { + return m.store.AddAddress(name, addr, info) +} + +// RemoveAddress removes the given address from the given wallet. +func (m *Manager) RemoveAddress(name string, addr types.Address) error { + return m.store.RemoveAddress(name, addr) +} + +// Addresses returns the addresses of the given wallet. +func (m *Manager) Addresses(name string) (map[types.Address]json.RawMessage, error) { + return m.store.Addresses(name) +} + +// Events returns the events of the given wallet. +func (m *Manager) Events(name string, offset, limit int) ([]Event, error) { + return m.store.WalletEvents(name, offset, limit) +} + +// UnspentSiacoinOutputs returns the unspent siacoin outputs of the given wallet +func (m *Manager) UnspentSiacoinOutputs(name string) ([]types.SiacoinElement, error) { + return m.store.UnspentSiacoinOutputs(name) +} + +// UnspentSiafundOutputs returns the unspent siafund outputs of the given wallet +func (m *Manager) UnspentSiafundOutputs(name string) ([]types.SiafundElement, error) { + return m.store.UnspentSiafundOutputs(name) +} + +// Annotate annotates the given transactions with the wallet they belong to. +func (m *Manager) Annotate(name string, pool []types.Transaction) ([]PoolTransaction, error) { + return m.store.Annotate(name, pool) +} + +// WalletBalance returns the balance of the given wallet. +func (m *Manager) WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) { + return m.store.WalletBalance(walletID) +} + +// AddressBalance returns the balance of the given address. +func (m *Manager) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { + return m.store.AddressBalance(address) +} + +// Reserve reserves the given ids for the given duration. +func (m *Manager) Reserve(ids []types.Hash256, duration time.Duration) error { + m.mu.Lock() + defer m.mu.Unlock() + + // check if any of the ids are already reserved + for _, id := range ids { + if m.used[id] { + return fmt.Errorf("output %q already reserved", id) + } + } + + // reserve the ids + for _, id := range ids { + m.used[id] = true + } + + // sleep for the duration and then unreserve the ids + time.AfterFunc(duration, func() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, id := range ids { + delete(m.used, id) + } + }) + return nil +} + +// Subscribe resubscribes the indexer starting at the given height. +func (m *Manager) Subscribe(startHeight uint64) error { + var index types.ChainIndex + if startHeight > 0 { + var ok bool + index, ok = m.chain.BestIndex(startHeight - 1) + if !ok { + return errors.New("invalid height") + } + } + m.chain.RemoveSubscriber(m.store) + return m.chain.AddSubscriber(m.store, index) +} + +// NewManager creates a new wallet manager. +func NewManager(cm ChainManager, store Store, log *zap.Logger) (*Manager, error) { + m := &Manager{ + chain: cm, + store: store, + log: log, + } + + lastTip, err := store.LastCommittedIndex() + if err != nil { + return nil, fmt.Errorf("failed to get last committed index: %w", err) + } else if err := cm.AddSubscriber(store, lastTip); err != nil { + return nil, fmt.Errorf("failed to subscribe to chain manager: %w", err) + } + return m, nil +} From 591f38e8ccc7444d19732de0d8e79913678c8d53 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 9 Jan 2024 15:50:18 -0800 Subject: [PATCH 03/24] api, cmd: use sqlite3 store --- api/api_test.go | 64 ++++++++++++++++++++++++------ api/client.go | 14 +++---- api/server.go | 97 ++++++++++++++++++--------------------------- cmd/walletd/main.go | 27 ++++++++++++- cmd/walletd/node.go | 37 ++++++++++------- wallet/manager.go | 2 +- 6 files changed, 145 insertions(+), 96 deletions(-) diff --git a/api/api_test.go b/api/api_test.go index 37ad08d..9ffa789 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,6 +3,7 @@ package api_test import ( "net" "net/http" + "path/filepath" "testing" "time" @@ -14,8 +15,9 @@ import ( "go.sia.tech/jape" "go.sia.tech/walletd/api" "go.sia.tech/walletd/internal/syncerutil" - "go.sia.tech/walletd/internal/walletutil" + "go.sia.tech/walletd/persist/sqlite" "go.sia.tech/walletd/wallet" + "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) @@ -48,6 +50,8 @@ func runServer(cm api.ChainManager, s api.Syncer, wm api.WalletManager) (*api.Cl } func TestWallet(t *testing.T) { + log := zaptest.NewLogger(t) + n, genesisBlock := testNetwork() giftPrivateKey := types.GeneratePrivateKey() giftAddress := types.StandardUnlockHash(giftPrivateKey.PublicKey()) @@ -62,7 +66,17 @@ func TestWallet(t *testing.T) { t.Fatal(err) } cm := chain.NewManager(dbstore, tipState) - wm := walletutil.NewEphemeralWalletManager(cm) + + ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + wm, err := wallet.NewManager(cm, ws, log.Named("wallet")) + if err != nil { + t.Fatal(err) + } + sav := wallet.NewSeedAddressVault(wallet.NewSeed(), 0, 20) c, shutdown := runServer(cm, nil, wm) defer shutdown() @@ -70,7 +84,7 @@ func TestWallet(t *testing.T) { t.Fatal(err) } wc := c.Wallet("primary") - if err := wc.Subscribe(0); err != nil { + if err := c.Resubscribe(0); err != nil { t.Fatal(err) } @@ -153,7 +167,7 @@ func TestWallet(t *testing.T) { } // transaction should appear in history - events, err = wc.Events(0, -1) + events, err = wc.Events(0, 100) if err != nil { t.Fatal(err) } else if len(events) == 0 { @@ -169,6 +183,8 @@ func TestWallet(t *testing.T) { } func TestV2(t *testing.T) { + log := zaptest.NewLogger(t) + n, genesisBlock := testNetwork() // gift primary wallet some coins primaryPrivateKey := types.GeneratePrivateKey() @@ -184,7 +200,15 @@ func TestV2(t *testing.T) { t.Fatal(err) } cm := chain.NewManager(dbstore, tipState) - wm := walletutil.NewEphemeralWalletManager(cm) + ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + wm, err := wallet.NewManager(cm, ws, log.Named("wallet")) + if err != nil { + t.Fatal(err) + } c, shutdown := runServer(cm, nil, wm) defer shutdown() if err := c.AddWallet("primary", nil); err != nil { @@ -194,9 +218,6 @@ func TestV2(t *testing.T) { if err := primary.AddAddress(primaryAddress, nil); err != nil { t.Fatal(err) } - if err := primary.Subscribe(0); err != nil { - t.Fatal(err) - } if err := c.AddWallet("secondary", nil); err != nil { t.Fatal(err) } @@ -204,7 +225,7 @@ func TestV2(t *testing.T) { if err := secondary.AddAddress(secondaryAddress, nil); err != nil { t.Fatal(err) } - if err := secondary.Subscribe(0); err != nil { + if err := c.Resubscribe(0); err != nil { t.Fatal(err) } @@ -373,6 +394,7 @@ func TestV2(t *testing.T) { } func TestP2P(t *testing.T) { + log := zaptest.NewLogger(t) n, genesisBlock := testNetwork() // gift primary wallet some coins primaryPrivateKey := types.GeneratePrivateKey() @@ -388,7 +410,15 @@ func TestP2P(t *testing.T) { t.Fatal(err) } cm1 := chain.NewManager(dbstore1, tipState) - wm1 := walletutil.NewEphemeralWalletManager(cm1) + ws1, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer ws1.Close() + wm1, err := wallet.NewManager(cm1, ws1, log.Named("wallet")) + if err != nil { + t.Fatal(err) + } l1, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) @@ -409,7 +439,7 @@ func TestP2P(t *testing.T) { if err := primary.AddAddress(primaryAddress, nil); err != nil { t.Fatal(err) } - if err := primary.Subscribe(0); err != nil { + if err := c1.Resubscribe(0); err != nil { t.Fatal(err) } @@ -418,7 +448,15 @@ func TestP2P(t *testing.T) { t.Fatal(err) } cm2 := chain.NewManager(dbstore2, tipState) - wm2 := walletutil.NewEphemeralWalletManager(cm2) + ws2, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer ws2.Close() + wm2, err := wallet.NewManager(cm2, ws2, log.Named("wallet")) + if err != nil { + t.Fatal(err) + } l2, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) @@ -439,7 +477,7 @@ func TestP2P(t *testing.T) { if err := secondary.AddAddress(secondaryAddress, nil); err != nil { t.Fatal(err) } - if err := secondary.Subscribe(0); err != nil { + if err := c2.Resubscribe(0); err != nil { t.Fatal(err) } diff --git a/api/client.go b/api/client.go index 973f194..3258d9e 100644 --- a/api/client.go +++ b/api/client.go @@ -105,6 +105,13 @@ func (c *Client) Wallet(name string) *WalletClient { return &WalletClient{c: c.c, name: name} } +// Resubscribe subscribes the wallet to consensus updates, starting at the +// specified height. This can only be done once. +func (c *Client) Resubscribe(height uint64) (err error) { + err = c.c.POST("/resubscribe", height, nil) + return +} + // A WalletClient provides methods for interacting with a particular wallet on a // walletd API server. type WalletClient struct { @@ -112,13 +119,6 @@ type WalletClient struct { name string } -// Subscribe subscribes the wallet to consensus updates, starting at the -// specified height. This can only be done once. -func (c *WalletClient) Subscribe(height uint64) (err error) { - err = c.c.POST(fmt.Sprintf("/wallets/%v/subscribe", c.name), height, nil) - return -} - // AddAddress adds the specified address and associated metadata to the // wallet. func (c *WalletClient) AddAddress(addr types.Address, info json.RawMessage) (err error) { diff --git a/api/server.go b/api/server.go index 898ba29..f7c87d7 100644 --- a/api/server.go +++ b/api/server.go @@ -3,7 +3,6 @@ package api import ( "encoding/json" "errors" - "fmt" "net/http" "reflect" "sync" @@ -46,17 +45,23 @@ type ( // A WalletManager manages wallets, keyed by name. WalletManager interface { + Subscribe(startHeight uint64) error + AddWallet(name string, info json.RawMessage) error DeleteWallet(name string) error - Wallets() map[string]json.RawMessage - SubscribeWallet(name string, startHeight uint64) error + Wallets() (map[string]json.RawMessage, error) AddAddress(name string, addr types.Address, info json.RawMessage) error RemoveAddress(name string, addr types.Address) error Addresses(name string) (map[types.Address]json.RawMessage, error) Events(name string, offset, limit int) ([]wallet.Event, error) - UnspentOutputs(name string) ([]types.SiacoinElement, []types.SiafundElement, error) + UnspentSiacoinOutputs(name string) ([]types.SiacoinElement, error) + UnspentSiafundOutputs(name string) ([]types.SiafundElement, error) + WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) Annotate(name string, pool []types.Transaction) ([]wallet.PoolTransaction, error) + + Reserve(ids []types.Hash256, duration time.Duration) error + AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) } ) @@ -165,7 +170,11 @@ func (s *server) txpoolBroadcastHandler(jc jape.Context) { } func (s *server) walletsHandler(jc jape.Context) { - jc.Encode(s.wm.Wallets()) + wallets, err := s.wm.Wallets() + if jc.Check("couldn't load wallets", err) != nil { + return + } + jc.Encode(wallets) } func (s *server) walletsNameHandlerPUT(jc jape.Context) { @@ -187,12 +196,11 @@ func (s *server) walletsNameHandlerDELETE(jc jape.Context) { } } -func (s *server) walletsSubscribeHandler(jc jape.Context) { - var name string +func (s *server) resubscribeHandler(jc jape.Context) { var height uint64 - if jc.DecodeParam("name", &name) != nil || jc.Decode(&height) != nil { + if jc.Decode(&height) != nil { return - } else if jc.Check("couldn't subscribe wallet", s.wm.SubscribeWallet(name, height)) != nil { + } else if jc.Check("couldn't subscribe wallet", s.wm.Subscribe(height)) != nil { return } } @@ -235,26 +243,14 @@ func (s *server) walletsBalanceHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil { return } - scos, sfos, err := s.wm.UnspentOutputs(name) - if jc.Check("couldn't load outputs", err) != nil { + + sc, sf, err := s.wm.WalletBalance(name) + if jc.Check("couldn't load balance", err) != nil { return } - height := s.cm.TipState().Index.Height - var sc, immature types.Currency - var sf uint64 - for _, sco := range scos { - if height >= sco.MaturityHeight { - sc = sc.Add(sco.SiacoinOutput.Value) - } else { - immature = immature.Add(sco.SiacoinOutput.Value) - } - } - for _, sfo := range sfos { - sf += sfo.SiafundOutput.Value - } jc.Encode(WalletBalanceResponse{ Siacoins: sc, - ImmatureSiacoins: immature, + ImmatureSiacoins: types.ZeroCurrency, Siafunds: sf, }) } @@ -289,8 +285,13 @@ func (s *server) walletsOutputsHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil { return } - scos, sfos, err := s.wm.UnspentOutputs(name) - if jc.Check("couldn't load outputs", err) != nil { + scos, err := s.wm.UnspentSiacoinOutputs(name) + if jc.Check("couldn't load siacoin outputs", err) != nil { + return + } + + sfos, err := s.wm.UnspentSiafundOutputs(name) + if jc.Check("couldn't load siafund outputs", err) != nil { return } jc.Encode(WalletOutputsResponse{ @@ -300,44 +301,23 @@ func (s *server) walletsOutputsHandler(jc jape.Context) { } func (s *server) walletsReserveHandler(jc jape.Context) { - var name string var wrr WalletReserveRequest - if jc.DecodeParam("name", &name) != nil || jc.Decode(&wrr) != nil { + if jc.Decode(&wrr) != nil { return } - s.mu.Lock() + ids := make([]types.Hash256, 0, len(wrr.SiacoinOutputs)+len(wrr.SiafundOutputs)) for _, id := range wrr.SiacoinOutputs { - if s.used[types.Hash256(id)] { - s.mu.Unlock() - jc.Error(fmt.Errorf("output %v is already reserved", id), http.StatusBadRequest) - return - } - s.used[types.Hash256(id)] = true + ids = append(ids, types.Hash256(id)) } + for _, id := range wrr.SiafundOutputs { - if s.used[types.Hash256(id)] { - s.mu.Unlock() - jc.Error(fmt.Errorf("output %v is already reserved", id), http.StatusBadRequest) - return - } - s.used[types.Hash256(id)] = true + ids = append(ids, types.Hash256(id)) } - s.mu.Unlock() - if wrr.Duration == 0 { - wrr.Duration = 10 * time.Minute + if jc.Check("couldn't reserve outputs", s.wm.Reserve(ids, wrr.Duration)) != nil { + return } - time.AfterFunc(wrr.Duration, func() { - s.mu.Lock() - defer s.mu.Unlock() - for _, id := range wrr.SiacoinOutputs { - delete(s.used, types.Hash256(id)) - } - for _, id := range wrr.SiafundOutputs { - delete(s.used, types.Hash256(id)) - } - }) } func (s *server) walletsReleaseHandler(jc jape.Context) { @@ -412,7 +392,7 @@ func (s *server) walletsFundHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil || jc.Decode(&wfr) != nil { return } - utxos, _, err := s.wm.UnspentOutputs(name) + utxos, err := s.wm.UnspentSiacoinOutputs(name) if jc.Check("couldn't get utxos to fund transaction", err) != nil { return } @@ -486,7 +466,7 @@ func (s *server) walletsFundSFHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil || jc.Decode(&wfr) != nil { return } - _, utxos, err := s.wm.UnspentOutputs(name) + utxos, err := s.wm.UnspentSiafundOutputs(name) if jc.Check("couldn't get utxos to fund transaction", err) != nil { return } @@ -524,10 +504,11 @@ func NewServer(cm ChainManager, s Syncer, wm WalletManager) http.Handler { "GET /txpool/fee": srv.txpoolFeeHandler, "POST /txpool/broadcast": srv.txpoolBroadcastHandler, + "POST /resubscribe": srv.resubscribeHandler, + "GET /wallets": srv.walletsHandler, "PUT /wallets/:name": srv.walletsNameHandlerPUT, "DELETE /wallets/:name": srv.walletsNameHandlerDELETE, - "POST /wallets/:name/subscribe": srv.walletsSubscribeHandler, "PUT /wallets/:name/addresses/:addr": srv.walletsAddressHandlerPUT, "DELETE /wallets/:name/addresses/:addr": srv.walletsAddressHandlerDELETE, "GET /wallets/:name/addresses": srv.walletsAddressesHandlerGET, diff --git a/cmd/walletd/main.go b/cmd/walletd/main.go index e87d9e8..ae4f3ec 100644 --- a/cmd/walletd/main.go +++ b/cmd/walletd/main.go @@ -11,6 +11,8 @@ import ( "go.sia.tech/core/types" "go.sia.tech/walletd/wallet" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" "golang.org/x/term" "lukechampine.com/flagg" "lukechampine.com/frand" @@ -162,7 +164,27 @@ func main() { if err != nil { log.Fatal(err) } - n, err := newNode(gatewayAddr, dir, network, upnp) + + // configure console logging note: this is configured before anything else + // to have consistent logging. File logging will be added after the cli + // flags and config is parsed + consoleCfg := zap.NewProductionEncoderConfig() + consoleCfg.TimeKey = "" // prevent duplicate timestamps + consoleCfg.EncodeTime = zapcore.RFC3339TimeEncoder + consoleCfg.EncodeDuration = zapcore.StringDurationEncoder + consoleCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder + consoleCfg.StacktraceKey = "" + consoleCfg.CallerKey = "" + consoleEncoder := zapcore.NewConsoleEncoder(consoleCfg) + + // only log info messages to console unless stdout logging is enabled + consoleCore := zapcore.NewCore(consoleEncoder, zapcore.Lock(os.Stdout), zap.NewAtomicLevelAt(zap.InfoLevel)) + logger := zap.New(consoleCore, zap.AddCaller()) + defer logger.Sync() + // redirect stdlib log to zap + zap.RedirectStdLog(logger.Named("stdlib")) + + n, err := newNode(gatewayAddr, dir, network, upnp, logger) if err != nil { log.Fatal(err) } @@ -170,6 +192,8 @@ func main() { stop := n.Start() log.Println("api: Listening on", l.Addr()) go startWeb(l, n, apiPassword) + log.Println("api: Listening on", l.Addr()) + go startWeb(l, n, apiPassword) signalCh := make(chan os.Signal, 1) signal.Notify(signalCh, os.Interrupt) <-signalCh @@ -204,7 +228,6 @@ func main() { seed := loadTestnetSeed(seed) c := initTestnetClient(apiAddr, network, seed) runTestnetMiner(c, seed) - case balanceCmd: if len(cmd.Args()) != 0 { cmd.Usage() diff --git a/cmd/walletd/node.go b/cmd/walletd/node.go index 3249fe6..bd6119b 100644 --- a/cmd/walletd/node.go +++ b/cmd/walletd/node.go @@ -3,7 +3,7 @@ package main import ( "context" "errors" - "log" + "fmt" "net" "path/filepath" "strconv" @@ -16,7 +16,8 @@ import ( "go.sia.tech/coreutils/chain" "go.sia.tech/coreutils/syncer" "go.sia.tech/walletd/internal/syncerutil" - "go.sia.tech/walletd/internal/walletutil" + "go.sia.tech/walletd/persist/sqlite" + "go.sia.tech/walletd/wallet" "go.uber.org/zap" "lukechampine.com/upnp" ) @@ -85,12 +86,12 @@ var anagamiBootstrap = []string{ type node struct { cm *chain.Manager s *syncer.Syncer - wm *walletutil.JSONWalletManager + wm *wallet.Manager Start func() (stop func()) } -func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) { +func newNode(addr, dir string, chainNetwork string, useUPNP bool, log *zap.Logger) (*node, error) { var network *consensus.Network var genesisBlock types.Block var bootstrapPeers []string @@ -110,11 +111,11 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) if err != nil { - log.Fatal(err) + return nil, fmt.Errorf("failed to open consensus database: %w", err) } dbstore, tipState, err := chain.NewDBStore(bdb, network, genesisBlock) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create chain store: %w", err) } cm := chain.NewManager(dbstore, tipState) @@ -127,21 +128,21 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if d, err := upnp.Discover(ctx); err != nil { - log.Println("WARN: couldn't discover UPnP device:", err) + log.Debug("couldn't discover UPnP router", zap.Error(err)) } else { _, portStr, _ := net.SplitHostPort(addr) port, _ := strconv.Atoi(portStr) if !d.IsForwarded(uint16(port), "TCP") { if err := d.Forward(uint16(port), "TCP", "walletd"); err != nil { - log.Println("WARN: couldn't forward port:", err) + log.Debug("couldn't forward port", zap.Error(err)) } else { - log.Println("p2p: Forwarded port", port) + log.Debug("upnp: forwarded p2p port", zap.Int("port", port)) } } if ip, err := d.ExternalIP(); err != nil { - log.Println("WARN: couldn't determine external IP:", err) + log.Debug("couldn't determine external IP", zap.Error(err)) } else { - log.Println("p2p: External IP is", ip) + log.Debug("external IP is", zap.String("ip", ip)) syncerAddr = net.JoinHostPort(ip, portStr) } } @@ -154,7 +155,7 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) ps, err := syncerutil.NewJSONPeerStore(filepath.Join(dir, "peers.json")) if err != nil { - log.Fatal(err) + return nil, fmt.Errorf("failed to open peer store: %w", err) } for _, peer := range bootstrapPeers { ps.AddPeer(peer) @@ -164,10 +165,16 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) UniqueID: gateway.GenerateUniqueID(), NetAddress: syncerAddr, } - s := syncer.New(l, cm, ps, header, syncer.WithLogger(zap.NewNop())) - wm, err := walletutil.NewJSONWalletManager(dir, cm) + s := syncer.New(l, cm, ps, header, syncer.WithLogger(log.Named("syncer"))) + + walletDB, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to open wallet database: %w", err) + } + + wm, err := wallet.NewManager(cm, walletDB, log.Named("wallet")) + if err != nil { + return nil, fmt.Errorf("failed to create wallet manager: %w", err) } return &node{ diff --git a/wallet/manager.go b/wallet/manager.go index e4a2380..3e982e5 100644 --- a/wallet/manager.go +++ b/wallet/manager.go @@ -7,8 +7,8 @@ import ( "sync" "time" - "go.sia.tech/core/chain" "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" "go.uber.org/zap" ) From 1bec4038e955c29e7256cf807687d1641cb62f10 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 9 Jan 2024 15:50:49 -0800 Subject: [PATCH 04/24] internal: remove ephemeral and JSON wallet store --- internal/walletutil/manager.go | 403 --------------------------------- internal/walletutil/store.go | 402 -------------------------------- 2 files changed, 805 deletions(-) delete mode 100644 internal/walletutil/manager.go delete mode 100644 internal/walletutil/store.go diff --git a/internal/walletutil/manager.go b/internal/walletutil/manager.go deleted file mode 100644 index be5e962..0000000 --- a/internal/walletutil/manager.go +++ /dev/null @@ -1,403 +0,0 @@ -package walletutil - -import ( - "encoding/json" - "errors" - "os" - "path/filepath" - "sync" - - "go.sia.tech/coreutils/chain" - "go.sia.tech/core/types" - "go.sia.tech/walletd/wallet" -) - -var errNoWallet = errors.New("wallet does not exist") - -type ChainManager interface { - AddSubscriber(s chain.Subscriber, tip types.ChainIndex) error - RemoveSubscriber(s chain.Subscriber) - BestIndex(height uint64) (types.ChainIndex, bool) -} - -type managedEphemeralWallet struct { - w *EphemeralStore - info json.RawMessage - subscribed bool -} - -// An EphemeralWalletManager manages multiple ephemeral wallet stores. -type EphemeralWalletManager struct { - cm ChainManager - mu sync.Mutex - wallets map[string]*managedEphemeralWallet -} - -// AddWallet implements api.WalletManager. -func (wm *EphemeralWalletManager) AddWallet(name string, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - if _, ok := wm.wallets[name]; ok { - return errors.New("wallet already exists") - } - store := NewEphemeralStore() - wm.wallets[name] = &managedEphemeralWallet{store, info, false} - return nil -} - -// DeleteWallet implements api.WalletManager. -func (wm *EphemeralWalletManager) DeleteWallet(name string) error { - wm.mu.Lock() - defer wm.mu.Unlock() - delete(wm.wallets, name) - return nil -} - -// Wallets implements api.WalletManager. -func (wm *EphemeralWalletManager) Wallets() map[string]json.RawMessage { - wm.mu.Lock() - defer wm.mu.Unlock() - ws := make(map[string]json.RawMessage, len(wm.wallets)) - for name, w := range wm.wallets { - ws[name] = w.info - } - return ws -} - -// AddAddress implements api.WalletManager. -func (wm *EphemeralWalletManager) AddAddress(name string, addr types.Address, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.AddAddress(addr, info) -} - -// RemoveAddress implements api.WalletManager. -func (wm *EphemeralWalletManager) RemoveAddress(name string, addr types.Address) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.RemoveAddress(addr) -} - -// Addresses implements api.WalletManager. -func (wm *EphemeralWalletManager) Addresses(name string) (map[types.Address]json.RawMessage, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Addresses() -} - -// Events implements api.WalletManager. -func (wm *EphemeralWalletManager) Events(name string, offset, limit int) ([]wallet.Event, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Events(offset, limit) -} - -// Annotate implements api.WalletManager. -func (wm *EphemeralWalletManager) Annotate(name string, txns []types.Transaction) ([]wallet.PoolTransaction, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Annotate(txns), nil -} - -// UnspentOutputs implements api.WalletManager. -func (wm *EphemeralWalletManager) UnspentOutputs(name string) ([]types.SiacoinElement, []types.SiafundElement, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, nil, errNoWallet - } - return mw.w.UnspentOutputs() -} - -// SubscribeWallet implements api.WalletManager. -func (wm *EphemeralWalletManager) SubscribeWallet(name string, startHeight uint64) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } else if mw.subscribed { - return errors.New("already subscribed") - } - // AddSubscriber applies each block *after* index, but we want to *include* - // the block at startHeight, so subtract one. - // - // NOTE: if subscribing from height 0, we must pass an empty index in order - // to receive the genesis block. - var index types.ChainIndex - if startHeight > 0 { - if index, ok = wm.cm.BestIndex(startHeight - 1); !ok { - return errors.New("invalid height") - } - } - if err := wm.cm.AddSubscriber(mw.w, index); err != nil { - return err - } - mw.subscribed = true - return nil -} - -// NewEphemeralWalletManager returns a new EphemeralWalletManager. -func NewEphemeralWalletManager(cm ChainManager) *EphemeralWalletManager { - return &EphemeralWalletManager{ - cm: cm, - wallets: make(map[string]*managedEphemeralWallet), - } -} - -type managedJSONWallet struct { - w *JSONStore - info json.RawMessage - subscribed bool -} - -type managerPersistData struct { - Wallets []managerPersistWallet `json:"wallets"` -} - -type managerPersistWallet struct { - Name string `json:"name"` - Info json.RawMessage `json:"info"` - Subscribed bool `json:"subscribed"` -} - -// A JSONWalletManager manages multiple JSON wallet stores. -type JSONWalletManager struct { - dir string - cm ChainManager - mu sync.Mutex - wallets map[string]*managedJSONWallet -} - -func (wm *JSONWalletManager) save() error { - var p managerPersistData - for name, mw := range wm.wallets { - p.Wallets = append(p.Wallets, managerPersistWallet{name, mw.info, mw.subscribed}) - } - js, err := json.MarshalIndent(p, "", " ") - if err != nil { - return err - } - dst := filepath.Join(wm.dir, "wallets.json") - f, err := os.OpenFile(dst+"_tmp", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0660) - if err != nil { - return err - } - defer f.Close() - if _, err = f.Write(js); err != nil { - return err - } else if f.Sync(); err != nil { - return err - } else if f.Close(); err != nil { - return err - } else if err := os.Rename(dst+"_tmp", dst); err != nil { - return err - } - return nil -} - -func (wm *JSONWalletManager) load() error { - dst := filepath.Join(wm.dir, "wallets.json") - f, err := os.Open(dst) - if os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - defer f.Close() - var p managerPersistData - if err := json.NewDecoder(f).Decode(&p); err != nil { - return err - } - for _, pw := range p.Wallets { - wm.wallets[pw.Name] = &managedJSONWallet{nil, pw.Info, pw.Subscribed} - } - return nil -} - -// AddWallet implements api.WalletManager. -func (wm *JSONWalletManager) AddWallet(name string, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - if mw, ok := wm.wallets[name]; ok { - // update existing wallet - mw.info = info - return wm.save() - } else if _, err := os.Stat(filepath.Join(wm.dir, "wallets", name+".json")); err == nil { - // shouldn't happen in normal conditions - return errors.New("a wallet with that name already exists, but is absent from wallets.json") - } - store, _, err := NewJSONStore(filepath.Join(wm.dir, "wallets", name+".json")) - if err != nil { - return err - } - wm.wallets[name] = &managedJSONWallet{store, info, false} - return wm.save() -} - -// DeleteWallet implements api.WalletManager. -func (wm *JSONWalletManager) DeleteWallet(name string) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil - } - wm.cm.RemoveSubscriber(mw.w) - delete(wm.wallets, name) - return os.RemoveAll(filepath.Join(wm.dir, "wallets", name+".json")) -} - -// Wallets implements api.WalletManager. -func (wm *JSONWalletManager) Wallets() map[string]json.RawMessage { - wm.mu.Lock() - defer wm.mu.Unlock() - ws := make(map[string]json.RawMessage, len(wm.wallets)) - for name, w := range wm.wallets { - ws[name] = w.info - } - return ws -} - -// AddAddress implements api.WalletManager. -func (wm *JSONWalletManager) AddAddress(name string, addr types.Address, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.AddAddress(addr, info) -} - -// RemoveAddress implements api.WalletManager. -func (wm *JSONWalletManager) RemoveAddress(name string, addr types.Address) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.RemoveAddress(addr) -} - -// Addresses implements api.WalletManager. -func (wm *JSONWalletManager) Addresses(name string) (map[types.Address]json.RawMessage, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Addresses() -} - -// Events implements api.WalletManager. -func (wm *JSONWalletManager) Events(name string, offset, limit int) ([]wallet.Event, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Events(offset, limit) -} - -// Annotate implements api.WalletManager. -func (wm *JSONWalletManager) Annotate(name string, txns []types.Transaction) ([]wallet.PoolTransaction, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Annotate(txns), nil -} - -// UnspentOutputs implements api.WalletManager. -func (wm *JSONWalletManager) UnspentOutputs(name string) ([]types.SiacoinElement, []types.SiafundElement, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, nil, errNoWallet - } - return mw.w.UnspentOutputs() -} - -// SubscribeWallet implements api.WalletManager. -func (wm *JSONWalletManager) SubscribeWallet(name string, startHeight uint64) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } else if mw.subscribed { - return errors.New("already subscribed") - } - // AddSubscriber applies each block *after* index, but we want to *include* - // the block at startHeight, so subtract one. - // - // NOTE: if subscribing from height 0, we must pass an empty index in order - // to receive the genesis block. - var index types.ChainIndex - if startHeight > 0 { - if index, ok = wm.cm.BestIndex(startHeight - 1); !ok { - return errors.New("invalid height") - } - } - if err := wm.cm.AddSubscriber(mw.w, index); err != nil { - return err - } - mw.subscribed = true - return wm.save() -} - -// NewJSONWalletManager returns a wallet manager that stores wallets in the -// specified directory. -func NewJSONWalletManager(dir string, cm ChainManager) (*JSONWalletManager, error) { - wm := &JSONWalletManager{ - dir: dir, - cm: cm, - wallets: make(map[string]*managedJSONWallet), - } - if err := os.MkdirAll(filepath.Join(dir, "wallets"), 0700); err != nil { - return nil, err - } else if err := wm.load(); err != nil { - return nil, err - } - for name, mw := range wm.wallets { - store, tip, err := NewJSONStore(filepath.Join(dir, "wallets", name+".json")) - if err != nil { - return nil, err - } - if mw.subscribed { - if err := cm.AddSubscriber(store, tip); err != nil { - return nil, err - } - } - mw.w = store - } - return wm, nil -} diff --git a/internal/walletutil/store.go b/internal/walletutil/store.go deleted file mode 100644 index 8a2ffef..0000000 --- a/internal/walletutil/store.go +++ /dev/null @@ -1,402 +0,0 @@ -package walletutil - -import ( - "encoding/json" - "fmt" - "os" - "sync" - - "go.sia.tech/coreutils/chain" - "go.sia.tech/core/types" - "go.sia.tech/walletd/wallet" -) - -// An EphemeralStore stores wallet state in memory. -type EphemeralStore struct { - tip types.ChainIndex - addrs map[types.Address]json.RawMessage - sces map[types.SiacoinOutputID]types.SiacoinElement - sfes map[types.SiafundOutputID]types.SiafundElement - events []wallet.Event - mu sync.Mutex -} - -func (s *EphemeralStore) ownsAddress(addr types.Address) bool { - _, ok := s.addrs[addr] - return ok -} - -// Events implements api.Wallet. -func (s *EphemeralStore) Events(offset, limit int) (events []wallet.Event, err error) { - s.mu.Lock() - defer s.mu.Unlock() - if limit == -1 { - limit = len(s.events) - } - if offset > len(s.events) { - offset = len(s.events) - } - if offset+limit > len(s.events) { - limit = len(s.events) - offset - } - // reverse - es := make([]wallet.Event, limit) - for i := range es { - es[i] = s.events[len(s.events)-offset-i-1] - } - return es, nil -} - -// Annotate implements api.Wallet. -func (s *EphemeralStore) Annotate(txns []types.Transaction) (ptxns []wallet.PoolTransaction) { - s.mu.Lock() - defer s.mu.Unlock() - for _, txn := range txns { - ptxn := wallet.Annotate(txn, s.ownsAddress) - if ptxn.Type != "unrelated" { - ptxns = append(ptxns, ptxn) - } - } - return -} - -// UnspentOutputs implements api.Wallet. -func (s *EphemeralStore) UnspentOutputs() (sces []types.SiacoinElement, sfes []types.SiafundElement, err error) { - s.mu.Lock() - defer s.mu.Unlock() - for _, sco := range s.sces { - sces = append(sces, sco) - } - for _, sfo := range s.sfes { - sfes = append(sfes, sfo) - } - return -} - -// Addresses implements api.Wallet. -func (s *EphemeralStore) Addresses() (map[types.Address]json.RawMessage, error) { - s.mu.Lock() - defer s.mu.Unlock() - addrs := make(map[types.Address]json.RawMessage, len(s.addrs)) - for addr, info := range s.addrs { - addrs[addr] = info - } - return addrs, nil -} - -// AddAddress implements api.Wallet. -func (s *EphemeralStore) AddAddress(addr types.Address, info json.RawMessage) error { - s.mu.Lock() - defer s.mu.Unlock() - s.addrs[addr] = info - return nil -} - -// RemoveAddress implements api.Wallet. -func (s *EphemeralStore) RemoveAddress(addr types.Address) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.addrs[addr]; !ok { - return nil - } - delete(s.addrs, addr) - - // filter outputs - for scoid, sce := range s.sces { - if sce.SiacoinOutput.Address == addr { - delete(s.sces, scoid) - } - } - for sfoid, sfe := range s.sfes { - if sfe.SiafundOutput.Address == addr { - delete(s.sfes, sfoid) - } - } - - // filter events - relevantContract := func(fc types.FileContract) bool { - for _, sco := range fc.ValidProofOutputs { - if s.ownsAddress(sco.Address) { - return true - } - } - for _, sco := range fc.MissedProofOutputs { - if s.ownsAddress(sco.Address) { - return true - } - } - return false - } - relevantV2Contract := func(fc types.V2FileContract) bool { - return s.ownsAddress(fc.RenterOutput.Address) || s.ownsAddress(fc.HostOutput.Address) - } - relevantEvent := func(e wallet.Event) bool { - switch e := e.Val.(type) { - case *wallet.EventTransaction: - for _, sce := range e.SiacoinInputs { - if s.ownsAddress(sce.SiacoinOutput.Address) { - return true - } - } - for _, sce := range e.SiacoinOutputs { - if s.ownsAddress(sce.SiacoinOutput.Address) { - return true - } - } - for _, sfe := range e.SiafundInputs { - if s.ownsAddress(sfe.SiafundElement.SiafundOutput.Address) || - s.ownsAddress(sfe.ClaimElement.SiacoinOutput.Address) { - return true - } - } - for _, sfe := range e.SiafundOutputs { - if s.ownsAddress(sfe.SiafundOutput.Address) { - return true - } - } - for _, fc := range e.FileContracts { - if relevantContract(fc.FileContract.FileContract) || (fc.Revision != nil && relevantContract(*fc.Revision)) { - return true - } - } - for _, fc := range e.V2FileContracts { - if relevantV2Contract(fc.FileContract.V2FileContract) || (fc.Revision != nil && relevantV2Contract(*fc.Revision)) { - return true - } - if fc.Resolution != nil { - switch r := fc.Resolution.(type) { - case *types.V2FileContractFinalization: - if relevantV2Contract(types.V2FileContract(*r)) { - return true - } - case *types.V2FileContractRenewal: - if relevantV2Contract(r.FinalRevision) || relevantV2Contract(r.InitialRevision) { - return true - } - } - } - } - return false - case *wallet.EventMinerPayout: - return s.ownsAddress(e.SiacoinOutput.SiacoinOutput.Address) - case *wallet.EventMissedFileContract: - for _, sce := range e.MissedOutputs { - if s.ownsAddress(sce.SiacoinOutput.Address) { - return true - } - } - return false - default: - panic(fmt.Sprintf("unhandled event type %T", e)) - } - } - - rem := s.events[:0] - for _, e := range s.events { - if relevantEvent(e) { - rem = append(rem, e) - } - } - s.events = rem - return nil -} - -// ProcessChainApplyUpdate implements chain.Subscriber. -func (s *EphemeralStore) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, _ bool) error { - s.mu.Lock() - defer s.mu.Unlock() - - events := wallet.AppliedEvents(cau.State, cau.Block, cau, s.ownsAddress) - s.events = append(s.events, events...) - - // add/remove outputs - cau.ForEachSiacoinElement(func(sce types.SiacoinElement, spent bool) { - if s.ownsAddress(sce.SiacoinOutput.Address) { - if spent { - delete(s.sces, types.SiacoinOutputID(sce.ID)) - } else { - sce.MerkleProof = append([]types.Hash256(nil), sce.MerkleProof...) - s.sces[types.SiacoinOutputID(sce.ID)] = sce - } - } - }) - cau.ForEachSiafundElement(func(sfe types.SiafundElement, spent bool) { - if s.ownsAddress(sfe.SiafundOutput.Address) { - if spent { - delete(s.sfes, types.SiafundOutputID(sfe.ID)) - } else { - sfe.MerkleProof = append([]types.Hash256(nil), sfe.MerkleProof...) - s.sfes[types.SiafundOutputID(sfe.ID)] = sfe - } - } - }) - - // update proofs - for id, sce := range s.sces { - cau.UpdateElementProof(&sce.StateElement) - s.sces[id] = sce - } - for id, sfe := range s.sfes { - cau.UpdateElementProof(&sfe.StateElement) - s.sfes[id] = sfe - } - - s.tip = cau.State.Index - return nil -} - -// ProcessChainRevertUpdate implements chain.Subscriber. -func (s *EphemeralStore) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { - s.mu.Lock() - defer s.mu.Unlock() - - // terribly inefficient, but not a big deal because reverts are infrequent - numEvents := len(wallet.AppliedEvents(cru.State, cru.Block, cru, s.ownsAddress)) - s.events = s.events[:len(s.events)-numEvents] - - cru.ForEachSiacoinElement(func(sce types.SiacoinElement, spent bool) { - if s.ownsAddress(sce.SiacoinOutput.Address) { - if !spent { - delete(s.sces, types.SiacoinOutputID(sce.ID)) - } else { - sce.MerkleProof = append([]types.Hash256(nil), sce.MerkleProof...) - s.sces[types.SiacoinOutputID(sce.ID)] = sce - } - } - }) - cru.ForEachSiafundElement(func(sfe types.SiafundElement, spent bool) { - if s.ownsAddress(sfe.SiafundOutput.Address) { - if !spent { - delete(s.sfes, types.SiafundOutputID(sfe.ID)) - } else { - sfe.MerkleProof = append([]types.Hash256(nil), sfe.MerkleProof...) - s.sfes[types.SiafundOutputID(sfe.ID)] = sfe - } - } - }) - - // update proofs - for id, sce := range s.sces { - cru.UpdateElementProof(&sce.StateElement) - s.sces[id] = sce - } - for id, sfe := range s.sfes { - cru.UpdateElementProof(&sfe.StateElement) - s.sfes[id] = sfe - } - - s.tip = cru.State.Index - return nil -} - -// NewEphemeralStore returns a new EphemeralStore. -func NewEphemeralStore() *EphemeralStore { - return &EphemeralStore{ - addrs: make(map[types.Address]json.RawMessage), - sces: make(map[types.SiacoinOutputID]types.SiacoinElement), - sfes: make(map[types.SiafundOutputID]types.SiafundElement), - } -} - -// A JSONStore stores wallet state in memory, backed by a JSON file. -type JSONStore struct { - *EphemeralStore - path string -} - -type persistData struct { - Tip types.ChainIndex - Addresses map[types.Address]json.RawMessage - SiacoinElements map[types.SiacoinOutputID]types.SiacoinElement - SiafundElements map[types.SiafundOutputID]types.SiafundElement - Events []wallet.Event -} - -func (s *JSONStore) save() error { - js, err := json.MarshalIndent(persistData{ - Tip: s.tip, - Addresses: s.addrs, - SiacoinElements: s.sces, - SiafundElements: s.sfes, - Events: s.events, - }, "", " ") - if err != nil { - return err - } - - f, err := os.OpenFile(s.path+"_tmp", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0660) - if err != nil { - return err - } - defer f.Close() - if _, err = f.Write(js); err != nil { - return err - } else if f.Sync(); err != nil { - return err - } else if f.Close(); err != nil { - return err - } else if err := os.Rename(s.path+"_tmp", s.path); err != nil { - return err - } - return nil -} - -func (s *JSONStore) load() error { - f, err := os.Open(s.path) - if os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - defer f.Close() - var p persistData - if err := json.NewDecoder(f).Decode(&p); err != nil { - return err - } - s.tip = p.Tip - s.addrs = p.Addresses - s.sces = p.SiacoinElements - s.sfes = p.SiafundElements - s.events = p.Events - return nil -} - -// ProcessChainApplyUpdate implements chain.Subscriber. -func (s *JSONStore) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) error { - err := s.EphemeralStore.ProcessChainApplyUpdate(cau, mayCommit) - if err == nil && mayCommit { - err = s.save() - } - return err -} - -// ProcessChainRevertUpdate implements chain.Subscriber. -func (s *JSONStore) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { - return s.EphemeralStore.ProcessChainRevertUpdate(cru) -} - -// AddAddress implements api.Wallet. -func (s *JSONStore) AddAddress(addr types.Address, info json.RawMessage) error { - if err := s.EphemeralStore.AddAddress(addr, info); err != nil { - return err - } - return s.save() -} - -// RemoveAddress implements api.Wallet. -func (s *JSONStore) RemoveAddress(addr types.Address) error { - if err := s.EphemeralStore.RemoveAddress(addr); err != nil { - return err - } - return s.save() -} - -// NewJSONStore returns a new JSONStore. -func NewJSONStore(path string) (*JSONStore, types.ChainIndex, error) { - s := &JSONStore{ - EphemeralStore: NewEphemeralStore(), - path: path, - } - err := s.load() - return s, s.tip, err -} From ffed07b17970672c9efe4beef798621a73c7274e Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 9 Jan 2024 16:00:47 -0800 Subject: [PATCH 05/24] cmd: create data directory if it doesn't exist --- cmd/walletd/main.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmd/walletd/main.go b/cmd/walletd/main.go index ae4f3ec..477344c 100644 --- a/cmd/walletd/main.go +++ b/cmd/walletd/main.go @@ -159,6 +159,11 @@ func main() { cmd.Usage() return } + + if err := os.MkdirAll(dir, 0700); err != nil { + log.Fatal(err) + } + apiPassword := getAPIPassword() l, err := net.Listen("tcp", apiAddr) if err != nil { From fb4466f8402f0439bc432b65a65af76d3719035b Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 9 Jan 2024 16:04:27 -0800 Subject: [PATCH 06/24] sqlite: fix tip encoding --- persist/sqlite/consensus.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 03f8f0a..1d33ab7 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -248,7 +248,7 @@ func applySiafundOutputs(tx txn, added map[types.Hash256]types.SiafundElement) e } func updateLastIndexedTip(tx txn, tip types.ChainIndex) error { - _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(tip.ID)) + _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(tip)) return err } From 6c45c1f0549c0c5f82dcbe2defd8e557cac4e86b Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 11:07:15 -0800 Subject: [PATCH 07/24] cmd: gracefully close stores on shutdown --- cmd/walletd/node.go | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/cmd/walletd/node.go b/cmd/walletd/node.go index bd6119b..a3e3c34 100644 --- a/cmd/walletd/node.go +++ b/cmd/walletd/node.go @@ -84,13 +84,23 @@ var anagamiBootstrap = []string{ } type node struct { - cm *chain.Manager - s *syncer.Syncer - wm *wallet.Manager + chainStore *boltDB + cm *chain.Manager + + s *syncer.Syncer + + walletStore *sqlite.Store + wm *wallet.Manager Start func() (stop func()) } +// Close shuts down the node and closes its database. +func (n *node) Close() error { + n.chainStore.Close() + return n.walletStore.Close() +} + func newNode(addr, dir string, chainNetwork string, useUPNP bool, log *zap.Logger) (*node, error) { var network *consensus.Network var genesisBlock types.Block @@ -178,9 +188,11 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool, log *zap.Logge } return &node{ - cm: cm, - s: s, - wm: wm, + chainStore: db, + cm: cm, + s: s, + walletStore: walletDB, + wm: wm, Start: func() func() { ch := make(chan struct{}) go func() { From f78017189e72fd7b475f13fbd177f8036a6c96a4 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 11:17:07 -0800 Subject: [PATCH 08/24] ci: fix lint errors --- .github/actions/test/action.yml | 8 +- .golangci.yml | 161 ++++++++++++++++++++++++++++++++ persist/sqlite/consensus.go | 3 + persist/sqlite/types.go | 1 + persist/sqlite/wallet.go | 12 +++ wallet/manager.go | 2 + wallet/state.go | 83 ---------------- wallet/wallet.go | 16 +++- 8 files changed, 196 insertions(+), 90 deletions(-) create mode 100644 .golangci.yml delete mode 100644 wallet/state.go diff --git a/.github/actions/test/action.yml b/.github/actions/test/action.yml index de33d2c..5754f0e 100644 --- a/.github/actions/test/action.yml +++ b/.github/actions/test/action.yml @@ -7,10 +7,10 @@ runs: - name: Configure git # required for golangci-lint on Windows shell: bash run: git config --global core.autocrlf false -# - name: Lint -# uses: golangci/golangci-lint-action@v3 -# with: -# skip-cache: true + - name: Lint + uses: golangci/golangci-lint-action@v3 + with: + skip-cache: true # - name: Analyze # uses: SiaFoundation/action-golang-analysis@HEAD # with: diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..ca4188f --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,161 @@ +# Based off of the example file at https://github.com/golangci/golangci-lint + +# options for analysis running +run: + # default concurrency is a available CPU number + concurrency: 4 + + # timeout for analysis, e.g. 30s, 5m, default is 1m + timeout: 600s + + # exit code when at least one issue was found, default is 1 + issues-exit-code: 1 + + # include test files or not, default is true + tests: true + + # list of build tags, all linters use it. Default is empty list. + build-tags: [] + + # which dirs to skip: issues from them won't be reported; + # can use regexp here: generated.*, regexp is applied on full path; + # default value is empty list, but default dirs are skipped independently + # from this option's value (see skip-dirs-use-default). + skip-dirs: + - cover + + # default is true. Enables skipping of directories: + # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ + skip-dirs-use-default: true + + # which files to skip: they will be analyzed, but issues from them + # won't be reported. Default value is empty list, but there is + # no need to include all autogenerated files, we confidently recognize + # autogenerated files. If it's not please let us know. + skip-files: [] + +# output configuration options +output: + # colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number" + format: colored-line-number + + # print lines of code with issue, default is true + print-issued-lines: true + + # print linter name in the end of issue text, default is true + print-linter-name: true + +# all available settings of specific linters +linters-settings: + ## Enabled linters: + govet: + # report about shadowed variables + check-shadowing: false + disable-all: false + + tagliatelle: + case: + rules: + json: goCamel + yaml: goCamel + + + gocritic: + # Which checks should be enabled; can't be combined with 'disabled-checks'; + # See https://go-critic.github.io/overview#checks-overview + # To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run` + # By default list of stable checks is used. + enabled-checks: + - argOrder # Diagnostic options + - badCond + - caseOrder + - dupArg + - dupBranchBody + - dupCase + - dupSubExpr + - nilValReturn + - offBy1 + - weakCond + - boolExprSimplify # Style options here and below. + - builtinShadow + - emptyFallthrough + - hexLiteral + - underef + - equalFold + revive: + ignore-generated-header: true + rules: + - name: blank-imports + disabled: false + - name: bool-literal-in-expr + disabled: false + - name: confusing-results + disabled: false + - name: constant-logical-expr + disabled: false + - name: context-as-argument + disabled: false + - name: exported + disabled: false + - name: errorf + disabled: false + - name: if-return + disabled: false + - name: indent-error-flow + disabled: false + - name: increment-decrement + disabled: false + - name: modifies-value-receiver + disabled: false + - name: optimize-operands-order + disabled: false + - name: range-val-in-closure + disabled: false + - name: struct-tag + disabled: false + - name: superfluous-else + disabled: false + - name: time-equal + disabled: false + - name: unexported-naming + disabled: false + - name: unexported-return + disabled: false + - name: unnecessary-stmt + disabled: false + - name: unreachable-code + disabled: false + - name: package-comments + disabled: true + +linters: + disable-all: true + fast: false + enable: + - tagliatelle + - gocritic + - gofmt + - revive + - govet + - misspell + - typecheck + - whitespace + +issues: + # Maximum issues count per one linter. Set to 0 to disable. Default is 50. + max-issues-per-linter: 0 + + # Maximum count of issues with the same text. Set to 0 to disable. Default is 3. + max-same-issues: 0 + + # List of regexps of issue texts to exclude, empty list by default. + # But independently from this option we use default exclude patterns, + # it can be disabled by `exclude-use-default: false`. To list all + # excluded by default patterns execute `golangci-lint run --help` + exclude: [] + + # Independently from option `exclude` we use default exclude patterns, + # it can be disabled by this option. To list all + # excluded by default patterns execute `golangci-lint run --help`. + # Default value for this option is true. + exclude-use-default: false \ No newline at end of file diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 1d33ab7..5edb89e 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -314,6 +314,7 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { return nil } +// applyChainUpdates applies the given chain updates to the database. func applyChainUpdates(tx txn, updates []*chain.ApplyUpdate) error { stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { @@ -399,6 +400,7 @@ func applyChainUpdates(tx txn, updates []*chain.ApplyUpdate) error { return nil } +// ProcessChainApplyUpdate implements chain.Subscriber func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) error { s.updates = append(s.updates, cau) @@ -414,6 +416,7 @@ func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) return nil } +// ProcessChainRevertUpdate implements chain.Subscriber func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { // update hasn't been committed yet if len(s.updates) > 0 && s.updates[len(s.updates)-1].Block.ID() == cru.Block.ID() { diff --git a/persist/sqlite/types.go b/persist/sqlite/types.go index 44f1f10..31083c6 100644 --- a/persist/sqlite/types.go +++ b/persist/sqlite/types.go @@ -115,6 +115,7 @@ type decodable[T types.DecoderFrom] struct { n int64 } +// Scan implements the sql.Scanner interface. func (d *decodable[T]) Scan(src any) error { switch src := src.(type) { case []byte: diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index 56ba3fd..8306e1c 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -19,6 +19,7 @@ RETURNING id` return } +// WalletEvents returns the events relevant to a wallet, sorted by height descending. func (s *Store) WalletEvents(walletID string, offset, limit int) (events []wallet.Event, err error) { err = s.transaction(func(tx txn) error { const query = `SELECT ev.id, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data @@ -76,6 +77,7 @@ LIMIT $2 OFFSET $3` return } +// AddWallet adds a wallet to the database. func (s *Store) AddWallet(name string, info json.RawMessage) error { return s.transaction(func(tx txn) error { const query = `INSERT INTO wallets (id, extra_data) VALUES ($1, $2)` @@ -88,6 +90,8 @@ func (s *Store) AddWallet(name string, info json.RawMessage) error { }) } +// DeleteWallet deletes a wallet from the database. This does not stop tracking +// addresses that were previously associated with the wallet. func (s *Store) DeleteWallet(name string) error { return s.transaction(func(tx txn) error { _, err := tx.Exec(`DELETE FROM wallets WHERE id=$1`, name) @@ -95,6 +99,7 @@ func (s *Store) DeleteWallet(name string) error { }) } +// Wallets returns a map of wallet names to wallet extra data. func (s *Store) Wallets() (map[string]json.RawMessage, error) { wallets := make(map[string]json.RawMessage) err := s.transaction(func(tx txn) error { @@ -119,6 +124,7 @@ func (s *Store) Wallets() (map[string]json.RawMessage, error) { return wallets, err } +// AddAddress adds an address to a wallet. func (s *Store) AddAddress(walletID string, address types.Address, info json.RawMessage) error { return s.transaction(func(tx txn) error { addressID, err := insertAddress(tx, address) @@ -130,6 +136,8 @@ func (s *Store) AddAddress(walletID string, address types.Address, info json.Raw }) } +// RemoveAddress removes an address from a wallet. This does not stop tracking +// the address. func (s *Store) RemoveAddress(walletID string, address types.Address) error { return s.transaction(func(tx txn) error { const query = `DELETE FROM wallet_addresses WHERE wallet_id=$1 AND address_id=(SELECT id FROM sia_addresses WHERE sia_address=$2)` @@ -138,6 +146,7 @@ func (s *Store) RemoveAddress(walletID string, address types.Address) error { }) } +// Addresses returns a map of addresses to their extra data for a wallet. func (s *Store) Addresses(walletID string) (map[types.Address]json.RawMessage, error) { addresses := make(map[types.Address]json.RawMessage) err := s.transaction(func(tx txn) error { @@ -165,6 +174,7 @@ WHERE wa.wallet_id=$1` return addresses, err } +// UnspentSiacoinOutputs returns the unspent siacoin outputs for a wallet. func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.SiacoinElement, err error) { err = s.transaction(func(tx txn) error { const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siacoin_value, sa.sia_address, se.maturity_height @@ -193,6 +203,7 @@ func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.Siacoin return } +// UnspentSiafundOutputs returns the unspent siafund outputs for a wallet. func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.SiafundElement, err error) { err = s.transaction(func(tx txn) error { const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siafund_value, se.claim_start, sa.sia_address @@ -257,6 +268,7 @@ func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uin return } +// Annotate annotates a list of transactions using the wallet's addresses. func (s *Store) Annotate(walletID string, txns []types.Transaction) (annotated []wallet.PoolTransaction, err error) { err = s.transaction(func(tx txn) error { stmt, err := tx.Prepare(`SELECT sia_address FROM wallet_addresses WHERE wallet_id=$1 AND sia_address=$2 LIMIT 1`) diff --git a/wallet/manager.go b/wallet/manager.go index 3e982e5..f1b0e62 100644 --- a/wallet/manager.go +++ b/wallet/manager.go @@ -13,6 +13,7 @@ import ( ) type ( + // A ChainManager manages the consensus state ChainManager interface { AddSubscriber(chain.Subscriber, types.ChainIndex) error RemoveSubscriber(chain.Subscriber) @@ -20,6 +21,7 @@ type ( BestIndex(height uint64) (types.ChainIndex, bool) } + // A Store is a persistent store of wallet data. Store interface { chain.Subscriber diff --git a/wallet/state.go b/wallet/state.go deleted file mode 100644 index 219a7f2..0000000 --- a/wallet/state.go +++ /dev/null @@ -1,83 +0,0 @@ -package wallet - -import ( - "go.sia.tech/core/types" - "go.sia.tech/coreutils/chain" -) - -// A Midstate is a snapshot of unapplied consensus changes. -type Midstate struct { - SpentSiacoinOutputs map[types.Hash256]bool - SpentSiafundOutputs map[types.Hash256]bool - - NewSiacoinOutputs map[types.Hash256]types.SiacoinElement - NewSiafundOutputs map[types.Hash256]types.SiafundElement - - Events []Event -} - -func (ms *Midstate) Apply(cau *chain.ApplyUpdate, ownsAddress func(types.Address) bool) { - events := AppliedEvents(cau.State, cau.Block, cau, ownsAddress) - ms.Events = append(ms.Events, events...) - - cau.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { - if !ownsAddress(se.SiacoinOutput.Address) { - return - } - - if spent { - ms.SpentSiacoinOutputs[se.ID] = true - delete(ms.NewSiacoinOutputs, se.ID) - } else { - ms.NewSiacoinOutputs[se.ID] = se - } - }) - - cau.ForEachSiafundElement(func(sf types.SiafundElement, spent bool) { - if !ownsAddress(sf.SiafundOutput.Address) { - return - } - - if spent { - ms.SpentSiafundOutputs[sf.ID] = true - delete(ms.NewSiafundOutputs, sf.ID) - } else { - ms.NewSiafundOutputs[sf.ID] = sf - } - }) -} - -func (ms *Midstate) Revert(cru *chain.RevertUpdate, ownsAddress func(types.Address) bool) { - revertedBlockID := cru.Block.ID() - for i := len(ms.Events) - 1; i >= 0; i-- { - // working backwards, revert all events until the block ID no longer - // matches. - if ms.Events[i].Index.ID != revertedBlockID { - break - } - ms.Events = ms.Events[:i] - } - - cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { - if !ownsAddress(se.SiacoinOutput.Address) { - return - } - - if !spent { - delete(ms.SpentSiacoinOutputs, se.ID) - } - }) - - cru.ForEachSiafundElement(func(sf types.SiafundElement, spent bool) { - if !ownsAddress(sf.SiafundOutput.Address) { - return - } - - if spent { - ms.SpentSiafundOutputs[sf.ID] = true - delete(ms.NewSiafundOutputs, sf.ID) - } else { - ms.NewSiafundOutputs[sf.ID] = sf - } - }) -} diff --git a/wallet/wallet.go b/wallet/wallet.go index d1a458c..7a806b1 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -9,8 +9,8 @@ import ( "go.sia.tech/core/types" ) +// event type constants const ( - // transactions EventTypeTransaction = "transaction" EventTypeMinerPayout = "miner payout" EventTypeMissedFileContract = "missed file contract" @@ -153,8 +153,13 @@ type Event struct { Val interface{ EventType() string } } -func (*EventTransaction) EventType() string { return EventTypeTransaction } -func (*EventMinerPayout) EventType() string { return EventTypeMinerPayout } +// EventType implements Event. +func (*EventTransaction) EventType() string { return EventTypeTransaction } + +// EventType implements Event. +func (*EventMinerPayout) EventType() string { return EventTypeMinerPayout } + +// EventType implements Event. func (*EventMissedFileContract) EventType() string { return EventTypeMissedFileContract } // MarshalJSON implements json.Marshaler. @@ -235,6 +240,7 @@ type V2FileContract struct { Outputs []types.SiacoinElement `json:"outputs,omitempty"` } +// An EventTransaction represents a transaction that affects the wallet. type EventTransaction struct { ID types.TransactionID `json:"id"` SiacoinInputs []types.SiacoinElement `json:"siacoinInputs"` @@ -247,15 +253,19 @@ type EventTransaction struct { Fee types.Currency `json:"fee"` } +// An EventMinerPayout represents a miner payout from a block. type EventMinerPayout struct { SiacoinOutput types.SiacoinElement `json:"siacoinOutput"` } +// An EventMissedFileContract represents a file contract that has expired +// without a storage proof type EventMissedFileContract struct { FileContract types.FileContractElement `json:"fileContract"` MissedOutputs []types.SiacoinElement `json:"missedOutputs"` } +// A ChainUpdate is a set of changes to the consensus state. type ChainUpdate interface { ForEachSiacoinElement(func(sce types.SiacoinElement, spent bool)) ForEachSiafundElement(func(sfe types.SiafundElement, spent bool)) From e6f17174008f37f08024458b70d23c54b6bf0b49 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 11:18:43 -0800 Subject: [PATCH 09/24] ci: disable jape analyzer --- .golangci.yml | 2 -- persist/sqlite/consensus.go | 2 -- 2 files changed, 4 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index ca4188f..041664e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -101,8 +101,6 @@ linters-settings: disabled: false - name: if-return disabled: false - - name: indent-error-flow - disabled: false - name: increment-decrement disabled: false - name: modifies-value-receiver diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 5edb89e..dc3d82a 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "go.sia.tech/core/types" "go.sia.tech/coreutils/chain" @@ -60,7 +59,6 @@ func applyEvents(tx txn, events []wallet.Event) error { } else if _, err := addRelevantAddrStmt.Exec(eventID, addressID, event.Index.Height); err != nil { return fmt.Errorf("failed to add relevant address: %w", err) } - log.Println("added relevant address", eventID, addr) } } return nil From 7a3f173f4404d295e7eb927483bd9c21d9d8f1be Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 14:15:38 -0800 Subject: [PATCH 10/24] sqlite: fix siacoin element arg order --- persist/sqlite/consensus.go | 8 ++++---- persist/sqlite/wallet.go | 9 +++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index dc3d82a..c720675 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -144,7 +144,7 @@ func applySiacoinOutputs(tx txn, added map[types.Hash256]types.SiacoinElement) e } // insert the created utxo - _, err = addStmt.Exec(encode(se.ID), addressID, sqlCurrency(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.MaturityHeight, se.LeafIndex) + _, err = addStmt.Exec(encode(se.ID), addressID, sqlCurrency(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) if err != nil { return fmt.Errorf("failed to insert output %q: %w", se.ID, err) } @@ -261,7 +261,7 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { } defer stmt.Close() - updateStmt, err := tx.Prepare(`UPDATE ` + table + ` SET merkle_proof=$1, leaf_index=$2 WHERE id=$3`) + updateStmt, err := tx.Prepare(`UPDATE ` + table + ` SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare update statement: %w", err) } @@ -298,7 +298,8 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { } for _, se := range updated { - _, err := updateStmt.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) + var dummy types.Hash256 + err := updateStmt.QueryRow(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)).Scan(decode(&dummy, 32)) if err != nil { return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) } @@ -308,7 +309,6 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { break } } - return nil } diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index 8306e1c..e142a0b 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -190,12 +190,11 @@ func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.Siacoin for rows.Next() { var siacoin types.SiacoinElement - var proof []byte - - err := rows.Scan(decode(&siacoin.ID, 32), &siacoin.LeafIndex, &proof, (*sqlCurrency)(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address, 32), &siacoin.MaturityHeight) + err := rows.Scan(decode(&siacoin.ID, 32), &siacoin.LeafIndex, decodeSlice[types.Hash256](&siacoin.MerkleProof, 32*1000), (*sqlCurrency)(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address, 32), &siacoin.MaturityHeight) if err != nil { return fmt.Errorf("failed to scan siacoin element: %w", err) } + siacoins = append(siacoins, siacoin) } return nil @@ -219,9 +218,7 @@ func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.Siafund for rows.Next() { var siafund types.SiafundElement - var proof []byte - - err := rows.Scan(decode(&siafund.ID, 32), &siafund.LeafIndex, &proof, &siafund.SiafundOutput.Value, (*sqlCurrency)(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address, 32)) + err := rows.Scan(decode(&siafund.ID, 32), &siafund.LeafIndex, decodeSlice(&siafund.MerkleProof, 32*1000), &siafund.SiafundOutput.Value, (*sqlCurrency)(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address, 32)) if err != nil { return fmt.Errorf("failed to scan siacoin element: %w", err) } From 77ec0611ecb829e7be3c1dcc1a3f438db899c7b8 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 14:36:54 -0800 Subject: [PATCH 11/24] ci: enable cgo, bump go versions for test --- .github/workflows/main.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8e59abc..863f1a5 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -6,6 +6,9 @@ on: branches: - master +env: + CGO_ENABLED: 1 + jobs: test: runs-on: ${{ matrix.os }} @@ -14,7 +17,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest , macos-latest, windows-latest ] - go-version: [ '1.19', '1.20' ] + go-version: [ '1.20', '1.21' ] steps: - name: Configure git run: git config --global core.autocrlf false # required on Windows From 986347158b4e15adf9804350bff1446c634367ca Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 16:14:14 -0800 Subject: [PATCH 12/24] sqlite: use NewBufDecoder --- persist/sqlite/consensus.go | 10 +++++----- persist/sqlite/types.go | 22 ++++++---------------- persist/sqlite/wallet.go | 8 ++++---- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index c720675..8722547 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -100,7 +100,7 @@ func deleteSiacoinOutputs(tx txn, spent []types.SiacoinElement) error { } var dummy types.Hash256 - err = deleteStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy, 32)) + err = deleteStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) if err != nil { return fmt.Errorf("failed to delete output %q: %w", se.ID, err) } @@ -191,7 +191,7 @@ func deleteSiafundOutputs(tx txn, spent []types.SiafundElement) error { } var dummy types.Hash256 - err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy, 32)) + err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) if err != nil { return fmt.Errorf("failed to delete output %q: %w", se.ID, err) } @@ -284,7 +284,7 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { more = true var se types.StateElement - err := rows.Scan(decode(&se.ID, 32), decodeSlice(&se.MerkleProof, 32*1000), &se.LeafIndex) + err := rows.Scan(decode(&se.ID), decodeSlice(&se.MerkleProof), &se.LeafIndex) if err != nil { return false, fmt.Errorf("failed to scan state element: %w", err) } @@ -299,7 +299,7 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { for _, se := range updated { var dummy types.Hash256 - err := updateStmt.QueryRow(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)).Scan(decode(&dummy, 32)) + err := updateStmt.QueryRow(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)).Scan(decode(&dummy)) if err != nil { return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) } @@ -506,6 +506,6 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { // LastCommittedIndex returns the last chain index that was committed. func (s *Store) LastCommittedIndex() (index types.ChainIndex, err error) { - err = s.db.QueryRow(`SELECT last_indexed_tip FROM global_settings`).Scan(decode(&index, 40)) + err = s.db.QueryRow(`SELECT last_indexed_tip FROM global_settings`).Scan(decode(&index)) return } diff --git a/persist/sqlite/types.go b/persist/sqlite/types.go index 31083c6..f7f0973 100644 --- a/persist/sqlite/types.go +++ b/persist/sqlite/types.go @@ -6,7 +6,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "go.sia.tech/core/types" @@ -78,16 +77,12 @@ func encodeSlice[T types.EncoderTo](v []T) []byte { type decodableSlice[T any] struct { v *[]T - n int64 } func (d *decodableSlice[T]) Scan(src any) error { switch src := src.(type) { case []byte: - dec := types.NewDecoder(io.LimitedReader{ - R: bytes.NewReader(src), - N: d.n, - }) + dec := types.NewBufDecoder(src) s := make([]T, dec.ReadPrefix()) for i := range s { dv, ok := any(&s[i]).(types.DecoderFrom) @@ -106,24 +101,19 @@ func (d *decodableSlice[T]) Scan(src any) error { } } -func decodeSlice[T any](v *[]T, maxLen int64) sql.Scanner { - return &decodableSlice[T]{v: v, n: maxLen} +func decodeSlice[T any](v *[]T) sql.Scanner { + return &decodableSlice[T]{v: v} } type decodable[T types.DecoderFrom] struct { v T - n int64 } // Scan implements the sql.Scanner interface. func (d *decodable[T]) Scan(src any) error { switch src := src.(type) { case []byte: - dec := types.NewDecoder(io.LimitedReader{ - R: bytes.NewReader(src), - N: d.n, - }) - + dec := types.NewBufDecoder(src) d.v.DecodeFrom(dec) return dec.Err() default: @@ -131,6 +121,6 @@ func (d *decodable[T]) Scan(src any) error { } } -func decode[T types.DecoderFrom](v T, maxLen int64) sql.Scanner { - return &decodable[T]{v, maxLen} +func decode[T types.DecoderFrom](v T) sql.Scanner { + return &decodable[T]{v} } diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index e142a0b..5d9f37c 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -41,7 +41,7 @@ LIMIT $2 OFFSET $3` var eventType string var eventBuf []byte - err := rows.Scan(&eventID, (*sqlTime)(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID, 32), &eventType, &eventBuf) + err := rows.Scan(&eventID, (*sqlTime)(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) if err != nil { return fmt.Errorf("failed to scan event: %w", err) } @@ -164,7 +164,7 @@ WHERE wa.wallet_id=$1` for rows.Next() { var address types.Address var extraData json.RawMessage - if err := rows.Scan(decode(&address, 32), &extraData); err != nil { + if err := rows.Scan(decode(&address), &extraData); err != nil { return fmt.Errorf("failed to scan address: %w", err) } addresses[address] = extraData @@ -190,7 +190,7 @@ func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.Siacoin for rows.Next() { var siacoin types.SiacoinElement - err := rows.Scan(decode(&siacoin.ID, 32), &siacoin.LeafIndex, decodeSlice[types.Hash256](&siacoin.MerkleProof, 32*1000), (*sqlCurrency)(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address, 32), &siacoin.MaturityHeight) + err := rows.Scan(decode(&siacoin.ID), &siacoin.LeafIndex, decodeSlice[types.Hash256](&siacoin.MerkleProof), (*sqlCurrency)(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address), &siacoin.MaturityHeight) if err != nil { return fmt.Errorf("failed to scan siacoin element: %w", err) } @@ -218,7 +218,7 @@ func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.Siafund for rows.Next() { var siafund types.SiafundElement - err := rows.Scan(decode(&siafund.ID, 32), &siafund.LeafIndex, decodeSlice(&siafund.MerkleProof, 32*1000), &siafund.SiafundOutput.Value, (*sqlCurrency)(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address, 32)) + err := rows.Scan(decode(&siafund.ID), &siafund.LeafIndex, decodeSlice(&siafund.MerkleProof), &siafund.SiafundOutput.Value, (*sqlCurrency)(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address)) if err != nil { return fmt.Errorf("failed to scan siacoin element: %w", err) } From 37c4daaf85bd1c0becd4c08bd82e7f8b1428132b Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 16:14:21 -0800 Subject: [PATCH 13/24] api: fix client docstring --- api/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/client.go b/api/client.go index 3258d9e..8365495 100644 --- a/api/client.go +++ b/api/client.go @@ -106,7 +106,7 @@ func (c *Client) Wallet(name string) *WalletClient { } // Resubscribe subscribes the wallet to consensus updates, starting at the -// specified height. This can only be done once. +// specified height. func (c *Client) Resubscribe(height uint64) (err error) { err = c.c.POST("/resubscribe", height, nil) return From ef6583049592fea6a2dd47db5b7c3f6971feb9bb Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 16:22:40 -0800 Subject: [PATCH 14/24] api,cmd,sqlite,syncer: remove ephemeral store, add sqlite peer store --- api/api_test.go | 23 ++-- api/server.go | 3 +- cmd/walletd/node.go | 39 +++---- cmd/walletd/testnet.go | 3 +- internal/syncerutil/store.go | 208 ----------------------------------- persist/sqlite/init.sql | 15 +++ persist/sqlite/peers.go | 188 +++++++++++++++++++++++++++++++ 7 files changed, 235 insertions(+), 244 deletions(-) delete mode 100644 internal/syncerutil/store.go create mode 100644 persist/sqlite/peers.go diff --git a/api/api_test.go b/api/api_test.go index 9ffa789..afcb5a5 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -14,7 +14,6 @@ import ( "go.sia.tech/coreutils/syncer" "go.sia.tech/jape" "go.sia.tech/walletd/api" - "go.sia.tech/walletd/internal/syncerutil" "go.sia.tech/walletd/persist/sqlite" "go.sia.tech/walletd/wallet" "go.uber.org/zap/zaptest" @@ -394,7 +393,7 @@ func TestV2(t *testing.T) { } func TestP2P(t *testing.T) { - log := zaptest.NewLogger(t) + logger := zaptest.NewLogger(t) n, genesisBlock := testNetwork() // gift primary wallet some coins primaryPrivateKey := types.GeneratePrivateKey() @@ -409,13 +408,14 @@ func TestP2P(t *testing.T) { if err != nil { t.Fatal(err) } + log1 := logger.Named("one") cm1 := chain.NewManager(dbstore1, tipState) - ws1, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + store1, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log1.Named("sqlite3")) if err != nil { t.Fatal(err) } - defer ws1.Close() - wm1, err := wallet.NewManager(cm1, ws1, log.Named("wallet")) + defer store1.Close() + wm1, err := wallet.NewManager(cm1, store1, log1.Named("wallet")) if err != nil { t.Fatal(err) } @@ -424,7 +424,7 @@ func TestP2P(t *testing.T) { t.Fatal(err) } defer l1.Close() - s1 := syncer.New(l1, cm1, syncerutil.NewEphemeralPeerStore(), gateway.Header{ + s1 := syncer.New(l1, cm1, store1, gateway.Header{ GenesisID: genesisBlock.ID(), UniqueID: gateway.GenerateUniqueID(), NetAddress: l1.Addr().String(), @@ -447,13 +447,14 @@ func TestP2P(t *testing.T) { if err != nil { t.Fatal(err) } + log2 := logger.Named("two") cm2 := chain.NewManager(dbstore2, tipState) - ws2, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + store2, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log2.Named("sqlite3")) if err != nil { t.Fatal(err) } - defer ws2.Close() - wm2, err := wallet.NewManager(cm2, ws2, log.Named("wallet")) + defer store2.Close() + wm2, err := wallet.NewManager(cm2, store2, log2.Named("wallet")) if err != nil { t.Fatal(err) } @@ -462,11 +463,11 @@ func TestP2P(t *testing.T) { t.Fatal(err) } defer l2.Close() - s2 := syncer.New(l2, cm2, syncerutil.NewEphemeralPeerStore(), gateway.Header{ + s2 := syncer.New(l2, cm2, store2, gateway.Header{ GenesisID: genesisBlock.ID(), UniqueID: gateway.GenerateUniqueID(), NetAddress: l2.Addr().String(), - }) + }, syncer.WithLogger(zaptest.NewLogger(t))) go s2.Run() c2, shutdown2 := runServer(cm2, s2, wm2) defer shutdown2() diff --git a/api/server.go b/api/server.go index f7c87d7..5030592 100644 --- a/api/server.go +++ b/api/server.go @@ -92,7 +92,8 @@ func (s *server) syncerPeersHandler(jc jape.Context) { for _, p := range s.s.Peers() { info, ok := s.s.PeerInfo(p.Addr()) if !ok { - continue + jc.Error(errors.New("peer not found"), http.StatusNotFound) + return } peers = append(peers, GatewayPeer{ Addr: p.Addr(), diff --git a/cmd/walletd/node.go b/cmd/walletd/node.go index a3e3c34..b5a6808 100644 --- a/cmd/walletd/node.go +++ b/cmd/walletd/node.go @@ -15,7 +15,6 @@ import ( "go.sia.tech/coreutils" "go.sia.tech/coreutils/chain" "go.sia.tech/coreutils/syncer" - "go.sia.tech/walletd/internal/syncerutil" "go.sia.tech/walletd/persist/sqlite" "go.sia.tech/walletd/wallet" "go.uber.org/zap" @@ -84,13 +83,12 @@ var anagamiBootstrap = []string{ } type node struct { - chainStore *boltDB + chainStore *coreutils.BoltChainDB cm *chain.Manager - s *syncer.Syncer - - walletStore *sqlite.Store - wm *wallet.Manager + store *sqlite.Store + s *syncer.Syncer + wm *wallet.Manager Start func() (stop func()) } @@ -98,7 +96,7 @@ type node struct { // Close shuts down the node and closes its database. func (n *node) Close() error { n.chainStore.Close() - return n.walletStore.Close() + return n.store.Close() } func newNode(addr, dir string, chainNetwork string, useUPNP bool, log *zap.Logger) (*node, error) { @@ -163,36 +161,31 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool, log *zap.Logge syncerAddr = net.JoinHostPort("127.0.0.1", port) } - ps, err := syncerutil.NewJSONPeerStore(filepath.Join(dir, "peers.json")) + store, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) if err != nil { - return nil, fmt.Errorf("failed to open peer store: %w", err) + return nil, fmt.Errorf("failed to open wallet database: %w", err) } + for _, peer := range bootstrapPeers { - ps.AddPeer(peer) + store.AddPeer(peer) } header := gateway.Header{ GenesisID: genesisBlock.ID(), UniqueID: gateway.GenerateUniqueID(), NetAddress: syncerAddr, } - s := syncer.New(l, cm, ps, header, syncer.WithLogger(log.Named("syncer"))) - - walletDB, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) - if err != nil { - return nil, fmt.Errorf("failed to open wallet database: %w", err) - } - - wm, err := wallet.NewManager(cm, walletDB, log.Named("wallet")) + s := syncer.New(l, cm, store, header, syncer.WithLogger(log.Named("syncer"))) + wm, err := wallet.NewManager(cm, store, log.Named("wallet")) if err != nil { return nil, fmt.Errorf("failed to create wallet manager: %w", err) } return &node{ - chainStore: db, - cm: cm, - s: s, - walletStore: walletDB, - wm: wm, + chainStore: bdb, + cm: cm, + store: store, + s: s, + wm: wm, Start: func() func() { ch := make(chan struct{}) go func() { diff --git a/cmd/walletd/testnet.go b/cmd/walletd/testnet.go index 254d7a0..b87ad21 100644 --- a/cmd/walletd/testnet.go +++ b/cmd/walletd/testnet.go @@ -115,12 +115,13 @@ func initTestnetClient(addr string, network string, seed wallet.Seed) *api.Clien } else if err := wc.AddAddress(ourAddr, nil); err != nil { fmt.Println() log.Fatal(err) - } else if err := wc.Subscribe(0); err != nil { + } else if err := c.Resubscribe(0); err != nil { fmt.Println() log.Fatal(err) } fmt.Println("done.") } + return c } diff --git a/internal/syncerutil/store.go b/internal/syncerutil/store.go deleted file mode 100644 index 6456c6b..0000000 --- a/internal/syncerutil/store.go +++ /dev/null @@ -1,208 +0,0 @@ -package syncerutil - -import ( - "encoding/json" - "net" - "os" - "sync" - "time" - - "go.sia.tech/coreutils/syncer" -) - -type peerBan struct { - Expiry time.Time `json:"expiry"` - Reason string `json:"reason"` -} - -// EphemeralPeerStore implements PeerStore with an in-memory map. -type EphemeralPeerStore struct { - peers map[string]syncer.PeerInfo - bans map[string]peerBan - mu sync.Mutex -} - -func (eps *EphemeralPeerStore) banned(peer string) bool { - host, _, err := net.SplitHostPort(peer) - if err != nil { - return false // shouldn't happen - } - for _, s := range []string{ - peer, // 1.2.3.4:5678 - syncer.Subnet(host, "/32"), // 1.2.3.4:* - syncer.Subnet(host, "/24"), // 1.2.3.* - syncer.Subnet(host, "/16"), // 1.2.* - syncer.Subnet(host, "/8"), // 1.* - } { - if b, ok := eps.bans[s]; ok { - if time.Until(b.Expiry) <= 0 { - delete(eps.bans, s) - } else { - return true - } - } - } - return false -} - -// AddPeer implements PeerStore. -func (eps *EphemeralPeerStore) AddPeer(peer string) { - eps.mu.Lock() - defer eps.mu.Unlock() - if _, ok := eps.peers[peer]; !ok { - eps.peers[peer] = syncer.PeerInfo{FirstSeen: time.Now()} - } -} - -// Peers implements PeerStore. -func (eps *EphemeralPeerStore) Peers() []string { - eps.mu.Lock() - defer eps.mu.Unlock() - var peers []string - for p := range eps.peers { - if !eps.banned(p) { - peers = append(peers, p) - } - } - return peers -} - -// UpdatePeerInfo implements PeerStore. -func (eps *EphemeralPeerStore) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { - eps.mu.Lock() - defer eps.mu.Unlock() - info, ok := eps.peers[peer] - if !ok { - return - } - fn(&info) - eps.peers[peer] = info -} - -// PeerInfo implements PeerStore. -func (eps *EphemeralPeerStore) PeerInfo(peer string) (syncer.PeerInfo, bool) { - eps.mu.Lock() - defer eps.mu.Unlock() - info, ok := eps.peers[peer] - return info, ok -} - -// Ban implements PeerStore. -func (eps *EphemeralPeerStore) Ban(peer string, duration time.Duration, reason string) { - eps.mu.Lock() - defer eps.mu.Unlock() - // canonicalize - if _, ipnet, err := net.ParseCIDR(peer); err == nil { - peer = ipnet.String() - } - eps.bans[peer] = peerBan{Expiry: time.Now().Add(duration), Reason: reason} -} - -// Banned implements PeerStore. -func (eps *EphemeralPeerStore) Banned(peer string) bool { - eps.mu.Lock() - defer eps.mu.Unlock() - return eps.banned(peer) -} - -// NewEphemeralPeerStore initializes an EphemeralPeerStore. -func NewEphemeralPeerStore() *EphemeralPeerStore { - return &EphemeralPeerStore{ - peers: make(map[string]syncer.PeerInfo), - bans: make(map[string]peerBan), - } -} - -type jsonPersist struct { - Peers map[string]syncer.PeerInfo `json:"peers"` - Bans map[string]peerBan `json:"bans"` -} - -// JSONPeerStore implements PeerStore with a JSON file on disk. -type JSONPeerStore struct { - *EphemeralPeerStore - path string - lastSave time.Time -} - -func (jps *JSONPeerStore) load() error { - f, err := os.Open(jps.path) - if os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - defer f.Close() - var p jsonPersist - if err := json.NewDecoder(f).Decode(&p); err != nil { - return err - } - jps.EphemeralPeerStore.peers = p.Peers - jps.EphemeralPeerStore.bans = p.Bans - return nil -} - -func (jps *JSONPeerStore) save() error { - jps.EphemeralPeerStore.mu.Lock() - defer jps.EphemeralPeerStore.mu.Unlock() - if time.Since(jps.lastSave) < 5*time.Second { - return nil - } - defer func() { jps.lastSave = time.Now() }() - // clear out expired bans - for peer, b := range jps.EphemeralPeerStore.bans { - if time.Until(b.Expiry) <= 0 { - delete(jps.EphemeralPeerStore.bans, peer) - } - } - p := jsonPersist{ - Peers: jps.EphemeralPeerStore.peers, - Bans: jps.EphemeralPeerStore.bans, - } - js, err := json.MarshalIndent(p, "", " ") - if err != nil { - return err - } - f, err := os.OpenFile(jps.path+"_tmp", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0660) - if err != nil { - return err - } - defer f.Close() - if _, err = f.Write(js); err != nil { - return err - } else if f.Sync(); err != nil { - return err - } else if f.Close(); err != nil { - return err - } else if err := os.Rename(jps.path+"_tmp", jps.path); err != nil { - return err - } - return nil -} - -// AddPeer implements PeerStore. -func (jps *JSONPeerStore) AddPeer(peer string) { - jps.EphemeralPeerStore.AddPeer(peer) - jps.save() -} - -// UpdatePeerInfo implements PeerStore. -func (jps *JSONPeerStore) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { - jps.EphemeralPeerStore.UpdatePeerInfo(peer, fn) - jps.save() -} - -// Ban implements PeerStore. -func (jps *JSONPeerStore) Ban(peer string, duration time.Duration, reason string) { - jps.EphemeralPeerStore.Ban(peer, duration, reason) - jps.save() -} - -// NewJSONPeerStore returns a JSONPeerStore backed by the specified file. -func NewJSONPeerStore(path string) (*JSONPeerStore, error) { - jps := &JSONPeerStore{ - EphemeralPeerStore: NewEphemeralPeerStore(), - path: path, - } - return jps, jps.load() -} diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index d6a8619..504e3b5 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -63,6 +63,21 @@ CREATE INDEX event_addresses_event_id_idx ON event_addresses (event_id); CREATE INDEX event_addresses_address_id_idx ON event_addresses (address_id); CREATE INDEX event_addresses_event_id_address_id_block_height ON event_addresses(event_id, address_id, block_height DESC); +CREATE TABLE syncer_peers ( + peer_address TEXT PRIMARY KEY NOT NULL, + first_seen INTEGER NOT NULL, + last_connect INTEGER NOT NULL, + synced_blocks INTEGER NOT NULL, + sync_duration INTEGER NOT NULL +); + +CREATE TABLE syncer_bans ( + net_cidr TEXT PRIMARY KEY NOT NULL, + expiration INTEGER NOT NULL, + reason TEXT NOT NULL +); +CREATE INDEX syncer_bans_expiration_index ON syncer_bans (expiration); + CREATE TABLE global_settings ( id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row db_version INTEGER NOT NULL, -- used for migrations diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go new file mode 100644 index 0000000..206822b --- /dev/null +++ b/persist/sqlite/peers.go @@ -0,0 +1,188 @@ +package sqlite + +import ( + "database/sql" + "errors" + "fmt" + "net" + "strconv" + "strings" + "time" + + "go.sia.tech/coreutils/syncer" + "go.uber.org/zap" +) + +func getPeerInfo(tx txn, peer string) (syncer.PeerInfo, error) { + const query = `SELECT first_seen, last_connect, synced_blocks, sync_duration FROM syncer_peers WHERE peer_address=$1` + var info syncer.PeerInfo + err := tx.QueryRow(query, peer).Scan((*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), &info.SyncedBlocks, &info.SyncDuration) + return info, err +} + +func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error { + const query = `UPDATE syncer_peers SET first_seen=$2, last_connect=$3, synced_blocks=$4, sync_duration=$5 WHERE peer_address=$1` + _, err := tx.Exec(query, peer, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration) + return err +} + +// AddPeer adds the given peer to the store. +func (s *Store) AddPeer(peer string) { + err := s.transaction(func(tx txn) error { + const query = `INSERT INTO syncer_peers (peer_address, first_seen, last_connect, synced_blocks, sync_duration) VALUES ($1, $2, 0, 0, 0) ON CONFLICT (peer_address) DO NOTHING` + _, err := tx.Exec(query, peer, sqlTime(time.Now())) + return err + }) + if err != nil { + s.log.Error("failed to add peer", zap.Error(err)) + } +} + +// Peers returns the addresses of all known peers. +func (s *Store) Peers() (peers []string) { + err := s.transaction(func(tx txn) error { + const query = `SELECT peer_address FROM syncer_peers` + rows, err := tx.Query(query) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var peer string + if err := rows.Scan(&peer); err != nil { + return err + } + peers = append(peers, peer) + } + return nil + }) + if err != nil { + panic(err) // 😔 + } + return +} + +// UpdatePeerInfo updates the info for the given peer. +func (s *Store) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { + err := s.transaction(func(tx txn) error { + info, err := getPeerInfo(tx, peer) + if err != nil { + return err + } + fn(&info) + return s.updatePeerInfo(tx, peer, info) + }) + if err != nil { + panic(err) // 😔 + } +} + +// PeerInfo returns the info for the given peer. +func (s *Store) PeerInfo(peer string) (syncer.PeerInfo, bool) { + var info syncer.PeerInfo + var err error + err = s.transaction(func(tx txn) error { + info, err = getPeerInfo(tx, peer) + return err + }) + if errors.Is(err, sql.ErrNoRows) { + return info, false + } else if err != nil { + panic(err) // 😔 + } + return info, true +} + +// normalizePeer normalizes a peer address to a CIDR subnet. +func normalizePeer(peer string) (string, error) { + host, _, err := net.SplitHostPort(peer) + if err != nil { + host = peer + } + if strings.IndexByte(host, '/') != -1 { + _, subnet, err := net.ParseCIDR(host) + if err != nil { + return "", fmt.Errorf("failed to parse CIDR: %w", err) + } + return subnet.String(), nil + } + + ip := net.ParseIP(host) + if ip == nil { + return "", errors.New("invalid IP address") + } + + var maskLen int + if ip.To4() != nil { + maskLen = 32 + } else { + maskLen = 128 + } + + _, normalized, err := net.ParseCIDR(fmt.Sprintf("%s/%d", ip.String(), maskLen)) + if err != nil { + panic("failed to parse CIDR") + } + return normalized.String(), nil +} + +// Ban temporarily bans one or more IPs. The addr should either be a single +// IP with port (e.g. 1.2.3.4:5678) or a CIDR subnet (e.g. 1.2.3.4/16). +func (s *Store) Ban(peer string, duration time.Duration, reason string) { + address, err := normalizePeer(peer) + if err != nil { + s.log.Error("failed to normalize peer", zap.Error(err)) + return + } + err = s.transaction(func(tx txn) error { + const query = `INSERT INTO syncer_bans (net_cidr, expiration, reason) VALUES ($1, $2, $3) ON CONFLICT (net_cidr) DO UPDATE SET expiration=EXCLUDED.expiration, reason=EXCLUDED.reason` + _, err := tx.Exec(query, address, sqlTime(time.Now().Add(duration)), reason) + return err + }) + if err != nil { + s.log.Error("failed to ban peer", zap.Error(err)) + } +} + +// Banned returns true if the peer is banned. +func (s *Store) Banned(peer string) (banned bool) { + // normalize the peer into a CIDR subnet + peer, err := normalizePeer(peer) + if err != nil { + s.log.Error("failed to normalize peer", zap.Error(err)) + return false + } + + _, subnet, err := net.ParseCIDR(peer) + if err != nil { + s.log.Error("failed to parse CIDR", zap.Error(err)) + return false + } + + // check all subnets from the given subnet to the max subnet length + var maxMaskLen int + if subnet.IP.To4() != nil { + maxMaskLen = 32 + } else { + maxMaskLen = 128 + } + + checkSubnets := make([]string, 0, maxMaskLen) + for i := maxMaskLen; i > 0; i-- { + check := subnet.IP.String() + "/" + strconv.Itoa(i) + checkSubnets = append(checkSubnets, check) + } + + err = s.transaction(func(tx txn) error { + query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1` + + var expiration time.Time + err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan((*sqlTime)(&expiration)) + banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows + return err + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + s.log.Error("failed to check ban status", zap.Error(err)) + } + return +} From 0dcc110b9e1f998262474b122d859cd1242cc0aa Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 16:43:11 -0800 Subject: [PATCH 15/24] sqlite: add peer tests --- persist/sqlite/peers.go | 17 ++++-- persist/sqlite/peers_test.go | 101 +++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 persist/sqlite/peers_test.go diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go index 206822b..7046f73 100644 --- a/persist/sqlite/peers.go +++ b/persist/sqlite/peers.go @@ -21,8 +21,8 @@ func getPeerInfo(tx txn, peer string) (syncer.PeerInfo, error) { } func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error { - const query = `UPDATE syncer_peers SET first_seen=$2, last_connect=$3, synced_blocks=$4, sync_duration=$5 WHERE peer_address=$1` - _, err := tx.Exec(query, peer, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration) + const query = `UPDATE syncer_peers SET first_seen=$1, last_connect=$2, synced_blocks=$3, sync_duration=$4 WHERE peer_address=$5 RETURNING peer_address` + err := tx.QueryRow(query, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer) return err } @@ -169,16 +169,23 @@ func (s *Store) Banned(peer string) (banned bool) { checkSubnets := make([]string, 0, maxMaskLen) for i := maxMaskLen; i > 0; i-- { - check := subnet.IP.String() + "/" + strconv.Itoa(i) - checkSubnets = append(checkSubnets, check) + _, subnet, err := net.ParseCIDR(subnet.IP.String() + "/" + strconv.Itoa(i)) + if err != nil { + panic("failed to parse CIDR") + } + checkSubnets = append(checkSubnets, subnet.String()) } err = s.transaction(func(tx txn) error { query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1` + var subnet string var expiration time.Time - err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan((*sqlTime)(&expiration)) + err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan(&subnet, (*sqlTime)(&expiration)) banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows + if err == nil && banned { + s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration)) + } return err }) if err != nil && !errors.Is(err, sql.ErrNoRows) { diff --git a/persist/sqlite/peers_test.go b/persist/sqlite/peers_test.go new file mode 100644 index 0000000..2de6d2f --- /dev/null +++ b/persist/sqlite/peers_test.go @@ -0,0 +1,101 @@ +package sqlite + +import ( + "net" + "path/filepath" + "testing" + "time" + + "go.sia.tech/walletd/syncer" + "go.uber.org/zap/zaptest" +) + +func TestAddPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + const peer = "1.2.3.4:9981" + + if err := db.AddPeer(peer); err != nil { + t.Fatal(err) + } + + lastConnect := time.Now().Truncate(time.Second) // stored as unix milliseconds + syncedBlocks := uint64(15) + syncDuration := 5 * time.Second + + err = db.UpdatePeerInfo(peer, func(info *syncer.PeerInfo) { + info.LastConnect = lastConnect + info.SyncedBlocks = syncedBlocks + info.SyncDuration = syncDuration + }) + if err != nil { + t.Fatal(err) + } + + info, err := db.PeerInfo(peer) + if err != nil { + t.Fatal(err) + } + + if !info.LastConnect.Equal(lastConnect) { + t.Errorf("expected LastConnect = %v; got %v", lastConnect, info.LastConnect) + } + if info.SyncedBlocks != syncedBlocks { + t.Errorf("expected SyncedBlocks = %d; got %d", syncedBlocks, info.SyncedBlocks) + } + if info.SyncDuration != 5*time.Second { + t.Errorf("expected SyncDuration = %s; got %s", syncDuration, info.SyncDuration) + } +} + +func TestBanPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + const peer = "1.2.3.4" + + if db.Banned(peer) { + t.Fatal("expected peer to not be banned") + } + + // ban the peer + if err := db.Ban(peer, time.Second, "test"); err != nil { + t.Fatal(err) + } + + if !db.Banned(peer) { + t.Fatal("expected peer to be banned") + } + + // wait for the ban to expire + time.Sleep(time.Second) + + if db.Banned(peer) { + t.Fatal("expected peer to not be banned") + } + + // ban a subnet + _, subnet, err := net.ParseCIDR(peer + "/24") + if err != nil { + t.Fatal(err) + } + + t.Log("banning", subnet) + + if err := db.Ban(subnet.String(), time.Second, "test"); err != nil { + t.Fatal(err) + } + + if !db.Banned(peer) { + t.Fatal("expected peer to be banned") + } +} From ac351b5510d78a87ebb98af776a62871f1e15f13 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 10 Jan 2024 16:49:18 -0800 Subject: [PATCH 16/24] sqlite: better update proof logic --- persist/sqlite/consensus.go | 77 +++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 8722547..3424e49 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -250,10 +250,38 @@ func updateLastIndexedTip(tx txn, tip types.ChainIndex) error { return err } +func getStateElementBatch(stmt *loggedStmt, offset, limit int) ([]types.StateElement, error) { + rows, err := stmt.Query(limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query siacoin elements: %w", err) + } + defer rows.Close() + + var updated []types.StateElement + for rows.Next() { + var se types.StateElement + err := rows.Scan(decode(&se.ID), decodeSlice(&se.MerkleProof), &se.LeafIndex) + if err != nil { + return nil, fmt.Errorf("failed to scan state element: %w", err) + } + updated = append(updated, se) + } + return updated, nil +} + +func updateStateElement(stmt *loggedStmt, se types.StateElement) error { + res, err := stmt.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) + if err != nil { + return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row to be affected, got %d", n) + } + return nil +} + // how slow is this going to be 😬? -// -// todo: determine if it's feasible for exchange mode to keep everything in -// memory. func updateElementProofs(tx txn, table string, updater proofUpdater) error { stmt, err := tx.Prepare(`SELECT id, merkle_proof, leaf_index FROM ` + table + ` LIMIT $1 OFFSET $2`) if err != nil { @@ -267,47 +295,20 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { } defer updateStmt.Close() - var updated []types.StateElement for offset := 0; ; offset += updateProofBatchSize { - updated = updated[:0] - - more, err := func(n int) (bool, error) { - rows, err := stmt.Query(updateProofBatchSize, n) - if err != nil { - return false, fmt.Errorf("failed to query siacoin elements: %w", err) - } - defer rows.Close() - - var more bool - for rows.Next() { - // if we get here, there may be more rows to process - more = true - - var se types.StateElement - err := rows.Scan(decode(&se.ID), decodeSlice(&se.MerkleProof), &se.LeafIndex) - if err != nil { - return false, fmt.Errorf("failed to scan state element: %w", err) - } - updater.UpdateElementProof(&se) - updated = append(updated, se) - } - return more, nil - }(offset) + elements, err := getStateElementBatch(stmt, offset, updateProofBatchSize) if err != nil { - return err + return fmt.Errorf("failed to get state element batch: %w", err) + } else if len(elements) == 0 { + break } - for _, se := range updated { - var dummy types.Hash256 - err := updateStmt.QueryRow(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)).Scan(decode(&dummy)) - if err != nil { - return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) + for _, se := range elements { + updater.UpdateElementProof(&se) + if err := updateStateElement(updateStmt, se); err != nil { + return fmt.Errorf("failed to update state element: %w", err) } } - - if !more { - break - } } return nil } From d470946aa9b4a0687f0410e58802fc849fd4b2d4 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Thu, 11 Jan 2024 07:24:28 -0800 Subject: [PATCH 17/24] sqlite: fix ownsAddress --- persist/sqlite/consensus.go | 2 +- persist/sqlite/wallet.go | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 3424e49..525cce4 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -425,7 +425,7 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { // update has been committed, revert it return s.transaction(func(tx txn) error { - stmt, err := tx.Prepare(`SELECT sia_address FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index 5d9f37c..cbd3de7 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -268,7 +268,10 @@ func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uin // Annotate annotates a list of transactions using the wallet's addresses. func (s *Store) Annotate(walletID string, txns []types.Transaction) (annotated []wallet.PoolTransaction, err error) { err = s.transaction(func(tx txn) error { - stmt, err := tx.Prepare(`SELECT sia_address FROM wallet_addresses WHERE wallet_id=$1 AND sia_address=$2 LIMIT 1`) + const query = `SELECT sa.id FROM sia_addresses sa +INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) +WHERE wa.wallet_id=$1 AND sa.sia_address=$2 LIMIT 1` + stmt, err := tx.Prepare(query) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } From d066588ad82a9248416b95c5a7f0f0c0079b95e8 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Fri, 12 Jan 2024 11:10:22 -0800 Subject: [PATCH 18/24] sqlite: better encoding --- persist/sqlite/consensus.go | 14 ++-- persist/sqlite/encoding.go | 116 +++++++++++++++++++++++++++++++++ persist/sqlite/peers.go | 12 ++-- persist/sqlite/types.go | 126 ------------------------------------ persist/sqlite/wallet.go | 12 ++-- 5 files changed, 135 insertions(+), 145 deletions(-) create mode 100644 persist/sqlite/encoding.go delete mode 100644 persist/sqlite/types.go diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 525cce4..ff9c1b4 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -47,7 +47,7 @@ func applyEvents(tx txn, events []wallet.Event) error { } var eventID int64 - err = stmt.QueryRow(sqlTime(event.Timestamp), id, event.Val.EventType(), buf).Scan(&eventID) + err = stmt.QueryRow(encode(event.Timestamp), id, event.Val.EventType(), buf).Scan(&eventID) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } @@ -87,14 +87,14 @@ func deleteSiacoinOutputs(tx txn, spent []types.SiacoinElement) error { // query the address database ID and balance var addressID int64 var balance types.Currency - err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, (*sqlCurrency)(&balance)) + err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, decode(&balance)) if err != nil { return fmt.Errorf("failed to lookup address %q: %w", se.SiacoinOutput.Address, err) } // update the balance balance = balance.Sub(se.SiacoinOutput.Value) - _, err = updateBalanceStmt.Exec((*sqlCurrency)(&balance), addressID) + _, err = updateBalanceStmt.Exec(encode(balance), addressID) if err != nil { return fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) } @@ -131,20 +131,20 @@ func applySiacoinOutputs(tx txn, added map[types.Hash256]types.SiacoinElement) e // query the address database ID and balance var addressID int64 var balance types.Currency - err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, (*sqlCurrency)(&balance)) + err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, decode(&balance)) if err != nil { return fmt.Errorf("failed to lookup address %q: %w", se.SiacoinOutput.Address, err) } // update the balance balance = balance.Add(se.SiacoinOutput.Value) - _, err = updateBalanceStmt.Exec((*sqlCurrency)(&balance), addressID) + _, err = updateBalanceStmt.Exec(encode(balance), addressID) if err != nil { return fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) } // insert the created utxo - _, err = addStmt.Exec(encode(se.ID), addressID, sqlCurrency(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) + _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) if err != nil { return fmt.Errorf("failed to insert output %q: %w", se.ID, err) } @@ -237,7 +237,7 @@ func applySiafundOutputs(tx txn, added map[types.Hash256]types.SiafundElement) e return fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) } - _, err = addStmt.Exec(encode(se.ID), addressID, sqlCurrency(se.ClaimStart), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex) + _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.ClaimStart), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex) if err != nil { return fmt.Errorf("failed to insert output %q: %w", se.ID, err) } diff --git a/persist/sqlite/encoding.go b/persist/sqlite/encoding.go new file mode 100644 index 0000000..f011b23 --- /dev/null +++ b/persist/sqlite/encoding.go @@ -0,0 +1,116 @@ +package sqlite + +import ( + "bytes" + "database/sql" + "encoding/binary" + "errors" + "fmt" + "time" + + "go.sia.tech/core/types" +) + +func encode(obj any) any { + switch obj := obj.(type) { + case types.EncoderTo: + var buf bytes.Buffer + e := types.NewEncoder(&buf) + obj.EncodeTo(e) + e.Flush() + return buf.Bytes() + case uint64: + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, obj) + return b + case time.Time: + return obj.Unix() + default: + panic(fmt.Sprintf("dbEncode: unsupported type %T", obj)) + } +} + +type decodable struct { + v any +} + +// Scan implements the sql.Scanner interface. +func (d *decodable) Scan(src any) error { + if src == nil { + return errors.New("cannot scan nil into decodable") + } + + switch src := src.(type) { + case []byte: + switch v := d.v.(type) { + case types.DecoderFrom: + dec := types.NewBufDecoder(src) + v.DecodeFrom(dec) + return dec.Err() + case *uint64: + *v = binary.LittleEndian.Uint64(src) + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } + return nil + case int64: + switch v := d.v.(type) { + case *uint64: + *v = uint64(src) + case *time.Time: + *v = time.Unix(src, 0).UTC() + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } + return nil + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } +} + +func decode(obj any) sql.Scanner { + return &decodable{obj} +} + +type decodableSlice[T any] struct { + v *[]T +} + +func (d *decodableSlice[T]) Scan(src any) error { + switch src := src.(type) { + case []byte: + dec := types.NewBufDecoder(src) + s := make([]T, dec.ReadPrefix()) + for i := range s { + dv, ok := any(&s[i]).(types.DecoderFrom) + if !ok { + panic(fmt.Errorf("cannot decode %T", s[i])) + } + dv.DecodeFrom(dec) + } + if err := dec.Err(); err != nil { + return err + } + *d.v = s + return nil + default: + return fmt.Errorf("cannot scan %T to []byte", src) + } +} + +func decodeSlice[T any](v *[]T) sql.Scanner { + return &decodableSlice[T]{v: v} +} + +func encodeSlice[T types.EncoderTo](v []T) []byte { + var buf bytes.Buffer + enc := types.NewEncoder(&buf) + enc.WritePrefix(len(v)) + for _, e := range v { + e.EncodeTo(enc) + } + if err := enc.Flush(); err != nil { + panic(err) + } + return buf.Bytes() +} diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go index 7046f73..3b59936 100644 --- a/persist/sqlite/peers.go +++ b/persist/sqlite/peers.go @@ -16,13 +16,13 @@ import ( func getPeerInfo(tx txn, peer string) (syncer.PeerInfo, error) { const query = `SELECT first_seen, last_connect, synced_blocks, sync_duration FROM syncer_peers WHERE peer_address=$1` var info syncer.PeerInfo - err := tx.QueryRow(query, peer).Scan((*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), &info.SyncedBlocks, &info.SyncDuration) + err := tx.QueryRow(query, peer).Scan(decode(&info.FirstSeen), decode(&info.LastConnect), &info.SyncedBlocks, &info.SyncDuration) return info, err } func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error { const query = `UPDATE syncer_peers SET first_seen=$1, last_connect=$2, synced_blocks=$3, sync_duration=$4 WHERE peer_address=$5 RETURNING peer_address` - err := tx.QueryRow(query, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer) + err := tx.QueryRow(query, encode(info.FirstSeen), encode(info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer) return err } @@ -30,7 +30,7 @@ func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error func (s *Store) AddPeer(peer string) { err := s.transaction(func(tx txn) error { const query = `INSERT INTO syncer_peers (peer_address, first_seen, last_connect, synced_blocks, sync_duration) VALUES ($1, $2, 0, 0, 0) ON CONFLICT (peer_address) DO NOTHING` - _, err := tx.Exec(query, peer, sqlTime(time.Now())) + _, err := tx.Exec(query, peer, encode(time.Now())) return err }) if err != nil { @@ -67,7 +67,7 @@ func (s *Store) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { err := s.transaction(func(tx txn) error { info, err := getPeerInfo(tx, peer) if err != nil { - return err + return fmt.Errorf("failed to get peer info: %w", err) } fn(&info) return s.updatePeerInfo(tx, peer, info) @@ -136,7 +136,7 @@ func (s *Store) Ban(peer string, duration time.Duration, reason string) { } err = s.transaction(func(tx txn) error { const query = `INSERT INTO syncer_bans (net_cidr, expiration, reason) VALUES ($1, $2, $3) ON CONFLICT (net_cidr) DO UPDATE SET expiration=EXCLUDED.expiration, reason=EXCLUDED.reason` - _, err := tx.Exec(query, address, sqlTime(time.Now().Add(duration)), reason) + _, err := tx.Exec(query, address, encode(time.Now().Add(duration)), reason) return err }) if err != nil { @@ -181,7 +181,7 @@ func (s *Store) Banned(peer string) (banned bool) { var subnet string var expiration time.Time - err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan(&subnet, (*sqlTime)(&expiration)) + err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan(&subnet, decode(&expiration)) banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows if err == nil && banned { s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration)) diff --git a/persist/sqlite/types.go b/persist/sqlite/types.go deleted file mode 100644 index f7f0973..0000000 --- a/persist/sqlite/types.go +++ /dev/null @@ -1,126 +0,0 @@ -package sqlite - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding/binary" - "fmt" - "time" - - "go.sia.tech/core/types" -) - -type ( - sqlCurrency types.Currency - sqlTime time.Time -) - -// Scan implements the sql.Scanner interface. -func (sc *sqlCurrency) Scan(src any) error { - buf, ok := src.([]byte) - if !ok { - return fmt.Errorf("cannot scan %T to Currency", src) - } else if len(buf) != 16 { - return fmt.Errorf("cannot scan %d bytes to Currency", len(buf)) - } - - sc.Lo = binary.LittleEndian.Uint64(buf[:8]) - sc.Hi = binary.LittleEndian.Uint64(buf[8:]) - return nil -} - -// Value implements the driver.Valuer interface. -func (sc sqlCurrency) Value() (driver.Value, error) { - buf := make([]byte, 16) - binary.LittleEndian.PutUint64(buf[:8], sc.Lo) - binary.LittleEndian.PutUint64(buf[8:], sc.Hi) - return buf, nil -} - -func (st *sqlTime) Scan(src any) error { - switch src := src.(type) { - case int64: - *st = sqlTime(time.Unix(src, 0)) - return nil - default: - return fmt.Errorf("cannot scan %T to Time", src) - } -} - -func (st sqlTime) Value() (driver.Value, error) { - return time.Time(st).Unix(), nil -} - -func encode[T types.EncoderTo](v T) []byte { - var buf bytes.Buffer - enc := types.NewEncoder(&buf) - v.EncodeTo(enc) - if err := enc.Flush(); err != nil { - panic(err) - } - return buf.Bytes() -} - -func encodeSlice[T types.EncoderTo](v []T) []byte { - var buf bytes.Buffer - enc := types.NewEncoder(&buf) - enc.WritePrefix(len(v)) - for _, e := range v { - e.EncodeTo(enc) - } - if err := enc.Flush(); err != nil { - panic(err) - } - return buf.Bytes() -} - -type decodableSlice[T any] struct { - v *[]T -} - -func (d *decodableSlice[T]) Scan(src any) error { - switch src := src.(type) { - case []byte: - dec := types.NewBufDecoder(src) - s := make([]T, dec.ReadPrefix()) - for i := range s { - dv, ok := any(&s[i]).(types.DecoderFrom) - if !ok { - panic(fmt.Errorf("cannot decode %T", s[i])) - } - dv.DecodeFrom(dec) - } - if err := dec.Err(); err != nil { - return err - } - *d.v = s - return nil - default: - return fmt.Errorf("cannot scan %T to []byte", src) - } -} - -func decodeSlice[T any](v *[]T) sql.Scanner { - return &decodableSlice[T]{v: v} -} - -type decodable[T types.DecoderFrom] struct { - v T -} - -// Scan implements the sql.Scanner interface. -func (d *decodable[T]) Scan(src any) error { - switch src := src.(type) { - case []byte: - dec := types.NewBufDecoder(src) - d.v.DecodeFrom(dec) - return dec.Err() - default: - return fmt.Errorf("cannot scan %T to []byte", src) - } -} - -func decode[T types.DecoderFrom](v T) sql.Scanner { - return &decodable[T]{v} -} diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index cbd3de7..0fc0a05 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -15,7 +15,7 @@ func insertAddress(tx txn, addr types.Address) (id int64, err error) { VALUES ($1, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id` - err = tx.QueryRow(query, encode(addr), (*sqlCurrency)(&types.ZeroCurrency)).Scan(&id) + err = tx.QueryRow(query, encode(addr), encode(types.ZeroCurrency)).Scan(&id) return } @@ -41,7 +41,7 @@ LIMIT $2 OFFSET $3` var eventType string var eventBuf []byte - err := rows.Scan(&eventID, (*sqlTime)(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) + err := rows.Scan(&eventID, decode(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) if err != nil { return fmt.Errorf("failed to scan event: %w", err) } @@ -190,7 +190,7 @@ func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.Siacoin for rows.Next() { var siacoin types.SiacoinElement - err := rows.Scan(decode(&siacoin.ID), &siacoin.LeafIndex, decodeSlice[types.Hash256](&siacoin.MerkleProof), (*sqlCurrency)(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address), &siacoin.MaturityHeight) + err := rows.Scan(decode(&siacoin.ID), &siacoin.LeafIndex, decodeSlice[types.Hash256](&siacoin.MerkleProof), decode(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address), &siacoin.MaturityHeight) if err != nil { return fmt.Errorf("failed to scan siacoin element: %w", err) } @@ -218,7 +218,7 @@ func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.Siafund for rows.Next() { var siafund types.SiafundElement - err := rows.Scan(decode(&siafund.ID), &siafund.LeafIndex, decodeSlice(&siafund.MerkleProof), &siafund.SiafundOutput.Value, (*sqlCurrency)(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address)) + err := rows.Scan(decode(&siafund.ID), &siafund.LeafIndex, decodeSlice(&siafund.MerkleProof), &siafund.SiafundOutput.Value, decode(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address)) if err != nil { return fmt.Errorf("failed to scan siacoin element: %w", err) } @@ -245,7 +245,7 @@ func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, er var siacoin types.Currency var siafund uint64 - if err := rows.Scan((*sqlCurrency)(&siacoin), &siafund); err != nil { + if err := rows.Scan(decode(&siacoin), &siafund); err != nil { return fmt.Errorf("failed to scan address balance: %w", err) } sc = sc.Add(siacoin) @@ -260,7 +260,7 @@ func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, er func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { err = s.transaction(func(tx txn) error { const query = `SELECT siacoin_balance, siafund_balance FROM address_balance WHERE sia_address=$1` - return tx.QueryRow(query, encode(address)).Scan((*sqlCurrency)(&sc), &sf) + return tx.QueryRow(query, encode(address)).Scan(decode(&sc), &sf) }) return } From db194563a3871e9a4b83f4c2f3597be31c589ef1 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 23 Jan 2024 15:11:41 -0800 Subject: [PATCH 19/24] sqlite: fix tests --- persist/sqlite/peers_test.go | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/persist/sqlite/peers_test.go b/persist/sqlite/peers_test.go index 2de6d2f..4f3e26e 100644 --- a/persist/sqlite/peers_test.go +++ b/persist/sqlite/peers_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "go.sia.tech/walletd/syncer" + "go.sia.tech/coreutils/syncer" "go.uber.org/zap/zaptest" ) @@ -20,15 +20,13 @@ func TestAddPeer(t *testing.T) { const peer = "1.2.3.4:9981" - if err := db.AddPeer(peer); err != nil { - t.Fatal(err) - } + db.AddPeer(peer) lastConnect := time.Now().Truncate(time.Second) // stored as unix milliseconds syncedBlocks := uint64(15) syncDuration := 5 * time.Second - err = db.UpdatePeerInfo(peer, func(info *syncer.PeerInfo) { + db.UpdatePeerInfo(peer, func(info *syncer.PeerInfo) { info.LastConnect = lastConnect info.SyncedBlocks = syncedBlocks info.SyncDuration = syncDuration @@ -37,9 +35,9 @@ func TestAddPeer(t *testing.T) { t.Fatal(err) } - info, err := db.PeerInfo(peer) - if err != nil { - t.Fatal(err) + info, ok := db.PeerInfo(peer) + if !ok { + t.Fatal("expected peer to be in database") } if !info.LastConnect.Equal(lastConnect) { @@ -68,9 +66,7 @@ func TestBanPeer(t *testing.T) { } // ban the peer - if err := db.Ban(peer, time.Second, "test"); err != nil { - t.Fatal(err) - } + db.Ban(peer, time.Second, "test") if !db.Banned(peer) { t.Fatal("expected peer to be banned") @@ -90,11 +86,7 @@ func TestBanPeer(t *testing.T) { } t.Log("banning", subnet) - - if err := db.Ban(subnet.String(), time.Second, "test"); err != nil { - t.Fatal(err) - } - + db.Ban(subnet.String(), time.Second, "test") if !db.Banned(peer) { t.Fatal("expected peer to be banned") } From 21dd5141508bd024b37460b832ab7e74ab89999a Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 23 Jan 2024 18:04:54 -0800 Subject: [PATCH 20/24] sqlite: remove txn interface --- persist/sqlite/consensus.go | 30 ++++----- persist/sqlite/init.go | 6 +- persist/sqlite/migrations.go | 2 +- persist/sqlite/peers.go | 16 ++--- persist/sqlite/sql.go | 123 +++++++++++++++-------------------- persist/sqlite/store.go | 17 +++-- persist/sqlite/wallet.go | 26 ++++---- 7 files changed, 104 insertions(+), 116 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index ff9c1b4..9079fc8 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -17,12 +17,12 @@ type proofUpdater interface { UpdateElementProof(*types.StateElement) } -func insertChainIndex(tx txn, index types.ChainIndex) (id int64, err error) { +func insertChainIndex(tx *txn, index types.ChainIndex) (id int64, err error) { err = tx.QueryRow(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`, index.Height, encode(index.ID)).Scan(&id) return } -func applyEvents(tx txn, events []wallet.Event) error { +func applyEvents(tx *txn, events []wallet.Event) error { stmt, err := tx.Prepare(`INSERT INTO events (date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4) RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) @@ -64,7 +64,7 @@ func applyEvents(tx txn, events []wallet.Event) error { return nil } -func deleteSiacoinOutputs(tx txn, spent []types.SiacoinElement) error { +func deleteSiacoinOutputs(tx *txn, spent []types.SiacoinElement) error { addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -108,7 +108,7 @@ func deleteSiacoinOutputs(tx txn, spent []types.SiacoinElement) error { return nil } -func applySiacoinOutputs(tx txn, added map[types.Hash256]types.SiacoinElement) error { +func applySiacoinOutputs(tx *txn, added map[types.Hash256]types.SiacoinElement) error { addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -152,7 +152,7 @@ func applySiacoinOutputs(tx txn, added map[types.Hash256]types.SiacoinElement) e return nil } -func deleteSiafundOutputs(tx txn, spent []types.SiafundElement) error { +func deleteSiafundOutputs(tx *txn, spent []types.SiafundElement) error { addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -199,7 +199,7 @@ func deleteSiafundOutputs(tx txn, spent []types.SiafundElement) error { return nil } -func applySiafundOutputs(tx txn, added map[types.Hash256]types.SiafundElement) error { +func applySiafundOutputs(tx *txn, added map[types.Hash256]types.SiafundElement) error { addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare lookup statement: %w", err) @@ -245,13 +245,13 @@ func applySiafundOutputs(tx txn, added map[types.Hash256]types.SiafundElement) e return nil } -func updateLastIndexedTip(tx txn, tip types.ChainIndex) error { +func updateLastIndexedTip(tx *txn, tip types.ChainIndex) error { _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(tip)) return err } -func getStateElementBatch(stmt *loggedStmt, offset, limit int) ([]types.StateElement, error) { - rows, err := stmt.Query(limit, offset) +func getStateElementBatch(s *stmt, offset, limit int) ([]types.StateElement, error) { + rows, err := s.Query(limit, offset) if err != nil { return nil, fmt.Errorf("failed to query siacoin elements: %w", err) } @@ -269,8 +269,8 @@ func getStateElementBatch(stmt *loggedStmt, offset, limit int) ([]types.StateEle return updated, nil } -func updateStateElement(stmt *loggedStmt, se types.StateElement) error { - res, err := stmt.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) +func updateStateElement(s *stmt, se types.StateElement) error { + res, err := s.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) if err != nil { return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) } else if n, err := res.RowsAffected(); err != nil { @@ -282,7 +282,7 @@ func updateStateElement(stmt *loggedStmt, se types.StateElement) error { } // how slow is this going to be 😬? -func updateElementProofs(tx txn, table string, updater proofUpdater) error { +func updateElementProofs(tx *txn, table string, updater proofUpdater) error { stmt, err := tx.Prepare(`SELECT id, merkle_proof, leaf_index FROM ` + table + ` LIMIT $1 OFFSET $2`) if err != nil { return fmt.Errorf("failed to prepare batch statement: %w", err) @@ -314,7 +314,7 @@ func updateElementProofs(tx txn, table string, updater proofUpdater) error { } // applyChainUpdates applies the given chain updates to the database. -func applyChainUpdates(tx txn, updates []*chain.ApplyUpdate) error { +func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate) error { stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) @@ -404,7 +404,7 @@ func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) s.updates = append(s.updates, cau) if mayCommit { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if err := applyChainUpdates(tx, s.updates); err != nil { return err } @@ -424,7 +424,7 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { } // update has been committed, revert it - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go index bc6c1a4..24ae378 100644 --- a/persist/sqlite/init.go +++ b/persist/sqlite/init.go @@ -17,13 +17,13 @@ import ( //go:embed init.sql var initDatabase string -func initializeSettings(tx txn, target int64) error { +func initializeSettings(tx *txn, target int64) error { _, err := tx.Exec(`INSERT INTO global_settings (id, db_version, last_indexed_tip) VALUES (0, ?, ?)`, target, encode(types.ChainIndex{})) return err } func (s *Store) initNewDatabase(target int64) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { if _, err := tx.Exec(initDatabase); err != nil { return fmt.Errorf("failed to initialize database: %w", err) } else if err := initializeSettings(tx, target); err != nil { @@ -48,7 +48,7 @@ func (s *Store) upgradeDatabase(current, target int64) error { } }() - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { for _, fn := range migrations[current-1:] { current++ start := time.Now() diff --git a/persist/sqlite/migrations.go b/persist/sqlite/migrations.go index 7aa3692..99d01e1 100644 --- a/persist/sqlite/migrations.go +++ b/persist/sqlite/migrations.go @@ -7,4 +7,4 @@ import ( // migrations is a list of functions that are run to migrate the database from // one version to the next. Migrations are used to update existing databases to // match the schema in init.sql. -var migrations = []func(tx txn, log *zap.Logger) error{} +var migrations = []func(tx *txn, log *zap.Logger) error{} diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go index 3b59936..4d8de8d 100644 --- a/persist/sqlite/peers.go +++ b/persist/sqlite/peers.go @@ -13,14 +13,14 @@ import ( "go.uber.org/zap" ) -func getPeerInfo(tx txn, peer string) (syncer.PeerInfo, error) { +func getPeerInfo(tx *txn, peer string) (syncer.PeerInfo, error) { const query = `SELECT first_seen, last_connect, synced_blocks, sync_duration FROM syncer_peers WHERE peer_address=$1` var info syncer.PeerInfo err := tx.QueryRow(query, peer).Scan(decode(&info.FirstSeen), decode(&info.LastConnect), &info.SyncedBlocks, &info.SyncDuration) return info, err } -func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error { +func (s *Store) updatePeerInfo(tx *txn, peer string, info syncer.PeerInfo) error { const query = `UPDATE syncer_peers SET first_seen=$1, last_connect=$2, synced_blocks=$3, sync_duration=$4 WHERE peer_address=$5 RETURNING peer_address` err := tx.QueryRow(query, encode(info.FirstSeen), encode(info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer) return err @@ -28,7 +28,7 @@ func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error // AddPeer adds the given peer to the store. func (s *Store) AddPeer(peer string) { - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `INSERT INTO syncer_peers (peer_address, first_seen, last_connect, synced_blocks, sync_duration) VALUES ($1, $2, 0, 0, 0) ON CONFLICT (peer_address) DO NOTHING` _, err := tx.Exec(query, peer, encode(time.Now())) return err @@ -40,7 +40,7 @@ func (s *Store) AddPeer(peer string) { // Peers returns the addresses of all known peers. func (s *Store) Peers() (peers []string) { - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `SELECT peer_address FROM syncer_peers` rows, err := tx.Query(query) if err != nil { @@ -64,7 +64,7 @@ func (s *Store) Peers() (peers []string) { // UpdatePeerInfo updates the info for the given peer. func (s *Store) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { info, err := getPeerInfo(tx, peer) if err != nil { return fmt.Errorf("failed to get peer info: %w", err) @@ -81,7 +81,7 @@ func (s *Store) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { func (s *Store) PeerInfo(peer string) (syncer.PeerInfo, bool) { var info syncer.PeerInfo var err error - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { info, err = getPeerInfo(tx, peer) return err }) @@ -134,7 +134,7 @@ func (s *Store) Ban(peer string, duration time.Duration, reason string) { s.log.Error("failed to normalize peer", zap.Error(err)) return } - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `INSERT INTO syncer_bans (net_cidr, expiration, reason) VALUES ($1, $2, $3) ON CONFLICT (net_cidr) DO UPDATE SET expiration=EXCLUDED.expiration, reason=EXCLUDED.reason` _, err := tx.Exec(query, address, encode(time.Now().Add(duration)), reason) return err @@ -176,7 +176,7 @@ func (s *Store) Banned(peer string) (banned bool) { checkSubnets = append(checkSubnets, subnet.String()) } - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1` var subnet string diff --git a/persist/sqlite/sql.go b/persist/sqlite/sql.go index 6ea715f..fb253d1 100644 --- a/persist/sqlite/sql.go +++ b/persist/sqlite/sql.go @@ -22,122 +22,107 @@ type ( Scan(dest ...any) error } - // A txn is an interface for executing queries within a transaction. - txn interface { - // Exec executes a query without returning any rows. The args are for - // any placeholder parameters in the query. - Exec(query string, args ...any) (sql.Result, error) - // Prepare creates a prepared statement for later queries or executions. - // Multiple queries or executions may be run concurrently from the - // returned statement. The caller must call the statement's Close method - // when the statement is no longer needed. - Prepare(query string) (*loggedStmt, error) - // Query executes a query that returns rows, typically a SELECT. The - // args are for any placeholder parameters in the query. - Query(query string, args ...any) (*loggedRows, error) - // QueryRow executes a query that is expected to return at most one row. - // QueryRow always returns a non-nil value. Errors are deferred until - // Row's Scan method is called. If the query selects no rows, the *Row's - // Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the - // first selected row and discards the rest. - QueryRow(query string, args ...any) *loggedRow - } - - loggedStmt struct { + // A stmt wraps a *sql.Stmt, logging slow queries. + stmt struct { *sql.Stmt query string - log *zap.Logger + + log *zap.Logger } - loggedTxn struct { + // A txn wraps a *sql.Tx, logging slow queries. + txn struct { *sql.Tx log *zap.Logger } - loggedRow struct { + // A row wraps a *sql.Row, logging slow queries. + row struct { *sql.Row log *zap.Logger } - loggedRows struct { + // rows wraps a *sql.Rows, logging slow queries. + rows struct { *sql.Rows + log *zap.Logger } ) -func (lr *loggedRows) Next() bool { +func (r *rows) Next() bool { start := time.Now() - next := lr.Rows.Next() + next := r.Rows.Next() if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) } return next } -func (lr *loggedRows) Scan(dest ...any) error { +func (r *rows) Scan(dest ...any) error { start := time.Now() - err := lr.Rows.Scan(dest...) + err := r.Rows.Scan(dest...) if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) } return err } -func (lr *loggedRow) Scan(dest ...any) error { +func (r *row) Scan(dest ...any) error { start := time.Now() - err := lr.Row.Scan(dest...) + err := r.Row.Scan(dest...) if dur := time.Since(start); dur > longQueryDuration { - lr.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) } return err } -func (ls *loggedStmt) Exec(args ...any) (sql.Result, error) { - return ls.ExecContext(context.Background(), args...) +func (s *stmt) Exec(args ...any) (sql.Result, error) { + return s.ExecContext(context.Background(), args...) } -func (ls *loggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { +func (s *stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { start := time.Now() - result, err := ls.Stmt.ExecContext(ctx, args...) + result, err := s.Stmt.ExecContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow exec", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow exec", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return result, err } -func (ls *loggedStmt) Query(args ...any) (*sql.Rows, error) { - return ls.QueryContext(context.Background(), args...) +func (s *stmt) Query(args ...any) (*sql.Rows, error) { + return s.QueryContext(context.Background(), args...) } -func (ls *loggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { +func (s *stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { start := time.Now() - rows, err := ls.Stmt.QueryContext(ctx, args...) + rows, err := s.Stmt.QueryContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow query", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow query", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return rows, err } -func (ls *loggedStmt) QueryRow(args ...any) *loggedRow { - return ls.QueryRowContext(context.Background(), args...) +func (s *stmt) QueryRow(args ...any) *row { + return s.QueryRowContext(context.Background(), args...) } -func (ls *loggedStmt) QueryRowContext(ctx context.Context, args ...any) *loggedRow { +func (s *stmt) QueryRowContext(ctx context.Context, args ...any) *row { start := time.Now() - row := ls.Stmt.QueryRowContext(ctx, args...) + r := s.Stmt.QueryRowContext(ctx, args...) if dur := time.Since(start); dur > longQueryDuration { - ls.log.Debug("slow query row", zap.String("query", ls.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + s.log.Debug("slow query row", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRow{row, ls.log.Named("row")} + return &row{r, s.log.Named("row")} } // Exec executes a query without returning any rows. The args are for // any placeholder parameters in the query. -func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) { +func (tx *txn) Exec(query string, args ...any) (sql.Result, error) { start := time.Now() - result, err := lt.Tx.Exec(query, args...) + result, err := tx.Tx.Exec(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } return result, err } @@ -146,30 +131,30 @@ func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) { // Multiple queries or executions may be run concurrently from the // returned statement. The caller must call the statement's Close method // when the statement is no longer needed. -func (lt *loggedTxn) Prepare(query string) (*loggedStmt, error) { +func (tx *txn) Prepare(query string) (*stmt, error) { start := time.Now() - stmt, err := lt.Tx.Prepare(query) + s, err := tx.Tx.Prepare(query) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } else if err != nil { return nil, err } - return &loggedStmt{ - Stmt: stmt, + return &stmt{ + Stmt: s, query: query, - log: lt.log.Named("statement"), + log: tx.log.Named("statement"), }, nil } // Query executes a query that returns rows, typically a SELECT. The // args are for any placeholder parameters in the query. -func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) { +func (tx *txn) Query(query string, args ...any) (*rows, error) { start := time.Now() - rows, err := lt.Tx.Query(query, args...) + r, err := tx.Tx.Query(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRows{rows, lt.log.Named("rows")}, err + return &rows{r, tx.log.Named("rows")}, err } // QueryRow executes a query that is expected to return at most one row. @@ -177,13 +162,13 @@ func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) { // Row's Scan method is called. If the query selects no rows, the *Row's // Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the // first selected row and discards the rest. -func (lt *loggedTxn) QueryRow(query string, args ...any) *loggedRow { +func (tx *txn) QueryRow(query string, args ...any) *row { start := time.Now() - row := lt.Tx.QueryRow(query, args...) + r := tx.Tx.QueryRow(query, args...) if dur := time.Since(start); dur > longQueryDuration { - lt.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + tx.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) } - return &loggedRow{row, lt.log.Named("row")} + return &row{r, tx.log.Named("row")} } func queryPlaceHolders(n int) string { @@ -220,7 +205,7 @@ func getDBVersion(db *sql.DB) (version int64) { } // setDBVersion sets the current version of the database. -func setDBVersion(tx txn, version int64) error { +func setDBVersion(tx *txn, version int64) error { const query = `UPDATE global_settings SET db_version=$1 RETURNING id;` var dbID int64 return tx.QueryRow(query, version).Scan(&dbID) diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go index 0db301e..e50fda5 100644 --- a/persist/sqlite/store.go +++ b/persist/sqlite/store.go @@ -3,6 +3,7 @@ package sqlite import ( "database/sql" "encoding/hex" + "errors" "fmt" "math" "strings" @@ -27,7 +28,7 @@ type ( // function returns an error, the transaction is rolled back. Otherwise, the // transaction is committed. If the transaction fails due to a busy error, it is // retried up to 10 times before returning. -func (s *Store) transaction(fn func(txn) error) error { +func (s *Store) transaction(fn func(*txn) error) error { var err error txnID := hex.EncodeToString(frand.Bytes(4)) log := s.log.Named("transaction").With(zap.String("id", txnID)) @@ -76,25 +77,27 @@ func sqliteFilepath(fp string) string { // doTransaction is a helper function to execute a function within a transaction. If fn returns // an error, the transaction is rolled back. Otherwise, the transaction is // committed. -func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx txn) error) error { +func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx *txn) error) error { start := time.Now() - tx, err := db.Begin() + dbtx, err := db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } - defer tx.Rollback() defer func() { + if err := dbtx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) { + log.Error("failed to rollback transaction", zap.Error(err)) + } // log the transaction if it took longer than txn duration if time.Since(start) > longTxnDuration { log.Debug("long transaction", zap.Duration("elapsed", time.Since(start)), zap.Stack("stack"), zap.Bool("failed", err != nil)) } }() - ltx := &loggedTxn{ - Tx: tx, + tx := &txn{ + Tx: dbtx, log: log, } - if err = fn(ltx); err != nil { + if err = fn(tx); err != nil { return err } else if err = tx.Commit(); err != nil { return fmt.Errorf("failed to commit transaction: %w", err) diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index 0fc0a05..f5f081d 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -10,7 +10,7 @@ import ( "go.sia.tech/walletd/wallet" ) -func insertAddress(tx txn, addr types.Address) (id int64, err error) { +func insertAddress(tx *txn, addr types.Address) (id int64, err error) { const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) VALUES ($1, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id` @@ -21,7 +21,7 @@ RETURNING id` // WalletEvents returns the events relevant to a wallet, sorted by height descending. func (s *Store) WalletEvents(walletID string, offset, limit int) (events []wallet.Event, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT ev.id, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data FROM events ev INNER JOIN chain_indices ci ON (ev.index_id = ci.id) @@ -79,7 +79,7 @@ LIMIT $2 OFFSET $3` // AddWallet adds a wallet to the database. func (s *Store) AddWallet(name string, info json.RawMessage) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { const query = `INSERT INTO wallets (id, extra_data) VALUES ($1, $2)` _, err := tx.Exec(query, name, info) @@ -93,7 +93,7 @@ func (s *Store) AddWallet(name string, info json.RawMessage) error { // DeleteWallet deletes a wallet from the database. This does not stop tracking // addresses that were previously associated with the wallet. func (s *Store) DeleteWallet(name string) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { _, err := tx.Exec(`DELETE FROM wallets WHERE id=$1`, name) return err }) @@ -102,7 +102,7 @@ func (s *Store) DeleteWallet(name string) error { // Wallets returns a map of wallet names to wallet extra data. func (s *Store) Wallets() (map[string]json.RawMessage, error) { wallets := make(map[string]json.RawMessage) - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `SELECT id, extra_data FROM wallets` rows, err := tx.Query(query) @@ -126,7 +126,7 @@ func (s *Store) Wallets() (map[string]json.RawMessage, error) { // AddAddress adds an address to a wallet. func (s *Store) AddAddress(walletID string, address types.Address, info json.RawMessage) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { addressID, err := insertAddress(tx, address) if err != nil { return fmt.Errorf("failed to insert address: %w", err) @@ -139,7 +139,7 @@ func (s *Store) AddAddress(walletID string, address types.Address, info json.Raw // RemoveAddress removes an address from a wallet. This does not stop tracking // the address. func (s *Store) RemoveAddress(walletID string, address types.Address) error { - return s.transaction(func(tx txn) error { + return s.transaction(func(tx *txn) error { const query = `DELETE FROM wallet_addresses WHERE wallet_id=$1 AND address_id=(SELECT id FROM sia_addresses WHERE sia_address=$2)` _, err := tx.Exec(query, walletID, encode(address)) return err @@ -149,7 +149,7 @@ func (s *Store) RemoveAddress(walletID string, address types.Address) error { // Addresses returns a map of addresses to their extra data for a wallet. func (s *Store) Addresses(walletID string) (map[types.Address]json.RawMessage, error) { addresses := make(map[types.Address]json.RawMessage) - err := s.transaction(func(tx txn) error { + err := s.transaction(func(tx *txn) error { const query = `SELECT sa.sia_address, wa.extra_data FROM wallet_addresses wa INNER JOIN sia_addresses sa ON (sa.id = wa.address_id) @@ -176,7 +176,7 @@ WHERE wa.wallet_id=$1` // UnspentSiacoinOutputs returns the unspent siacoin outputs for a wallet. func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.SiacoinElement, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siacoin_value, sa.sia_address, se.maturity_height FROM siacoin_elements se INNER JOIN sia_addresses sa ON (se.address_id = sa.id) @@ -204,7 +204,7 @@ func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.Siacoin // UnspentSiafundOutputs returns the unspent siafund outputs for a wallet. func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.SiafundElement, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siafund_value, se.claim_start, sa.sia_address FROM siafund_elements se INNER JOIN sia_addresses sa ON (se.address_id = sa.id) @@ -231,7 +231,7 @@ func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.Siafund // WalletBalance returns the total balance of a wallet. func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT siacoin_balance, siafund_balance FROM sia_addresses sa INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) WHERE wa.wallet_id=$1` @@ -258,7 +258,7 @@ func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, er // AddressBalance returns the balance of a single address. func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT siacoin_balance, siafund_balance FROM address_balance WHERE sia_address=$1` return tx.QueryRow(query, encode(address)).Scan(decode(&sc), &sf) }) @@ -267,7 +267,7 @@ func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uin // Annotate annotates a list of transactions using the wallet's addresses. func (s *Store) Annotate(walletID string, txns []types.Transaction) (annotated []wallet.PoolTransaction, err error) { - err = s.transaction(func(tx txn) error { + err = s.transaction(func(tx *txn) error { const query = `SELECT sa.id FROM sia_addresses sa INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) WHERE wa.wallet_id=$1 AND sa.sia_address=$2 LIMIT 1` From 44e471d74f541e6f6b8573964d872d49ccefabdf Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Fri, 26 Jan 2024 11:00:24 -0800 Subject: [PATCH 21/24] sqlite: fix address balance tracking --- persist/sqlite/consensus.go | 376 +++++++++++++++--------------------- 1 file changed, 161 insertions(+), 215 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 9079fc8..889da40 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -9,12 +9,16 @@ import ( "go.sia.tech/core/types" "go.sia.tech/coreutils/chain" "go.sia.tech/walletd/wallet" + "go.uber.org/zap" ) const updateProofBatchSize = 1000 -type proofUpdater interface { +type chainUpdate interface { UpdateElementProof(*types.StateElement) + ForEachTreeNode(func(row, col uint64, h types.Hash256)) + ForEachSiacoinElement(func(types.SiacoinElement, bool)) + ForEachSiafundElement(func(types.SiafundElement, bool)) } func insertChainIndex(tx *txn, index types.ChainIndex) (id int64, err error) { @@ -64,12 +68,14 @@ func applyEvents(tx *txn, events []wallet.Event) error { return nil } -func deleteSiacoinOutputs(tx *txn, spent []types.SiacoinElement) error { - addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) +func applySiacoinElements(tx *txn, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { + addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) VALUES ($1, $2, 0) +ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address +RETURNING id, siacoin_balance`) if err != nil { - return fmt.Errorf("failed to prepare lookup statement: %w", err) + return fmt.Errorf("failed to prepare statement: %w", err) } - defer addrStmt.Close() + defer addrStatement.Close() updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1 WHERE id=$2`) if err != nil { @@ -77,172 +83,183 @@ func deleteSiacoinOutputs(tx *txn, spent []types.SiacoinElement) error { } defer updateBalanceStmt.Close() - deleteStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) + addStmt, err := tx.Prepare(`INSERT INTO siacoin_elements (id, address_id, siacoin_value, merkle_proof, leaf_index, maturity_height) VALUES ($1, $2, $3, $4, $5, $6)`) + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer addStmt.Close() + + spendStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } - defer deleteStmt.Close() + defer spendStmt.Close() + + // using ForEachSiacoinElement creates an interesting problem. The + // ForEachSiacoinElement function is only called once for each element. So + // if a siacoin element is spent and created in the same block, the element + // will not exist in the database. + // + // This creates a problem with balance tracking since it subtracts the + // element value from the balance. However, since the element value was + // never added to the balance in the first place, the balance will be + // incorrect. The solution is to check if the UTXO is in the database before + // decrementing the balance. + // + // This is an important implementation detail since the store must assume + // the chain manager is correct and can't check the integrity of the database + // without reimplementing some of the consensus logic. + cu.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + // sticky error + if err != nil { + return + } else if !relevantAddress(se.SiacoinOutput.Address) { + return + } - for _, se := range spent { // query the address database ID and balance var addressID int64 var balance types.Currency - err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, decode(&balance)) + err = addrStatement.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, decode(&balance)) if err != nil { - return fmt.Errorf("failed to lookup address %q: %w", se.SiacoinOutput.Address, err) + err = fmt.Errorf("failed to query address %q: %w", se.SiacoinOutput.Address, err) + return } - // update the balance - balance = balance.Sub(se.SiacoinOutput.Value) - _, err = updateBalanceStmt.Exec(encode(balance), addressID) - if err != nil { - return fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) - } + if spent { + var dummy types.Hash256 + err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) + if errors.Is(err, sql.ErrNoRows) { + // spent output not found, most likely an ephemeral output. ignore + err = nil + return + } else if err != nil { + err = fmt.Errorf("failed to delete output %q: %w", se.ID, err) + return + } - var dummy types.Hash256 - err = deleteStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) - if err != nil { - return fmt.Errorf("failed to delete output %q: %w", se.ID, err) + // update the balance after making sure the utxo was in the database + // and not an ephemeral output + updated, underflow := balance.SubWithUnderflow(se.SiacoinOutput.Value) + if underflow { + log.Panic("balance is negative", zap.Stringer("address", se.SiacoinOutput.Address), zap.String("balance", balance.ExactString()), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString())) + } + _, err = updateBalanceStmt.Exec(encode(updated), addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) + return + } + + log.Debug("removed utxo", zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) + } else { + balance = balance.Add(se.SiacoinOutput.Value) + + // update the balance + _, err = updateBalanceStmt.Exec(encode(balance), addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) + return + } + + // insert the created utxo + _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) + if err != nil { + err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) + return + } + log.Debug("added utxo", zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) } - } - return nil + }) + return err } -func applySiacoinOutputs(tx *txn, added map[types.Hash256]types.SiacoinElement) error { - addrStmt, err := tx.Prepare(`SELECT id, siacoin_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) +func applySiafundElements(tx *txn, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { + addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) VALUES ($1, $2, 0) +ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address +RETURNING id, siafund_balance`) if err != nil { - return fmt.Errorf("failed to prepare lookup statement: %w", err) + return fmt.Errorf("failed to prepare statement: %w", err) } - defer addrStmt.Close() + defer addrStatement.Close() - updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1 WHERE id=$2`) + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) if err != nil { return fmt.Errorf("failed to prepare update statement: %w", err) } defer updateBalanceStmt.Close() - addStmt, err := tx.Prepare(`INSERT INTO siacoin_elements (id, address_id, siacoin_value, merkle_proof, leaf_index, maturity_height) VALUES ($1, $2, $3, $4, $5, $6)`) + addStmt, err := tx.Prepare(`INSERT INTO siafund_elements (id, address_id, claim_start, merkle_proof, leaf_index, siafund_value) VALUES ($1, $2, $3, $4, $5, $6)`) if err != nil { return fmt.Errorf("failed to prepare insert statement: %w", err) } defer addStmt.Close() - for _, se := range added { - // query the address database ID and balance - var addressID int64 - var balance types.Currency - err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address)).Scan(&addressID, decode(&balance)) - if err != nil { - return fmt.Errorf("failed to lookup address %q: %w", se.SiacoinOutput.Address, err) - } - - // update the balance - balance = balance.Add(se.SiacoinOutput.Value) - _, err = updateBalanceStmt.Exec(encode(balance), addressID) - if err != nil { - return fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) - } - - // insert the created utxo - _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) - if err != nil { - return fmt.Errorf("failed to insert output %q: %w", se.ID, err) - } - } - return nil -} - -func deleteSiafundOutputs(tx *txn, spent []types.SiafundElement) error { - addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) - if err != nil { - return fmt.Errorf("failed to prepare lookup statement: %w", err) - } - defer addrStmt.Close() - - updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) - if err != nil { - return fmt.Errorf("failed to prepare update statement: %w", err) - } - defer updateBalanceStmt.Close() - - spendStmt, err := tx.Prepare(`DELETE FROM siafund_elements WHERE id=$1 RETURNING id`) + spendStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } defer spendStmt.Close() - for _, se := range spent { - // query the address database ID and balance - var addressID int64 - var balance uint64 - err := addrStmt.QueryRow(encode(se.SiafundOutput.Address)).Scan(&addressID, balance) + cu.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { + // sticky error if err != nil { - return fmt.Errorf("failed to lookup address %q: %w", se.SiafundOutput.Address, err) + return + } else if !relevantAddress(se.SiafundOutput.Address) { + return } - // update the balance - if balance < se.SiafundOutput.Value { - panic("siafund balance is negative") // developer error - } - balance -= se.SiafundOutput.Value - _, err = updateBalanceStmt.Exec(balance, addressID) - if err != nil { - return fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) - } - - var dummy types.Hash256 - err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) - if err != nil { - return fmt.Errorf("failed to delete output %q: %w", se.ID, err) - } - } - return nil -} - -func applySiafundOutputs(tx *txn, added map[types.Hash256]types.SiafundElement) error { - addrStmt, err := tx.Prepare(`SELECT id, siafund_balance FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) - if err != nil { - return fmt.Errorf("failed to prepare lookup statement: %w", err) - } - defer addrStmt.Close() - - updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) - if err != nil { - return fmt.Errorf("failed to prepare update statement: %w", err) - } - defer updateBalanceStmt.Close() - - addStmt, err := tx.Prepare(`INSERT INTO siafund_elements (id, address_id, claim_start, siafund_value, merkle_proof, leaf_index) VALUES ($1, $2, $3, $4, $5, $6)`) - if err != nil { - return fmt.Errorf("failed to prepare statement: %w", err) - } - defer addStmt.Close() - - for _, se := range added { // query the address database ID and balance var addressID int64 var balance uint64 - err := addrStmt.QueryRow(encode(se.SiafundOutput.Address)).Scan(&addressID, balance) + err = addrStatement.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, &balance) if err != nil { - return fmt.Errorf("failed to lookup address %q: %w", se.SiafundOutput.Address, err) + err = fmt.Errorf("failed to query address %q: %w", se.SiafundOutput.Address, err) + return } // update the balance - if balance < se.SiafundOutput.Value { - panic("siafund balance is negative") // developer error - } - balance -= se.SiafundOutput.Value - _, err = updateBalanceStmt.Exec(balance, addressID) - if err != nil { - return fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) - } + if spent { + var dummy types.Hash256 + err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) + if errors.Is(err, sql.ErrNoRows) { + // spent output not found, most likely an ephemeral output. + // ignore + err = nil + return + } else if err != nil { + err = fmt.Errorf("failed to delete output %q: %w", se.ID, err) + return + } - _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.ClaimStart), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex) - if err != nil { - return fmt.Errorf("failed to insert output %q: %w", se.ID, err) + // update the balance only if the utxo was successfully deleted + if se.SiafundOutput.Value > balance { + log.Panic("balance is negative", zap.Stringer("address", se.SiafundOutput.Address), zap.Uint64("balance", se.SiafundOutput.Value), zap.Stringer("outputID", se.ID), zap.Uint64("value", se.SiafundOutput.Value)) + } + + balance -= se.SiafundOutput.Value + _, err = updateBalanceStmt.Exec(encode(balance), addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) + return + } + } else { + balance += se.SiafundOutput.Value + // update the balance + _, err = updateBalanceStmt.Exec(balance, addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) + return + } + + // insert the created utxo + _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.ClaimStart), encodeSlice(se.MerkleProof), se.LeafIndex, se.SiafundOutput.Value) + if err != nil { + err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) + return + } } - } - return nil + }) + return err } func updateLastIndexedTip(tx *txn, tip types.ChainIndex) error { @@ -282,7 +299,7 @@ func updateStateElement(s *stmt, se types.StateElement) error { } // how slow is this going to be 😬? -func updateElementProofs(tx *txn, table string, updater proofUpdater) error { +func updateElementProofs(tx *txn, table string, cu chainUpdate) error { stmt, err := tx.Prepare(`SELECT id, merkle_proof, leaf_index FROM ` + table + ` LIMIT $1 OFFSET $2`) if err != nil { return fmt.Errorf("failed to prepare batch statement: %w", err) @@ -304,7 +321,7 @@ func updateElementProofs(tx *txn, table string, updater proofUpdater) error { } for _, se := range elements { - updater.UpdateElementProof(&se) + cu.UpdateElementProof(&se) if err := updateStateElement(updateStmt, se); err != nil { return fmt.Errorf("failed to update state element: %w", err) } @@ -314,7 +331,7 @@ func updateElementProofs(tx *txn, table string, updater proofUpdater) error { } // applyChainUpdates applies the given chain updates to the database. -func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate) error { +func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate, log *zap.Logger) error { stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) @@ -327,7 +344,7 @@ func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate) error { // address. Monitor performance and consider changing this in the // future. From a memory perspective, it would be fine to lazy load all // addresses into memory. - ownsAddress := func(address types.Address) bool { + relevantAddress := func(address types.Address) bool { var dbID int64 err := stmt.QueryRow(encode(address)).Scan(&dbID) if err != nil && !errors.Is(err, sql.ErrNoRows) { @@ -337,51 +354,15 @@ func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate) error { } for _, update := range updates { - events := wallet.AppliedEvents(update.State, update.Block, update, ownsAddress) + events := wallet.AppliedEvents(update.State, update.Block, update, relevantAddress) if err := applyEvents(tx, events); err != nil { return fmt.Errorf("failed to apply events: %w", err) } - var spentSiacoinOutputs []types.SiacoinElement - newSiacoinOutputs := make(map[types.Hash256]types.SiacoinElement) - update.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { - if !ownsAddress(se.SiacoinOutput.Address) { - return - } - - if spent { - spentSiacoinOutputs = append(spentSiacoinOutputs, se) - delete(newSiacoinOutputs, se.ID) - } else { - newSiacoinOutputs[se.ID] = se - } - }) - - if err := deleteSiacoinOutputs(tx, spentSiacoinOutputs); err != nil { - return fmt.Errorf("failed to delete siacoin outputs: %w", err) - } else if err := applySiacoinOutputs(tx, newSiacoinOutputs); err != nil { - return fmt.Errorf("failed to apply siacoin outputs: %w", err) - } - - var spentSiafundOutputs []types.SiafundElement - newSiafundOutputs := make(map[types.Hash256]types.SiafundElement) - update.ForEachSiafundElement(func(sf types.SiafundElement, spent bool) { - if !ownsAddress(sf.SiafundOutput.Address) { - return - } - - if spent { - spentSiafundOutputs = append(spentSiafundOutputs, sf) - delete(newSiafundOutputs, sf.ID) - } else { - newSiafundOutputs[sf.ID] = sf - } - }) - - if err := deleteSiafundOutputs(tx, spentSiafundOutputs); err != nil { - return fmt.Errorf("failed to delete siafund outputs: %w", err) - } else if err := applySiafundOutputs(tx, newSiafundOutputs); err != nil { - return fmt.Errorf("failed to apply siafund outputs: %w", err) + if err := applySiacoinElements(tx, update, relevantAddress, log.Named("siacoins")); err != nil { + return fmt.Errorf("failed to apply siacoin elements: %w", err) + } else if err := applySiafundElements(tx, update, relevantAddress, log.Named("siafunds")); err != nil { + return fmt.Errorf("failed to apply siafund elements: %w", err) } // update proofs @@ -405,7 +386,7 @@ func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) if mayCommit { return s.transaction(func(tx *txn) error { - if err := applyChainUpdates(tx, s.updates); err != nil { + if err := applyChainUpdates(tx, s.updates, s.log.Named("apply")); err != nil { return err } s.updates = nil @@ -417,6 +398,8 @@ func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) // ProcessChainRevertUpdate implements chain.Subscriber func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { + log := s.log.Named("revert") + // update hasn't been committed yet if len(s.updates) > 0 && s.updates[len(s.updates)-1].Block.ID() == cru.Block.ID() { s.updates = s.updates[:len(s.updates)-1] @@ -437,7 +420,7 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { // address. Monitor performance and consider changing this in the // future. From a memory perspective, it would be fine to lazy load all // addresses into memory. - ownsAddress := func(address types.Address) bool { + relevantAddress := func(address types.Address) bool { var dbID int64 err := stmt.QueryRow(encode(address)).Scan(&dbID) if err != nil && !errors.Is(err, sql.ErrNoRows) { @@ -446,47 +429,10 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { return err == nil } - var spentSiacoinOutputs []types.SiacoinElement - var spentSiafundOutputs []types.SiafundElement - addedSiacoinOutputs := make(map[types.Hash256]types.SiacoinElement) - addedSiafundOutputs := make(map[types.Hash256]types.SiafundElement) - - cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { - if !ownsAddress(se.SiacoinOutput.Address) { - return - } - - if !spent { - spentSiacoinOutputs = append(spentSiacoinOutputs, se) - } else { - addedSiacoinOutputs[se.ID] = se - } - }) - - cru.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { - if !ownsAddress(se.SiafundOutput.Address) { - return - } - - if !spent { - spentSiafundOutputs = append(spentSiafundOutputs, se) - } else { - addedSiafundOutputs[se.ID] = se - } - }) - - // revert siacoin outputs - if err := deleteSiacoinOutputs(tx, spentSiacoinOutputs); err != nil { - return fmt.Errorf("failed to delete siacoin outputs: %w", err) - } else if err := applySiacoinOutputs(tx, addedSiacoinOutputs); err != nil { - return fmt.Errorf("failed to apply siacoin outputs: %w", err) - } - - // revert siafund outputs - if err := deleteSiafundOutputs(tx, spentSiafundOutputs); err != nil { - return fmt.Errorf("failed to delete siafund outputs: %w", err) - } else if err := applySiafundOutputs(tx, addedSiafundOutputs); err != nil { - return fmt.Errorf("failed to apply siafund outputs: %w", err) + if err := applySiacoinElements(tx, cru, relevantAddress, log.Named("siacoins")); err != nil { + return fmt.Errorf("failed to apply siacoin elements: %w", err) + } else if err := applySiafundElements(tx, cru, relevantAddress, log.Named("siafunds")); err != nil { + return fmt.Errorf("failed to apply siafund elements: %w", err) } // revert events From b7ef75436140b0a952342065e6763fb2d6d5d661 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Fri, 26 Jan 2024 11:39:59 -0800 Subject: [PATCH 22/24] api,sqlite,wallet: support immature siacoin balance --- api/api_test.go | 56 +++++++++++++++- api/server.go | 6 +- persist/sqlite/consensus.go | 128 +++++++++++++++++++++++++++++------- persist/sqlite/init.sql | 1 + persist/sqlite/wallet.go | 20 +++--- wallet/manager.go | 4 +- 6 files changed, 176 insertions(+), 39 deletions(-) diff --git a/api/api_test.go b/api/api_test.go index afcb5a5..71f79cf 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -90,7 +90,7 @@ func TestWallet(t *testing.T) { balance, err := wc.Balance() if err != nil { t.Fatal(err) - } else if !balance.Siacoins.IsZero() || balance.Siafunds != 0 { + } else if !balance.Siacoins.IsZero() || !balance.ImmatureSiacoins.IsZero() || balance.Siafunds != 0 { t.Fatal("balance should be 0") } @@ -163,6 +163,8 @@ func TestWallet(t *testing.T) { t.Fatal(err) } else if !balance.Siacoins.Equals(types.Siacoins(1)) { t.Error("balance should be 1 SC, got", balance.Siacoins) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Error("immature balance should be 0 SC, got", balance.ImmatureSiacoins) } // transaction should appear in history @@ -179,6 +181,58 @@ func TestWallet(t *testing.T) { } else if len(outputs) != 2 { t.Error("should have two UTXOs, got", len(outputs)) } + + // mine a block to add an immature balance + cs = cm.TipState() + b = types.Block{ + ParentID: cs.Index.ID, + Timestamp: types.CurrentTimestamp(), + MinerPayouts: []types.SiacoinOutput{{Address: addr, Value: cs.BlockReward()}}, + } + for b.ID().CmpWork(cs.ChildTarget) < 0 { + b.Nonce += cs.NonceFactor() + } + if err := cm.AddBlocks([]types.Block{b}); err != nil { + t.Fatal(err) + } + + // get new balance + balance, err = wc.Balance() + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.Equals(types.Siacoins(1)) { + t.Error("balance should be 1 SC, got", balance.Siacoins) + } else if !balance.ImmatureSiacoins.Equals(b.MinerPayouts[0].Value) { + t.Errorf("immature balance should be %d SC, got %d SC", b.MinerPayouts[0].Value, balance.ImmatureSiacoins) + } + + // mine enough blocks for the miner payout to mature + expectedBalance := types.Siacoins(1).Add(b.MinerPayouts[0].Value) + target := cs.MaturityHeight() + for cs.Index.Height < target { + cs = cm.TipState() + b := types.Block{ + ParentID: cs.Index.ID, + Timestamp: types.CurrentTimestamp(), + MinerPayouts: []types.SiacoinOutput{{Address: types.VoidAddress, Value: cs.BlockReward()}}, + } + for b.ID().CmpWork(cs.ChildTarget) < 0 { + b.Nonce += cs.NonceFactor() + } + if err := cm.AddBlocks([]types.Block{b}); err != nil { + t.Fatal(err) + } + } + + // get new balance + balance, err = wc.Balance() + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.Equals(expectedBalance) { + t.Errorf("balance should be %d, got %d", expectedBalance, balance.Siacoins) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Error("immature balance should be 0 SC, got", balance.ImmatureSiacoins) + } } func TestV2(t *testing.T) { diff --git a/api/server.go b/api/server.go index 5030592..6949fba 100644 --- a/api/server.go +++ b/api/server.go @@ -57,7 +57,7 @@ type ( Events(name string, offset, limit int) ([]wallet.Event, error) UnspentSiacoinOutputs(name string) ([]types.SiacoinElement, error) UnspentSiafundOutputs(name string) ([]types.SiafundElement, error) - WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) + WalletBalance(walletID string) (sc, immatureSC types.Currency, sf uint64, err error) Annotate(name string, pool []types.Transaction) ([]wallet.PoolTransaction, error) Reserve(ids []types.Hash256, duration time.Duration) error @@ -245,13 +245,13 @@ func (s *server) walletsBalanceHandler(jc jape.Context) { return } - sc, sf, err := s.wm.WalletBalance(name) + sc, isc, sf, err := s.wm.WalletBalance(name) if jc.Check("couldn't load balance", err) != nil { return } jc.Encode(WalletBalanceResponse{ Siacoins: sc, - ImmatureSiacoins: types.ZeroCurrency, + ImmatureSiacoins: isc, Siafunds: sf, }) } diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index 889da40..33fa61d 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -68,16 +68,16 @@ func applyEvents(tx *txn, events []wallet.Event) error { return nil } -func applySiacoinElements(tx *txn, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { - addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) VALUES ($1, $2, 0) +func applySiacoinElements(tx *txn, index types.ChainIndex, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { + addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address -RETURNING id, siacoin_balance`) +RETURNING id, siacoin_balance, immature_siacoin_balance`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } defer addrStatement.Close() - updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1 WHERE id=$2`) + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) if err != nil { return fmt.Errorf("failed to prepare update statement: %w", err) } @@ -119,8 +119,8 @@ RETURNING id, siacoin_balance`) // query the address database ID and balance var addressID int64 - var balance types.Currency - err = addrStatement.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, decode(&balance)) + var balance, immatureBalance types.Currency + err = addrStatement.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, decode(&balance), decode(&immatureBalance)) if err != nil { err = fmt.Errorf("failed to query address %q: %w", se.SiacoinOutput.Address, err) return @@ -138,13 +138,13 @@ RETURNING id, siacoin_balance`) return } - // update the balance after making sure the utxo was in the database - // and not an ephemeral output - updated, underflow := balance.SubWithUnderflow(se.SiacoinOutput.Value) - if underflow { - log.Panic("balance is negative", zap.Stringer("address", se.SiacoinOutput.Address), zap.String("balance", balance.ExactString()), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString())) + if se.MaturityHeight > index.Height { + immatureBalance = immatureBalance.Sub(se.SiacoinOutput.Value) + } else { + balance = balance.Sub(se.SiacoinOutput.Value) } - _, err = updateBalanceStmt.Exec(encode(updated), addressID) + + _, err = updateBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) if err != nil { err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) return @@ -152,29 +152,36 @@ RETURNING id, siacoin_balance`) log.Debug("removed utxo", zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) } else { - balance = balance.Add(se.SiacoinOutput.Value) - - // update the balance - _, err = updateBalanceStmt.Exec(encode(balance), addressID) + // insert the created utxo + _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) if err != nil { - err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) + err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) return } - // insert the created utxo - _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) + if se.MaturityHeight > index.Height { + immatureBalance = immatureBalance.Add(se.SiacoinOutput.Value) + log.Debug("adding immature balance") + } else { + balance = balance.Add(se.SiacoinOutput.Value) + log.Debug("adding balance") + } + + // update the balance + _, err = updateBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) if err != nil { - err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) + err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) return } - log.Debug("added utxo", zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) + log.Debug("added utxo", zap.Uint64("maturityHeight", se.MaturityHeight), zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) } }) return err } func applySiafundElements(tx *txn, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { - addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) VALUES ($1, $2, 0) + // create the address if it doesn't exist + addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id, siafund_balance`) if err != nil { @@ -211,6 +218,7 @@ RETURNING id, siafund_balance`) // query the address database ID and balance var addressID int64 var balance uint64 + // get the address ID err = addrStatement.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, &balance) if err != nil { err = fmt.Errorf("failed to query address %q: %w", se.SiafundOutput.Address, err) @@ -330,6 +338,67 @@ func updateElementProofs(tx *txn, table string, cu chainUpdate) error { return nil } +func getMaturedValue(tx *txn, index types.ChainIndex) (matured map[int64]types.Currency, err error) { + rows, err := tx.Query(`SELECT address_id, siacoin_value FROM siacoin_elements WHERE maturity_height=$1`, index.Height) + if err != nil { + return nil, fmt.Errorf("failed to query siacoin elements: %w", err) + } + defer rows.Close() + + matured = make(map[int64]types.Currency) + for rows.Next() { + var addressID int64 + var value types.Currency + err := rows.Scan(&addressID, decode(&value)) + if err != nil { + return nil, fmt.Errorf("failed to scan matured balance: %w", err) + } + matured[addressID] = matured[addressID].Add(value) + } + return +} + +func updateImmatureBalance(tx *txn, index types.ChainIndex, revert bool) error { + balanceStmt, err := tx.Prepare(`SELECT siacoin_balance, immature_siacoin_balance FROM sia_addresses WHERE id=$1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer balanceStmt.Close() + + updateStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer updateStmt.Close() + + delta, err := getMaturedValue(tx, index) + if err != nil { + return fmt.Errorf("failed to get matured utxos: %w", err) + } + + for addressID, value := range delta { + var balance, immatureBalance types.Currency + err := balanceStmt.QueryRow(addressID).Scan(decode(&balance), decode(&immatureBalance)) + if err != nil { + return fmt.Errorf("failed to query address %d: %w", addressID, err) + } + + if revert { + balance = balance.Sub(value) + immatureBalance = immatureBalance.Add(value) + } else { + balance = balance.Add(value) + immatureBalance = immatureBalance.Sub(value) + } + + _, err = updateStmt.Exec(encode(balance), encode(immatureBalance), addressID) + if err != nil { + return fmt.Errorf("failed to update address %d: %w", addressID, err) + } + } + return nil +} + // applyChainUpdates applies the given chain updates to the database. func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate, log *zap.Logger) error { stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) @@ -354,12 +423,18 @@ func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate, log *zap.Logger) e } for _, update := range updates { + // mature the immature balance first + if err := updateImmatureBalance(tx, update.State.Index, false); err != nil { + return fmt.Errorf("failed to update immature balance: %w", err) + } + // apply new events events := wallet.AppliedEvents(update.State, update.Block, update, relevantAddress) if err := applyEvents(tx, events); err != nil { return fmt.Errorf("failed to apply events: %w", err) } - if err := applySiacoinElements(tx, update, relevantAddress, log.Named("siacoins")); err != nil { + // apply new elements + if err := applySiacoinElements(tx, update.State.Index, update, relevantAddress, log.Named("siacoins")); err != nil { return fmt.Errorf("failed to apply siacoin elements: %w", err) } else if err := applySiafundElements(tx, update, relevantAddress, log.Named("siafunds")); err != nil { return fmt.Errorf("failed to apply siafund elements: %w", err) @@ -429,7 +504,7 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { return err == nil } - if err := applySiacoinElements(tx, cru, relevantAddress, log.Named("siacoins")); err != nil { + if err := applySiacoinElements(tx, cru.State.Index, cru, relevantAddress, log.Named("siacoins")); err != nil { return fmt.Errorf("failed to apply siacoin elements: %w", err) } else if err := applySiafundElements(tx, cru, relevantAddress, log.Named("siafunds")); err != nil { return fmt.Errorf("failed to apply siafund elements: %w", err) @@ -441,6 +516,11 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { return fmt.Errorf("failed to delete chain index: %w", err) } + // revert immature balance + if err := updateImmatureBalance(tx, cru.State.Index, true); err != nil { + return fmt.Errorf("failed to update immature balance: %w", err) + } + // update proofs if err := updateElementProofs(tx, "siacoin_elements", cru); err != nil { return fmt.Errorf("failed to update siacoin element proofs: %w", err) diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index 504e3b5..d9d4cff 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -8,6 +8,7 @@ CREATE TABLE sia_addresses ( id INTEGER PRIMARY KEY, sia_address BLOB UNIQUE NOT NULL, siacoin_balance BLOB NOT NULL, + immature_siacoin_balance BLOB NOT NULL, siafund_balance INTEGER NOT NULL ); diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index f5f081d..88f1248 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -11,8 +11,8 @@ import ( ) func insertAddress(tx *txn, addr types.Address) (id int64, err error) { - const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, siafund_balance) -VALUES ($1, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address + const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) +VALUES ($1, $2, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id` err = tx.QueryRow(query, encode(addr), encode(types.ZeroCurrency)).Scan(&id) @@ -230,9 +230,9 @@ func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.Siafund } // WalletBalance returns the total balance of a wallet. -func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) { +func (s *Store) WalletBalance(walletID string) (sc, immatureSC types.Currency, sf uint64, err error) { err = s.transaction(func(tx *txn) error { - const query = `SELECT siacoin_balance, siafund_balance FROM sia_addresses sa + const query = `SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses sa INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) WHERE wa.wallet_id=$1` @@ -242,14 +242,16 @@ func (s *Store) WalletBalance(walletID string) (sc types.Currency, sf uint64, er } for rows.Next() { - var siacoin types.Currency - var siafund uint64 + var addressSC types.Currency + var addressISC types.Currency + var addressSF uint64 - if err := rows.Scan(decode(&siacoin), &siafund); err != nil { + if err := rows.Scan(decode(&addressSC), decode(&addressISC), decode(&addressSF)); err != nil { return fmt.Errorf("failed to scan address balance: %w", err) } - sc = sc.Add(siacoin) - sf += siafund + sc = sc.Add(addressSC) + immatureSC = immatureSC.Add(addressISC) + sf += addressSF } return nil }) diff --git a/wallet/manager.go b/wallet/manager.go index f1b0e62..74039ed 100644 --- a/wallet/manager.go +++ b/wallet/manager.go @@ -36,7 +36,7 @@ type ( UnspentSiacoinOutputs(walletID string) ([]types.SiacoinElement, error) UnspentSiafundOutputs(walletID string) ([]types.SiafundElement, error) Annotate(walletID string, txns []types.Transaction) ([]PoolTransaction, error) - WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) + WalletBalance(walletID string) (sc, immature types.Currency, sf uint64, err error) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) @@ -105,7 +105,7 @@ func (m *Manager) Annotate(name string, pool []types.Transaction) ([]PoolTransac } // WalletBalance returns the balance of the given wallet. -func (m *Manager) WalletBalance(walletID string) (sc types.Currency, sf uint64, err error) { +func (m *Manager) WalletBalance(walletID string) (sc, immature types.Currency, sf uint64, err error) { return m.store.WalletBalance(walletID) } From 75812b1caf2dd6498666a81c3bf566538bb5b356 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Thu, 1 Feb 2024 10:56:02 -0400 Subject: [PATCH 23/24] remove duplicate api startup --- cmd/walletd/main.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/cmd/walletd/main.go b/cmd/walletd/main.go index 477344c..2b69192 100644 --- a/cmd/walletd/main.go +++ b/cmd/walletd/main.go @@ -197,8 +197,6 @@ func main() { stop := n.Start() log.Println("api: Listening on", l.Addr()) go startWeb(l, n, apiPassword) - log.Println("api: Listening on", l.Addr()) - go startWeb(l, n, apiPassword) signalCh := make(chan os.Signal, 1) signal.Notify(signalCh, os.Interrupt) <-signalCh From 0df580e02db066eb5f1f19d605c1839402084b18 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Thu, 1 Feb 2024 11:12:53 -0400 Subject: [PATCH 24/24] sqlite: fix currency encoding --- go.mod | 2 +- go.sum | 2 ++ persist/sqlite/encoding.go | 11 +++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index b50ae10..084ce0b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( github.com/mattn/go-sqlite3 v1.14.21 - go.sia.tech/core v0.2.1-0.20240130145801-8067f34b2ecc + go.sia.tech/core v0.2.1 go.sia.tech/coreutils v0.0.0-20240130201319-8303550528d7 go.sia.tech/jape v0.9.0 go.sia.tech/web/walletd v0.16.0 diff --git a/go.sum b/go.sum index e25fb79..f29db2e 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ go.etcd.io/bbolt v1.3.8 h1:xs88BrvEv273UsB79e0hcVrlUWmS0a8upikMFhSyAtA= go.etcd.io/bbolt v1.3.8/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw= go.sia.tech/core v0.2.1-0.20240130145801-8067f34b2ecc h1:oUCCTOatQIwYkJ2FUWRvJtgU+i/BwlzmzCxoSvmmJVQ= go.sia.tech/core v0.2.1-0.20240130145801-8067f34b2ecc/go.mod h1:3EoY+rR78w1/uGoXXVqcYdwSjSJKuEMI5bL7WROA27Q= +go.sia.tech/core v0.2.1 h1:CqmMd+T5rAhC+Py3NxfvGtvsj/GgwIqQHHVrdts/LqY= +go.sia.tech/core v0.2.1/go.mod h1:3EoY+rR78w1/uGoXXVqcYdwSjSJKuEMI5bL7WROA27Q= go.sia.tech/coreutils v0.0.0-20240130201319-8303550528d7 h1:G2l6fRzAdNZy2z7+FhoG2y8ARtFpR6PkXXTB5tkdfZ8= go.sia.tech/coreutils v0.0.0-20240130201319-8303550528d7/go.mod h1:3Mb206QDd3NtRiaHZ2kN87/HKXhcBF6lHVatS7PkViY= go.sia.tech/jape v0.9.0 h1:kWgMFqALYhLMJYOwWBgJda5ko/fi4iZzRxHRP7pp8NY= diff --git a/persist/sqlite/encoding.go b/persist/sqlite/encoding.go index f011b23..8f98230 100644 --- a/persist/sqlite/encoding.go +++ b/persist/sqlite/encoding.go @@ -13,6 +13,11 @@ import ( func encode(obj any) any { switch obj := obj.(type) { + case types.Currency: + buf := make([]byte, 16) + binary.LittleEndian.PutUint64(buf, obj.Lo) + binary.LittleEndian.PutUint64(buf[8:], obj.Hi) + return buf case types.EncoderTo: var buf bytes.Buffer e := types.NewEncoder(&buf) @@ -43,6 +48,12 @@ func (d *decodable) Scan(src any) error { switch src := src.(type) { case []byte: switch v := d.v.(type) { + case *types.Currency: + if len(src) != 16 { + return fmt.Errorf("cannot scan %d bytes into Currency", len(src)) + } + v.Lo = binary.LittleEndian.Uint64(src) + v.Hi = binary.LittleEndian.Uint64(src[8:]) case types.DecoderFrom: dec := types.NewBufDecoder(src) v.DecodeFrom(dec)