diff --git a/api/api_test.go b/api/api_test.go index 59c40fa..642357e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -66,7 +66,7 @@ func TestWallet(t *testing.T) { } cm := chain.NewManager(dbstore, tipState) - ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), sqlite.WithLogger(log.Named("sqlite3"))) if err != nil { t.Fatal(err) } @@ -254,7 +254,7 @@ func TestV2(t *testing.T) { t.Fatal(err) } cm := chain.NewManager(dbstore, tipState) - ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), sqlite.WithLogger(log.Named("sqlite3"))) if err != nil { t.Fatal(err) } @@ -465,7 +465,7 @@ func TestP2P(t *testing.T) { } log1 := logger.Named("one") cm1 := chain.NewManager(dbstore1, tipState) - store1, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log1.Named("sqlite3")) + store1, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), sqlite.WithLogger(log1.Named("sqlite3"))) if err != nil { t.Fatal(err) } @@ -504,7 +504,7 @@ func TestP2P(t *testing.T) { } log2 := logger.Named("two") cm2 := chain.NewManager(dbstore2, tipState) - store2, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log2.Named("sqlite3")) + store2, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), sqlite.WithLogger(log2.Named("sqlite3"))) if err != nil { t.Fatal(err) } diff --git a/cmd/walletd/node.go b/cmd/walletd/node.go index b5a6808..d4f8086 100644 --- a/cmd/walletd/node.go +++ b/cmd/walletd/node.go @@ -161,7 +161,7 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool, log *zap.Logge syncerAddr = net.JoinHostPort("127.0.0.1", port) } - store, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + store, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), sqlite.WithLogger(log.Named("sqlite3"))) if err != nil { return nil, fmt.Errorf("failed to open wallet database: %w", err) } diff --git a/persist/sqlite/addresses.go b/persist/sqlite/addresses.go new file mode 100644 index 0000000..1ab38cb --- /dev/null +++ b/persist/sqlite/addresses.go @@ -0,0 +1,51 @@ +package sqlite + +import ( + "database/sql" + "errors" + "fmt" + + "go.sia.tech/core/types" + "go.sia.tech/walletd/wallet" +) + +// AddressBalance returns the balance of a single address. +func (s *Store) AddressBalance(address types.Address) (balance wallet.Balance, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses WHERE sia_address=$1` + return tx.QueryRow(query, encode(address)).Scan(decode(&balance.Siacoins), decode(&balance.ImmatureSiacoins), &balance.Siafunds) + }) + if errors.Is(err, sql.ErrNoRows) { + return wallet.Balance{}, wallet.ErrNotFound + } + return +} + +// AddressEvents returns the events related to a single address. +func (s *Store) AddressEvents(address types.Address, offset, limit int) (events []wallet.Event, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT ev.id, ev.event_id, ev.maturity_height, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data +FROM events ev +INNER JOIN chain_indices ci ON (ev.index_id = ci.id) +INNER JOIN event_addresses ea ON (ev.id = ea.event_id) +INNER JOIN sia_addresses sa ON (ea.address_id = sa.id) +WHERE sa.sia_address = $1 +ORDER BY ev.maturity_height DESC +LIMIT $2 OFFSET $3` + rows, err := tx.Query(query, encode(address), limit, offset) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + _, event, err := scanEvent(rows) + if err != nil { + return fmt.Errorf("failed to scan event: %w", err) + } + event.Relevant = []types.Address{address} // only the address is relevant + events = append(events, event) + } + return rows.Err() + }) + return +} diff --git a/persist/sqlite/config.go b/persist/sqlite/config.go new file mode 100644 index 0000000..3e1963a --- /dev/null +++ b/persist/sqlite/config.go @@ -0,0 +1,23 @@ +package sqlite + +import ( + "go.uber.org/zap" +) + +// An Option is a functional option for configuring a Store. +type Option func(*Store) + +// WithLogger sets the logger used by the Store. +func WithLogger(log *zap.Logger) Option { + return func(s *Store) { + s.log = log + } +} + +// WithFullIndex sets the store to index all transactions and outputs, rather +// than just those relevant to the wallet. +func WithFullIndex() Option { + return func(s *Store) { + s.fullIndex = true + } +} diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index a16246b..acc465e 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -16,6 +16,7 @@ import ( type updateTx struct { tx *txn + fullIndex bool relevantAddresses map[types.Address]bool } @@ -49,6 +50,10 @@ func (ut *updateTx) SiacoinStateElements() ([]types.StateElement, error) { } func (ut *updateTx) UpdateSiacoinStateElements(elements []types.StateElement) error { + if ut.fullIndex { + panic("UpdateSiafundStateElements should not be called with full index enabled") + } + const query = `UPDATE siacoin_elements SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id` stmt, err := ut.tx.Prepare(query) if err != nil { @@ -86,6 +91,10 @@ func (ut *updateTx) SiafundStateElements() ([]types.StateElement, error) { } func (ut *updateTx) UpdateSiafundStateElements(elements []types.StateElement) error { + if ut.fullIndex { + panic("UpdateSiafundStateElements should not be called with full index enabled") + } + const query = `UPDATE siafund_elements SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id` stmt, err := ut.tx.Prepare(query) if err != nil { @@ -103,7 +112,32 @@ func (ut *updateTx) UpdateSiafundStateElements(elements []types.StateElement) er return nil } +// UpdateStateTree updates the state tree with the given changes. +func (ut *updateTx) UpdateStateTree(changes []wallet.TreeNodeUpdate) error { + if !ut.fullIndex { + panic("UpdateStateTree should not be called with full index disabled") + } + + stmt, err := ut.tx.Prepare(`INSERT INTO state_tree (row, column, value) VALUES($1, $2, $3) ON CONFLICT (row, column) DO UPDATE SET value=EXCLUDED.value;`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + for _, change := range changes { + _, err := stmt.Exec(change.Row, change.Column, encode(change.Hash)) + if err != nil { + return fmt.Errorf("failed to execute statement: %w", err) + } + } + return nil +} + func (ut *updateTx) AddressRelevant(addr types.Address) (bool, error) { + if ut.fullIndex { + panic("AddressRelevant should not be called with full index enabled") + } + if relevant, ok := ut.relevantAddresses[addr]; ok { return relevant, nil } @@ -122,11 +156,17 @@ func (ut *updateTx) AddressRelevant(addr types.Address) (bool, error) { func (ut *updateTx) AddressBalance(addr types.Address) (balance wallet.Balance, err error) { err = ut.tx.QueryRow(`SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses WHERE sia_address=$1`, encode(addr)).Scan(decode(&balance.Siacoins), decode(&balance.ImmatureSiacoins), &balance.Siafunds) + if errors.Is(err, sql.ErrNoRows) { + if ut.fullIndex { + return wallet.Balance{}, nil + } + return wallet.Balance{}, wallet.ErrNotFound + } return } func (ut *updateTx) UpdateBalances(balances []wallet.AddressBalance) error { - const query = `UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2, siafund_balance=$3 WHERE sia_address=$4` + const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $3, $4) ON CONFLICT (sia_address) DO UPDATE SET siacoin_balance=EXCLUDED.siacoin_balance, immature_siacoin_balance=EXCLUDED.immature_siacoin_balance, siafund_balance=EXCLUDED.siafund_balance;` stmt, err := ut.tx.Prepare(query) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) @@ -134,7 +174,7 @@ func (ut *updateTx) UpdateBalances(balances []wallet.AddressBalance) error { defer stmt.Close() for _, ab := range balances { - _, err := stmt.Exec(encode(ab.Balance.Siacoins), encode(ab.Balance.ImmatureSiacoins), ab.Balance.Siafunds, encode(ab.Address)) + _, err := stmt.Exec(encode(ab.Address), encode(ab.Balance.Siacoins), encode(ab.Balance.ImmatureSiacoins), ab.Balance.Siafunds) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } @@ -170,20 +210,20 @@ func (ut *updateTx) AddSiacoinElements(elements []types.SiacoinElement) error { } defer addrStmt.Close() - inserStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) + insertStmt, err := ut.tx.Prepare(`INSERT INTO siacoin_elements (id, siacoin_value, merkle_proof, leaf_index, maturity_height, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) if err != nil { return fmt.Errorf("failed to prepare insert statement: %w", err) } - defer inserStmt.Close() + defer insertStmt.Close() for _, se := range elements { var addressID int64 - err := addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) + err = addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) if err != nil { return fmt.Errorf("failed to query address: %w", err) } - _, err = inserStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID) + _, err = insertStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addressID) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } @@ -215,20 +255,20 @@ func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement) error { } defer addrStmt.Close() - inserStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) + insertStmt, err := ut.tx.Prepare(`INSERT INTO siafund_elements (id, siafund_value, merkle_proof, leaf_index, claim_start, address_id) VALUES ($1, $2, $3, $4, $5, $6)`) if err != nil { return fmt.Errorf("failed to prepare statement: %w", err) } - defer inserStmt.Close() + defer insertStmt.Close() for _, se := range elements { var addressID int64 - err := addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) + err = addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0).Scan(&addressID) if err != nil { return fmt.Errorf("failed to query address: %w", err) } - _, err = inserStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID) + _, err = insertStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addressID) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) } @@ -254,7 +294,7 @@ func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundOutputID) erro } func (ut *updateTx) AddEvents(events []wallet.Event) error { - indexStmt, err := ut.tx.Prepare(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`) + indexStmt, err := insertIndexStmt(ut.tx) if err != nil { return fmt.Errorf("failed to prepare index statement: %w", err) } @@ -321,10 +361,10 @@ func (ut *updateTx) AddEvents(events []wallet.Event) error { return nil } -// RevertEvents reverts the events that were added in the given block. -func (ut *updateTx) RevertEvents(blockID types.BlockID) error { +// RevertEvents reverts any events that were added by the index +func (ut *updateTx) RevertEvents(index types.ChainIndex) error { var id int64 - err := ut.tx.QueryRow(`DELETE FROM chain_indices WHERE block_id=$1 RETURNING id`, encode(blockID)).Scan(&id) + err := ut.tx.QueryRow(`DELETE FROM chain_indices WHERE block_id=$1 AND height=$2 RETURNING id`, encode(index.ID), index.Height).Scan(&id) if errors.Is(err, sql.ErrNoRows) { return nil } @@ -341,10 +381,11 @@ func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) return s.transaction(func(tx *txn) error { utx := &updateTx{ tx: tx, + fullIndex: s.fullIndex, relevantAddresses: make(map[types.Address]bool), } - if err := wallet.ApplyChainUpdates(utx, s.updates); err != nil { + if err := wallet.ApplyChainUpdates(utx, s.updates, s.fullIndex); err != nil { return fmt.Errorf("failed to apply updates: %w", err) } else if err := setLastCommittedIndex(tx, cau.State.Index); err != nil { return fmt.Errorf("failed to set last committed index: %w", err) @@ -373,10 +414,11 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { return s.transaction(func(tx *txn) error { utx := &updateTx{ tx: tx, + fullIndex: s.fullIndex, relevantAddresses: make(map[types.Address]bool), } - if err := wallet.RevertChainUpdate(utx, cru); err != nil { + if err := wallet.RevertChainUpdate(utx, cru, s.fullIndex); err != nil { return fmt.Errorf("failed to revert update: %w", err) } else if err := setLastCommittedIndex(tx, cru.State.Index); err != nil { return fmt.Errorf("failed to set last committed index: %w", err) @@ -399,3 +441,7 @@ func setLastCommittedIndex(tx *txn, index types.ChainIndex) error { func insertAddressStatement(tx *txn) (*stmt, error) { return tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, $3) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`) } + +func insertIndexStmt(tx *txn) (*stmt, error) { + return tx.Prepare(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`) +} diff --git a/persist/sqlite/consensus_test.go b/persist/sqlite/consensus_test.go index 16ff48e..aa401f5 100644 --- a/persist/sqlite/consensus_test.go +++ b/persist/sqlite/consensus_test.go @@ -77,7 +77,7 @@ func mineV2Block(state consensus.State, txns []types.V2Transaction, minerAddr ty func TestReorg(t *testing.T) { log := zaptest.NewLogger(t) dir := t.TempDir() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), sqlite.WithLogger(log.Named("sqlite3"))) if err != nil { t.Fatal(err) } @@ -113,6 +113,7 @@ func TestReorg(t *testing.T) { } expectedPayout := cm.TipState().BlockReward() + maturityHeight := cm.TipState().MaturityHeight() // mine a block sending the payout to the wallet if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { t.Fatal(err) @@ -136,6 +137,18 @@ func TestReorg(t *testing.T) { t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) } + // check that the utxo was created + utxos, err := db.UnspentSiacoinOutputs("test") + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + } + // mine to trigger a reorg var blocks []types.Block state := genesisState @@ -163,12 +176,112 @@ func TestReorg(t *testing.T) { } else if len(events) != 0 { t.Fatalf("expected 0 events, got %v", len(events)) } + + // check that the utxo was removed + utxos, err = db.UnspentSiacoinOutputs("test") + if err != nil { + t.Fatal(err) + } else if len(utxos) != 0 { + t.Fatalf("expected 0 outputs, got %v", len(utxos)) + } + + // mine a new payout + expectedPayout = cm.TipState().BlockReward() + maturityHeight = cm.TipState().MaturityHeight() + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + + // check that the payout was received + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a payout event was recorded + events, err = db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } + + // check that the utxo was created + utxos, err = db.UnspentSiacoinOutputs("test") + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + } + + // mine until the payout matures + var prevState consensus.State + for i := cm.TipState().Index.Height; i < maturityHeight+1; i++ { + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + if i == maturityHeight-5 { + prevState = cm.TipState() + } + } + + // check that the balance was updated + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Fatalf("expected %v, got %v", types.ZeroCurrency, balance.ImmatureSiacoins) + } else if !balance.Siacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.Siacoins) + } + + // reorg the last few blocks to re-mature the payout + blocks = nil + state = prevState + for i := 0; i < 10; i++ { + blocks = append(blocks, mineBlock(state, nil, types.VoidAddress)) + state.Index.ID = blocks[len(blocks)-1].ID() + state.Index.Height = state.Index.Height + 1 + } + if err := cm.AddBlocks(blocks); err != nil { + t.Fatal(err) + } + + // check that the balance is correct + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Fatalf("expected %v, got %v", types.ZeroCurrency, balance.ImmatureSiacoins) + } else if !balance.Siacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.Siacoins) + } + + // check that only the single utxo still exists + utxos, err = db.UnspentSiacoinOutputs("test") + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + } } func TestEphemeralBalance(t *testing.T) { log := zaptest.NewLogger(t) dir := t.TempDir() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), sqlite.WithLogger(log.Named("sqlite3"))) if err != nil { t.Fatal(err) } @@ -205,8 +318,10 @@ func TestEphemeralBalance(t *testing.T) { expectedPayout := cm.TipState().BlockReward() maturityHeight := cm.TipState().MaturityHeight() + 1 + block := mineBlock(cm.TipState(), nil, addr) + minerPayoutID := block.ID().MinerOutputID(0) // mine a block sending the payout to the wallet - if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + if err := cm.AddBlocks([]types.Block{block}); err != nil { t.Fatal(err) } @@ -226,6 +341,8 @@ func TestEphemeralBalance(t *testing.T) { t.Fatalf("expected 1 event, got %v", len(events)) } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } else if events[0].ID != types.Hash256(minerPayoutID) { + t.Fatalf("expected %v, got %v", minerPayoutID, events[0].ID) } // mine until the payout matures @@ -306,6 +423,24 @@ func TestEphemeralBalance(t *testing.T) { t.Fatalf("expected 0, got %v", balance.Siacoins) } + // check that both transactions were added + events, err = db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 3 { // 1 payout, 2 transactions + t.Fatalf("expected 3 events, got %v", len(events)) + } else if events[2].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected miner payout event, got %v", events[2].Data.EventType()) + } else if events[1].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[1].Data.EventType()) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } else if events[1].ID != types.Hash256(parentTxn.ID()) { // parent txn first + t.Fatalf("expected %v, got %v", parentTxn.ID(), events[1].ID) + } else if events[0].ID != types.Hash256(txn.ID()) { // child txn second + t.Fatalf("expected %v, got %v", txn.ID(), events[0].ID) + } + // trigger a reorg var blocks []types.Block state := revertState @@ -340,7 +475,7 @@ func TestEphemeralBalance(t *testing.T) { func TestV2(t *testing.T) { log := zaptest.NewLogger(t) dir := t.TempDir() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), sqlite.WithLogger(log.Named("sqlite3"))) if err != nil { t.Fatal(err) } @@ -453,3 +588,157 @@ func TestV2(t *testing.T) { t.Fatalf("expected address %v, got %v", addr, events[0].Relevant[0]) } } + +func TestFullIndex(t *testing.T) { + log := zaptest.NewLogger(t) + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), + sqlite.WithLogger(log.Named("sqlite3")), + sqlite.WithFullIndex()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal(err) + } + defer bdb.Close() + + network, genesisBlock := testV1Network() + + store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + cm := chain.NewManager(store, genesisState) + + if err := cm.AddSubscriber(db, types.ChainIndex{}); err != nil { + t.Fatal(err) + } + + siacoinAirdropAddr, err := types.ParseAddress("addr:3d7f707d05f2e0ec7ccc9220ed7c8af3bc560fbee84d068c2cc28151d617899e1ee8bc069946") + if err != nil { + t.Fatal(err) + } + siacoinAirdropValue := types.Siacoins(1).Mul64(1e12) + + siafundAirdropAddr, err := types.ParseAddress("addr:053b2def3cbdd078c19d62ce2b4f0b1a3c5e0ffbeeff01280efb1f8969b2f5bb4fdc680f0807") + if err != nil { + t.Fatal(err) + } + siafundAirdropValue := uint64(10000) + + balance, err := db.AddressBalance(siacoinAirdropAddr) + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.Equals(siacoinAirdropValue) { + t.Fatalf("expected %v, got %v", siacoinAirdropValue, balance.Siacoins) + } + + events, err := db.AddressEvents(siacoinAirdropAddr, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } else if events[0].ID != types.Hash256(genesisBlock.Transactions[0].ID()) { + t.Fatalf("expected transaction ID %q got %q", genesisBlock.Transactions[0].ID(), events[0].ID) + } + + tx, ok := events[0].Data.(*wallet.EventTransaction) + if !ok { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } else if tx.SiacoinOutputs[0].SiacoinOutput.Address != siacoinAirdropAddr { + t.Fatalf("expected address %v, got %v", siacoinAirdropAddr, tx.SiacoinOutputs[0].SiacoinOutput.Address) + } else if !tx.SiacoinOutputs[0].SiacoinOutput.Value.Equals(siacoinAirdropValue) { + t.Fatalf("expected %v, got %v", siacoinAirdropValue, tx.SiacoinOutputs[0].SiacoinOutput.Value) + } + + balance, err = db.AddressBalance(siafundAirdropAddr) + if err != nil { + t.Fatal(err) + } else if balance.Siafunds != siafundAirdropValue { + t.Fatalf("expected %v, got %v", siafundAirdropValue, balance.Siafunds) + } + + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + + expectedPayout := cm.TipState().BlockReward() + block := mineBlock(cm.TipState(), nil, addr) + minerPayoutID := block.ID().MinerOutputID(0) + // mine a block sending the payout to the wallet + if err := cm.AddBlocks([]types.Block{block}); err != nil { + t.Fatal(err) + } + + // check that the payout was received + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.ImmatureSiacoins) + } + + // check that a payout event was recorded + events, err = db.AddressEvents(addr, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } else if events[0].ID != types.Hash256(minerPayoutID) { + t.Fatalf("expected %v, got %v", minerPayoutID, events[0].ID) + } + + // mine until the payout matures + maturityHeight := cm.TipState().MaturityHeight() + 1 + for i := cm.TipState().Index.Height; i < maturityHeight; i++ { + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + } + + // check that the balance matured + balance, err = db.AddressBalance(addr) + if err != nil { + t.Fatal(err) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Fatalf("expected 0, got %v", balance.ImmatureSiacoins) + } else if !balance.Siacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, balance.Siacoins) + } + + // add the address to a wallet + if err := db.AddWallet("test", nil); err != nil { + t.Fatal(err) + } else if err := db.AddAddress("test", addr, nil); err != nil { + t.Fatal(err) + } + + // check that the wallet balance is correct + walletBalance, err := db.WalletBalance("test") + if err != nil { + t.Fatal(err) + } else if !walletBalance.Siacoins.Equals(expectedPayout) { + t.Fatalf("expected %v, got %v", expectedPayout, walletBalance.Siacoins) + } + + // check that the payout event was associated with the wallet + events, err = db.WalletEvents("test", 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } else if events[0].ID != types.Hash256(minerPayoutID) { + t.Fatalf("expected %v, got %v", minerPayoutID, events[0].ID) + } +} diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index 2366b64..414f63c 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -21,6 +21,7 @@ CREATE TABLE siacoin_elements ( address_id INTEGER NOT NULL REFERENCES sia_addresses (id) ); CREATE INDEX siacoin_elements_address_id ON siacoin_elements (address_id); +CREATE INDEX siacoin_elements_maturity_height ON siacoin_elements (maturity_height); CREATE TABLE siafund_elements ( id BLOB PRIMARY KEY, @@ -32,6 +33,13 @@ CREATE TABLE siafund_elements ( ); CREATE INDEX siafund_elements_address_id ON siafund_elements (address_id); +CREATE TABLE state_tree ( + row INTEGER, + column INTEGER, + value BLOB NOT NULL, + PRIMARY KEY (row, column) +); + CREATE TABLE wallets ( id TEXT PRIMARY KEY NOT NULL, extra_data BLOB NOT NULL @@ -48,9 +56,9 @@ CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id); CREATE TABLE events ( id INTEGER PRIMARY KEY, event_id BLOB NOT NULL, + index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, maturity_height INTEGER NOT NULL, date_created INTEGER NOT NULL, - index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, event_type TEXT NOT NULL, event_data TEXT NOT NULL ); diff --git a/persist/sqlite/peers_test.go b/persist/sqlite/peers_test.go index 4f3e26e..4f4474d 100644 --- a/persist/sqlite/peers_test.go +++ b/persist/sqlite/peers_test.go @@ -12,7 +12,7 @@ import ( func TestAddPeer(t *testing.T) { log := zaptest.NewLogger(t) - db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), WithLogger(log)) if err != nil { t.Fatal(err) } @@ -53,7 +53,7 @@ func TestAddPeer(t *testing.T) { func TestBanPeer(t *testing.T) { log := zaptest.NewLogger(t) - db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), WithLogger(log)) if err != nil { t.Fatal(err) } diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go index e50fda5..8b032e9 100644 --- a/persist/sqlite/store.go +++ b/persist/sqlite/store.go @@ -17,8 +17,10 @@ import ( type ( // A Store is a persistent store that uses a SQL database as its backend. Store struct { - db *sql.DB - log *zap.Logger + db *sql.DB + + log *zap.Logger + fullIndex bool updates []*chain.ApplyUpdate } @@ -107,14 +109,17 @@ func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx *txn) error) error { // OpenDatabase creates a new SQLite store and initializes the database. If the // database does not exist, it is created. -func OpenDatabase(fp string, log *zap.Logger) (*Store, error) { +func OpenDatabase(fp string, opts ...Option) (*Store, error) { db, err := sql.Open("sqlite3", sqliteFilepath(fp)) if err != nil { return nil, err } store := &Store{ db: db, - log: log, + log: zap.NewNop(), + } + for _, opt := range opts { + opt(store) } if err := store.init(); err != nil { return nil, fmt.Errorf("failed to initialize database: %w", err) diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index ace68de..112d75f 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -19,12 +19,53 @@ RETURNING id` return } +func scanEvent(s scanner) (id int64, event wallet.Event, err error) { + var eventType string + var eventBuf []byte + + err = s.Scan(&id, decode(&event.ID), &event.MaturityHeight, decode(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) + if err != nil { + return 0, wallet.Event{}, fmt.Errorf("failed to scan event: %w", err) + } + + switch eventType { + case wallet.EventTypeTransaction: + var tx wallet.EventTransaction + if err = json.Unmarshal(eventBuf, &tx); err != nil { + return 0, wallet.Event{}, fmt.Errorf("failed to unmarshal transaction event: %w", err) + } + event.Data = &tx + case wallet.EventTypeContractPayout: + var m wallet.EventContractPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return 0, wallet.Event{}, fmt.Errorf("failed to unmarshal missed file contract event: %w", err) + } + event.Data = &m + case wallet.EventTypeMinerPayout: + var m wallet.EventMinerPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return 0, wallet.Event{}, fmt.Errorf("failed to unmarshal payout event: %w", err) + } + event.Data = &m + case wallet.EventTypeFoundationSubsidy: + var m wallet.EventFoundationSubsidy + if err = json.Unmarshal(eventBuf, &m); err != nil { + return 0, wallet.Event{}, fmt.Errorf("failed to unmarshal foundation subsidy event: %w", err) + } + event.Data = &m + default: + return 0, wallet.Event{}, fmt.Errorf("unknown event type: %s", eventType) + } + + return +} + func getWalletEvents(tx *txn, walletID string, offset, limit int) (events []wallet.Event, eventIDs []int64, err error) { const query = `SELECT ev.id, ev.event_id, ev.maturity_height, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data FROM events ev INNER JOIN chain_indices ci ON (ev.index_id = ci.id) WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)) - ORDER BY ev.maturity_height DESC + ORDER BY ev.maturity_height DESC, ev.id DESC LIMIT $2 OFFSET $3` rows, err := tx.Query(query, walletID, limit, offset) @@ -34,48 +75,17 @@ func getWalletEvents(tx *txn, walletID string, offset, limit int) (events []wall defer rows.Close() for rows.Next() { - var eventID int64 - var event wallet.Event - var eventType string - var eventBuf []byte - - err := rows.Scan(&eventID, decode(&event.ID), &event.MaturityHeight, decode(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) + eventID, event, err := scanEvent(rows) if err != nil { return nil, nil, fmt.Errorf("failed to scan event: %w", err) } - switch eventType { - case wallet.EventTypeTransaction: - var tx wallet.EventTransaction - if err = json.Unmarshal(eventBuf, &tx); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal transaction event: %w", err) - } - event.Data = &tx - case wallet.EventTypeContractPayout: - var m wallet.EventContractPayout - if err = json.Unmarshal(eventBuf, &m); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal missed file contract event: %w", err) - } - event.Data = &m - case wallet.EventTypeMinerPayout: - var m wallet.EventMinerPayout - if err = json.Unmarshal(eventBuf, &m); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal payout event: %w", err) - } - event.Data = &m - case wallet.EventTypeFoundationSubsidy: - var m wallet.EventFoundationSubsidy - if err = json.Unmarshal(eventBuf, &m); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal foundation subsidy event: %w", err) - } - event.Data = &m - default: - return nil, nil, fmt.Errorf("unknown event type: %s", eventType) - } - events = append(events, event) eventIDs = append(eventIDs, eventID) } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("failed to scan events: %w", err) + } return } @@ -312,15 +322,6 @@ func (s *Store) WalletBalance(walletID string) (balance wallet.Balance, err erro return } -// AddressBalance returns the balance of a single address. -func (s *Store) AddressBalance(address types.Address) (balance wallet.Balance, err error) { - err = s.transaction(func(tx *txn) error { - const query = `SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses WHERE sia_address=$1` - return tx.QueryRow(query, encode(address)).Scan(decode(&balance.Siacoins), decode(&balance.ImmatureSiacoins), &balance.Siafunds) - }) - return -} - // Annotate annotates a list of transactions using the wallet's addresses. func (s *Store) Annotate(walletID string, txns []types.Transaction) (annotated []wallet.PoolTransaction, err error) { err = s.transaction(func(tx *txn) error { diff --git a/wallet/manager.go b/wallet/manager.go index 8ecb5b3..2940e93 100644 --- a/wallet/manager.go +++ b/wallet/manager.go @@ -54,6 +54,10 @@ type ( } ) +// ErrNotFound should be returned by the store when an address or wallet is +// not found. +var ErrNotFound = errors.New("not found") + // AddWallet adds the given wallet. func (m *Manager) AddWallet(name string, info json.RawMessage) error { return m.store.AddWallet(name, info) diff --git a/wallet/update.go b/wallet/update.go index 4784c79..dbe3f41 100644 --- a/wallet/update.go +++ b/wallet/update.go @@ -14,50 +14,53 @@ type ( Balance } - // An ApplyTx atomically applies a set of updates to a store. - ApplyTx interface { + // A TreeNodeUpdate is a change to a Merkle tree node. + TreeNodeUpdate struct { + Row uint64 + Column uint64 + Hash types.Hash256 + } + + // An UpdateTx atomically updates the state of a store. + UpdateTx interface { SiacoinStateElements() ([]types.StateElement, error) UpdateSiacoinStateElements([]types.StateElement) error SiafundStateElements() ([]types.StateElement, error) UpdateSiafundStateElements([]types.StateElement) error - AddressRelevant(types.Address) (bool, error) - AddressBalance(types.Address) (Balance, error) - UpdateBalances([]AddressBalance) error - - MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) AddSiacoinElements([]types.SiacoinElement) error RemoveSiacoinElements([]types.SiacoinOutputID) error AddSiafundElements([]types.SiafundElement) error RemoveSiafundElements([]types.SiafundOutputID) error + MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) + + AddressRelevant(types.Address) (bool, error) + AddressBalance(types.Address) (Balance, error) + UpdateBalances([]AddressBalance) error + + UpdateStateTree(changes []TreeNodeUpdate) error + } + + // An ApplyTx atomically applies a set of updates to a store. + ApplyTx interface { + UpdateTx + AddEvents([]Event) error } // RevertTx atomically reverts an update from a store. RevertTx interface { - RevertEvents(types.BlockID) error - - SiacoinStateElements() ([]types.StateElement, error) - UpdateSiacoinStateElements([]types.StateElement) error - - SiafundStateElements() ([]types.StateElement, error) - UpdateSiafundStateElements([]types.StateElement) error - - AddressRelevant(types.Address) (bool, error) - AddressBalance(types.Address) (Balance, error) - UpdateBalances([]AddressBalance) error + UpdateTx - MaturedSiacoinElements(types.ChainIndex) ([]types.SiacoinElement, error) - AddSiacoinElements([]types.SiacoinElement) error - RemoveSiacoinElements([]types.SiacoinOutputID) error + RevertEvents(index types.ChainIndex) error } ) // ApplyChainUpdates atomically applies a set of chain updates to a store -func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { +func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate, fullIndex bool) error { var events []Event balances := make(map[types.Address]Balance) newSiacoinElements := make(map[types.SiacoinOutputID]types.SiacoinElement) @@ -65,6 +68,8 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { spentSiacoinElements := make(map[types.SiacoinOutputID]bool) spentSiafundElements := make(map[types.SiafundOutputID]bool) + var treeUpdates []TreeNodeUpdate + updateBalance := func(addr types.Address, fn func(b *Balance)) error { balance, ok := balances[addr] if !ok { @@ -80,14 +85,18 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { return nil } - // fetch all siacoin and siafund state elements - siacoinStateElements, err := tx.SiacoinStateElements() - if err != nil { - return fmt.Errorf("failed to get siacoin state elements: %w", err) - } - siafundStateElements, err := tx.SiafundStateElements() - if err != nil { - return fmt.Errorf("failed to get siafund state elements: %w", err) + var siacoinStateElements, siafundStateElements []types.StateElement + var err error + if !fullIndex { + // fetch all siacoin and siafund state elements + siacoinStateElements, err = tx.SiacoinStateElements() + if err != nil { + return fmt.Errorf("failed to get siacoin state elements: %w", err) + } + siafundStateElements, err = tx.SiafundStateElements() + if err != nil { + return fmt.Errorf("failed to get siafund state elements: %w", err) + } } for _, cau := range updates { @@ -136,12 +145,14 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { return } - relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) - if err != nil { - siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return - } else if !relevant { - return + if !fullIndex { + relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) + if err != nil { + siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } } if spent { @@ -178,12 +189,14 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { return } - relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) - if err != nil { - siafundElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return - } else if !relevant { - return + if !fullIndex { + relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) + if err != nil { + siafundElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } } if spent { @@ -211,6 +224,10 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { // add events relevant := func(addr types.Address) bool { + if fullIndex { + return true + } + relevant, err := tx.AddressRelevant(addr) if err != nil { panic(fmt.Errorf("failed to check if address is relevant: %w", err)) @@ -222,24 +239,34 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { } events = append(events, AppliedEvents(cau.State, cau.Block, cau, relevant)...) - // update siacoin element proofs - for id := range newSiacoinElements { - ele := newSiacoinElements[id] - cau.UpdateElementProof(&ele.StateElement) - newSiacoinElements[id] = ele - } - for i := range siacoinStateElements { - cau.UpdateElementProof(&siacoinStateElements[i]) - } + if !fullIndex { + // update siacoin element proofs + for id := range newSiacoinElements { + ele := newSiacoinElements[id] + cau.UpdateElementProof(&ele.StateElement) + newSiacoinElements[id] = ele + } + for i := range siacoinStateElements { + cau.UpdateElementProof(&siacoinStateElements[i]) + } - // update siafund element proofs - for id := range newSiafundElements { - ele := newSiafundElements[id] - cau.UpdateElementProof(&ele.StateElement) - newSiafundElements[id] = ele - } - for i := range siafundStateElements { - cau.UpdateElementProof(&siafundStateElements[i]) + // update siafund element proofs + for id := range newSiafundElements { + ele := newSiafundElements[id] + cau.UpdateElementProof(&ele.StateElement) + newSiafundElements[id] = ele + } + for i := range siafundStateElements { + cau.UpdateElementProof(&siafundStateElements[i]) + } + } else { + cau.ForEachTreeNode(func(row, column uint64, hash types.Hash256) { + treeUpdates = append(treeUpdates, TreeNodeUpdate{ + Row: row, + Column: column, + Hash: hash, + }) + }) } } @@ -296,39 +323,46 @@ func ApplyChainUpdates(tx ApplyTx, updates []*chain.ApplyUpdate) error { return fmt.Errorf("failed to add events: %w", err) } - // update the siacoin state elements - filteredStateElements := siacoinStateElements[:0] - for _, se := range siacoinStateElements { - if _, ok := spentSiacoinElements[types.SiacoinOutputID(se.ID)]; !ok { - filteredStateElements = append(filteredStateElements, se) + if !fullIndex { + // update the siacoin state elements + filteredStateElements := siacoinStateElements[:0] + for _, se := range siacoinStateElements { + if _, ok := spentSiacoinElements[types.SiacoinOutputID(se.ID)]; !ok { + filteredStateElements = append(filteredStateElements, se) + } + } + err = tx.UpdateSiacoinStateElements(filteredStateElements) + if err != nil { + return fmt.Errorf("failed to update siacoin state elements: %w", err) } - } - err = tx.UpdateSiacoinStateElements(filteredStateElements) - if err != nil { - return fmt.Errorf("failed to update siacoin state elements: %w", err) - } - // update the siafund state elements - filteredStateElements = siafundStateElements[:0] - for _, se := range siafundStateElements { - if _, ok := spentSiafundElements[types.SiafundOutputID(se.ID)]; !ok { - filteredStateElements = append(filteredStateElements, se) + // update the siafund state elements + filteredStateElements = siafundStateElements[:0] + for _, se := range siafundStateElements { + if _, ok := spentSiafundElements[types.SiafundOutputID(se.ID)]; !ok { + filteredStateElements = append(filteredStateElements, se) + } + } + if err = tx.UpdateSiafundStateElements(filteredStateElements); err != nil { + return fmt.Errorf("failed to update siafund state elements: %w", err) + } + } else { + if err := tx.UpdateStateTree(treeUpdates); err != nil { + return fmt.Errorf("failed to update state tree: %w", err) } - } - if err = tx.UpdateSiafundStateElements(filteredStateElements); err != nil { - return fmt.Errorf("failed to update siafund state elements: %w", err) } return nil } // RevertChainUpdate atomically reverts a chain update from a store -func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { +func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate, fullIndex bool) error { balances := make(map[types.Address]Balance) - newSiacoinElements := make(map[types.SiacoinOutputID]types.SiacoinElement) - newSiafundElements := make(map[types.SiafundOutputID]types.SiafundElement) - spentSiacoinElements := make(map[types.SiacoinOutputID]bool) - spentSiafundElements := make(map[types.SiafundOutputID]bool) + + var deletedSiacoinElements []types.SiacoinOutputID + var addedSiacoinElements []types.SiacoinElement + var deletedSiafundElements []types.SiafundOutputID + var addedSiafundElements []types.SiafundElement updateBalance := func(addr types.Address, fn func(b *Balance)) error { balance, ok := balances[addr] @@ -366,26 +400,50 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { } } + // revert the immature balance of each relevant address + revertedIndex := types.ChainIndex{ + Height: cru.State.Index.Height + 1, + ID: cru.Block.ID(), + } + + matured, err := tx.MaturedSiacoinElements(revertedIndex) + if err != nil { + return fmt.Errorf("failed to get matured siacoin elements: %w", err) + } + for _, se := range matured { + err := updateBalance(se.SiacoinOutput.Address, func(b *Balance) { + b.ImmatureSiacoins = b.ImmatureSiacoins.Add(se.SiacoinOutput.Value) + b.Siacoins = b.Siacoins.Sub(se.SiacoinOutput.Value) + }) + if err != nil { + return fmt.Errorf("failed to update address balance: %w", err) + } + } + var siacoinElementErr error cru.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { if siacoinElementErr != nil { return } - relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) - if err != nil { - siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return - } else if !relevant { - return - } else if ephemeral[se.ID] { - return + if !fullIndex { + relevant, err := tx.AddressRelevant(se.SiacoinOutput.Address) + if err != nil { + siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } else if ephemeral[se.ID] { + return + } } - if !spent { - newSiacoinElements[types.SiacoinOutputID(se.ID)] = se + if spent { + // re-add any spent siacoin elements + addedSiacoinElements = append(addedSiacoinElements, se) } else { - spentSiacoinElements[types.SiacoinOutputID(se.ID)] = true + // delete any created siacoin elements + deletedSiacoinElements = append(deletedSiacoinElements, types.SiacoinOutputID(se.ID)) } siacoinElementErr = updateBalance(se.SiacoinOutput.Address, func(b *Balance) { @@ -409,27 +467,31 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { return } - relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) - if err != nil { - siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) - return - } else if !relevant { - return - } else if ephemeral[se.ID] { - return + if !fullIndex { + relevant, err := tx.AddressRelevant(se.SiafundOutput.Address) + if err != nil { + siacoinElementErr = fmt.Errorf("failed to check if address is relevant: %w", err) + return + } else if !relevant { + return + } else if ephemeral[se.ID] { + return + } } - if !spent { - newSiafundElements[types.SiafundOutputID(se.ID)] = se + if spent { + // re-add any spent siafund elements + addedSiafundElements = append(addedSiafundElements, se) } else { - spentSiafundElements[types.SiafundOutputID(se.ID)] = true + // delete any created siafund elements + deletedSiafundElements = append(deletedSiafundElements, types.SiafundOutputID(se.ID)) } siafundElementErr = updateBalance(se.SiafundOutput.Address, func(b *Balance) { if spent { - b.Siafunds -= se.SiafundOutput.Value - } else { b.Siafunds += se.SiafundOutput.Value + } else { + b.Siafunds -= se.SiafundOutput.Value } }) }) @@ -448,5 +510,52 @@ func RevertChainUpdate(tx RevertTx, cru *chain.RevertUpdate) error { return fmt.Errorf("failed to update address balance: %w", err) } - return tx.RevertEvents(cru.Block.ID()) + // revert siacoin element changes + if err := tx.AddSiacoinElements(addedSiacoinElements); err != nil { + return fmt.Errorf("failed to add siacoin elements: %w", err) + } else if err := tx.RemoveSiacoinElements(deletedSiacoinElements); err != nil { + return fmt.Errorf("failed to remove siacoin elements: %w", err) + } + + // revert siafund element changes + if err := tx.AddSiafundElements(addedSiafundElements); err != nil { + return fmt.Errorf("failed to add siafund elements: %w", err) + } else if err := tx.RemoveSiafundElements(deletedSiafundElements); err != nil { + return fmt.Errorf("failed to remove siafund elements: %w", err) + } + + if !fullIndex { + // update siacoin element proofs + siacoinElements, err := tx.SiacoinStateElements() + if err != nil { + return fmt.Errorf("failed to get siacoin state elements: %w", err) + } + for i := range siacoinElements { + cru.UpdateElementProof(&siacoinElements[i]) + } + + // update siafund element proofs + siafundElements, err := tx.SiafundStateElements() + if err != nil { + return fmt.Errorf("failed to get siafund state elements: %w", err) + } + for i := range siafundElements { + cru.UpdateElementProof(&siafundElements[i]) + } + } else { + var treeUpdates []TreeNodeUpdate + cru.ForEachTreeNode(func(row, column uint64, hash types.Hash256) { + treeUpdates = append(treeUpdates, TreeNodeUpdate{ + Row: row, + Column: column, + Hash: hash, + }) + }) + + if err := tx.UpdateStateTree(treeUpdates); err != nil { + return fmt.Errorf("failed to update state tree: %w", err) + } + } + + return tx.RevertEvents(revertedIndex) } diff --git a/wallet/wallet.go b/wallet/wallet.go index 2e1dc95..510b239 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -186,12 +186,14 @@ func (*EventContractPayout) EventType() string { return EventTypeContractPayout func (e Event) MarshalJSON() ([]byte, error) { val, _ := json.Marshal(e.Data) return json.Marshal(struct { + ID types.Hash256 `json:"id"` Timestamp time.Time `json:"timestamp"` Index types.ChainIndex `json:"index"` Relevant []types.Address `json:"relevant"` Type string `json:"type"` Val json.RawMessage `json:"val"` }{ + ID: e.ID, Timestamp: e.Timestamp, Index: e.Index, Relevant: e.Relevant, @@ -203,15 +205,17 @@ func (e Event) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unarshaler. func (e *Event) UnmarshalJSON(data []byte) error { var s struct { - Timestamp time.Time - Index types.ChainIndex - Relevant []types.Address - Type string - Val json.RawMessage + ID types.Hash256 `json:"id"` + Timestamp time.Time `json:"timestamp"` + Index types.ChainIndex `json:"index"` + Relevant []types.Address `json:"relevant"` + Type string `json:"type"` + Val json.RawMessage `json:"val"` } if err := json.Unmarshal(data, &s); err != nil { return err } + e.ID = s.ID e.Timestamp = s.Timestamp e.Index = s.Index e.Relevant = s.Relevant @@ -312,6 +316,7 @@ func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant f } events = append(events, Event{ + ID: id, Timestamp: b.Timestamp, Index: cs.Index, MaturityHeight: maturityHeight,