diff --git a/api/api.go b/api/api.go index f5cb51b..28c20d4 100644 --- a/api/api.go +++ b/api/api.go @@ -4,6 +4,7 @@ import ( "time" "go.sia.tech/core/types" + "go.sia.tech/walletd/wallet" ) // A GatewayPeer is a currently-connected peer. @@ -30,12 +31,8 @@ type TxpoolTransactionsResponse struct { V2Transactions []types.V2Transaction `json:"v2transactions"` } -// WalletBalanceResponse is the response type for /wallets/:name/balance. -type WalletBalanceResponse struct { - Siacoins types.Currency `json:"siacoins"` - ImmatureSiacoins types.Currency `json:"immatureSiacoins"` - Siafunds uint64 `json:"siafunds"` -} +// BalanceResponse is the response type for /wallets/:name/balance. +type BalanceResponse wallet.Balance // WalletOutputsResponse is the response type for /wallets/:name/outputs. type WalletOutputsResponse struct { diff --git a/api/api_test.go b/api/api_test.go index 71f79cf..59c40fa 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -71,6 +71,7 @@ func TestWallet(t *testing.T) { t.Fatal(err) } defer ws.Close() + wm, err := wallet.NewManager(cm, ws, log.Named("wallet")) if err != nil { t.Fatal(err) diff --git a/api/client.go b/api/client.go index 8365495..8729863 100644 --- a/api/client.go +++ b/api/client.go @@ -139,7 +139,7 @@ func (c *WalletClient) Addresses() (resp map[types.Address]json.RawMessage, err } // Balance returns the current wallet balance. -func (c *WalletClient) Balance() (resp WalletBalanceResponse, err error) { +func (c *WalletClient) Balance() (resp BalanceResponse, err error) { err = c.c.GET(fmt.Sprintf("/wallets/%v/balance", c.name), &resp) return } diff --git a/api/server.go b/api/server.go index 6949fba..81f63db 100644 --- a/api/server.go +++ b/api/server.go @@ -57,11 +57,11 @@ type ( Events(name string, offset, limit int) ([]wallet.Event, error) UnspentSiacoinOutputs(name string) ([]types.SiacoinElement, error) UnspentSiafundOutputs(name string) ([]types.SiafundElement, error) - WalletBalance(walletID string) (sc, immatureSC types.Currency, sf uint64, err error) + WalletBalance(walletID string) (wallet.Balance, error) Annotate(name string, pool []types.Transaction) ([]wallet.PoolTransaction, error) Reserve(ids []types.Hash256, duration time.Duration) error - AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) + AddressBalance(address types.Address) (wallet.Balance, error) } ) @@ -245,15 +245,11 @@ func (s *server) walletsBalanceHandler(jc jape.Context) { return } - sc, isc, sf, err := s.wm.WalletBalance(name) + b, err := s.wm.WalletBalance(name) if jc.Check("couldn't load balance", err) != nil { return } - jc.Encode(WalletBalanceResponse{ - Siacoins: sc, - ImmatureSiacoins: isc, - Siafunds: sf, - }) + jc.Encode(BalanceResponse(b)) } func (s *server) walletsEventsHandler(jc jape.Context) { diff --git a/cmd/walletd/testnet.go b/cmd/walletd/testnet.go index b87ad21..9a66bd0 100644 --- a/cmd/walletd/testnet.go +++ b/cmd/walletd/testnet.go @@ -324,7 +324,7 @@ func printTestnetEvents(c *api.Client, seed wallet.Seed) { check("Couldn't get events:", err) for i := range events { e := events[len(events)-1-i] - switch t := e.Val.(type) { + switch t := e.Data.(type) { case *wallet.EventTransaction: if len(t.SiacoinInputs) == 0 || len(t.SiacoinOutputs) == 0 { continue diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index f194652..a16246b 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -1,6 +1,7 @@ package sqlite import ( + "bytes" "database/sql" "encoding/json" "errors" @@ -12,520 +13,373 @@ import ( "go.uber.org/zap" ) -const updateProofBatchSize = 1000 +type updateTx struct { + tx *txn -type chainUpdate interface { - UpdateElementProof(*types.StateElement) - ForEachTreeNode(func(row, col uint64, h types.Hash256)) - ForEachSiacoinElement(func(types.SiacoinElement, bool)) - ForEachSiafundElement(func(types.SiafundElement, bool)) + relevantAddresses map[types.Address]bool } -func insertChainIndex(tx *txn, index types.ChainIndex) (id int64, err error) { - err = tx.QueryRow(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`, index.Height, encode(index.ID)).Scan(&id) +func scanStateElement(s scanner) (se types.StateElement, err error) { + err = s.Scan(decode(&se.ID), &se.LeafIndex, decodeSlice(&se.MerkleProof)) return } -func applyEvents(tx *txn, events []wallet.Event) error { - stmt, err := tx.Prepare(`INSERT INTO events (date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4) RETURNING id`) - if err != nil { - return fmt.Errorf("failed to prepare statement: %w", err) - } - defer stmt.Close() +func scanSiacoinElement(s scanner) (se types.SiacoinElement, err error) { + err = s.Scan(decode(&se.ID), decode(&se.SiacoinOutput.Value), decodeSlice(&se.MerkleProof), &se.LeafIndex, &se.MaturityHeight, decode(&se.SiacoinOutput.Address)) + return +} - addRelevantAddrStmt, err := tx.Prepare(`INSERT INTO event_addresses (event_id, address_id, block_height) VALUES ($1, $2, $3)`) +func (ut *updateTx) SiacoinStateElements() ([]types.StateElement, error) { + const query = `SELECT id, leaf_index, merkle_proof FROM siacoin_elements` + rows, err := ut.tx.Query(query) if err != nil { - return fmt.Errorf("failed to prepare statement: %w", err) + return nil, fmt.Errorf("failed to query siacoin elements: %w", err) } - defer addRelevantAddrStmt.Close() + defer rows.Close() - for _, event := range events { - id, err := insertChainIndex(tx, event.Index) + var elements []types.StateElement + for rows.Next() { + se, err := scanStateElement(rows) if err != nil { - return fmt.Errorf("failed to create chain index: %w", err) + return nil, fmt.Errorf("failed to scan state element: %w", err) } + elements = append(elements, se) + } + return elements, nil +} - buf, err := json.Marshal(event.Val) - if err != nil { - return fmt.Errorf("failed to marshal event: %w", err) - } +func (ut *updateTx) UpdateSiacoinStateElements(elements []types.StateElement) error { + const query = `UPDATE siacoin_elements SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id` + stmt, err := ut.tx.Prepare(query) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() - var eventID int64 - err = stmt.QueryRow(encode(event.Timestamp), id, event.Val.EventType(), buf).Scan(&eventID) + for _, se := range elements { + var dummy types.Hash256 + err := stmt.QueryRow(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)).Scan(decode(&dummy)) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } - - for _, addr := range event.Relevant { - addressID, err := insertAddress(tx, addr) - if err != nil { - return fmt.Errorf("failed to insert address: %w", err) - } else if _, err := addRelevantAddrStmt.Exec(eventID, addressID, event.Index.Height); err != nil { - return fmt.Errorf("failed to add relevant address: %w", err) - } - } } return nil } -func applySiacoinElements(tx *txn, index types.ChainIndex, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { - addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, 0) -ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address -RETURNING id, siacoin_balance, immature_siacoin_balance`) - if err != nil { - return fmt.Errorf("failed to prepare statement: %w", err) - } - defer addrStatement.Close() - - updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) +func (ut *updateTx) SiafundStateElements() ([]types.StateElement, error) { + const query = `SELECT id, leaf_index, merkle_proof FROM siafund_elements` + rows, err := ut.tx.Query(query) if err != nil { - return fmt.Errorf("failed to prepare update statement: %w", err) + return nil, fmt.Errorf("failed to query siacoin elements: %w", err) } - defer updateBalanceStmt.Close() + defer rows.Close() - addStmt, err := tx.Prepare(`INSERT INTO siacoin_elements (id, address_id, siacoin_value, merkle_proof, leaf_index, maturity_height) VALUES ($1, $2, $3, $4, $5, $6)`) - if err != nil { - return fmt.Errorf("failed to prepare insert statement: %w", err) + var elements []types.StateElement + for rows.Next() { + se, err := scanStateElement(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan state element: %w", err) + } + elements = append(elements, se) } - defer addStmt.Close() + return elements, nil +} - spendStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) +func (ut *updateTx) UpdateSiafundStateElements(elements []types.StateElement) error { + const query = `UPDATE siafund_elements SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id` + stmt, err := ut.tx.Prepare(query) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } - defer spendStmt.Close() - - // using ForEachSiacoinElement creates an interesting problem. The - // ForEachSiacoinElement function is only called once for each element. So - // if a siacoin element is spent and created in the same block, the element - // will not exist in the database. - // - // This creates a problem with balance tracking since it subtracts the - // element value from the balance. However, since the element value was - // never added to the balance in the first place, the balance will be - // incorrect. The solution is to check if the UTXO is in the database before - // decrementing the balance. - // - // This is an important implementation detail since the store must assume - // the chain manager is correct and can't check the integrity of the database - // without reimplementing some of the consensus logic. - cu.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { - // sticky error - if err != nil { - return - } else if !relevantAddress(se.SiacoinOutput.Address) { - return - } + defer stmt.Close() - // query the address database ID and balance - var addressID int64 - var balance, immatureBalance types.Currency - err = addrStatement.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, decode(&balance), decode(&immatureBalance)) + for _, se := range elements { + var dummy types.Hash256 + err := stmt.QueryRow(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)).Scan(decode(&dummy)) if err != nil { - err = fmt.Errorf("failed to query address %q: %w", se.SiacoinOutput.Address, err) - return - } - - if spent { - var dummy types.Hash256 - err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) - if errors.Is(err, sql.ErrNoRows) { - // spent output not found, most likely an ephemeral output. ignore - err = nil - return - } else if err != nil { - err = fmt.Errorf("failed to delete output %q: %w", se.ID, err) - return - } - - if se.MaturityHeight > index.Height { - immatureBalance = immatureBalance.Sub(se.SiacoinOutput.Value) - } else { - balance = balance.Sub(se.SiacoinOutput.Value) - } - - _, err = updateBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) - if err != nil { - err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) - return - } - - log.Debug("removed utxo", zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) - } else { - // insert the created utxo - _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) - if err != nil { - err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) - return - } - - if se.MaturityHeight > index.Height { - immatureBalance = immatureBalance.Add(se.SiacoinOutput.Value) - log.Debug("adding immature balance") - } else { - balance = balance.Add(se.SiacoinOutput.Value) - log.Debug("adding balance") - } - - // update the balance - _, err = updateBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) - if err != nil { - err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) - return - } - log.Debug("added utxo", zap.Uint64("maturityHeight", se.MaturityHeight), zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) + return fmt.Errorf("failed to execute statement: %w", err) } - }) - return err + } + return nil } -func applySiafundElements(tx *txn, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { - // create the address if it doesn't exist - addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, 0) -ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address -RETURNING id, siafund_balance`) - if err != nil { - return fmt.Errorf("failed to prepare statement: %w", err) +func (ut *updateTx) AddressRelevant(addr types.Address) (bool, error) { + if relevant, ok := ut.relevantAddresses[addr]; ok { + return relevant, nil } - defer addrStatement.Close() - updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) - if err != nil { - return fmt.Errorf("failed to prepare update statement: %w", err) + var id int64 + err := ut.tx.QueryRow(`SELECT id FROM sia_addresses WHERE sia_address=$1`, encode(addr)).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + ut.relevantAddresses[addr] = false + return false, nil + } else if err != nil { + return false, fmt.Errorf("failed to query address: %w", err) } - defer updateBalanceStmt.Close() + ut.relevantAddresses[addr] = true + return ut.relevantAddresses[addr], nil +} - addStmt, err := tx.Prepare(`INSERT INTO siafund_elements (id, address_id, claim_start, merkle_proof, leaf_index, siafund_value) VALUES ($1, $2, $3, $4, $5, $6)`) - if err != nil { - return fmt.Errorf("failed to prepare insert statement: %w", err) - } - defer addStmt.Close() +func (ut *updateTx) AddressBalance(addr types.Address) (balance wallet.Balance, err error) { + err = ut.tx.QueryRow(`SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses WHERE sia_address=$1`, encode(addr)).Scan(decode(&balance.Siacoins), decode(&balance.ImmatureSiacoins), &balance.Siafunds) + return +} - spendStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) +func (ut *updateTx) UpdateBalances(balances []wallet.AddressBalance) error { + const query = `UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2, siafund_balance=$3 WHERE sia_address=$4` + stmt, err := ut.tx.Prepare(query) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } - defer spendStmt.Close() - - cu.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { - // sticky error - if err != nil { - return - } else if !relevantAddress(se.SiafundOutput.Address) { - return - } + defer stmt.Close() - // query the address database ID and balance - var addressID int64 - var balance uint64 - // get the address ID - err = addrStatement.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, &balance) + for _, ab := range balances { + _, err := stmt.Exec(encode(ab.Balance.Siacoins), encode(ab.Balance.ImmatureSiacoins), ab.Balance.Siafunds, encode(ab.Address)) if err != nil { - err = fmt.Errorf("failed to query address %q: %w", se.SiafundOutput.Address, err) - return - } - - // update the balance - if spent { - var dummy types.Hash256 - err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) - if errors.Is(err, sql.ErrNoRows) { - // spent output not found, most likely an ephemeral output. - // ignore - err = nil - return - } else if err != nil { - err = fmt.Errorf("failed to delete output %q: %w", se.ID, err) - return - } - - // update the balance only if the utxo was successfully deleted - if se.SiafundOutput.Value > balance { - log.Panic("balance is negative", zap.Stringer("address", se.SiafundOutput.Address), zap.Uint64("balance", balance), zap.Stringer("outputID", se.ID), zap.Uint64("value", se.SiafundOutput.Value)) - } - - balance -= se.SiafundOutput.Value - _, err = updateBalanceStmt.Exec(encode(balance), addressID) - if err != nil { - err = fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) - return - } - } else { - balance += se.SiafundOutput.Value - // update the balance - _, err = updateBalanceStmt.Exec(balance, addressID) - if err != nil { - err = fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) - return - } - - // insert the created utxo - _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.ClaimStart), encodeSlice(se.MerkleProof), se.LeafIndex, se.SiafundOutput.Value) - if err != nil { - err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) - return - } + return fmt.Errorf("failed to execute statement: %w", err) } - }) - return err -} - -func updateLastIndexedTip(tx *txn, tip types.ChainIndex) error { - _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(tip)) - return err + } + return nil } -func getStateElementBatch(s *stmt, offset, limit int) ([]types.StateElement, error) { - rows, err := s.Query(limit, offset) +func (ut *updateTx) MaturedSiacoinElements(index types.ChainIndex) (elements []types.SiacoinElement, err error) { + const query = `SELECT se.id, se.siacoin_value, se.merkle_proof, se.leaf_index, se.maturity_height, a.sia_address +FROM siacoin_elements se +INNER JOIN sia_addresses a ON (se.address_id=a.id) +WHERE maturity_height=$1` + rows, err := ut.tx.Query(query, index.Height) if err != nil { return nil, fmt.Errorf("failed to query siacoin elements: %w", err) } defer rows.Close() - var updated []types.StateElement for rows.Next() { - var se types.StateElement - err := rows.Scan(decode(&se.ID), decodeSlice(&se.MerkleProof), &se.LeafIndex) + element, err := scanSiacoinElement(rows) if err != nil { - return nil, fmt.Errorf("failed to scan state element: %w", err) + return nil, fmt.Errorf("failed to scan siacoin element: %w", err) } - updated = append(updated, se) + elements = append(elements, element) } - return updated, nil -} - -func updateStateElement(s *stmt, se types.StateElement) error { - res, err := s.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) - if err != nil { - return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) - } else if n, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) - } else if n != 1 { - return fmt.Errorf("expected 1 row to be affected, got %d", n) - } - return nil + return } -// how slow is this going to be 😬? -func updateElementProofs(tx *txn, table string, cu chainUpdate) error { - stmt, err := tx.Prepare(`SELECT id, merkle_proof, leaf_index FROM ` + table + ` LIMIT $1 OFFSET $2`) +func (ut *updateTx) AddSiacoinElements(elements []types.SiacoinElement) error { + addrStmt, err := insertAddressStatement(ut.tx) if err != nil { - return fmt.Errorf("failed to prepare batch statement: %w", err) + return fmt.Errorf("failed to prepare address statement: %w", err) } - defer stmt.Close() + defer addrStmt.Close() - updateStmt, err := tx.Prepare(`UPDATE ` + table + ` SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id`) + inserStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) if err != nil { - return fmt.Errorf("failed to prepare update statement: %w", err) + return fmt.Errorf("failed to prepare insert statement: %w", err) } - defer updateStmt.Close() + defer inserStmt.Close() - for offset := 0; ; offset += updateProofBatchSize { - elements, err := getStateElementBatch(stmt, offset, updateProofBatchSize) + for _, se := range elements { + var addressID int64 + err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) if err != nil { - return fmt.Errorf("failed to get state element batch: %w", err) - } else if len(elements) == 0 { - break + return fmt.Errorf("failed to query address: %w", err) } - for _, se := range elements { - cu.UpdateElementProof(&se) - if err := updateStateElement(updateStmt, se); err != nil { - return fmt.Errorf("failed to update state element: %w", err) - } + _, err = inserStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID) + if err != nil { + return fmt.Errorf("failed to execute statement: %w", err) } } return nil } -func getMaturedValue(tx *txn, index types.ChainIndex) (matured map[int64]types.Currency, err error) { - rows, err := tx.Query(`SELECT address_id, siacoin_value FROM siacoin_elements WHERE maturity_height=$1`, index.Height) +func (ut *updateTx) RemoveSiacoinElements(elements []types.SiacoinOutputID) error { + stmt, err := ut.tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) if err != nil { - return nil, fmt.Errorf("failed to query siacoin elements: %w", err) + return fmt.Errorf("failed to prepare statement: %w", err) } - defer rows.Close() + defer stmt.Close() - matured = make(map[int64]types.Currency) - for rows.Next() { - var addressID int64 - var value types.Currency - err := rows.Scan(&addressID, decode(&value)) + for _, id := range elements { + var dummy types.Hash256 + err := stmt.QueryRow(encode(id)).Scan(decode(&dummy)) if err != nil { - return nil, fmt.Errorf("failed to scan matured balance: %w", err) + return fmt.Errorf("failed to delete element %q: %w", id, err) } - matured[addressID] = matured[addressID].Add(value) } - return + return nil } -func updateImmatureBalance(tx *txn, index types.ChainIndex, revert bool) error { - balanceStmt, err := tx.Prepare(`SELECT siacoin_balance, immature_siacoin_balance FROM sia_addresses WHERE id=$1`) +func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement) error { + addrStmt, err := insertAddressStatement(ut.tx) if err != nil { - return fmt.Errorf("failed to prepare statement: %w", err) + return fmt.Errorf("failed to prepare address statement: %w", err) } - defer balanceStmt.Close() + defer addrStmt.Close() - updateStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + inserStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } - defer updateStmt.Close() - - delta, err := getMaturedValue(tx, index) - if err != nil { - return fmt.Errorf("failed to get matured utxos: %w", err) - } + defer inserStmt.Close() - for addressID, value := range delta { - var balance, immatureBalance types.Currency - err := balanceStmt.QueryRow(addressID).Scan(decode(&balance), decode(&immatureBalance)) + for _, se := range elements { + var addressID int64 + err := addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) if err != nil { - return fmt.Errorf("failed to query address %d: %w", addressID, err) + return fmt.Errorf("failed to query address: %w", err) } - if revert { - balance = balance.Sub(value) - immatureBalance = immatureBalance.Add(value) - } else { - balance = balance.Add(value) - immatureBalance = immatureBalance.Sub(value) - } - - _, err = updateStmt.Exec(encode(balance), encode(immatureBalance), addressID) + _, err = inserStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID) if err != nil { - return fmt.Errorf("failed to update address %d: %w", addressID, err) + return fmt.Errorf("failed to execute statement: %w", err) } } return nil } -// applyChainUpdates applies the given chain updates to the database. -func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate, log *zap.Logger) error { - stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) +func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundOutputID) error { + stmt, err := ut.tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } defer stmt.Close() - // note: this would be more performant for small wallets to load all - // addresses into memory. However, for larger wallets (> 10K addresses), - // this is time consuming. Instead, the database is queried for each - // address. Monitor performance and consider changing this in the - // future. From a memory perspective, it would be fine to lazy load all - // addresses into memory. - relevantAddress := func(address types.Address) bool { - var dbID int64 - err := stmt.QueryRow(encode(address)).Scan(&dbID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - panic(err) // database error + for _, id := range elements { + var dummy types.Hash256 + err := stmt.QueryRow(encode(id)).Scan(decode(&dummy)) + if err != nil { + return fmt.Errorf("failed to delete element %q: %w", id, err) } - return err == nil } + return nil +} + +func (ut *updateTx) AddEvents(events []wallet.Event) error { + indexStmt, err := ut.tx.Prepare(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare index statement: %w", err) + } + defer indexStmt.Close() + + eventStmt, err := ut.tx.Prepare(`INSERT INTO events (event_id, maturity_height, date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare event statement: %w", err) + } + defer eventStmt.Close() - for _, update := range updates { - // mature the immature balance first - if err := updateImmatureBalance(tx, update.State.Index, false); err != nil { - return fmt.Errorf("failed to update immature balance: %w", err) + addrStmt, err := insertAddressStatement(ut.tx) + if err != nil { + return fmt.Errorf("failed to prepare address statement: %w", err) + } + defer addrStmt.Close() + + relevantAddrStmt, err := ut.tx.Prepare(`INSERT INTO event_addresses (event_id, address_id) VALUES ($1, $2) ON CONFLICT (event_id, address_id) DO NOTHING`) + if err != nil { + return fmt.Errorf("failed to prepare relevant address statement: %w", err) + } + defer addrStmt.Close() + + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + for _, event := range events { + var chainIndexID int64 + err := indexStmt.QueryRow(event.Index.Height, encode(event.Index.ID)).Scan(&chainIndexID) + if err != nil { + return fmt.Errorf("failed to execute statement: %w", err) } - // apply new events - events := wallet.AppliedEvents(update.State, update.Block, update, relevantAddress) - if err := applyEvents(tx, events); err != nil { - return fmt.Errorf("failed to apply events: %w", err) + + buf.Reset() + if err := enc.Encode(event.Data); err != nil { + return fmt.Errorf("failed to encode event: %w", err) } - // apply new elements - if err := applySiacoinElements(tx, update.State.Index, update, relevantAddress, log.Named("siacoins")); err != nil { - return fmt.Errorf("failed to apply siacoin elements: %w", err) - } else if err := applySiafundElements(tx, update, relevantAddress, log.Named("siafunds")); err != nil { - return fmt.Errorf("failed to apply siafund elements: %w", err) + var eventID int64 + err = eventStmt.QueryRow(encode(event.ID), event.MaturityHeight, encode(event.Timestamp), chainIndexID, event.Data.EventType(), buf.String()).Scan(&eventID) + if err != nil { + return fmt.Errorf("failed to add event: %w", err) } - // update proofs - if err := updateElementProofs(tx, "siacoin_elements", update); err != nil { - return fmt.Errorf("failed to update siacoin element proofs: %w", err) - } else if err := updateElementProofs(tx, "siafund_elements", update); err != nil { - return fmt.Errorf("failed to update siafund element proofs: %w", err) + used := make(map[types.Address]bool) + for _, addr := range event.Relevant { + if used[addr] { + continue + } + + var addressID int64 + err = addrStmt.QueryRow(encode(addr), encode(types.ZeroCurrency), 0).Scan(&addressID) + if err != nil { + return fmt.Errorf("failed to get address: %w", err) + } + + _, err = relevantAddrStmt.Exec(eventID, addressID) + if err != nil { + return fmt.Errorf("failed to add relevant address: %w", err) + } + + used[addr] = true } } + return nil +} - lastTip := updates[len(updates)-1].State.Index - if err := updateLastIndexedTip(tx, lastTip); err != nil { - return fmt.Errorf("failed to update last indexed tip: %w", err) +// RevertEvents reverts the events that were added in the given block. +func (ut *updateTx) RevertEvents(blockID types.BlockID) error { + var id int64 + err := ut.tx.QueryRow(`DELETE FROM chain_indices WHERE block_id=$1 RETURNING id`, encode(blockID)).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + return nil } - return nil + return err } // ProcessChainApplyUpdate implements chain.Subscriber func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) error { s.updates = append(s.updates, cau) - + log := s.log.Named("ProcessChainApplyUpdate").With(zap.Stringer("index", cau.State.Index)) + log.Debug("received update") if mayCommit { + log.Debug("committing updates", zap.Int("n", len(s.updates))) return s.transaction(func(tx *txn) error { - if err := applyChainUpdates(tx, s.updates, s.log.Named("apply")); err != nil { - return err + utx := &updateTx{ + tx: tx, + relevantAddresses: make(map[types.Address]bool), + } + + if err := wallet.ApplyChainUpdates(utx, s.updates); err != nil { + return fmt.Errorf("failed to apply updates: %w", err) + } else if err := setLastCommittedIndex(tx, cau.State.Index); err != nil { + return fmt.Errorf("failed to set last committed index: %w", err) } s.updates = nil return nil }) } + return nil } // ProcessChainRevertUpdate implements chain.Subscriber func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { - log := s.log.Named("revert") + log := s.log.Named("ProcessChainRevertUpdate").With(zap.Stringer("index", cru.State.Index)) // update hasn't been committed yet if len(s.updates) > 0 && s.updates[len(s.updates)-1].Block.ID() == cru.Block.ID() { + log.Debug("removed uncommitted update") s.updates = s.updates[:len(s.updates)-1] return nil } + log.Debug("reverting update") // update has been committed, revert it return s.transaction(func(tx *txn) error { - stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) - if err != nil { - return fmt.Errorf("failed to prepare statement: %w", err) - } - defer stmt.Close() - - // note: this would be more performant for small wallets to load all - // addresses into memory. However, for larger wallets (> 10K addresses), - // this is time consuming. Instead, the database is queried for each - // address. Monitor performance and consider changing this in the - // future. From a memory perspective, it would be fine to lazy load all - // addresses into memory. - relevantAddress := func(address types.Address) bool { - var dbID int64 - err := stmt.QueryRow(encode(address)).Scan(&dbID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - panic(err) // database error - } - return err == nil + utx := &updateTx{ + tx: tx, + relevantAddresses: make(map[types.Address]bool), } - if err := applySiacoinElements(tx, cru.State.Index, cru, relevantAddress, log.Named("siacoins")); err != nil { - return fmt.Errorf("failed to apply siacoin elements: %w", err) - } else if err := applySiafundElements(tx, cru, relevantAddress, log.Named("siafunds")); err != nil { - return fmt.Errorf("failed to apply siafund elements: %w", err) - } - - // revert events - _, err = tx.Exec(`DELETE FROM chain_indices WHERE block_id=$1`, cru.Block.ID()) - if err != nil { - return fmt.Errorf("failed to delete chain index: %w", err) - } - - // revert immature balance - if err := updateImmatureBalance(tx, cru.State.Index, true); err != nil { - return fmt.Errorf("failed to update immature balance: %w", err) - } - - // update proofs - if err := updateElementProofs(tx, "siacoin_elements", cru); err != nil { - return fmt.Errorf("failed to update siacoin element proofs: %w", err) - } else if err := updateElementProofs(tx, "siafund_elements", cru); err != nil { - return fmt.Errorf("failed to update siafund element proofs: %w", err) + if err := wallet.RevertChainUpdate(utx, cru); err != nil { + return fmt.Errorf("failed to revert update: %w", err) + } else if err := setLastCommittedIndex(tx, cru.State.Index); err != nil { + return fmt.Errorf("failed to set last committed index: %w", err) } return nil }) @@ -536,3 +390,12 @@ func (s *Store) LastCommittedIndex() (index types.ChainIndex, err error) { err = s.db.QueryRow(`SELECT last_indexed_tip FROM global_settings`).Scan(decode(&index)) return } + +func setLastCommittedIndex(tx *txn, index types.ChainIndex) error { + _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(index)) + return err +} + +func insertAddressStatement(tx *txn) (*stmt, error) { + return tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, $3) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`) +} diff --git a/persist/sqlite/consensus_test.go b/persist/sqlite/consensus_test.go new file mode 100644 index 0000000..16ff48e --- /dev/null +++ b/persist/sqlite/consensus_test.go @@ -0,0 +1,455 @@ +package sqlite_test + +import ( + "path/filepath" + "testing" + + "go.sia.tech/core/consensus" + "go.sia.tech/core/types" + "go.sia.tech/coreutils" + "go.sia.tech/coreutils/chain" + "go.sia.tech/walletd/persist/sqlite" + "go.sia.tech/walletd/wallet" + "go.uber.org/zap/zaptest" +) + +func testV1Network() (*consensus.Network, types.Block) { + // use a modified version of Zen + n, genesisBlock := chain.TestnetZen() + n.InitialTarget = types.BlockID{0xFF} + n.HardforkDevAddr.Height = 1 + n.HardforkTax.Height = 1 + n.HardforkStorageProof.Height = 1 + n.HardforkOak.Height = 1 + n.HardforkASIC.Height = 1 + n.HardforkFoundation.Height = 1 + n.HardforkV2.AllowHeight = 1000 + n.HardforkV2.RequireHeight = 1000 + return n, genesisBlock +} + +func testV2Network() (*consensus.Network, types.Block) { + // use a modified version of Zen + n, genesisBlock := chain.TestnetZen() + n.InitialTarget = types.BlockID{0xFF} + n.HardforkDevAddr.Height = 1 + n.HardforkTax.Height = 1 + n.HardforkStorageProof.Height = 1 + n.HardforkOak.Height = 1 + n.HardforkASIC.Height = 1 + n.HardforkFoundation.Height = 1 + n.HardforkV2.AllowHeight = 100 + n.HardforkV2.RequireHeight = 110 + return n, genesisBlock +} + +func mineBlock(state consensus.State, txns []types.Transaction, minerAddr types.Address) types.Block { + b := types.Block{ + ParentID: state.Index.ID, + Timestamp: types.CurrentTimestamp(), + Transactions: txns, + MinerPayouts: []types.SiacoinOutput{{Address: minerAddr, Value: state.BlockReward()}}, + } + for b.ID().CmpWork(state.ChildTarget) < 0 { + b.Nonce += state.NonceFactor() + } + return b +} + +func mineV2Block(state consensus.State, txns []types.V2Transaction, minerAddr types.Address) types.Block { + b := types.Block{ + ParentID: state.Index.ID, + Timestamp: types.CurrentTimestamp(), + MinerPayouts: []types.SiacoinOutput{{Address: minerAddr, Value: state.BlockReward()}}, + + V2: &types.V2BlockData{ + Transactions: txns, + Height: state.Index.Height + 1, + }, + } + b.V2.Commitment = state.Commitment(state.TransactionsCommitment(b.Transactions, b.V2Transactions()), b.MinerPayouts[0].Address) + for b.ID().CmpWork(state.ChildTarget) < 0 { + b.Nonce += state.NonceFactor() + } + return b +} + +func TestReorg(t *testing.T) { + log := zaptest.NewLogger(t) + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal(err) + } + defer bdb.Close() + + network, genesisBlock := testV1Network() + + store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + cm := chain.NewManager(store, genesisState) + + if err := cm.AddSubscriber(db, types.ChainIndex{}); err != nil { + t.Fatal(err) + } + + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + + if err := db.AddWallet("test", nil); err != nil { + t.Fatal(err) + } else if err := db.AddAddress("test", addr, nil); err != nil { + t.Fatal(err) + } + + expectedPayout := cm.TipState().BlockReward() + // mine a block sending the payout to the wallet + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + + // check that the payout was received + balance, err := db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a payout event was recorded + events, err := db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } + + // mine to trigger a reorg + var blocks []types.Block + state := genesisState + for i := 0; i < 5; i++ { + blocks = append(blocks, mineBlock(state, nil, types.VoidAddress)) + state.Index.ID = blocks[len(blocks)-1].ID() + state.Index.Height = state.Index.Height + 1 + } + if err := cm.AddBlocks(blocks); err != nil { + t.Fatal(err) + } + + // check that the payout was reverted + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Fatalf("expected 0, got %v", balance.ImmatureSiacoins) + } + + // check that the payout event was reverted + events, err = db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 0 { + t.Fatalf("expected 0 events, got %v", len(events)) + } +} + +func TestEphemeralBalance(t *testing.T) { + log := zaptest.NewLogger(t) + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal(err) + } + defer bdb.Close() + + network, genesisBlock := testV1Network() + + store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + cm := chain.NewManager(store, genesisState) + + if err := cm.AddSubscriber(db, types.ChainIndex{}); err != nil { + t.Fatal(err) + } + + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + + if err := db.AddWallet("test", nil); err != nil { + t.Fatal(err) + } else if err := db.AddAddress("test", addr, nil); err != nil { + t.Fatal(err) + } + + expectedPayout := cm.TipState().BlockReward() + maturityHeight := cm.TipState().MaturityHeight() + 1 + // mine a block sending the payout to the wallet + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + + // check that the payout was received + balance, err := db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a payout event was recorded + events, err := db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } + + // mine until the payout matures + for i := cm.TipState().Index.Height; i < maturityHeight; i++ { + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + } + + // create a transaction that spends the matured payout + utxos, err := db.UnspentSiacoinOutputs("test") + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } + + unlockConditions := types.StandardUnlockConditions(pk.PublicKey()) + parentTxn := types.Transaction{ + SiacoinInputs: []types.SiacoinInput{ + { + ParentID: types.SiacoinOutputID(utxos[0].ID), + UnlockConditions: unlockConditions, + }, + }, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: addr, Value: types.Siacoins(100)}, + {Address: types.VoidAddress, Value: utxos[0].SiacoinOutput.Value.Sub(types.Siacoins(100))}, + }, + Signatures: []types.TransactionSignature{ + { + ParentID: utxos[0].ID, + PublicKeyIndex: 0, + CoveredFields: types.CoveredFields{WholeTransaction: true}, + }, + }, + } + parentSigHash := cm.TipState().WholeSigHash(parentTxn, utxos[0].ID, 0, 0, nil) + parentSig := pk.SignHash(parentSigHash) + parentTxn.Signatures[0].Signature = parentSig[:] + + outputID := parentTxn.SiacoinOutputID(0) + txn := types.Transaction{ + SiacoinInputs: []types.SiacoinInput{ + { + ParentID: outputID, + UnlockConditions: unlockConditions, + }, + }, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: types.VoidAddress, Value: types.Siacoins(100)}, + }, + Signatures: []types.TransactionSignature{ + { + ParentID: types.Hash256(outputID), + PublicKeyIndex: 0, + CoveredFields: types.CoveredFields{WholeTransaction: true}, + }, + }, + } + sigHash := cm.TipState().WholeSigHash(txn, types.Hash256(outputID), 0, 0, nil) + sig := pk.SignHash(sigHash) + txn.Signatures[0].Signature = sig[:] + + txnset := []types.Transaction{parentTxn, txn} + + // broadcast the transactions + revertState := cm.TipState() + if err := cm.AddBlocks([]types.Block{mineBlock(revertState, txnset, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + + // check that the payout was spent + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.IsZero() { + t.Fatalf("expected 0, got %v", balance.Siacoins) + } + + // trigger a reorg + var blocks []types.Block + state := revertState + for i := 0; i < 2; i++ { + blocks = append(blocks, mineBlock(state, nil, types.VoidAddress)) + state.Index.ID = blocks[len(blocks)-1].ID() + state.Index.Height = state.Index.Height + 1 + } + if err := cm.AddBlocks(blocks); err != nil { + t.Fatal(err) + } + + // check that the transaction was reverted + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.Siacoins) + } + + // check that only the payout event remains + events, err = db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 events, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } +} + +func TestV2(t *testing.T) { + log := zaptest.NewLogger(t) + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal(err) + } + defer bdb.Close() + + network, genesisBlock := testV2Network() + + store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + cm := chain.NewManager(store, genesisState) + + if err := cm.AddSubscriber(db, types.ChainIndex{}); err != nil { + t.Fatal(err) + } + + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + + if err := db.AddWallet("test", nil); err != nil { + t.Fatal(err) + } else if err := db.AddAddress("test", addr, nil); err != nil { + t.Fatal(err) + } + + expectedPayout := cm.TipState().BlockReward() + // mine a block sending the payout to the wallet + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + + // check that the payout was received + balance, err := db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a payout event was recorded + events, err := db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } + + // mine until the payout matures + maturityHeight := cm.TipState().MaturityHeight() + 1 + for i := cm.TipState().Index.Height; i < maturityHeight; i++ { + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + } + + // create a v2 transaction that spends the matured payout + utxos, err := db.UnspentSiacoinOutputs("test") + if err != nil { + t.Fatal(err) + } + + sce := utxos[0] + policy := types.PolicyTypeUnlockConditions(types.StandardUnlockConditions(pk.PublicKey())) + txn := types.V2Transaction{ + SiacoinInputs: []types.V2SiacoinInput{{ + Parent: sce, + SatisfiedPolicy: types.SatisfiedPolicy{ + Policy: types.SpendPolicy{Type: policy}, + }, + }}, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: types.VoidAddress, Value: sce.SiacoinOutput.Value.Sub(types.Siacoins(100))}, + {Address: addr, Value: types.Siacoins(100)}, + }, + } + txn.SiacoinInputs[0].SatisfiedPolicy.Signatures = []types.Signature{pk.SignHash(cm.TipState().InputSigHash(txn))} + + if err := cm.AddBlocks([]types.Block{mineV2Block(cm.TipState(), []types.V2Transaction{txn}, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + + // check that the change was received + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.Equals(types.Siacoins(100)) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a transaction event was recorded + events, err = db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 2 { + t.Fatalf("expected 2 events, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } else if events[0].Relevant[0] != addr { + t.Fatalf("expected address %v, got %v", addr, events[0].Relevant[0]) + } +} diff --git a/persist/sqlite/encoding.go b/persist/sqlite/encoding.go index 8f98230..966c6b0 100644 --- a/persist/sqlite/encoding.go +++ b/persist/sqlite/encoding.go @@ -14,9 +14,10 @@ import ( func encode(obj any) any { switch obj := obj.(type) { case types.Currency: + // Currency is encoded as two 64-bit big-endian integers for sorting buf := make([]byte, 16) - binary.LittleEndian.PutUint64(buf, obj.Lo) - binary.LittleEndian.PutUint64(buf[8:], obj.Hi) + binary.BigEndian.PutUint64(buf, obj.Hi) + binary.BigEndian.PutUint64(buf[8:], obj.Lo) return buf case types.EncoderTo: var buf bytes.Buffer @@ -52,8 +53,8 @@ func (d *decodable) Scan(src any) error { if len(src) != 16 { return fmt.Errorf("cannot scan %d bytes into Currency", len(src)) } - v.Lo = binary.LittleEndian.Uint64(src) - v.Hi = binary.LittleEndian.Uint64(src[8:]) + v.Hi = binary.BigEndian.Uint64(src) + v.Lo = binary.BigEndian.Uint64(src[8:]) case types.DecoderFrom: dec := types.NewBufDecoder(src) v.DecodeFrom(dec) diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go index 24ae378..a29a00c 100644 --- a/persist/sqlite/init.go +++ b/persist/sqlite/init.go @@ -70,10 +70,6 @@ func (s *Store) upgradeDatabase(current, target int64) error { func (s *Store) init() error { // calculate the expected final database version target := int64(len(migrations) + 1) - // disable foreign key constraints during migration - if _, err := s.db.Exec("PRAGMA foreign_keys = OFF"); err != nil { - return fmt.Errorf("failed to disable foreign key constraints: %w", err) - } version := getDBVersion(s.db) switch { diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index d9d4cff..2366b64 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -47,6 +47,8 @@ CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id); CREATE TABLE events ( id INTEGER PRIMARY KEY, + event_id BLOB NOT NULL, + maturity_height INTEGER NOT NULL, date_created INTEGER NOT NULL, index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, event_type TEXT NOT NULL, @@ -54,15 +56,12 @@ CREATE TABLE events ( ); CREATE TABLE event_addresses ( - id INTEGER PRIMARY KEY, event_id INTEGER NOT NULL REFERENCES events (id) ON DELETE CASCADE, address_id INTEGER NOT NULL REFERENCES sia_addresses (id), - block_height INTEGER NOT NULL, /* prevents extra join when querying for events */ - UNIQUE (event_id, address_id) + PRIMARY KEY (event_id, address_id) ); CREATE INDEX event_addresses_event_id_idx ON event_addresses (event_id); CREATE INDEX event_addresses_address_id_idx ON event_addresses (address_id); -CREATE INDEX event_addresses_event_id_address_id_block_height ON event_addresses(event_id, address_id, block_height DESC); CREATE TABLE syncer_peers ( peer_address TEXT PRIMARY KEY NOT NULL, diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index 88f1248..ace68de 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -19,58 +19,106 @@ RETURNING id` return } +func getWalletEvents(tx *txn, walletID string, offset, limit int) (events []wallet.Event, eventIDs []int64, err error) { + const query = `SELECT ev.id, ev.event_id, ev.maturity_height, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data + FROM events ev + INNER JOIN chain_indices ci ON (ev.index_id = ci.id) + WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)) + ORDER BY ev.maturity_height DESC + LIMIT $2 OFFSET $3` + + rows, err := tx.Query(query, walletID, limit, offset) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + for rows.Next() { + var eventID int64 + var event wallet.Event + var eventType string + var eventBuf []byte + + err := rows.Scan(&eventID, decode(&event.ID), &event.MaturityHeight, decode(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) + if err != nil { + return nil, nil, fmt.Errorf("failed to scan event: %w", err) + } + + switch eventType { + case wallet.EventTypeTransaction: + var tx wallet.EventTransaction + if err = json.Unmarshal(eventBuf, &tx); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal transaction event: %w", err) + } + event.Data = &tx + case wallet.EventTypeContractPayout: + var m wallet.EventContractPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal missed file contract event: %w", err) + } + event.Data = &m + case wallet.EventTypeMinerPayout: + var m wallet.EventMinerPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal payout event: %w", err) + } + event.Data = &m + case wallet.EventTypeFoundationSubsidy: + var m wallet.EventFoundationSubsidy + if err = json.Unmarshal(eventBuf, &m); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal foundation subsidy event: %w", err) + } + event.Data = &m + default: + return nil, nil, fmt.Errorf("unknown event type: %s", eventType) + } + + events = append(events, event) + eventIDs = append(eventIDs, eventID) + } + return +} + +func (s *Store) getWalletEventRelevantAddresses(tx *txn, walletID string, eventIDs []int64) (map[int64][]types.Address, error) { + query := `SELECT ea.event_id, sa.sia_address +FROM event_addresses ea +INNER JOIN sia_addresses sa ON (ea.address_id = sa.id) +WHERE event_id IN (` + queryPlaceHolders(len(eventIDs)) + `) AND address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=?)` + + rows, err := tx.Query(query, append(queryArgs(eventIDs), walletID)...) + if err != nil { + return nil, err + } + defer rows.Close() + + relevantAddresses := make(map[int64][]types.Address) + for rows.Next() { + var eventID int64 + var address types.Address + if err := rows.Scan(&eventID, decode(&address)); err != nil { + return nil, fmt.Errorf("failed to scan relevant address: %w", err) + } + relevantAddresses[eventID] = append(relevantAddresses[eventID], address) + } + return relevantAddresses, nil +} + // WalletEvents returns the events relevant to a wallet, sorted by height descending. func (s *Store) WalletEvents(walletID string, offset, limit int) (events []wallet.Event, err error) { err = s.transaction(func(tx *txn) error { - const query = `SELECT ev.id, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data -FROM events ev -INNER JOIN chain_indices ci ON (ev.index_id = ci.id) -WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)) -ORDER BY ci.height DESC, ev.id ASC -LIMIT $2 OFFSET $3` - - rows, err := tx.Query(query, walletID, limit, offset) + var dbIDs []int64 + events, dbIDs, err = getWalletEvents(tx, walletID, offset, limit) if err != nil { - return err + return fmt.Errorf("failed to get wallet events: %w", err) } - defer rows.Close() - - for rows.Next() { - var eventID int64 - var event wallet.Event - var eventType string - var eventBuf []byte - err := rows.Scan(&eventID, decode(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) - if err != nil { - return fmt.Errorf("failed to scan event: %w", err) - } - - switch eventType { - case wallet.EventTypeTransaction: - var tx wallet.EventTransaction - if err = json.Unmarshal(eventBuf, &tx); err != nil { - return fmt.Errorf("failed to unmarshal transaction event: %w", err) - } - event.Val = &tx - case wallet.EventTypeMissedFileContract: - var m wallet.EventMissedFileContract - if err = json.Unmarshal(eventBuf, &m); err != nil { - return fmt.Errorf("failed to unmarshal missed file contract event: %w", err) - } - event.Val = &m - case wallet.EventTypeMinerPayout: - var m wallet.EventMinerPayout - if err = json.Unmarshal(eventBuf, &m); err != nil { - return fmt.Errorf("failed to unmarshal payout event: %w", err) - } - event.Val = &m - default: - return fmt.Errorf("unknown event type: %s", eventType) - } + eventRelevantAddresses, err := s.getWalletEventRelevantAddresses(tx, walletID, dbIDs) + if err != nil { + return fmt.Errorf("failed to get relevant addresses: %w", err) + } - // event.Relevant = relevantAddresses[eventID] - events = append(events, event) + for i := range events { + events[i].Relevant = eventRelevantAddresses[dbIDs[i]] } return nil }) @@ -79,6 +127,9 @@ LIMIT $2 OFFSET $3` // AddWallet adds a wallet to the database. func (s *Store) AddWallet(name string, info json.RawMessage) error { + if info == nil { + info = json.RawMessage("{}") + } return s.transaction(func(tx *txn) error { const query = `INSERT INTO wallets (id, extra_data) VALUES ($1, $2)` @@ -126,6 +177,9 @@ func (s *Store) Wallets() (map[string]json.RawMessage, error) { // AddAddress adds an address to a wallet. func (s *Store) AddAddress(walletID string, address types.Address, info json.RawMessage) error { + if info == nil { + info = json.RawMessage("{}") + } return s.transaction(func(tx *txn) error { addressID, err := insertAddress(tx, address) if err != nil { @@ -230,7 +284,7 @@ func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.Siafund } // WalletBalance returns the total balance of a wallet. -func (s *Store) WalletBalance(walletID string) (sc, immatureSC types.Currency, sf uint64, err error) { +func (s *Store) WalletBalance(walletID string) (balance wallet.Balance, err error) { err = s.transaction(func(tx *txn) error { const query = `SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses sa INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) @@ -246,12 +300,12 @@ func (s *Store) WalletBalance(walletID string) (sc, immatureSC types.Currency, s var addressISC types.Currency var addressSF uint64 - if err := rows.Scan(decode(&addressSC), decode(&addressISC), decode(&addressSF)); err != nil { + if err := rows.Scan(decode(&addressSC), decode(&addressISC), &addressSF); err != nil { return fmt.Errorf("failed to scan address balance: %w", err) } - sc = sc.Add(addressSC) - immatureSC = immatureSC.Add(addressISC) - sf += addressSF + balance.Siacoins = balance.Siacoins.Add(addressSC) + balance.ImmatureSiacoins = balance.ImmatureSiacoins.Add(addressISC) + balance.Siafunds += addressSF } return nil }) @@ -259,10 +313,10 @@ func (s *Store) WalletBalance(walletID string) (sc, immatureSC types.Currency, s } // AddressBalance returns the balance of a single address. -func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { +func (s *Store) AddressBalance(address types.Address) (balance wallet.Balance, err error) { err = s.transaction(func(tx *txn) error { - const query = `SELECT siacoin_balance, siafund_balance FROM address_balance WHERE sia_address=$1` - return tx.QueryRow(query, encode(address)).Scan(decode(&sc), &sf) + const query = `SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses WHERE sia_address=$1` + return tx.QueryRow(query, encode(address)).Scan(decode(&balance.Siacoins), decode(&balance.ImmatureSiacoins), &balance.Siafunds) }) return } diff --git a/wallet/manager.go b/wallet/manager.go index 74039ed..8ecb5b3 100644 --- a/wallet/manager.go +++ b/wallet/manager.go @@ -36,9 +36,9 @@ type ( UnspentSiacoinOutputs(walletID string) ([]types.SiacoinElement, error) UnspentSiafundOutputs(walletID string) ([]types.SiafundElement, error) Annotate(walletID string, txns []types.Transaction) ([]PoolTransaction, error) - WalletBalance(walletID string) (sc, immature types.Currency, sf uint64, err error) + WalletBalance(walletID string) (Balance, error) - AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) + AddressBalance(address types.Address) (Balance, error) LastCommittedIndex() (types.ChainIndex, error) } @@ -105,12 +105,12 @@ func (m *Manager) Annotate(name string, pool []types.Transaction) ([]PoolTransac } // WalletBalance returns the balance of the given wallet. -func (m *Manager) WalletBalance(walletID string) (sc, immature types.Currency, sf uint64, err error) { +func (m *Manager) WalletBalance(walletID string) (Balance, error) { return m.store.WalletBalance(walletID) } // AddressBalance returns the balance of the given address. -func (m *Manager) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { +func (m *Manager) AddressBalance(address types.Address) (Balance, error) { return m.store.AddressBalance(address) } diff --git a/wallet/update.go b/wallet/update.go new file mode 100644 index 0000000..4784c79 --- /dev/null +++ b/wallet/update.go @@ -0,0 +1,452 @@ +package wallet + +import ( + "fmt" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" +) + +type ( + // AddressBalance pairs an address with its balance. + AddressBalance struct { + Address types.Address `json:"address"` + Balance + } + + // An ApplyTx atomically applies a set of updates to a store. + ApplyTx interface { + SiacoinStateElements() ([]types.StateElement, error) + UpdateSiacoinStateElements([]types.StateElement) error + + SiafundStateElements() ([]types.StateElement, error) + UpdateSiafundStateElements([]types.StateElement) error + + AddressRelevant(types.Address) (bool, error) + AddressBalance(types.Address) (Balance, error) + UpdateBalances([]AddressBalance) error + + MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) + AddSiacoinElements([]types.SiacoinElement) error + RemoveSiacoinElements([]types.SiacoinOutputID) error + + AddSiafundElements([]types.SiafundElement) error + RemoveSiafundElements([]types.SiafundOutputID) error + + AddEvents([]Event) error + } + + // RevertTx atomically reverts an update from a store. + RevertTx interface { + RevertEvents(types.BlockID) error + + SiacoinStateElements() ([]types.StateElement, error) + UpdateSiacoinStateElements([]types.StateElement) error + + SiafundStateElements() ([]types.StateElement, error) + UpdateSiafundStateElements([]types.StateElement) error + + AddressRelevant(types.Address) (bool, error) + AddressBalance(types.Address) (Balance, error) + UpdateBalances([]AddressBalance) error + + MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) + AddSiacoinElements([]types.SiacoinElement) error + RemoveSiacoinElements([]types.SiacoinOutputID) error + } +) + +// ApplyChainUpdates atomically applies a set of chain updates to a store +func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { + var events []Event + balances := make(map[types.Address]Balance) + newSiacoinElements := make(map[types.SiacoinOutputID]types.SiacoinElement) + newSiafundElements := make(map[types.SiafundOutputID]types.SiafundElement) + spentSiacoinElements := make(map[types.SiacoinOutputID]bool) + spentSiafundElements := make(map[types.SiafundOutputID]bool) + + updateBalance := func(addr types.Address, fn func(b *Balance)) error { + balance, ok := balances[addr] + if !ok { + var err error + balance, err = tx.AddressBalance(addr) + if err != nil { + return fmt.Errorf("failed to get address balance: %w", err) + } + } + + fn(&balance) + balances[addr] = balance + return nil + } + + // fetch all siacoin and siafund state elements + siacoinStateElements, err := tx.SiacoinStateElements() + if err != nil { + return fmt.Errorf("failed to get siacoin state elements: %w", err) + } + siafundStateElements, err := tx.SiafundStateElements() + if err != nil { + return fmt.Errorf("failed to get siafund state elements: %w", err) + } + + for _, cau := range updates { + // update the immature balance of each relevant address + matured, err := tx.MaturedSiacoinElements(cau.State.Index) + if err != nil { + return fmt.Errorf("failed to get matured siacoin elements: %w", err) + } + for _, se := range matured { + err := updateBalance(se.SiacoinOutput.Address, func(b *Balance) { + b.ImmatureSiacoins = b.ImmatureSiacoins.Sub(se.SiacoinOutput.Value) + b.Siacoins = b.Siacoins.Add(se.SiacoinOutput.Value) + }) + if err != nil { + return fmt.Errorf("failed to update address balance: %w", err) + } + } + + // determine which siacoin and siafund elements are ephemeral + // + // note: I thought we could use LeafIndex == EphemeralLeafIndex, but + // it seems to be set before the subscriber is called. + created := make(map[types.Hash256]bool) + ephemeral := make(map[types.Hash256]bool) + for _, txn := range cau.Block.Transactions { + for i := range txn.SiacoinOutputs { + created[types.Hash256(txn.SiacoinOutputID(i))] = true + } + for _, input := range txn.SiacoinInputs { + ephemeral[types.Hash256(input.ParentID)] = created[types.Hash256(input.ParentID)] + } + for i := range txn.SiafundOutputs { + created[types.Hash256(txn.SiafundOutputID(i))] = true + } + for _, input := range txn.SiafundInputs { + ephemeral[types.Hash256(input.ParentID)] = created[types.Hash256(input.ParentID)] + } + } + + // add new siacoin elements to the store + var siacoinElementErr error + cau.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + if siacoinElementErr != nil { + return + } else if ephemeral[se.ID] { + return + } + + relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) + if err != nil { + siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } + + if spent { + delete(newSiacoinElements, types.SiacoinOutputID(se.ID)) + spentSiacoinElements[types.SiacoinOutputID(se.ID)] = true + } else { + newSiacoinElements[types.SiacoinOutputID(se.ID)] = se + } + + err = updateBalance(se.SiacoinOutput.Address, func(b *Balance) { + switch { + case se.MaturityHeight > cau.State.Index.Height: + b.ImmatureSiacoins = b.ImmatureSiacoins.Add(se.SiacoinOutput.Value) + case spent: + b.Siacoins = b.Siacoins.Sub(se.SiacoinOutput.Value) + default: + b.Siacoins = b.Siacoins.Add(se.SiacoinOutput.Value) + } + }) + if err != nil { + siacoinElementErr = fmt.Errorf("failed to update address balance: %w", err) + return + } + }) + if siacoinElementErr != nil { + return fmt.Errorf("failed to add siacoin elements: %w", siacoinElementErr) + } + + var siafundElementErr error + cau.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { + if siafundElementErr != nil { + return + } else if ephemeral[se.ID] { + return + } + + relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) + if err != nil { + siafundElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } + + if spent { + delete(newSiafundElements, types.SiafundOutputID(se.ID)) + spentSiafundElements[types.SiafundOutputID(se.ID)] = true + } else { + newSiafundElements[types.SiafundOutputID(se.ID)] = se + } + + err = updateBalance(se.SiafundOutput.Address, func(b *Balance) { + if spent { + if b.Siafunds < se.SiafundOutput.Value { + panic(fmt.Errorf("negative siafund balance")) + } + b.Siafunds -= se.SiafundOutput.Value + } else { + b.Siafunds += se.SiafundOutput.Value + } + }) + if err != nil { + siafundElementErr = fmt.Errorf("failed to update address balance: %w", err) + return + } + }) + + // add events + relevant := func(addr types.Address) bool { + relevant, err := tx.AddressRelevant(addr) + if err != nil { + panic(fmt.Errorf("failed to check if address is relevant: %w", err)) + } + return relevant + } + if err != nil { + return fmt.Errorf("failed to get applied events: %w", err) + } + events = append(events, AppliedEvents(cau.State, cau.Block, cau, relevant)...) + + // update siacoin element proofs + for id := range newSiacoinElements { + ele := newSiacoinElements[id] + cau.UpdateElementProof(&ele.StateElement) + newSiacoinElements[id] = ele + } + for i := range siacoinStateElements { + cau.UpdateElementProof(&siacoinStateElements[i]) + } + + // update siafund element proofs + for id := range newSiafundElements { + ele := newSiafundElements[id] + cau.UpdateElementProof(&ele.StateElement) + newSiafundElements[id] = ele + } + for i := range siafundStateElements { + cau.UpdateElementProof(&siafundStateElements[i]) + } + } + + // update the address balances + balanceChanges := make([]AddressBalance, 0, len(balances)) + for addr, balance := range balances { + balanceChanges = append(balanceChanges, AddressBalance{ + Address: addr, + Balance: balance, + }) + } + if err = tx.UpdateBalances(balanceChanges); err != nil { + return fmt.Errorf("failed to update address balance: %w", err) + } + + // add the new siacoin elements + siacoinElements := make([]types.SiacoinElement, 0, len(newSiacoinElements)) + for _, ele := range newSiacoinElements { + siacoinElements = append(siacoinElements, ele) + } + if err = tx.AddSiacoinElements(siacoinElements); err != nil { + return fmt.Errorf("failed to add siacoin elements: %w", err) + } + + // remove the spent siacoin elements + siacoinOutputIDs := make([]types.SiacoinOutputID, 0, len(spentSiacoinElements)) + for id := range spentSiacoinElements { + siacoinOutputIDs = append(siacoinOutputIDs, id) + } + if err = tx.RemoveSiacoinElements(siacoinOutputIDs); err != nil { + return fmt.Errorf("failed to remove siacoin elements: %w", err) + } + + // add the new siafund elements + siafundElements := make([]types.SiafundElement, 0, len(newSiafundElements)) + for _, ele := range newSiafundElements { + siafundElements = append(siafundElements, ele) + } + if err = tx.AddSiafundElements(siafundElements); err != nil { + return fmt.Errorf("failed to add siafund elements: %w", err) + } + + // remove the spent siafund elements + siafundOutputIDs := make([]types.SiafundOutputID, 0, len(spentSiafundElements)) + for id := range spentSiafundElements { + siafundOutputIDs = append(siafundOutputIDs, id) + } + if err = tx.RemoveSiafundElements(siafundOutputIDs); err != nil { + return fmt.Errorf("failed to remove siafund elements: %w", err) + } + + // add new events + if err = tx.AddEvents(events); err != nil { + return fmt.Errorf("failed to add events: %w", err) + } + + // update the siacoin state elements + filteredStateElements := siacoinStateElements[:0] + for _, se := range siacoinStateElements { + if _, ok := spentSiacoinElements[types.SiacoinOutputID(se.ID)]; !ok { + filteredStateElements = append(filteredStateElements, se) + } + } + err = tx.UpdateSiacoinStateElements(filteredStateElements) + if err != nil { + return fmt.Errorf("failed to update siacoin state elements: %w", err) + } + + // update the siafund state elements + filteredStateElements = siafundStateElements[:0] + for _, se := range siafundStateElements { + if _, ok := spentSiafundElements[types.SiafundOutputID(se.ID)]; !ok { + filteredStateElements = append(filteredStateElements, se) + } + } + if err = tx.UpdateSiafundStateElements(filteredStateElements); err != nil { + return fmt.Errorf("failed to update siafund state elements: %w", err) + } + + return nil +} + +// RevertChainUpdate atomically reverts a chain update from a store +func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { + balances := make(map[types.Address]Balance) + newSiacoinElements := make(map[types.SiacoinOutputID]types.SiacoinElement) + newSiafundElements := make(map[types.SiafundOutputID]types.SiafundElement) + spentSiacoinElements := make(map[types.SiacoinOutputID]bool) + spentSiafundElements := make(map[types.SiafundOutputID]bool) + + updateBalance := func(addr types.Address, fn func(b *Balance)) error { + balance, ok := balances[addr] + if !ok { + var err error + balance, err = tx.AddressBalance(addr) + if err != nil { + return fmt.Errorf("failed to get address balance: %w", err) + } + } + + fn(&balance) + balances[addr] = balance + return nil + } + + // determine which siacoin and siafund elements are ephemeral + // + // note: I thought we could use LeafIndex == EphemeralLeafIndex, but + // it seems to be set before the subscriber is called. + created := make(map[types.Hash256]bool) + ephemeral := make(map[types.Hash256]bool) + for _, txn := range cru.Block.Transactions { + for i := range txn.SiacoinOutputs { + created[types.Hash256(txn.SiacoinOutputID(i))] = true + } + for _, input := range txn.SiacoinInputs { + ephemeral[types.Hash256(input.ParentID)] = created[types.Hash256(input.ParentID)] + } + for i := range txn.SiafundOutputs { + created[types.Hash256(txn.SiafundOutputID(i))] = true + } + for _, input := range txn.SiafundInputs { + ephemeral[types.Hash256(input.ParentID)] = created[types.Hash256(input.ParentID)] + } + } + + var siacoinElementErr error + cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + if siacoinElementErr != nil { + return + } + + relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) + if err != nil { + siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } else if ephemeral[se.ID] { + return + } + + if !spent { + newSiacoinElements[types.SiacoinOutputID(se.ID)] = se + } else { + spentSiacoinElements[types.SiacoinOutputID(se.ID)] = true + } + + siacoinElementErr = updateBalance(se.SiacoinOutput.Address, func(b *Balance) { + switch { + case se.MaturityHeight > cru.State.Index.Height: + b.ImmatureSiacoins = b.ImmatureSiacoins.Sub(se.SiacoinOutput.Value) + case spent: + b.Siacoins = b.Siacoins.Add(se.SiacoinOutput.Value) + default: + b.Siacoins = b.Siacoins.Sub(se.SiacoinOutput.Value) + } + }) + }) + if siacoinElementErr != nil { + return fmt.Errorf("failed to update address balance: %w", siacoinElementErr) + } + + var siafundElementErr error + cru.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { + if siafundElementErr != nil { + return + } + + relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) + if err != nil { + siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } else if ephemeral[se.ID] { + return + } + + if !spent { + newSiafundElements[types.SiafundOutputID(se.ID)] = se + } else { + spentSiafundElements[types.SiafundOutputID(se.ID)] = true + } + + siafundElementErr = updateBalance(se.SiafundOutput.Address, func(b *Balance) { + if spent { + b.Siafunds -= se.SiafundOutput.Value + } else { + b.Siafunds += se.SiafundOutput.Value + } + }) + }) + if siafundElementErr != nil { + return fmt.Errorf("failed to update address balance: %w", siafundElementErr) + } + + balanceChanges := make([]AddressBalance, 0, len(balances)) + for addr, balance := range balances { + balanceChanges = append(balanceChanges, AddressBalance{ + Address: addr, + Balance: balance, + }) + } + if err := tx.UpdateBalances(balanceChanges); err != nil { + return fmt.Errorf("failed to update address balance: %w", err) + } + + return tx.RevertEvents(cru.Block.ID()) +} diff --git a/wallet/wallet.go b/wallet/wallet.go index 7a806b1..2e1dc95 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -11,9 +11,20 @@ import ( // event type constants const ( - EventTypeTransaction = "transaction" - EventTypeMinerPayout = "miner payout" - EventTypeMissedFileContract = "missed file contract" + EventTypeTransaction = "transaction" + EventTypeMinerPayout = "miner payout" + EventTypeContractPayout = "contract payout" + EventTypeSiafundClaim = "siafund claim" + EventTypeFoundationSubsidy = "foundation subsidy" +) + +type ( + // Balance is a summary of a siacoin and siafund balance + Balance struct { + Siacoins types.Currency `json:"siacoins"` + ImmatureSiacoins types.Currency `json:"immatureSiacoins"` + Siafunds uint64 `json:"siafunds"` + } ) // StandardTransactionSignature is the most common form of TransactionSignature. @@ -145,12 +156,18 @@ func Annotate(txn types.Transaction, ownsAddress func(types.Address) bool) PoolT return ptxn } +type eventData interface { + EventType() string +} + // An Event is something interesting that happened on the Sia blockchain. type Event struct { - Index types.ChainIndex - Timestamp time.Time - Relevant []types.Address - Val interface{ EventType() string } + ID types.Hash256 `json:"id"` + Index types.ChainIndex `json:"index"` + Timestamp time.Time `json:"timestamp"` + MaturityHeight uint64 `json:"maturityHeight"` + Relevant []types.Address `json:"relevant"` + Data eventData `json:"data"` } // EventType implements Event. @@ -160,11 +177,14 @@ func (*EventTransaction) EventType() string { return EventTypeTransaction } func (*EventMinerPayout) EventType() string { return EventTypeMinerPayout } // EventType implements Event. -func (*EventMissedFileContract) EventType() string { return EventTypeMissedFileContract } +func (*EventFoundationSubsidy) EventType() string { return EventTypeFoundationSubsidy } + +// EventType implements Event. +func (*EventContractPayout) EventType() string { return EventTypeContractPayout } // MarshalJSON implements json.Marshaler. func (e Event) MarshalJSON() ([]byte, error) { - val, _ := json.Marshal(e.Val) + val, _ := json.Marshal(e.Data) return json.Marshal(struct { Timestamp time.Time `json:"timestamp"` Index types.ChainIndex `json:"index"` @@ -175,7 +195,7 @@ func (e Event) MarshalJSON() ([]byte, error) { Timestamp: e.Timestamp, Index: e.Index, Relevant: e.Relevant, - Type: e.Val.EventType(), + Type: e.Data.EventType(), Val: val, }) } @@ -197,16 +217,16 @@ func (e *Event) UnmarshalJSON(data []byte) error { e.Relevant = s.Relevant switch s.Type { case (*EventTransaction)(nil).EventType(): - e.Val = new(EventTransaction) + e.Data = new(EventTransaction) case (*EventMinerPayout)(nil).EventType(): - e.Val = new(EventMinerPayout) - case (*EventMissedFileContract)(nil).EventType(): - e.Val = new(EventMissedFileContract) + e.Data = new(EventMinerPayout) + case (*EventContractPayout)(nil).EventType(): + e.Data = new(EventContractPayout) } - if e.Val == nil { + if e.Data == nil { return fmt.Errorf("unknown event type %q", s.Type) } - return json.Unmarshal(s.Val, e.Val) + return json.Unmarshal(s.Val, e.Data) } // A HostAnnouncement represents a host announcement within an EventTransaction. @@ -242,7 +262,6 @@ type V2FileContract struct { // An EventTransaction represents a transaction that affects the wallet. type EventTransaction struct { - ID types.TransactionID `json:"id"` SiacoinInputs []types.SiacoinElement `json:"siacoinInputs"` SiacoinOutputs []types.SiacoinElement `json:"siacoinOutputs"` SiafundInputs []SiafundInput `json:"siafundInputs"` @@ -258,11 +277,16 @@ type EventMinerPayout struct { SiacoinOutput types.SiacoinElement `json:"siacoinOutput"` } -// An EventMissedFileContract represents a file contract that has expired -// without a storage proof -type EventMissedFileContract struct { +// EventFoundationSubsidy represents a foundation subsidy from a block. +type EventFoundationSubsidy struct { + SiacoinOutput types.SiacoinElement `json:"siacoinOutput"` +} + +// An EventContractPayout represents a file contract payout +type EventContractPayout struct { FileContract types.FileContractElement `json:"fileContract"` - MissedOutputs []types.SiacoinElement `json:"missedOutputs"` + SiacoinOutput types.SiacoinElement `json:"siacoinOutput"` + Missed bool `json:"missed"` } // A ChainUpdate is a set of changes to the consensus state. @@ -276,7 +300,7 @@ type ChainUpdate interface { // AppliedEvents extracts a list of relevant events from a chain update. func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant func(types.Address) bool) []Event { var events []Event - addEvent := func(v interface{ EventType() string }, relevant []types.Address) { + addEvent := func(id types.Hash256, maturityHeight uint64, v eventData, relevant []types.Address) { // dedup relevant addresses seen := make(map[types.Address]bool) unique := relevant[:0] @@ -288,10 +312,11 @@ func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant f } events = append(events, Event{ - Timestamp: b.Timestamp, - Index: cs.Index, - Relevant: unique, - Val: v, + Timestamp: b.Timestamp, + Index: cs.Index, + MaturityHeight: maturityHeight, + Relevant: unique, + Data: v, }) } @@ -461,7 +486,6 @@ func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant f } e := &EventTransaction{ - ID: txn.ID(), SiacoinInputs: make([]types.SiacoinElement, len(txn.SiacoinInputs)), SiacoinOutputs: make([]types.SiacoinElement, len(txn.SiacoinOutputs)), SiafundInputs: make([]SiafundInput, len(txn.SiafundInputs)), @@ -526,7 +550,7 @@ func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant f e.Fee = e.Fee.Add(txn.MinerFees[i]) } - addEvent(e, relevant) + addEvent(types.Hash256(txn.ID()), cs.Index.Height, e, relevant) // transaction maturity height is the current block height } // handle v2 transactions @@ -538,7 +562,6 @@ func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant f txid := txn.ID() e := &EventTransaction{ - ID: txid, SiacoinInputs: make([]types.SiacoinElement, len(txn.SiacoinInputs)), SiacoinOutputs: make([]types.SiacoinElement, len(txn.SiacoinOutputs)), SiafundInputs: make([]SiafundInput, len(txn.SiafundInputs)), @@ -597,35 +620,61 @@ func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant f } e.Fee = txn.MinerFee - addEvent(e, relevant) + addEvent(types.Hash256(txid), cs.Index.Height, e, relevant) // transaction maturity height is the current block height } // handle missed contracts cu.ForEachFileContractElement(func(fce types.FileContractElement, rev *types.FileContractElement, resolved, valid bool) { - if resolved && !valid { - relevant := relevantContract(fce.FileContract) - if len(relevant) == 0 { - return + if !resolved { + return + } + + relevant := relevantContract(fce.FileContract) + if len(relevant) == 0 { + return + } + + if valid { + for i := range fce.FileContract.ValidProofOutputs { + outputID := types.FileContractID(fce.ID).ValidOutputID(i) + addEvent(types.Hash256(outputID), cs.MaturityHeight(), &EventContractPayout{ + FileContract: fce, + SiacoinOutput: sces[outputID], + Missed: false, + }, relevant) } - missedOutputs := make([]types.SiacoinElement, len(fce.FileContract.MissedProofOutputs)) - for i := range missedOutputs { - missedOutputs[i] = sces[types.FileContractID(fce.ID).MissedOutputID(i)] + } else { + for i := range fce.FileContract.MissedProofOutputs { + outputID := types.FileContractID(fce.ID).MissedOutputID(i) + addEvent(types.Hash256(outputID), cs.MaturityHeight(), &EventContractPayout{ + FileContract: fce, + SiacoinOutput: sces[outputID], + Missed: true, + }, relevant) } - addEvent(&EventMissedFileContract{ - FileContract: fce, - MissedOutputs: missedOutputs, - }, relevant) } }) // handle block rewards for i := range b.MinerPayouts { if relevant(b.MinerPayouts[i].Address) { - addEvent(&EventMinerPayout{ - SiacoinOutput: sces[cs.Index.ID.MinerOutputID(i)], + outputID := cs.Index.ID.MinerOutputID(i) + addEvent(types.Hash256(outputID), cs.MaturityHeight(), &EventMinerPayout{ + SiacoinOutput: sces[outputID], }, []types.Address{b.MinerPayouts[i].Address}) } } + // handle foundation subsidy + if relevant(cs.FoundationPrimaryAddress) { + outputID := cs.Index.ID.FoundationOutputID() + sce, ok := sces[outputID] + if ok { + addEvent(types.Hash256(outputID), cs.MaturityHeight(), &EventFoundationSubsidy{ + SiacoinOutput: sce, + }, []types.Address{cs.FoundationPrimaryAddress}) + } + } + return events }