From 3a7e79f16ddf2034e4b500b021b736e63a9a042d Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Wed, 21 Feb 2024 11:11:07 -0800 Subject: [PATCH] sqlite,wallet: consolidate update tx, fix reorg balance, fix element revert --- persist/sqlite/consensus.go | 28 +++++---- persist/sqlite/init.sql | 3 +- wallet/update.go | 118 +++++++++++++++++++++++++----------- 3 files changed, 102 insertions(+), 47 deletions(-) diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index a16246b..0f3e67b 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -170,20 +170,20 @@ func (ut *updateTx) AddSiacoinElements(elements []types.SiacoinElement) error { } defer addrStmt.Close() - inserStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) + insertStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) if err != nil { return fmt.Errorf("failed to prepare insert statement: %w", err) } - defer inserStmt.Close() + defer insertStmt.Close() for _, se := range elements { var addressID int64 - err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) + err = addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) if err != nil { return fmt.Errorf("failed to query address: %w", err) } - _, err = inserStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID) + _, err = insertStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } @@ -215,20 +215,20 @@ func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement) error { } defer addrStmt.Close() - inserStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) + insertStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } - defer inserStmt.Close() + defer insertStmt.Close() for _, se := range elements { var addressID int64 - err := addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) + err = addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) if err != nil { return fmt.Errorf("failed to query address: %w", err) } - _, err = inserStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID) + _, err = insertStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } @@ -254,7 +254,7 @@ func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundOutputID) erro } func (ut *updateTx) AddEvents(events []wallet.Event) error { - indexStmt, err := ut.tx.Prepare(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`) + indexStmt, err := insertIndexStmt(ut.tx) if err != nil { return fmt.Errorf("failed to prepare index statement: %w", err) } @@ -321,10 +321,10 @@ func (ut *updateTx) AddEvents(events []wallet.Event) error { return nil } -// RevertEvents reverts the events that were added in the given block. -func (ut *updateTx) RevertEvents(blockID types.BlockID) error { +// RevertEvents reverts any events that were added by the index +func (ut *updateTx) RevertEvents(index types.ChainIndex) error { var id int64 - err := ut.tx.QueryRow(`DELETE FROM chain_indices WHERE block_id=$1 RETURNING id`, encode(blockID)).Scan(&id) + err := ut.tx.QueryRow(`DELETE FROM chain_indices WHERE block_id=$1 AND height=$2 RETURNING id`, encode(index.ID), index.Height).Scan(&id) if errors.Is(err, sql.ErrNoRows) { return nil } @@ -399,3 +399,7 @@ func setLastCommittedIndex(tx *txn, index types.ChainIndex) error { func insertAddressStatement(tx *txn) (*stmt, error) { return tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, $3) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`) } + +func insertIndexStmt(tx *txn) (*stmt, error) { + return tx.Prepare(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`) +} diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index 2366b64..b7c1767 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -21,6 +21,7 @@ CREATE TABLE siacoin_elements ( address_id INTEGER NOT NULL REFERENCES sia_addresses (id) ); CREATE INDEX siacoin_elements_address_id ON siacoin_elements (address_id); +CREATE INDEX siacoin_elements_maturity_height ON siacoin_elements (maturity_height); CREATE TABLE siafund_elements ( id BLOB PRIMARY KEY, @@ -48,9 +49,9 @@ CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id); CREATE TABLE events ( id INTEGER PRIMARY KEY, event_id BLOB NOT NULL, + index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, maturity_height INTEGER NOT NULL, date_created INTEGER NOT NULL, - index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, event_type TEXT NOT NULL, event_data TEXT NOT NULL ); diff --git a/wallet/update.go b/wallet/update.go index 4784c79..a48e984 100644 --- a/wallet/update.go +++ b/wallet/update.go @@ -14,45 +14,38 @@ type ( Balance } - // An ApplyTx atomically applies a set of updates to a store. - ApplyTx interface { + UpdateTx interface { SiacoinStateElements() ([]types.StateElement, error) UpdateSiacoinStateElements([]types.StateElement) error SiafundStateElements() ([]types.StateElement, error) UpdateSiafundStateElements([]types.StateElement) error - AddressRelevant(types.Address) (bool, error) - AddressBalance(types.Address) (Balance, error) - UpdateBalances([]AddressBalance) error - - MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) AddSiacoinElements([]types.SiacoinElement) error RemoveSiacoinElements([]types.SiacoinOutputID) error AddSiafundElements([]types.SiafundElement) error RemoveSiafundElements([]types.SiafundOutputID) error + MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) + + AddressRelevant(types.Address) (bool, error) + AddressBalance(types.Address) (Balance, error) + UpdateBalances([]AddressBalance) error + } + + // An ApplyTx atomically applies a set of updates to a store. + ApplyTx interface { + UpdateTx + AddEvents([]Event) error } // RevertTx atomically reverts an update from a store. RevertTx interface { - RevertEvents(types.BlockID) error + UpdateTx - SiacoinStateElements() ([]types.StateElement, error) - UpdateSiacoinStateElements([]types.StateElement) error - - SiafundStateElements() ([]types.StateElement, error) - UpdateSiafundStateElements([]types.StateElement) error - - AddressRelevant(types.Address) (bool, error) - AddressBalance(types.Address) (Balance, error) - UpdateBalances([]AddressBalance) error - - MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) - AddSiacoinElements([]types.SiacoinElement) error - RemoveSiacoinElements([]types.SiacoinOutputID) error + RevertEvents(index types.ChainIndex) error } ) @@ -325,10 +318,11 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { // RevertChainUpdate atomically reverts a chain update from a store func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { balances := make(map[types.Address]Balance) - newSiacoinElements := make(map[types.SiacoinOutputID]types.SiacoinElement) - newSiafundElements := make(map[types.SiafundOutputID]types.SiafundElement) - spentSiacoinElements := make(map[types.SiacoinOutputID]bool) - spentSiafundElements := make(map[types.SiafundOutputID]bool) + + var deletedSiacoinElements []types.SiacoinOutputID + var addedSiacoinElements []types.SiacoinElement + var deletedSiafundElements []types.SiafundOutputID + var addedSiafundElements []types.SiafundElement updateBalance := func(addr types.Address, fn func(b *Balance)) error { balance, ok := balances[addr] @@ -366,6 +360,26 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { } } + // revert the immature balance of each relevant address + revertedIndex := types.ChainIndex{ + Height: cru.State.Index.Height + 1, + ID: cru.Block.ID(), + } + + matured, err := tx.MaturedSiacoinElements(revertedIndex) + if err != nil { + return fmt.Errorf("failed to get matured siacoin elements: %w", err) + } + for _, se := range matured { + err := updateBalance(se.SiacoinOutput.Address, func(b *Balance) { + b.ImmatureSiacoins = b.ImmatureSiacoins.Add(se.SiacoinOutput.Value) + b.Siacoins = b.Siacoins.Sub(se.SiacoinOutput.Value) + }) + if err != nil { + return fmt.Errorf("failed to update address balance: %w", err) + } + } + var siacoinElementErr error cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { if siacoinElementErr != nil { @@ -382,10 +396,12 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { return } - if !spent { - newSiacoinElements[types.SiacoinOutputID(se.ID)] = se + if spent { + // re-add any spent siacoin elements + addedSiacoinElements = append(addedSiacoinElements, se) } else { - spentSiacoinElements[types.SiacoinOutputID(se.ID)] = true + // delete any created siacoin elements + deletedSiacoinElements = append(deletedSiacoinElements, types.SiacoinOutputID(se.ID)) } siacoinElementErr = updateBalance(se.SiacoinOutput.Address, func(b *Balance) { @@ -419,17 +435,19 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { return } - if !spent { - newSiafundElements[types.SiafundOutputID(se.ID)] = se + if spent { + // re-add any spent siafund elements + addedSiafundElements = append(addedSiafundElements, se) } else { - spentSiafundElements[types.SiafundOutputID(se.ID)] = true + // delete any created siafund elements + deletedSiafundElements = append(deletedSiafundElements, types.SiafundOutputID(se.ID)) } siafundElementErr = updateBalance(se.SiafundOutput.Address, func(b *Balance) { if spent { - b.Siafunds -= se.SiafundOutput.Value - } else { b.Siafunds += se.SiafundOutput.Value + } else { + b.Siafunds -= se.SiafundOutput.Value } }) }) @@ -448,5 +466,37 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { return fmt.Errorf("failed to update address balance: %w", err) } - return tx.RevertEvents(cru.Block.ID()) + // revert siacoin element changes + if err := tx.AddSiacoinElements(addedSiacoinElements); err != nil { + return fmt.Errorf("failed to add siacoin elements: %w", err) + } else if err := tx.RemoveSiacoinElements(deletedSiacoinElements); err != nil { + return fmt.Errorf("failed to remove siacoin elements: %w", err) + } + + // update siacoin element proofs + siacoinElements, err := tx.SiacoinStateElements() + if err != nil { + return fmt.Errorf("failed to get siacoin state elements: %w", err) + } + for i := range siacoinElements { + cru.UpdateElementProof(&siacoinElements[i]) + } + + // revert siafund element changes + if err := tx.AddSiafundElements(addedSiafundElements); err != nil { + return fmt.Errorf("failed to add siafund elements: %w", err) + } else if err := tx.RemoveSiafundElements(deletedSiafundElements); err != nil { + return fmt.Errorf("failed to remove siafund elements: %w", err) + } + + // update siafund element proofs + siafundElements, err := tx.SiafundStateElements() + if err != nil { + return fmt.Errorf("failed to get siafund state elements: %w", err) + } + for i := range siafundElements { + cru.UpdateElementProof(&siafundElements[i]) + } + + return tx.RevertEvents(revertedIndex) }