diff --git a/.github/actions/test/action.yml b/.github/actions/test/action.yml index de33d2c..5754f0e 100644 --- a/.github/actions/test/action.yml +++ b/.github/actions/test/action.yml @@ -7,10 +7,10 @@ runs: - name: Configure git # required for golangci-lint on Windows shell: bash run: git config --global core.autocrlf false -# - name: Lint -# uses: golangci/golangci-lint-action@v3 -# with: -# skip-cache: true + - name: Lint + uses: golangci/golangci-lint-action@v3 + with: + skip-cache: true # - name: Analyze # uses: SiaFoundation/action-golang-analysis@HEAD # with: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8e59abc..863f1a5 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -6,6 +6,9 @@ on: branches: - master +env: + CGO_ENABLED: 1 + jobs: test: runs-on: ${{ matrix.os }} @@ -14,7 +17,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest , macos-latest, windows-latest ] - go-version: [ '1.19', '1.20' ] + go-version: [ '1.20', '1.21' ] steps: - name: Configure git run: git config --global core.autocrlf false # required on Windows diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..041664e --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,159 @@ +# Based off of the example file at https://github.com/golangci/golangci-lint + +# options for analysis running +run: + # default concurrency is a available CPU number + concurrency: 4 + + # timeout for analysis, e.g. 30s, 5m, default is 1m + timeout: 600s + + # exit code when at least one issue was found, default is 1 + issues-exit-code: 1 + + # include test files or not, default is true + tests: true + + # list of build tags, all linters use it. Default is empty list. + build-tags: [] + + # which dirs to skip: issues from them won't be reported; + # can use regexp here: generated.*, regexp is applied on full path; + # default value is empty list, but default dirs are skipped independently + # from this option's value (see skip-dirs-use-default). + skip-dirs: + - cover + + # default is true. Enables skipping of directories: + # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ + skip-dirs-use-default: true + + # which files to skip: they will be analyzed, but issues from them + # won't be reported. Default value is empty list, but there is + # no need to include all autogenerated files, we confidently recognize + # autogenerated files. If it's not please let us know. + skip-files: [] + +# output configuration options +output: + # colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number" + format: colored-line-number + + # print lines of code with issue, default is true + print-issued-lines: true + + # print linter name in the end of issue text, default is true + print-linter-name: true + +# all available settings of specific linters +linters-settings: + ## Enabled linters: + govet: + # report about shadowed variables + check-shadowing: false + disable-all: false + + tagliatelle: + case: + rules: + json: goCamel + yaml: goCamel + + + gocritic: + # Which checks should be enabled; can't be combined with 'disabled-checks'; + # See https://go-critic.github.io/overview#checks-overview + # To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run` + # By default list of stable checks is used. + enabled-checks: + - argOrder # Diagnostic options + - badCond + - caseOrder + - dupArg + - dupBranchBody + - dupCase + - dupSubExpr + - nilValReturn + - offBy1 + - weakCond + - boolExprSimplify # Style options here and below. + - builtinShadow + - emptyFallthrough + - hexLiteral + - underef + - equalFold + revive: + ignore-generated-header: true + rules: + - name: blank-imports + disabled: false + - name: bool-literal-in-expr + disabled: false + - name: confusing-results + disabled: false + - name: constant-logical-expr + disabled: false + - name: context-as-argument + disabled: false + - name: exported + disabled: false + - name: errorf + disabled: false + - name: if-return + disabled: false + - name: increment-decrement + disabled: false + - name: modifies-value-receiver + disabled: false + - name: optimize-operands-order + disabled: false + - name: range-val-in-closure + disabled: false + - name: struct-tag + disabled: false + - name: superfluous-else + disabled: false + - name: time-equal + disabled: false + - name: unexported-naming + disabled: false + - name: unexported-return + disabled: false + - name: unnecessary-stmt + disabled: false + - name: unreachable-code + disabled: false + - name: package-comments + disabled: true + +linters: + disable-all: true + fast: false + enable: + - tagliatelle + - gocritic + - gofmt + - revive + - govet + - misspell + - typecheck + - whitespace + +issues: + # Maximum issues count per one linter. Set to 0 to disable. Default is 50. + max-issues-per-linter: 0 + + # Maximum count of issues with the same text. Set to 0 to disable. Default is 3. + max-same-issues: 0 + + # List of regexps of issue texts to exclude, empty list by default. + # But independently from this option we use default exclude patterns, + # it can be disabled by `exclude-use-default: false`. To list all + # excluded by default patterns execute `golangci-lint run --help` + exclude: [] + + # Independently from option `exclude` we use default exclude patterns, + # it can be disabled by this option. To list all + # excluded by default patterns execute `golangci-lint run --help`. + # Default value for this option is true. + exclude-use-default: false \ No newline at end of file diff --git a/api/api_test.go b/api/api_test.go index 37ad08d..71f79cf 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,6 +3,7 @@ package api_test import ( "net" "net/http" + "path/filepath" "testing" "time" @@ -13,9 +14,9 @@ import ( "go.sia.tech/coreutils/syncer" "go.sia.tech/jape" "go.sia.tech/walletd/api" - "go.sia.tech/walletd/internal/syncerutil" - "go.sia.tech/walletd/internal/walletutil" + "go.sia.tech/walletd/persist/sqlite" "go.sia.tech/walletd/wallet" + "go.uber.org/zap/zaptest" "lukechampine.com/frand" ) @@ -48,6 +49,8 @@ func runServer(cm api.ChainManager, s api.Syncer, wm api.WalletManager) (*api.Cl } func TestWallet(t *testing.T) { + log := zaptest.NewLogger(t) + n, genesisBlock := testNetwork() giftPrivateKey := types.GeneratePrivateKey() giftAddress := types.StandardUnlockHash(giftPrivateKey.PublicKey()) @@ -62,7 +65,17 @@ func TestWallet(t *testing.T) { t.Fatal(err) } cm := chain.NewManager(dbstore, tipState) - wm := walletutil.NewEphemeralWalletManager(cm) + + ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + wm, err := wallet.NewManager(cm, ws, log.Named("wallet")) + if err != nil { + t.Fatal(err) + } + sav := wallet.NewSeedAddressVault(wallet.NewSeed(), 0, 20) c, shutdown := runServer(cm, nil, wm) defer shutdown() @@ -70,14 +83,14 @@ func TestWallet(t *testing.T) { t.Fatal(err) } wc := c.Wallet("primary") - if err := wc.Subscribe(0); err != nil { + if err := c.Resubscribe(0); err != nil { t.Fatal(err) } balance, err := wc.Balance() if err != nil { t.Fatal(err) - } else if !balance.Siacoins.IsZero() || balance.Siafunds != 0 { + } else if !balance.Siacoins.IsZero() || !balance.ImmatureSiacoins.IsZero() || balance.Siafunds != 0 { t.Fatal("balance should be 0") } @@ -150,10 +163,12 @@ func TestWallet(t *testing.T) { t.Fatal(err) } else if !balance.Siacoins.Equals(types.Siacoins(1)) { t.Error("balance should be 1 SC, got", balance.Siacoins) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Error("immature balance should be 0 SC, got", balance.ImmatureSiacoins) } // transaction should appear in history - events, err = wc.Events(0, -1) + events, err = wc.Events(0, 100) if err != nil { t.Fatal(err) } else if len(events) == 0 { @@ -166,9 +181,63 @@ func TestWallet(t *testing.T) { } else if len(outputs) != 2 { t.Error("should have two UTXOs, got", len(outputs)) } + + // mine a block to add an immature balance + cs = cm.TipState() + b = types.Block{ + ParentID: cs.Index.ID, + Timestamp: types.CurrentTimestamp(), + MinerPayouts: []types.SiacoinOutput{{Address: addr, Value: cs.BlockReward()}}, + } + for b.ID().CmpWork(cs.ChildTarget) < 0 { + b.Nonce += cs.NonceFactor() + } + if err := cm.AddBlocks([]types.Block{b}); err != nil { + t.Fatal(err) + } + + // get new balance + balance, err = wc.Balance() + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.Equals(types.Siacoins(1)) { + t.Error("balance should be 1 SC, got", balance.Siacoins) + } else if !balance.ImmatureSiacoins.Equals(b.MinerPayouts[0].Value) { + t.Errorf("immature balance should be %d SC, got %d SC", b.MinerPayouts[0].Value, balance.ImmatureSiacoins) + } + + // mine enough blocks for the miner payout to mature + expectedBalance := types.Siacoins(1).Add(b.MinerPayouts[0].Value) + target := cs.MaturityHeight() + for cs.Index.Height < target { + cs = cm.TipState() + b := types.Block{ + ParentID: cs.Index.ID, + Timestamp: types.CurrentTimestamp(), + MinerPayouts: []types.SiacoinOutput{{Address: types.VoidAddress, Value: cs.BlockReward()}}, + } + for b.ID().CmpWork(cs.ChildTarget) < 0 { + b.Nonce += cs.NonceFactor() + } + if err := cm.AddBlocks([]types.Block{b}); err != nil { + t.Fatal(err) + } + } + + // get new balance + balance, err = wc.Balance() + if err != nil { + t.Fatal(err) + } else if !balance.Siacoins.Equals(expectedBalance) { + t.Errorf("balance should be %d, got %d", expectedBalance, balance.Siacoins) + } else if !balance.ImmatureSiacoins.IsZero() { + t.Error("immature balance should be 0 SC, got", balance.ImmatureSiacoins) + } } func TestV2(t *testing.T) { + log := zaptest.NewLogger(t) + n, genesisBlock := testNetwork() // gift primary wallet some coins primaryPrivateKey := types.GeneratePrivateKey() @@ -184,7 +253,15 @@ func TestV2(t *testing.T) { t.Fatal(err) } cm := chain.NewManager(dbstore, tipState) - wm := walletutil.NewEphemeralWalletManager(cm) + ws, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + wm, err := wallet.NewManager(cm, ws, log.Named("wallet")) + if err != nil { + t.Fatal(err) + } c, shutdown := runServer(cm, nil, wm) defer shutdown() if err := c.AddWallet("primary", nil); err != nil { @@ -194,9 +271,6 @@ func TestV2(t *testing.T) { if err := primary.AddAddress(primaryAddress, nil); err != nil { t.Fatal(err) } - if err := primary.Subscribe(0); err != nil { - t.Fatal(err) - } if err := c.AddWallet("secondary", nil); err != nil { t.Fatal(err) } @@ -204,7 +278,7 @@ func TestV2(t *testing.T) { if err := secondary.AddAddress(secondaryAddress, nil); err != nil { t.Fatal(err) } - if err := secondary.Subscribe(0); err != nil { + if err := c.Resubscribe(0); err != nil { t.Fatal(err) } @@ -373,6 +447,7 @@ func TestV2(t *testing.T) { } func TestP2P(t *testing.T) { + logger := zaptest.NewLogger(t) n, genesisBlock := testNetwork() // gift primary wallet some coins primaryPrivateKey := types.GeneratePrivateKey() @@ -387,14 +462,23 @@ func TestP2P(t *testing.T) { if err != nil { t.Fatal(err) } + log1 := logger.Named("one") cm1 := chain.NewManager(dbstore1, tipState) - wm1 := walletutil.NewEphemeralWalletManager(cm1) + store1, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log1.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer store1.Close() + wm1, err := wallet.NewManager(cm1, store1, log1.Named("wallet")) + if err != nil { + t.Fatal(err) + } l1, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } defer l1.Close() - s1 := syncer.New(l1, cm1, syncerutil.NewEphemeralPeerStore(), gateway.Header{ + s1 := syncer.New(l1, cm1, store1, gateway.Header{ GenesisID: genesisBlock.ID(), UniqueID: gateway.GenerateUniqueID(), NetAddress: l1.Addr().String(), @@ -409,7 +493,7 @@ func TestP2P(t *testing.T) { if err := primary.AddAddress(primaryAddress, nil); err != nil { t.Fatal(err) } - if err := primary.Subscribe(0); err != nil { + if err := c1.Resubscribe(0); err != nil { t.Fatal(err) } @@ -417,18 +501,27 @@ func TestP2P(t *testing.T) { if err != nil { t.Fatal(err) } + log2 := logger.Named("two") cm2 := chain.NewManager(dbstore2, tipState) - wm2 := walletutil.NewEphemeralWalletManager(cm2) + store2, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "wallets.db"), log2.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer store2.Close() + wm2, err := wallet.NewManager(cm2, store2, log2.Named("wallet")) + if err != nil { + t.Fatal(err) + } l2, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } defer l2.Close() - s2 := syncer.New(l2, cm2, syncerutil.NewEphemeralPeerStore(), gateway.Header{ + s2 := syncer.New(l2, cm2, store2, gateway.Header{ GenesisID: genesisBlock.ID(), UniqueID: gateway.GenerateUniqueID(), NetAddress: l2.Addr().String(), - }) + }, syncer.WithLogger(zaptest.NewLogger(t))) go s2.Run() c2, shutdown2 := runServer(cm2, s2, wm2) defer shutdown2() @@ -439,7 +532,7 @@ func TestP2P(t *testing.T) { if err := secondary.AddAddress(secondaryAddress, nil); err != nil { t.Fatal(err) } - if err := secondary.Subscribe(0); err != nil { + if err := c2.Resubscribe(0); err != nil { t.Fatal(err) } diff --git a/api/client.go b/api/client.go index 973f194..8365495 100644 --- a/api/client.go +++ b/api/client.go @@ -105,6 +105,13 @@ func (c *Client) Wallet(name string) *WalletClient { return &WalletClient{c: c.c, name: name} } +// Resubscribe subscribes the wallet to consensus updates, starting at the +// specified height. +func (c *Client) Resubscribe(height uint64) (err error) { + err = c.c.POST("/resubscribe", height, nil) + return +} + // A WalletClient provides methods for interacting with a particular wallet on a // walletd API server. type WalletClient struct { @@ -112,13 +119,6 @@ type WalletClient struct { name string } -// Subscribe subscribes the wallet to consensus updates, starting at the -// specified height. This can only be done once. -func (c *WalletClient) Subscribe(height uint64) (err error) { - err = c.c.POST(fmt.Sprintf("/wallets/%v/subscribe", c.name), height, nil) - return -} - // AddAddress adds the specified address and associated metadata to the // wallet. func (c *WalletClient) AddAddress(addr types.Address, info json.RawMessage) (err error) { diff --git a/api/server.go b/api/server.go index 898ba29..6949fba 100644 --- a/api/server.go +++ b/api/server.go @@ -3,7 +3,6 @@ package api import ( "encoding/json" "errors" - "fmt" "net/http" "reflect" "sync" @@ -46,17 +45,23 @@ type ( // A WalletManager manages wallets, keyed by name. WalletManager interface { + Subscribe(startHeight uint64) error + AddWallet(name string, info json.RawMessage) error DeleteWallet(name string) error - Wallets() map[string]json.RawMessage - SubscribeWallet(name string, startHeight uint64) error + Wallets() (map[string]json.RawMessage, error) AddAddress(name string, addr types.Address, info json.RawMessage) error RemoveAddress(name string, addr types.Address) error Addresses(name string) (map[types.Address]json.RawMessage, error) Events(name string, offset, limit int) ([]wallet.Event, error) - UnspentOutputs(name string) ([]types.SiacoinElement, []types.SiafundElement, error) + UnspentSiacoinOutputs(name string) ([]types.SiacoinElement, error) + UnspentSiafundOutputs(name string) ([]types.SiafundElement, error) + WalletBalance(walletID string) (sc, immatureSC types.Currency, sf uint64, err error) Annotate(name string, pool []types.Transaction) ([]wallet.PoolTransaction, error) + + Reserve(ids []types.Hash256, duration time.Duration) error + AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) } ) @@ -87,7 +92,8 @@ func (s *server) syncerPeersHandler(jc jape.Context) { for _, p := range s.s.Peers() { info, ok := s.s.PeerInfo(p.Addr()) if !ok { - continue + jc.Error(errors.New("peer not found"), http.StatusNotFound) + return } peers = append(peers, GatewayPeer{ Addr: p.Addr(), @@ -165,7 +171,11 @@ func (s *server) txpoolBroadcastHandler(jc jape.Context) { } func (s *server) walletsHandler(jc jape.Context) { - jc.Encode(s.wm.Wallets()) + wallets, err := s.wm.Wallets() + if jc.Check("couldn't load wallets", err) != nil { + return + } + jc.Encode(wallets) } func (s *server) walletsNameHandlerPUT(jc jape.Context) { @@ -187,12 +197,11 @@ func (s *server) walletsNameHandlerDELETE(jc jape.Context) { } } -func (s *server) walletsSubscribeHandler(jc jape.Context) { - var name string +func (s *server) resubscribeHandler(jc jape.Context) { var height uint64 - if jc.DecodeParam("name", &name) != nil || jc.Decode(&height) != nil { + if jc.Decode(&height) != nil { return - } else if jc.Check("couldn't subscribe wallet", s.wm.SubscribeWallet(name, height)) != nil { + } else if jc.Check("couldn't subscribe wallet", s.wm.Subscribe(height)) != nil { return } } @@ -235,26 +244,14 @@ func (s *server) walletsBalanceHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil { return } - scos, sfos, err := s.wm.UnspentOutputs(name) - if jc.Check("couldn't load outputs", err) != nil { + + sc, isc, sf, err := s.wm.WalletBalance(name) + if jc.Check("couldn't load balance", err) != nil { return } - height := s.cm.TipState().Index.Height - var sc, immature types.Currency - var sf uint64 - for _, sco := range scos { - if height >= sco.MaturityHeight { - sc = sc.Add(sco.SiacoinOutput.Value) - } else { - immature = immature.Add(sco.SiacoinOutput.Value) - } - } - for _, sfo := range sfos { - sf += sfo.SiafundOutput.Value - } jc.Encode(WalletBalanceResponse{ Siacoins: sc, - ImmatureSiacoins: immature, + ImmatureSiacoins: isc, Siafunds: sf, }) } @@ -289,8 +286,13 @@ func (s *server) walletsOutputsHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil { return } - scos, sfos, err := s.wm.UnspentOutputs(name) - if jc.Check("couldn't load outputs", err) != nil { + scos, err := s.wm.UnspentSiacoinOutputs(name) + if jc.Check("couldn't load siacoin outputs", err) != nil { + return + } + + sfos, err := s.wm.UnspentSiafundOutputs(name) + if jc.Check("couldn't load siafund outputs", err) != nil { return } jc.Encode(WalletOutputsResponse{ @@ -300,44 +302,23 @@ func (s *server) walletsOutputsHandler(jc jape.Context) { } func (s *server) walletsReserveHandler(jc jape.Context) { - var name string var wrr WalletReserveRequest - if jc.DecodeParam("name", &name) != nil || jc.Decode(&wrr) != nil { + if jc.Decode(&wrr) != nil { return } - s.mu.Lock() + ids := make([]types.Hash256, 0, len(wrr.SiacoinOutputs)+len(wrr.SiafundOutputs)) for _, id := range wrr.SiacoinOutputs { - if s.used[types.Hash256(id)] { - s.mu.Unlock() - jc.Error(fmt.Errorf("output %v is already reserved", id), http.StatusBadRequest) - return - } - s.used[types.Hash256(id)] = true + ids = append(ids, types.Hash256(id)) } + for _, id := range wrr.SiafundOutputs { - if s.used[types.Hash256(id)] { - s.mu.Unlock() - jc.Error(fmt.Errorf("output %v is already reserved", id), http.StatusBadRequest) - return - } - s.used[types.Hash256(id)] = true + ids = append(ids, types.Hash256(id)) } - s.mu.Unlock() - if wrr.Duration == 0 { - wrr.Duration = 10 * time.Minute + if jc.Check("couldn't reserve outputs", s.wm.Reserve(ids, wrr.Duration)) != nil { + return } - time.AfterFunc(wrr.Duration, func() { - s.mu.Lock() - defer s.mu.Unlock() - for _, id := range wrr.SiacoinOutputs { - delete(s.used, types.Hash256(id)) - } - for _, id := range wrr.SiafundOutputs { - delete(s.used, types.Hash256(id)) - } - }) } func (s *server) walletsReleaseHandler(jc jape.Context) { @@ -412,7 +393,7 @@ func (s *server) walletsFundHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil || jc.Decode(&wfr) != nil { return } - utxos, _, err := s.wm.UnspentOutputs(name) + utxos, err := s.wm.UnspentSiacoinOutputs(name) if jc.Check("couldn't get utxos to fund transaction", err) != nil { return } @@ -486,7 +467,7 @@ func (s *server) walletsFundSFHandler(jc jape.Context) { if jc.DecodeParam("name", &name) != nil || jc.Decode(&wfr) != nil { return } - _, utxos, err := s.wm.UnspentOutputs(name) + utxos, err := s.wm.UnspentSiafundOutputs(name) if jc.Check("couldn't get utxos to fund transaction", err) != nil { return } @@ -524,10 +505,11 @@ func NewServer(cm ChainManager, s Syncer, wm WalletManager) http.Handler { "GET /txpool/fee": srv.txpoolFeeHandler, "POST /txpool/broadcast": srv.txpoolBroadcastHandler, + "POST /resubscribe": srv.resubscribeHandler, + "GET /wallets": srv.walletsHandler, "PUT /wallets/:name": srv.walletsNameHandlerPUT, "DELETE /wallets/:name": srv.walletsNameHandlerDELETE, - "POST /wallets/:name/subscribe": srv.walletsSubscribeHandler, "PUT /wallets/:name/addresses/:addr": srv.walletsAddressHandlerPUT, "DELETE /wallets/:name/addresses/:addr": srv.walletsAddressHandlerDELETE, "GET /wallets/:name/addresses": srv.walletsAddressesHandlerGET, diff --git a/cmd/walletd/main.go b/cmd/walletd/main.go index e87d9e8..2b69192 100644 --- a/cmd/walletd/main.go +++ b/cmd/walletd/main.go @@ -11,6 +11,8 @@ import ( "go.sia.tech/core/types" "go.sia.tech/walletd/wallet" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" "golang.org/x/term" "lukechampine.com/flagg" "lukechampine.com/frand" @@ -157,12 +159,37 @@ func main() { cmd.Usage() return } + + if err := os.MkdirAll(dir, 0700); err != nil { + log.Fatal(err) + } + apiPassword := getAPIPassword() l, err := net.Listen("tcp", apiAddr) if err != nil { log.Fatal(err) } - n, err := newNode(gatewayAddr, dir, network, upnp) + + // configure console logging note: this is configured before anything else + // to have consistent logging. File logging will be added after the cli + // flags and config is parsed + consoleCfg := zap.NewProductionEncoderConfig() + consoleCfg.TimeKey = "" // prevent duplicate timestamps + consoleCfg.EncodeTime = zapcore.RFC3339TimeEncoder + consoleCfg.EncodeDuration = zapcore.StringDurationEncoder + consoleCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder + consoleCfg.StacktraceKey = "" + consoleCfg.CallerKey = "" + consoleEncoder := zapcore.NewConsoleEncoder(consoleCfg) + + // only log info messages to console unless stdout logging is enabled + consoleCore := zapcore.NewCore(consoleEncoder, zapcore.Lock(os.Stdout), zap.NewAtomicLevelAt(zap.InfoLevel)) + logger := zap.New(consoleCore, zap.AddCaller()) + defer logger.Sync() + // redirect stdlib log to zap + zap.RedirectStdLog(logger.Named("stdlib")) + + n, err := newNode(gatewayAddr, dir, network, upnp, logger) if err != nil { log.Fatal(err) } @@ -204,7 +231,6 @@ func main() { seed := loadTestnetSeed(seed) c := initTestnetClient(apiAddr, network, seed) runTestnetMiner(c, seed) - case balanceCmd: if len(cmd.Args()) != 0 { cmd.Usage() diff --git a/cmd/walletd/node.go b/cmd/walletd/node.go index eaf0d5d..b5a6808 100644 --- a/cmd/walletd/node.go +++ b/cmd/walletd/node.go @@ -3,7 +3,7 @@ package main import ( "context" "errors" - "log" + "fmt" "net" "path/filepath" "strconv" @@ -15,8 +15,9 @@ import ( "go.sia.tech/coreutils" "go.sia.tech/coreutils/chain" "go.sia.tech/coreutils/syncer" - "go.sia.tech/walletd/internal/syncerutil" - "go.sia.tech/walletd/internal/walletutil" + "go.sia.tech/walletd/persist/sqlite" + "go.sia.tech/walletd/wallet" + "go.uber.org/zap" "lukechampine.com/upnp" ) @@ -82,14 +83,23 @@ var anagamiBootstrap = []string{ } type node struct { - cm *chain.Manager - s *syncer.Syncer - wm *walletutil.JSONWalletManager + chainStore *coreutils.BoltChainDB + cm *chain.Manager + + store *sqlite.Store + s *syncer.Syncer + wm *wallet.Manager Start func() (stop func()) } -func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) { +// Close shuts down the node and closes its database. +func (n *node) Close() error { + n.chainStore.Close() + return n.store.Close() +} + +func newNode(addr, dir string, chainNetwork string, useUPNP bool, log *zap.Logger) (*node, error) { var network *consensus.Network var genesisBlock types.Block var bootstrapPeers []string @@ -109,11 +119,11 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) if err != nil { - log.Fatal(err) + return nil, fmt.Errorf("failed to open consensus database: %w", err) } dbstore, tipState, err := chain.NewDBStore(bdb, network, genesisBlock) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create chain store: %w", err) } cm := chain.NewManager(dbstore, tipState) @@ -126,21 +136,21 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if d, err := upnp.Discover(ctx); err != nil { - log.Println("WARN: couldn't discover UPnP device:", err) + log.Debug("couldn't discover UPnP router", zap.Error(err)) } else { _, portStr, _ := net.SplitHostPort(addr) port, _ := strconv.Atoi(portStr) if !d.IsForwarded(uint16(port), "TCP") { if err := d.Forward(uint16(port), "TCP", "walletd"); err != nil { - log.Println("WARN: couldn't forward port:", err) + log.Debug("couldn't forward port", zap.Error(err)) } else { - log.Println("p2p: Forwarded port", port) + log.Debug("upnp: forwarded p2p port", zap.Int("port", port)) } } if ip, err := d.ExternalIP(); err != nil { - log.Println("WARN: couldn't determine external IP:", err) + log.Debug("couldn't determine external IP", zap.Error(err)) } else { - log.Println("p2p: External IP is", ip) + log.Debug("external IP is", zap.String("ip", ip)) syncerAddr = net.JoinHostPort(ip, portStr) } } @@ -151,30 +161,31 @@ func newNode(addr, dir string, chainNetwork string, useUPNP bool) (*node, error) syncerAddr = net.JoinHostPort("127.0.0.1", port) } - ps, err := syncerutil.NewJSONPeerStore(filepath.Join(dir, "peers.json")) + store, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) if err != nil { - log.Fatal(err) + return nil, fmt.Errorf("failed to open wallet database: %w", err) } + for _, peer := range bootstrapPeers { - ps.AddPeer(peer) + store.AddPeer(peer) } header := gateway.Header{ GenesisID: genesisBlock.ID(), UniqueID: gateway.GenerateUniqueID(), NetAddress: syncerAddr, } - - s := syncer.New(l, cm, ps, header) - - wm, err := walletutil.NewJSONWalletManager(dir, cm) + s := syncer.New(l, cm, store, header, syncer.WithLogger(log.Named("syncer"))) + wm, err := wallet.NewManager(cm, store, log.Named("wallet")) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create wallet manager: %w", err) } return &node{ - cm: cm, - s: s, - wm: wm, + chainStore: bdb, + cm: cm, + store: store, + s: s, + wm: wm, Start: func() func() { ch := make(chan struct{}) go func() { diff --git a/cmd/walletd/testnet.go b/cmd/walletd/testnet.go index 254d7a0..b87ad21 100644 --- a/cmd/walletd/testnet.go +++ b/cmd/walletd/testnet.go @@ -115,12 +115,13 @@ func initTestnetClient(addr string, network string, seed wallet.Seed) *api.Clien } else if err := wc.AddAddress(ourAddr, nil); err != nil { fmt.Println() log.Fatal(err) - } else if err := wc.Subscribe(0); err != nil { + } else if err := c.Resubscribe(0); err != nil { fmt.Println() log.Fatal(err) } fmt.Println("done.") } + return c } diff --git a/go.mod b/go.mod index e6df28b..084ce0b 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module go.sia.tech/walletd go 1.21 require ( - go.sia.tech/core v0.2.1-0.20240130145801-8067f34b2ecc + github.com/mattn/go-sqlite3 v1.14.21 + go.sia.tech/core v0.2.1 go.sia.tech/coreutils v0.0.0-20240130201319-8303550528d7 go.sia.tech/jape v0.9.0 go.sia.tech/web/walletd v0.16.0 + go.uber.org/zap v1.26.0 golang.org/x/term v0.6.0 lukechampine.com/flagg v1.1.1 lukechampine.com/frand v1.4.2 @@ -20,7 +22,6 @@ require ( go.sia.tech/mux v1.2.0 // indirect go.sia.tech/web v0.0.0-20230628194305-c6e1696bad89 // indirect go.uber.org/multierr v1.10.0 // indirect - go.uber.org/zap v1.26.0 // indirect golang.org/x/crypto v0.0.0-20220507011949-2cf3adece122 // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/tools v0.7.0 // indirect diff --git a/go.sum b/go.sum index f3a3ac8..f29db2e 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/mattn/go-sqlite3 v1.14.21 h1:IXocQLOykluc3xPE0Lvy8FtggMz1G+U3mEjg+0zGizc= +github.com/mattn/go-sqlite3 v1.14.21/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= @@ -12,6 +14,8 @@ go.etcd.io/bbolt v1.3.8 h1:xs88BrvEv273UsB79e0hcVrlUWmS0a8upikMFhSyAtA= go.etcd.io/bbolt v1.3.8/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw= go.sia.tech/core v0.2.1-0.20240130145801-8067f34b2ecc h1:oUCCTOatQIwYkJ2FUWRvJtgU+i/BwlzmzCxoSvmmJVQ= go.sia.tech/core v0.2.1-0.20240130145801-8067f34b2ecc/go.mod h1:3EoY+rR78w1/uGoXXVqcYdwSjSJKuEMI5bL7WROA27Q= +go.sia.tech/core v0.2.1 h1:CqmMd+T5rAhC+Py3NxfvGtvsj/GgwIqQHHVrdts/LqY= +go.sia.tech/core v0.2.1/go.mod h1:3EoY+rR78w1/uGoXXVqcYdwSjSJKuEMI5bL7WROA27Q= go.sia.tech/coreutils v0.0.0-20240130201319-8303550528d7 h1:G2l6fRzAdNZy2z7+FhoG2y8ARtFpR6PkXXTB5tkdfZ8= go.sia.tech/coreutils v0.0.0-20240130201319-8303550528d7/go.mod h1:3Mb206QDd3NtRiaHZ2kN87/HKXhcBF6lHVatS7PkViY= go.sia.tech/jape v0.9.0 h1:kWgMFqALYhLMJYOwWBgJda5ko/fi4iZzRxHRP7pp8NY= diff --git a/internal/syncerutil/store.go b/internal/syncerutil/store.go deleted file mode 100644 index 6456c6b..0000000 --- a/internal/syncerutil/store.go +++ /dev/null @@ -1,208 +0,0 @@ -package syncerutil - -import ( - "encoding/json" - "net" - "os" - "sync" - "time" - - "go.sia.tech/coreutils/syncer" -) - -type peerBan struct { - Expiry time.Time `json:"expiry"` - Reason string `json:"reason"` -} - -// EphemeralPeerStore implements PeerStore with an in-memory map. -type EphemeralPeerStore struct { - peers map[string]syncer.PeerInfo - bans map[string]peerBan - mu sync.Mutex -} - -func (eps *EphemeralPeerStore) banned(peer string) bool { - host, _, err := net.SplitHostPort(peer) - if err != nil { - return false // shouldn't happen - } - for _, s := range []string{ - peer, // 1.2.3.4:5678 - syncer.Subnet(host, "/32"), // 1.2.3.4:* - syncer.Subnet(host, "/24"), // 1.2.3.* - syncer.Subnet(host, "/16"), // 1.2.* - syncer.Subnet(host, "/8"), // 1.* - } { - if b, ok := eps.bans[s]; ok { - if time.Until(b.Expiry) <= 0 { - delete(eps.bans, s) - } else { - return true - } - } - } - return false -} - -// AddPeer implements PeerStore. -func (eps *EphemeralPeerStore) AddPeer(peer string) { - eps.mu.Lock() - defer eps.mu.Unlock() - if _, ok := eps.peers[peer]; !ok { - eps.peers[peer] = syncer.PeerInfo{FirstSeen: time.Now()} - } -} - -// Peers implements PeerStore. -func (eps *EphemeralPeerStore) Peers() []string { - eps.mu.Lock() - defer eps.mu.Unlock() - var peers []string - for p := range eps.peers { - if !eps.banned(p) { - peers = append(peers, p) - } - } - return peers -} - -// UpdatePeerInfo implements PeerStore. -func (eps *EphemeralPeerStore) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { - eps.mu.Lock() - defer eps.mu.Unlock() - info, ok := eps.peers[peer] - if !ok { - return - } - fn(&info) - eps.peers[peer] = info -} - -// PeerInfo implements PeerStore. -func (eps *EphemeralPeerStore) PeerInfo(peer string) (syncer.PeerInfo, bool) { - eps.mu.Lock() - defer eps.mu.Unlock() - info, ok := eps.peers[peer] - return info, ok -} - -// Ban implements PeerStore. -func (eps *EphemeralPeerStore) Ban(peer string, duration time.Duration, reason string) { - eps.mu.Lock() - defer eps.mu.Unlock() - // canonicalize - if _, ipnet, err := net.ParseCIDR(peer); err == nil { - peer = ipnet.String() - } - eps.bans[peer] = peerBan{Expiry: time.Now().Add(duration), Reason: reason} -} - -// Banned implements PeerStore. -func (eps *EphemeralPeerStore) Banned(peer string) bool { - eps.mu.Lock() - defer eps.mu.Unlock() - return eps.banned(peer) -} - -// NewEphemeralPeerStore initializes an EphemeralPeerStore. -func NewEphemeralPeerStore() *EphemeralPeerStore { - return &EphemeralPeerStore{ - peers: make(map[string]syncer.PeerInfo), - bans: make(map[string]peerBan), - } -} - -type jsonPersist struct { - Peers map[string]syncer.PeerInfo `json:"peers"` - Bans map[string]peerBan `json:"bans"` -} - -// JSONPeerStore implements PeerStore with a JSON file on disk. -type JSONPeerStore struct { - *EphemeralPeerStore - path string - lastSave time.Time -} - -func (jps *JSONPeerStore) load() error { - f, err := os.Open(jps.path) - if os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - defer f.Close() - var p jsonPersist - if err := json.NewDecoder(f).Decode(&p); err != nil { - return err - } - jps.EphemeralPeerStore.peers = p.Peers - jps.EphemeralPeerStore.bans = p.Bans - return nil -} - -func (jps *JSONPeerStore) save() error { - jps.EphemeralPeerStore.mu.Lock() - defer jps.EphemeralPeerStore.mu.Unlock() - if time.Since(jps.lastSave) < 5*time.Second { - return nil - } - defer func() { jps.lastSave = time.Now() }() - // clear out expired bans - for peer, b := range jps.EphemeralPeerStore.bans { - if time.Until(b.Expiry) <= 0 { - delete(jps.EphemeralPeerStore.bans, peer) - } - } - p := jsonPersist{ - Peers: jps.EphemeralPeerStore.peers, - Bans: jps.EphemeralPeerStore.bans, - } - js, err := json.MarshalIndent(p, "", " ") - if err != nil { - return err - } - f, err := os.OpenFile(jps.path+"_tmp", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0660) - if err != nil { - return err - } - defer f.Close() - if _, err = f.Write(js); err != nil { - return err - } else if f.Sync(); err != nil { - return err - } else if f.Close(); err != nil { - return err - } else if err := os.Rename(jps.path+"_tmp", jps.path); err != nil { - return err - } - return nil -} - -// AddPeer implements PeerStore. -func (jps *JSONPeerStore) AddPeer(peer string) { - jps.EphemeralPeerStore.AddPeer(peer) - jps.save() -} - -// UpdatePeerInfo implements PeerStore. -func (jps *JSONPeerStore) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { - jps.EphemeralPeerStore.UpdatePeerInfo(peer, fn) - jps.save() -} - -// Ban implements PeerStore. -func (jps *JSONPeerStore) Ban(peer string, duration time.Duration, reason string) { - jps.EphemeralPeerStore.Ban(peer, duration, reason) - jps.save() -} - -// NewJSONPeerStore returns a JSONPeerStore backed by the specified file. -func NewJSONPeerStore(path string) (*JSONPeerStore, error) { - jps := &JSONPeerStore{ - EphemeralPeerStore: NewEphemeralPeerStore(), - path: path, - } - return jps, jps.load() -} diff --git a/internal/walletutil/manager.go b/internal/walletutil/manager.go deleted file mode 100644 index be5e962..0000000 --- a/internal/walletutil/manager.go +++ /dev/null @@ -1,403 +0,0 @@ -package walletutil - -import ( - "encoding/json" - "errors" - "os" - "path/filepath" - "sync" - - "go.sia.tech/coreutils/chain" - "go.sia.tech/core/types" - "go.sia.tech/walletd/wallet" -) - -var errNoWallet = errors.New("wallet does not exist") - -type ChainManager interface { - AddSubscriber(s chain.Subscriber, tip types.ChainIndex) error - RemoveSubscriber(s chain.Subscriber) - BestIndex(height uint64) (types.ChainIndex, bool) -} - -type managedEphemeralWallet struct { - w *EphemeralStore - info json.RawMessage - subscribed bool -} - -// An EphemeralWalletManager manages multiple ephemeral wallet stores. -type EphemeralWalletManager struct { - cm ChainManager - mu sync.Mutex - wallets map[string]*managedEphemeralWallet -} - -// AddWallet implements api.WalletManager. -func (wm *EphemeralWalletManager) AddWallet(name string, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - if _, ok := wm.wallets[name]; ok { - return errors.New("wallet already exists") - } - store := NewEphemeralStore() - wm.wallets[name] = &managedEphemeralWallet{store, info, false} - return nil -} - -// DeleteWallet implements api.WalletManager. -func (wm *EphemeralWalletManager) DeleteWallet(name string) error { - wm.mu.Lock() - defer wm.mu.Unlock() - delete(wm.wallets, name) - return nil -} - -// Wallets implements api.WalletManager. -func (wm *EphemeralWalletManager) Wallets() map[string]json.RawMessage { - wm.mu.Lock() - defer wm.mu.Unlock() - ws := make(map[string]json.RawMessage, len(wm.wallets)) - for name, w := range wm.wallets { - ws[name] = w.info - } - return ws -} - -// AddAddress implements api.WalletManager. -func (wm *EphemeralWalletManager) AddAddress(name string, addr types.Address, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.AddAddress(addr, info) -} - -// RemoveAddress implements api.WalletManager. -func (wm *EphemeralWalletManager) RemoveAddress(name string, addr types.Address) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.RemoveAddress(addr) -} - -// Addresses implements api.WalletManager. -func (wm *EphemeralWalletManager) Addresses(name string) (map[types.Address]json.RawMessage, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Addresses() -} - -// Events implements api.WalletManager. -func (wm *EphemeralWalletManager) Events(name string, offset, limit int) ([]wallet.Event, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Events(offset, limit) -} - -// Annotate implements api.WalletManager. -func (wm *EphemeralWalletManager) Annotate(name string, txns []types.Transaction) ([]wallet.PoolTransaction, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Annotate(txns), nil -} - -// UnspentOutputs implements api.WalletManager. -func (wm *EphemeralWalletManager) UnspentOutputs(name string) ([]types.SiacoinElement, []types.SiafundElement, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, nil, errNoWallet - } - return mw.w.UnspentOutputs() -} - -// SubscribeWallet implements api.WalletManager. -func (wm *EphemeralWalletManager) SubscribeWallet(name string, startHeight uint64) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } else if mw.subscribed { - return errors.New("already subscribed") - } - // AddSubscriber applies each block *after* index, but we want to *include* - // the block at startHeight, so subtract one. - // - // NOTE: if subscribing from height 0, we must pass an empty index in order - // to receive the genesis block. - var index types.ChainIndex - if startHeight > 0 { - if index, ok = wm.cm.BestIndex(startHeight - 1); !ok { - return errors.New("invalid height") - } - } - if err := wm.cm.AddSubscriber(mw.w, index); err != nil { - return err - } - mw.subscribed = true - return nil -} - -// NewEphemeralWalletManager returns a new EphemeralWalletManager. -func NewEphemeralWalletManager(cm ChainManager) *EphemeralWalletManager { - return &EphemeralWalletManager{ - cm: cm, - wallets: make(map[string]*managedEphemeralWallet), - } -} - -type managedJSONWallet struct { - w *JSONStore - info json.RawMessage - subscribed bool -} - -type managerPersistData struct { - Wallets []managerPersistWallet `json:"wallets"` -} - -type managerPersistWallet struct { - Name string `json:"name"` - Info json.RawMessage `json:"info"` - Subscribed bool `json:"subscribed"` -} - -// A JSONWalletManager manages multiple JSON wallet stores. -type JSONWalletManager struct { - dir string - cm ChainManager - mu sync.Mutex - wallets map[string]*managedJSONWallet -} - -func (wm *JSONWalletManager) save() error { - var p managerPersistData - for name, mw := range wm.wallets { - p.Wallets = append(p.Wallets, managerPersistWallet{name, mw.info, mw.subscribed}) - } - js, err := json.MarshalIndent(p, "", " ") - if err != nil { - return err - } - dst := filepath.Join(wm.dir, "wallets.json") - f, err := os.OpenFile(dst+"_tmp", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0660) - if err != nil { - return err - } - defer f.Close() - if _, err = f.Write(js); err != nil { - return err - } else if f.Sync(); err != nil { - return err - } else if f.Close(); err != nil { - return err - } else if err := os.Rename(dst+"_tmp", dst); err != nil { - return err - } - return nil -} - -func (wm *JSONWalletManager) load() error { - dst := filepath.Join(wm.dir, "wallets.json") - f, err := os.Open(dst) - if os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - defer f.Close() - var p managerPersistData - if err := json.NewDecoder(f).Decode(&p); err != nil { - return err - } - for _, pw := range p.Wallets { - wm.wallets[pw.Name] = &managedJSONWallet{nil, pw.Info, pw.Subscribed} - } - return nil -} - -// AddWallet implements api.WalletManager. -func (wm *JSONWalletManager) AddWallet(name string, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - if mw, ok := wm.wallets[name]; ok { - // update existing wallet - mw.info = info - return wm.save() - } else if _, err := os.Stat(filepath.Join(wm.dir, "wallets", name+".json")); err == nil { - // shouldn't happen in normal conditions - return errors.New("a wallet with that name already exists, but is absent from wallets.json") - } - store, _, err := NewJSONStore(filepath.Join(wm.dir, "wallets", name+".json")) - if err != nil { - return err - } - wm.wallets[name] = &managedJSONWallet{store, info, false} - return wm.save() -} - -// DeleteWallet implements api.WalletManager. -func (wm *JSONWalletManager) DeleteWallet(name string) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil - } - wm.cm.RemoveSubscriber(mw.w) - delete(wm.wallets, name) - return os.RemoveAll(filepath.Join(wm.dir, "wallets", name+".json")) -} - -// Wallets implements api.WalletManager. -func (wm *JSONWalletManager) Wallets() map[string]json.RawMessage { - wm.mu.Lock() - defer wm.mu.Unlock() - ws := make(map[string]json.RawMessage, len(wm.wallets)) - for name, w := range wm.wallets { - ws[name] = w.info - } - return ws -} - -// AddAddress implements api.WalletManager. -func (wm *JSONWalletManager) AddAddress(name string, addr types.Address, info json.RawMessage) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.AddAddress(addr, info) -} - -// RemoveAddress implements api.WalletManager. -func (wm *JSONWalletManager) RemoveAddress(name string, addr types.Address) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } - return mw.w.RemoveAddress(addr) -} - -// Addresses implements api.WalletManager. -func (wm *JSONWalletManager) Addresses(name string) (map[types.Address]json.RawMessage, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Addresses() -} - -// Events implements api.WalletManager. -func (wm *JSONWalletManager) Events(name string, offset, limit int) ([]wallet.Event, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Events(offset, limit) -} - -// Annotate implements api.WalletManager. -func (wm *JSONWalletManager) Annotate(name string, txns []types.Transaction) ([]wallet.PoolTransaction, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, errNoWallet - } - return mw.w.Annotate(txns), nil -} - -// UnspentOutputs implements api.WalletManager. -func (wm *JSONWalletManager) UnspentOutputs(name string) ([]types.SiacoinElement, []types.SiafundElement, error) { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return nil, nil, errNoWallet - } - return mw.w.UnspentOutputs() -} - -// SubscribeWallet implements api.WalletManager. -func (wm *JSONWalletManager) SubscribeWallet(name string, startHeight uint64) error { - wm.mu.Lock() - defer wm.mu.Unlock() - mw, ok := wm.wallets[name] - if !ok { - return errNoWallet - } else if mw.subscribed { - return errors.New("already subscribed") - } - // AddSubscriber applies each block *after* index, but we want to *include* - // the block at startHeight, so subtract one. - // - // NOTE: if subscribing from height 0, we must pass an empty index in order - // to receive the genesis block. - var index types.ChainIndex - if startHeight > 0 { - if index, ok = wm.cm.BestIndex(startHeight - 1); !ok { - return errors.New("invalid height") - } - } - if err := wm.cm.AddSubscriber(mw.w, index); err != nil { - return err - } - mw.subscribed = true - return wm.save() -} - -// NewJSONWalletManager returns a wallet manager that stores wallets in the -// specified directory. -func NewJSONWalletManager(dir string, cm ChainManager) (*JSONWalletManager, error) { - wm := &JSONWalletManager{ - dir: dir, - cm: cm, - wallets: make(map[string]*managedJSONWallet), - } - if err := os.MkdirAll(filepath.Join(dir, "wallets"), 0700); err != nil { - return nil, err - } else if err := wm.load(); err != nil { - return nil, err - } - for name, mw := range wm.wallets { - store, tip, err := NewJSONStore(filepath.Join(dir, "wallets", name+".json")) - if err != nil { - return nil, err - } - if mw.subscribed { - if err := cm.AddSubscriber(store, tip); err != nil { - return nil, err - } - } - mw.w = store - } - return wm, nil -} diff --git a/internal/walletutil/store.go b/internal/walletutil/store.go deleted file mode 100644 index 8a2ffef..0000000 --- a/internal/walletutil/store.go +++ /dev/null @@ -1,402 +0,0 @@ -package walletutil - -import ( - "encoding/json" - "fmt" - "os" - "sync" - - "go.sia.tech/coreutils/chain" - "go.sia.tech/core/types" - "go.sia.tech/walletd/wallet" -) - -// An EphemeralStore stores wallet state in memory. -type EphemeralStore struct { - tip types.ChainIndex - addrs map[types.Address]json.RawMessage - sces map[types.SiacoinOutputID]types.SiacoinElement - sfes map[types.SiafundOutputID]types.SiafundElement - events []wallet.Event - mu sync.Mutex -} - -func (s *EphemeralStore) ownsAddress(addr types.Address) bool { - _, ok := s.addrs[addr] - return ok -} - -// Events implements api.Wallet. -func (s *EphemeralStore) Events(offset, limit int) (events []wallet.Event, err error) { - s.mu.Lock() - defer s.mu.Unlock() - if limit == -1 { - limit = len(s.events) - } - if offset > len(s.events) { - offset = len(s.events) - } - if offset+limit > len(s.events) { - limit = len(s.events) - offset - } - // reverse - es := make([]wallet.Event, limit) - for i := range es { - es[i] = s.events[len(s.events)-offset-i-1] - } - return es, nil -} - -// Annotate implements api.Wallet. -func (s *EphemeralStore) Annotate(txns []types.Transaction) (ptxns []wallet.PoolTransaction) { - s.mu.Lock() - defer s.mu.Unlock() - for _, txn := range txns { - ptxn := wallet.Annotate(txn, s.ownsAddress) - if ptxn.Type != "unrelated" { - ptxns = append(ptxns, ptxn) - } - } - return -} - -// UnspentOutputs implements api.Wallet. -func (s *EphemeralStore) UnspentOutputs() (sces []types.SiacoinElement, sfes []types.SiafundElement, err error) { - s.mu.Lock() - defer s.mu.Unlock() - for _, sco := range s.sces { - sces = append(sces, sco) - } - for _, sfo := range s.sfes { - sfes = append(sfes, sfo) - } - return -} - -// Addresses implements api.Wallet. -func (s *EphemeralStore) Addresses() (map[types.Address]json.RawMessage, error) { - s.mu.Lock() - defer s.mu.Unlock() - addrs := make(map[types.Address]json.RawMessage, len(s.addrs)) - for addr, info := range s.addrs { - addrs[addr] = info - } - return addrs, nil -} - -// AddAddress implements api.Wallet. -func (s *EphemeralStore) AddAddress(addr types.Address, info json.RawMessage) error { - s.mu.Lock() - defer s.mu.Unlock() - s.addrs[addr] = info - return nil -} - -// RemoveAddress implements api.Wallet. -func (s *EphemeralStore) RemoveAddress(addr types.Address) error { - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.addrs[addr]; !ok { - return nil - } - delete(s.addrs, addr) - - // filter outputs - for scoid, sce := range s.sces { - if sce.SiacoinOutput.Address == addr { - delete(s.sces, scoid) - } - } - for sfoid, sfe := range s.sfes { - if sfe.SiafundOutput.Address == addr { - delete(s.sfes, sfoid) - } - } - - // filter events - relevantContract := func(fc types.FileContract) bool { - for _, sco := range fc.ValidProofOutputs { - if s.ownsAddress(sco.Address) { - return true - } - } - for _, sco := range fc.MissedProofOutputs { - if s.ownsAddress(sco.Address) { - return true - } - } - return false - } - relevantV2Contract := func(fc types.V2FileContract) bool { - return s.ownsAddress(fc.RenterOutput.Address) || s.ownsAddress(fc.HostOutput.Address) - } - relevantEvent := func(e wallet.Event) bool { - switch e := e.Val.(type) { - case *wallet.EventTransaction: - for _, sce := range e.SiacoinInputs { - if s.ownsAddress(sce.SiacoinOutput.Address) { - return true - } - } - for _, sce := range e.SiacoinOutputs { - if s.ownsAddress(sce.SiacoinOutput.Address) { - return true - } - } - for _, sfe := range e.SiafundInputs { - if s.ownsAddress(sfe.SiafundElement.SiafundOutput.Address) || - s.ownsAddress(sfe.ClaimElement.SiacoinOutput.Address) { - return true - } - } - for _, sfe := range e.SiafundOutputs { - if s.ownsAddress(sfe.SiafundOutput.Address) { - return true - } - } - for _, fc := range e.FileContracts { - if relevantContract(fc.FileContract.FileContract) || (fc.Revision != nil && relevantContract(*fc.Revision)) { - return true - } - } - for _, fc := range e.V2FileContracts { - if relevantV2Contract(fc.FileContract.V2FileContract) || (fc.Revision != nil && relevantV2Contract(*fc.Revision)) { - return true - } - if fc.Resolution != nil { - switch r := fc.Resolution.(type) { - case *types.V2FileContractFinalization: - if relevantV2Contract(types.V2FileContract(*r)) { - return true - } - case *types.V2FileContractRenewal: - if relevantV2Contract(r.FinalRevision) || relevantV2Contract(r.InitialRevision) { - return true - } - } - } - } - return false - case *wallet.EventMinerPayout: - return s.ownsAddress(e.SiacoinOutput.SiacoinOutput.Address) - case *wallet.EventMissedFileContract: - for _, sce := range e.MissedOutputs { - if s.ownsAddress(sce.SiacoinOutput.Address) { - return true - } - } - return false - default: - panic(fmt.Sprintf("unhandled event type %T", e)) - } - } - - rem := s.events[:0] - for _, e := range s.events { - if relevantEvent(e) { - rem = append(rem, e) - } - } - s.events = rem - return nil -} - -// ProcessChainApplyUpdate implements chain.Subscriber. -func (s *EphemeralStore) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, _ bool) error { - s.mu.Lock() - defer s.mu.Unlock() - - events := wallet.AppliedEvents(cau.State, cau.Block, cau, s.ownsAddress) - s.events = append(s.events, events...) - - // add/remove outputs - cau.ForEachSiacoinElement(func(sce types.SiacoinElement, spent bool) { - if s.ownsAddress(sce.SiacoinOutput.Address) { - if spent { - delete(s.sces, types.SiacoinOutputID(sce.ID)) - } else { - sce.MerkleProof = append([]types.Hash256(nil), sce.MerkleProof...) - s.sces[types.SiacoinOutputID(sce.ID)] = sce - } - } - }) - cau.ForEachSiafundElement(func(sfe types.SiafundElement, spent bool) { - if s.ownsAddress(sfe.SiafundOutput.Address) { - if spent { - delete(s.sfes, types.SiafundOutputID(sfe.ID)) - } else { - sfe.MerkleProof = append([]types.Hash256(nil), sfe.MerkleProof...) - s.sfes[types.SiafundOutputID(sfe.ID)] = sfe - } - } - }) - - // update proofs - for id, sce := range s.sces { - cau.UpdateElementProof(&sce.StateElement) - s.sces[id] = sce - } - for id, sfe := range s.sfes { - cau.UpdateElementProof(&sfe.StateElement) - s.sfes[id] = sfe - } - - s.tip = cau.State.Index - return nil -} - -// ProcessChainRevertUpdate implements chain.Subscriber. -func (s *EphemeralStore) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { - s.mu.Lock() - defer s.mu.Unlock() - - // terribly inefficient, but not a big deal because reverts are infrequent - numEvents := len(wallet.AppliedEvents(cru.State, cru.Block, cru, s.ownsAddress)) - s.events = s.events[:len(s.events)-numEvents] - - cru.ForEachSiacoinElement(func(sce types.SiacoinElement, spent bool) { - if s.ownsAddress(sce.SiacoinOutput.Address) { - if !spent { - delete(s.sces, types.SiacoinOutputID(sce.ID)) - } else { - sce.MerkleProof = append([]types.Hash256(nil), sce.MerkleProof...) - s.sces[types.SiacoinOutputID(sce.ID)] = sce - } - } - }) - cru.ForEachSiafundElement(func(sfe types.SiafundElement, spent bool) { - if s.ownsAddress(sfe.SiafundOutput.Address) { - if !spent { - delete(s.sfes, types.SiafundOutputID(sfe.ID)) - } else { - sfe.MerkleProof = append([]types.Hash256(nil), sfe.MerkleProof...) - s.sfes[types.SiafundOutputID(sfe.ID)] = sfe - } - } - }) - - // update proofs - for id, sce := range s.sces { - cru.UpdateElementProof(&sce.StateElement) - s.sces[id] = sce - } - for id, sfe := range s.sfes { - cru.UpdateElementProof(&sfe.StateElement) - s.sfes[id] = sfe - } - - s.tip = cru.State.Index - return nil -} - -// NewEphemeralStore returns a new EphemeralStore. -func NewEphemeralStore() *EphemeralStore { - return &EphemeralStore{ - addrs: make(map[types.Address]json.RawMessage), - sces: make(map[types.SiacoinOutputID]types.SiacoinElement), - sfes: make(map[types.SiafundOutputID]types.SiafundElement), - } -} - -// A JSONStore stores wallet state in memory, backed by a JSON file. -type JSONStore struct { - *EphemeralStore - path string -} - -type persistData struct { - Tip types.ChainIndex - Addresses map[types.Address]json.RawMessage - SiacoinElements map[types.SiacoinOutputID]types.SiacoinElement - SiafundElements map[types.SiafundOutputID]types.SiafundElement - Events []wallet.Event -} - -func (s *JSONStore) save() error { - js, err := json.MarshalIndent(persistData{ - Tip: s.tip, - Addresses: s.addrs, - SiacoinElements: s.sces, - SiafundElements: s.sfes, - Events: s.events, - }, "", " ") - if err != nil { - return err - } - - f, err := os.OpenFile(s.path+"_tmp", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0660) - if err != nil { - return err - } - defer f.Close() - if _, err = f.Write(js); err != nil { - return err - } else if f.Sync(); err != nil { - return err - } else if f.Close(); err != nil { - return err - } else if err := os.Rename(s.path+"_tmp", s.path); err != nil { - return err - } - return nil -} - -func (s *JSONStore) load() error { - f, err := os.Open(s.path) - if os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - defer f.Close() - var p persistData - if err := json.NewDecoder(f).Decode(&p); err != nil { - return err - } - s.tip = p.Tip - s.addrs = p.Addresses - s.sces = p.SiacoinElements - s.sfes = p.SiafundElements - s.events = p.Events - return nil -} - -// ProcessChainApplyUpdate implements chain.Subscriber. -func (s *JSONStore) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) error { - err := s.EphemeralStore.ProcessChainApplyUpdate(cau, mayCommit) - if err == nil && mayCommit { - err = s.save() - } - return err -} - -// ProcessChainRevertUpdate implements chain.Subscriber. -func (s *JSONStore) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { - return s.EphemeralStore.ProcessChainRevertUpdate(cru) -} - -// AddAddress implements api.Wallet. -func (s *JSONStore) AddAddress(addr types.Address, info json.RawMessage) error { - if err := s.EphemeralStore.AddAddress(addr, info); err != nil { - return err - } - return s.save() -} - -// RemoveAddress implements api.Wallet. -func (s *JSONStore) RemoveAddress(addr types.Address) error { - if err := s.EphemeralStore.RemoveAddress(addr); err != nil { - return err - } - return s.save() -} - -// NewJSONStore returns a new JSONStore. -func NewJSONStore(path string) (*JSONStore, types.ChainIndex, error) { - s := &JSONStore{ - EphemeralStore: NewEphemeralStore(), - path: path, - } - err := s.load() - return s, s.tip, err -} diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go new file mode 100644 index 0000000..33fa61d --- /dev/null +++ b/persist/sqlite/consensus.go @@ -0,0 +1,538 @@ +package sqlite + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/walletd/wallet" + "go.uber.org/zap" +) + +const updateProofBatchSize = 1000 + +type chainUpdate interface { + UpdateElementProof(*types.StateElement) + ForEachTreeNode(func(row, col uint64, h types.Hash256)) + ForEachSiacoinElement(func(types.SiacoinElement, bool)) + ForEachSiafundElement(func(types.SiafundElement, bool)) +} + +func insertChainIndex(tx *txn, index types.ChainIndex) (id int64, err error) { + err = tx.QueryRow(`INSERT INTO chain_indices (height, block_id) VALUES ($1, $2) ON CONFLICT (block_id) DO UPDATE SET height=EXCLUDED.height RETURNING id`, index.Height, encode(index.ID)).Scan(&id) + return +} + +func applyEvents(tx *txn, events []wallet.Event) error { + stmt, err := tx.Prepare(`INSERT INTO events (date_created, index_id, event_type, event_data) VALUES ($1, $2, $3, $4) RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + addRelevantAddrStmt, err := tx.Prepare(`INSERT INTO event_addresses (event_id, address_id, block_height) VALUES ($1, $2, $3)`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer addRelevantAddrStmt.Close() + + for _, event := range events { + id, err := insertChainIndex(tx, event.Index) + if err != nil { + return fmt.Errorf("failed to create chain index: %w", err) + } + + buf, err := json.Marshal(event.Val) + if err != nil { + return fmt.Errorf("failed to marshal event: %w", err) + } + + var eventID int64 + err = stmt.QueryRow(encode(event.Timestamp), id, event.Val.EventType(), buf).Scan(&eventID) + if err != nil { + return fmt.Errorf("failed to execute statement: %w", err) + } + + for _, addr := range event.Relevant { + addressID, err := insertAddress(tx, addr) + if err != nil { + return fmt.Errorf("failed to insert address: %w", err) + } else if _, err := addRelevantAddrStmt.Exec(eventID, addressID, event.Index.Height); err != nil { + return fmt.Errorf("failed to add relevant address: %w", err) + } + } + } + return nil +} + +func applySiacoinElements(tx *txn, index types.ChainIndex, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { + addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, 0) +ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address +RETURNING id, siacoin_balance, immature_siacoin_balance`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer addrStatement.Close() + + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateBalanceStmt.Close() + + addStmt, err := tx.Prepare(`INSERT INTO siacoin_elements (id, address_id, siacoin_value, merkle_proof, leaf_index, maturity_height) VALUES ($1, $2, $3, $4, $5, $6)`) + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer addStmt.Close() + + spendStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer spendStmt.Close() + + // using ForEachSiacoinElement creates an interesting problem. The + // ForEachSiacoinElement function is only called once for each element. So + // if a siacoin element is spent and created in the same block, the element + // will not exist in the database. + // + // This creates a problem with balance tracking since it subtracts the + // element value from the balance. However, since the element value was + // never added to the balance in the first place, the balance will be + // incorrect. The solution is to check if the UTXO is in the database before + // decrementing the balance. + // + // This is an important implementation detail since the store must assume + // the chain manager is correct and can't check the integrity of the database + // without reimplementing some of the consensus logic. + cu.ForEachSiacoinElement(func(se types.SiacoinElement, spent bool) { + // sticky error + if err != nil { + return + } else if !relevantAddress(se.SiacoinOutput.Address) { + return + } + + // query the address database ID and balance + var addressID int64 + var balance, immatureBalance types.Currency + err = addrStatement.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, decode(&balance), decode(&immatureBalance)) + if err != nil { + err = fmt.Errorf("failed to query address %q: %w", se.SiacoinOutput.Address, err) + return + } + + if spent { + var dummy types.Hash256 + err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) + if errors.Is(err, sql.ErrNoRows) { + // spent output not found, most likely an ephemeral output. ignore + err = nil + return + } else if err != nil { + err = fmt.Errorf("failed to delete output %q: %w", se.ID, err) + return + } + + if se.MaturityHeight > index.Height { + immatureBalance = immatureBalance.Sub(se.SiacoinOutput.Value) + } else { + balance = balance.Sub(se.SiacoinOutput.Value) + } + + _, err = updateBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) + return + } + + log.Debug("removed utxo", zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) + } else { + // insert the created utxo + _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight) + if err != nil { + err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) + return + } + + if se.MaturityHeight > index.Height { + immatureBalance = immatureBalance.Add(se.SiacoinOutput.Value) + log.Debug("adding immature balance") + } else { + balance = balance.Add(se.SiacoinOutput.Value) + log.Debug("adding balance") + } + + // update the balance + _, err = updateBalanceStmt.Exec(encode(balance), encode(immatureBalance), addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiacoinOutput.Address, err) + return + } + log.Debug("added utxo", zap.Uint64("maturityHeight", se.MaturityHeight), zap.Stringer("address", se.SiacoinOutput.Address), zap.Stringer("outputID", se.ID), zap.String("value", se.SiacoinOutput.Value.ExactString()), zap.Int64("addressID", addressID)) + } + }) + return err +} + +func applySiafundElements(tx *txn, cu chainUpdate, relevantAddress func(types.Address) bool, log *zap.Logger) error { + // create the address if it doesn't exist + addrStatement, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, 0) +ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address +RETURNING id, siafund_balance`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer addrStatement.Close() + + updateBalanceStmt, err := tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateBalanceStmt.Close() + + addStmt, err := tx.Prepare(`INSERT INTO siafund_elements (id, address_id, claim_start, merkle_proof, leaf_index, siafund_value) VALUES ($1, $2, $3, $4, $5, $6)`) + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer addStmt.Close() + + spendStmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer spendStmt.Close() + + cu.ForEachSiafundElement(func(se types.SiafundElement, spent bool) { + // sticky error + if err != nil { + return + } else if !relevantAddress(se.SiafundOutput.Address) { + return + } + + // query the address database ID and balance + var addressID int64 + var balance uint64 + // get the address ID + err = addrStatement.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency)).Scan(&addressID, &balance) + if err != nil { + err = fmt.Errorf("failed to query address %q: %w", se.SiafundOutput.Address, err) + return + } + + // update the balance + if spent { + var dummy types.Hash256 + err = spendStmt.QueryRow(encode(se.ID)).Scan(decode(&dummy)) + if errors.Is(err, sql.ErrNoRows) { + // spent output not found, most likely an ephemeral output. + // ignore + err = nil + return + } else if err != nil { + err = fmt.Errorf("failed to delete output %q: %w", se.ID, err) + return + } + + // update the balance only if the utxo was successfully deleted + if se.SiafundOutput.Value > balance { + log.Panic("balance is negative", zap.Stringer("address", se.SiafundOutput.Address), zap.Uint64("balance", se.SiafundOutput.Value), zap.Stringer("outputID", se.ID), zap.Uint64("value", se.SiafundOutput.Value)) + } + + balance -= se.SiafundOutput.Value + _, err = updateBalanceStmt.Exec(encode(balance), addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) + return + } + } else { + balance += se.SiafundOutput.Value + // update the balance + _, err = updateBalanceStmt.Exec(balance, addressID) + if err != nil { + err = fmt.Errorf("failed to update address %q balance: %w", se.SiafundOutput.Address, err) + return + } + + // insert the created utxo + _, err = addStmt.Exec(encode(se.ID), addressID, encode(se.ClaimStart), encodeSlice(se.MerkleProof), se.LeafIndex, se.SiafundOutput.Value) + if err != nil { + err = fmt.Errorf("failed to insert output %q: %w", se.ID, err) + return + } + } + }) + return err +} + +func updateLastIndexedTip(tx *txn, tip types.ChainIndex) error { + _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(tip)) + return err +} + +func getStateElementBatch(s *stmt, offset, limit int) ([]types.StateElement, error) { + rows, err := s.Query(limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query siacoin elements: %w", err) + } + defer rows.Close() + + var updated []types.StateElement + for rows.Next() { + var se types.StateElement + err := rows.Scan(decode(&se.ID), decodeSlice(&se.MerkleProof), &se.LeafIndex) + if err != nil { + return nil, fmt.Errorf("failed to scan state element: %w", err) + } + updated = append(updated, se) + } + return updated, nil +} + +func updateStateElement(s *stmt, se types.StateElement) error { + res, err := s.Exec(encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ID)) + if err != nil { + return fmt.Errorf("failed to update siacoin element %q: %w", se.ID, err) + } else if n, err := res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } else if n != 1 { + return fmt.Errorf("expected 1 row to be affected, got %d", n) + } + return nil +} + +// how slow is this going to be 😬? +func updateElementProofs(tx *txn, table string, cu chainUpdate) error { + stmt, err := tx.Prepare(`SELECT id, merkle_proof, leaf_index FROM ` + table + ` LIMIT $1 OFFSET $2`) + if err != nil { + return fmt.Errorf("failed to prepare batch statement: %w", err) + } + defer stmt.Close() + + updateStmt, err := tx.Prepare(`UPDATE ` + table + ` SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %w", err) + } + defer updateStmt.Close() + + for offset := 0; ; offset += updateProofBatchSize { + elements, err := getStateElementBatch(stmt, offset, updateProofBatchSize) + if err != nil { + return fmt.Errorf("failed to get state element batch: %w", err) + } else if len(elements) == 0 { + break + } + + for _, se := range elements { + cu.UpdateElementProof(&se) + if err := updateStateElement(updateStmt, se); err != nil { + return fmt.Errorf("failed to update state element: %w", err) + } + } + } + return nil +} + +func getMaturedValue(tx *txn, index types.ChainIndex) (matured map[int64]types.Currency, err error) { + rows, err := tx.Query(`SELECT address_id, siacoin_value FROM siacoin_elements WHERE maturity_height=$1`, index.Height) + if err != nil { + return nil, fmt.Errorf("failed to query siacoin elements: %w", err) + } + defer rows.Close() + + matured = make(map[int64]types.Currency) + for rows.Next() { + var addressID int64 + var value types.Currency + err := rows.Scan(&addressID, decode(&value)) + if err != nil { + return nil, fmt.Errorf("failed to scan matured balance: %w", err) + } + matured[addressID] = matured[addressID].Add(value) + } + return +} + +func updateImmatureBalance(tx *txn, index types.ChainIndex, revert bool) error { + balanceStmt, err := tx.Prepare(`SELECT siacoin_balance, immature_siacoin_balance FROM sia_addresses WHERE id=$1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer balanceStmt.Close() + + updateStmt, err := tx.Prepare(`UPDATE sia_addresses SET siacoin_balance=$1, immature_siacoin_balance=$2 WHERE id=$3`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer updateStmt.Close() + + delta, err := getMaturedValue(tx, index) + if err != nil { + return fmt.Errorf("failed to get matured utxos: %w", err) + } + + for addressID, value := range delta { + var balance, immatureBalance types.Currency + err := balanceStmt.QueryRow(addressID).Scan(decode(&balance), decode(&immatureBalance)) + if err != nil { + return fmt.Errorf("failed to query address %d: %w", addressID, err) + } + + if revert { + balance = balance.Sub(value) + immatureBalance = immatureBalance.Add(value) + } else { + balance = balance.Add(value) + immatureBalance = immatureBalance.Sub(value) + } + + _, err = updateStmt.Exec(encode(balance), encode(immatureBalance), addressID) + if err != nil { + return fmt.Errorf("failed to update address %d: %w", addressID, err) + } + } + return nil +} + +// applyChainUpdates applies the given chain updates to the database. +func applyChainUpdates(tx *txn, updates []*chain.ApplyUpdate, log *zap.Logger) error { + stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + // note: this would be more performant for small wallets to load all + // addresses into memory. However, for larger wallets (> 10K addresses), + // this is time consuming. Instead, the database is queried for each + // address. Monitor performance and consider changing this in the + // future. From a memory perspective, it would be fine to lazy load all + // addresses into memory. + relevantAddress := func(address types.Address) bool { + var dbID int64 + err := stmt.QueryRow(encode(address)).Scan(&dbID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + panic(err) // database error + } + return err == nil + } + + for _, update := range updates { + // mature the immature balance first + if err := updateImmatureBalance(tx, update.State.Index, false); err != nil { + return fmt.Errorf("failed to update immature balance: %w", err) + } + // apply new events + events := wallet.AppliedEvents(update.State, update.Block, update, relevantAddress) + if err := applyEvents(tx, events); err != nil { + return fmt.Errorf("failed to apply events: %w", err) + } + + // apply new elements + if err := applySiacoinElements(tx, update.State.Index, update, relevantAddress, log.Named("siacoins")); err != nil { + return fmt.Errorf("failed to apply siacoin elements: %w", err) + } else if err := applySiafundElements(tx, update, relevantAddress, log.Named("siafunds")); err != nil { + return fmt.Errorf("failed to apply siafund elements: %w", err) + } + + // update proofs + if err := updateElementProofs(tx, "siacoin_elements", update); err != nil { + return fmt.Errorf("failed to update siacoin element proofs: %w", err) + } else if err := updateElementProofs(tx, "siafund_elements", update); err != nil { + return fmt.Errorf("failed to update siafund element proofs: %w", err) + } + } + + lastTip := updates[len(updates)-1].State.Index + if err := updateLastIndexedTip(tx, lastTip); err != nil { + return fmt.Errorf("failed to update last indexed tip: %w", err) + } + return nil +} + +// ProcessChainApplyUpdate implements chain.Subscriber +func (s *Store) ProcessChainApplyUpdate(cau *chain.ApplyUpdate, mayCommit bool) error { + s.updates = append(s.updates, cau) + + if mayCommit { + return s.transaction(func(tx *txn) error { + if err := applyChainUpdates(tx, s.updates, s.log.Named("apply")); err != nil { + return err + } + s.updates = nil + return nil + }) + } + return nil +} + +// ProcessChainRevertUpdate implements chain.Subscriber +func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { + log := s.log.Named("revert") + + // update hasn't been committed yet + if len(s.updates) > 0 && s.updates[len(s.updates)-1].Block.ID() == cru.Block.ID() { + s.updates = s.updates[:len(s.updates)-1] + return nil + } + + // update has been committed, revert it + return s.transaction(func(tx *txn) error { + stmt, err := tx.Prepare(`SELECT id FROM sia_addresses WHERE sia_address=$1 LIMIT 1`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + // note: this would be more performant for small wallets to load all + // addresses into memory. However, for larger wallets (> 10K addresses), + // this is time consuming. Instead, the database is queried for each + // address. Monitor performance and consider changing this in the + // future. From a memory perspective, it would be fine to lazy load all + // addresses into memory. + relevantAddress := func(address types.Address) bool { + var dbID int64 + err := stmt.QueryRow(encode(address)).Scan(&dbID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + panic(err) // database error + } + return err == nil + } + + if err := applySiacoinElements(tx, cru.State.Index, cru, relevantAddress, log.Named("siacoins")); err != nil { + return fmt.Errorf("failed to apply siacoin elements: %w", err) + } else if err := applySiafundElements(tx, cru, relevantAddress, log.Named("siafunds")); err != nil { + return fmt.Errorf("failed to apply siafund elements: %w", err) + } + + // revert events + _, err = tx.Exec(`DELETE FROM chain_indices WHERE block_id=$1`, cru.Block.ID()) + if err != nil { + return fmt.Errorf("failed to delete chain index: %w", err) + } + + // revert immature balance + if err := updateImmatureBalance(tx, cru.State.Index, true); err != nil { + return fmt.Errorf("failed to update immature balance: %w", err) + } + + // update proofs + if err := updateElementProofs(tx, "siacoin_elements", cru); err != nil { + return fmt.Errorf("failed to update siacoin element proofs: %w", err) + } else if err := updateElementProofs(tx, "siafund_elements", cru); err != nil { + return fmt.Errorf("failed to update siafund element proofs: %w", err) + } + return nil + }) +} + +// LastCommittedIndex returns the last chain index that was committed. +func (s *Store) LastCommittedIndex() (index types.ChainIndex, err error) { + err = s.db.QueryRow(`SELECT last_indexed_tip FROM global_settings`).Scan(decode(&index)) + return +} diff --git a/persist/sqlite/consts_default.go b/persist/sqlite/consts_default.go new file mode 100644 index 0000000..50b7330 --- /dev/null +++ b/persist/sqlite/consts_default.go @@ -0,0 +1,12 @@ +//go:build !testing + +package sqlite + +import "time" + +const ( + busyTimeout = 10000 // 10 seconds + maxRetryAttempts = 30 // 30 attempts + factor = 1.8 // factor ^ retryAttempts = backoff time in milliseconds + maxBackoff = 15 * time.Second +) diff --git a/persist/sqlite/consts_testing.go b/persist/sqlite/consts_testing.go new file mode 100644 index 0000000..f4911e3 --- /dev/null +++ b/persist/sqlite/consts_testing.go @@ -0,0 +1,12 @@ +//go:build testing + +package sqlite + +import "time" + +const ( + busyTimeout = 100 // 100ms + maxRetryAttempts = 10 // 10 attempts + factor = 2.0 // factor ^ retryAttempts = backoff time in milliseconds + maxBackoff = 15 * time.Second +) diff --git a/persist/sqlite/encoding.go b/persist/sqlite/encoding.go new file mode 100644 index 0000000..8f98230 --- /dev/null +++ b/persist/sqlite/encoding.go @@ -0,0 +1,127 @@ +package sqlite + +import ( + "bytes" + "database/sql" + "encoding/binary" + "errors" + "fmt" + "time" + + "go.sia.tech/core/types" +) + +func encode(obj any) any { + switch obj := obj.(type) { + case types.Currency: + buf := make([]byte, 16) + binary.LittleEndian.PutUint64(buf, obj.Lo) + binary.LittleEndian.PutUint64(buf[8:], obj.Hi) + return buf + case types.EncoderTo: + var buf bytes.Buffer + e := types.NewEncoder(&buf) + obj.EncodeTo(e) + e.Flush() + return buf.Bytes() + case uint64: + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, obj) + return b + case time.Time: + return obj.Unix() + default: + panic(fmt.Sprintf("dbEncode: unsupported type %T", obj)) + } +} + +type decodable struct { + v any +} + +// Scan implements the sql.Scanner interface. +func (d *decodable) Scan(src any) error { + if src == nil { + return errors.New("cannot scan nil into decodable") + } + + switch src := src.(type) { + case []byte: + switch v := d.v.(type) { + case *types.Currency: + if len(src) != 16 { + return fmt.Errorf("cannot scan %d bytes into Currency", len(src)) + } + v.Lo = binary.LittleEndian.Uint64(src) + v.Hi = binary.LittleEndian.Uint64(src[8:]) + case types.DecoderFrom: + dec := types.NewBufDecoder(src) + v.DecodeFrom(dec) + return dec.Err() + case *uint64: + *v = binary.LittleEndian.Uint64(src) + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } + return nil + case int64: + switch v := d.v.(type) { + case *uint64: + *v = uint64(src) + case *time.Time: + *v = time.Unix(src, 0).UTC() + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } + return nil + default: + return fmt.Errorf("cannot scan %T to %T", src, d.v) + } +} + +func decode(obj any) sql.Scanner { + return &decodable{obj} +} + +type decodableSlice[T any] struct { + v *[]T +} + +func (d *decodableSlice[T]) Scan(src any) error { + switch src := src.(type) { + case []byte: + dec := types.NewBufDecoder(src) + s := make([]T, dec.ReadPrefix()) + for i := range s { + dv, ok := any(&s[i]).(types.DecoderFrom) + if !ok { + panic(fmt.Errorf("cannot decode %T", s[i])) + } + dv.DecodeFrom(dec) + } + if err := dec.Err(); err != nil { + return err + } + *d.v = s + return nil + default: + return fmt.Errorf("cannot scan %T to []byte", src) + } +} + +func decodeSlice[T any](v *[]T) sql.Scanner { + return &decodableSlice[T]{v: v} +} + +func encodeSlice[T types.EncoderTo](v []T) []byte { + var buf bytes.Buffer + enc := types.NewEncoder(&buf) + enc.WritePrefix(len(v)) + for _, e := range v { + e.EncodeTo(enc) + } + if err := enc.Flush(); err != nil { + panic(err) + } + return buf.Bytes() +} diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go new file mode 100644 index 0000000..24ae378 --- /dev/null +++ b/persist/sqlite/init.go @@ -0,0 +1,89 @@ +package sqlite + +import ( + "database/sql" + _ "embed" // for init.sql + "errors" + "time" + + "fmt" + + "go.sia.tech/core/types" + "go.uber.org/zap" +) + +// init queries are run when the database is first created. +// +//go:embed init.sql +var initDatabase string + +func initializeSettings(tx *txn, target int64) error { + _, err := tx.Exec(`INSERT INTO global_settings (id, db_version, last_indexed_tip) VALUES (0, ?, ?)`, target, encode(types.ChainIndex{})) + return err +} + +func (s *Store) initNewDatabase(target int64) error { + return s.transaction(func(tx *txn) error { + if _, err := tx.Exec(initDatabase); err != nil { + return fmt.Errorf("failed to initialize database: %w", err) + } else if err := initializeSettings(tx, target); err != nil { + return fmt.Errorf("failed to initialize settings: %w", err) + } + return nil + }) +} + +func (s *Store) upgradeDatabase(current, target int64) error { + log := s.log.Named("migrations") + log.Info("migrating database", zap.Int64("current", current), zap.Int64("target", target)) + + // disable foreign key constraints during migration + if _, err := s.db.Exec("PRAGMA foreign_keys = OFF"); err != nil { + return fmt.Errorf("failed to disable foreign key constraints: %w", err) + } + defer func() { + // re-enable foreign key constraints + if _, err := s.db.Exec("PRAGMA foreign_keys = ON"); err != nil { + log.Panic("failed to enable foreign key constraints", zap.Error(err)) + } + }() + + return s.transaction(func(tx *txn) error { + for _, fn := range migrations[current-1:] { + current++ + start := time.Now() + if err := fn(tx, log.With(zap.Int64("version", current))); err != nil { + return fmt.Errorf("failed to migrate database to version %v: %w", current, err) + } + // check that no foreign key constraints were violated + if err := tx.QueryRow("PRAGMA foreign_key_check").Scan(); !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("foreign key constraints are not satisfied") + } + log.Debug("migration complete", zap.Int64("current", current), zap.Int64("target", target), zap.Duration("elapsed", time.Since(start))) + } + + // set the final database version + return setDBVersion(tx, target) + }) +} + +func (s *Store) init() error { + // calculate the expected final database version + target := int64(len(migrations) + 1) + // disable foreign key constraints during migration + if _, err := s.db.Exec("PRAGMA foreign_keys = OFF"); err != nil { + return fmt.Errorf("failed to disable foreign key constraints: %w", err) + } + + version := getDBVersion(s.db) + switch { + case version == 0: + return s.initNewDatabase(target) + case version < target: + return s.upgradeDatabase(version, target) + case version > target: + return fmt.Errorf("database version %v is newer than expected %v. database downgrades are not supported", version, target) + } + // nothing to do + return nil +} diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql new file mode 100644 index 0000000..d9d4cff --- /dev/null +++ b/persist/sqlite/init.sql @@ -0,0 +1,86 @@ +CREATE TABLE chain_indices ( + id INTEGER PRIMARY KEY, + block_id BLOB UNIQUE NOT NULL, + height INTEGER UNIQUE NOT NULL +); + +CREATE TABLE sia_addresses ( + id INTEGER PRIMARY KEY, + sia_address BLOB UNIQUE NOT NULL, + siacoin_balance BLOB NOT NULL, + immature_siacoin_balance BLOB NOT NULL, + siafund_balance INTEGER NOT NULL +); + +CREATE TABLE siacoin_elements ( + id BLOB PRIMARY KEY, + siacoin_value BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index INTEGER NOT NULL, + maturity_height INTEGER NOT NULL, /* stored as int64 for easier querying */ + address_id INTEGER NOT NULL REFERENCES sia_addresses (id) +); +CREATE INDEX siacoin_elements_address_id ON siacoin_elements (address_id); + +CREATE TABLE siafund_elements ( + id BLOB PRIMARY KEY, + claim_start BLOB NOT NULL, + merkle_proof BLOB NOT NULL, + leaf_index INTEGER NOT NULL, + siafund_value INTEGER NOT NULL, + address_id INTEGER NOT NULL REFERENCES sia_addresses (id) +); +CREATE INDEX siafund_elements_address_id ON siafund_elements (address_id); + +CREATE TABLE wallets ( + id TEXT PRIMARY KEY NOT NULL, + extra_data BLOB NOT NULL +); + +CREATE TABLE wallet_addresses ( + wallet_id TEXT NOT NULL REFERENCES wallets (id), + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + extra_data BLOB NOT NULL, + UNIQUE (wallet_id, address_id) +); +CREATE INDEX wallet_addresses_address_id ON wallet_addresses (address_id); + +CREATE TABLE events ( + id INTEGER PRIMARY KEY, + date_created INTEGER NOT NULL, + index_id BLOB NOT NULL REFERENCES chain_indices (id) ON DELETE CASCADE, + event_type TEXT NOT NULL, + event_data TEXT NOT NULL +); + +CREATE TABLE event_addresses ( + id INTEGER PRIMARY KEY, + event_id INTEGER NOT NULL REFERENCES events (id) ON DELETE CASCADE, + address_id INTEGER NOT NULL REFERENCES sia_addresses (id), + block_height INTEGER NOT NULL, /* prevents extra join when querying for events */ + UNIQUE (event_id, address_id) +); +CREATE INDEX event_addresses_event_id_idx ON event_addresses (event_id); +CREATE INDEX event_addresses_address_id_idx ON event_addresses (address_id); +CREATE INDEX event_addresses_event_id_address_id_block_height ON event_addresses(event_id, address_id, block_height DESC); + +CREATE TABLE syncer_peers ( + peer_address TEXT PRIMARY KEY NOT NULL, + first_seen INTEGER NOT NULL, + last_connect INTEGER NOT NULL, + synced_blocks INTEGER NOT NULL, + sync_duration INTEGER NOT NULL +); + +CREATE TABLE syncer_bans ( + net_cidr TEXT PRIMARY KEY NOT NULL, + expiration INTEGER NOT NULL, + reason TEXT NOT NULL +); +CREATE INDEX syncer_bans_expiration_index ON syncer_bans (expiration); + +CREATE TABLE global_settings ( + id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row + db_version INTEGER NOT NULL, -- used for migrations + last_indexed_tip BLOB -- the last chain index that was processed +); diff --git a/persist/sqlite/migrations.go b/persist/sqlite/migrations.go new file mode 100644 index 0000000..99d01e1 --- /dev/null +++ b/persist/sqlite/migrations.go @@ -0,0 +1,10 @@ +package sqlite + +import ( + "go.uber.org/zap" +) + +// migrations is a list of functions that are run to migrate the database from +// one version to the next. Migrations are used to update existing databases to +// match the schema in init.sql. +var migrations = []func(tx *txn, log *zap.Logger) error{} diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go new file mode 100644 index 0000000..4d8de8d --- /dev/null +++ b/persist/sqlite/peers.go @@ -0,0 +1,195 @@ +package sqlite + +import ( + "database/sql" + "errors" + "fmt" + "net" + "strconv" + "strings" + "time" + + "go.sia.tech/coreutils/syncer" + "go.uber.org/zap" +) + +func getPeerInfo(tx *txn, peer string) (syncer.PeerInfo, error) { + const query = `SELECT first_seen, last_connect, synced_blocks, sync_duration FROM syncer_peers WHERE peer_address=$1` + var info syncer.PeerInfo + err := tx.QueryRow(query, peer).Scan(decode(&info.FirstSeen), decode(&info.LastConnect), &info.SyncedBlocks, &info.SyncDuration) + return info, err +} + +func (s *Store) updatePeerInfo(tx *txn, peer string, info syncer.PeerInfo) error { + const query = `UPDATE syncer_peers SET first_seen=$1, last_connect=$2, synced_blocks=$3, sync_duration=$4 WHERE peer_address=$5 RETURNING peer_address` + err := tx.QueryRow(query, encode(info.FirstSeen), encode(info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer) + return err +} + +// AddPeer adds the given peer to the store. +func (s *Store) AddPeer(peer string) { + err := s.transaction(func(tx *txn) error { + const query = `INSERT INTO syncer_peers (peer_address, first_seen, last_connect, synced_blocks, sync_duration) VALUES ($1, $2, 0, 0, 0) ON CONFLICT (peer_address) DO NOTHING` + _, err := tx.Exec(query, peer, encode(time.Now())) + return err + }) + if err != nil { + s.log.Error("failed to add peer", zap.Error(err)) + } +} + +// Peers returns the addresses of all known peers. +func (s *Store) Peers() (peers []string) { + err := s.transaction(func(tx *txn) error { + const query = `SELECT peer_address FROM syncer_peers` + rows, err := tx.Query(query) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var peer string + if err := rows.Scan(&peer); err != nil { + return err + } + peers = append(peers, peer) + } + return nil + }) + if err != nil { + panic(err) // 😔 + } + return +} + +// UpdatePeerInfo updates the info for the given peer. +func (s *Store) UpdatePeerInfo(peer string, fn func(*syncer.PeerInfo)) { + err := s.transaction(func(tx *txn) error { + info, err := getPeerInfo(tx, peer) + if err != nil { + return fmt.Errorf("failed to get peer info: %w", err) + } + fn(&info) + return s.updatePeerInfo(tx, peer, info) + }) + if err != nil { + panic(err) // 😔 + } +} + +// PeerInfo returns the info for the given peer. +func (s *Store) PeerInfo(peer string) (syncer.PeerInfo, bool) { + var info syncer.PeerInfo + var err error + err = s.transaction(func(tx *txn) error { + info, err = getPeerInfo(tx, peer) + return err + }) + if errors.Is(err, sql.ErrNoRows) { + return info, false + } else if err != nil { + panic(err) // 😔 + } + return info, true +} + +// normalizePeer normalizes a peer address to a CIDR subnet. +func normalizePeer(peer string) (string, error) { + host, _, err := net.SplitHostPort(peer) + if err != nil { + host = peer + } + if strings.IndexByte(host, '/') != -1 { + _, subnet, err := net.ParseCIDR(host) + if err != nil { + return "", fmt.Errorf("failed to parse CIDR: %w", err) + } + return subnet.String(), nil + } + + ip := net.ParseIP(host) + if ip == nil { + return "", errors.New("invalid IP address") + } + + var maskLen int + if ip.To4() != nil { + maskLen = 32 + } else { + maskLen = 128 + } + + _, normalized, err := net.ParseCIDR(fmt.Sprintf("%s/%d", ip.String(), maskLen)) + if err != nil { + panic("failed to parse CIDR") + } + return normalized.String(), nil +} + +// Ban temporarily bans one or more IPs. The addr should either be a single +// IP with port (e.g. 1.2.3.4:5678) or a CIDR subnet (e.g. 1.2.3.4/16). +func (s *Store) Ban(peer string, duration time.Duration, reason string) { + address, err := normalizePeer(peer) + if err != nil { + s.log.Error("failed to normalize peer", zap.Error(err)) + return + } + err = s.transaction(func(tx *txn) error { + const query = `INSERT INTO syncer_bans (net_cidr, expiration, reason) VALUES ($1, $2, $3) ON CONFLICT (net_cidr) DO UPDATE SET expiration=EXCLUDED.expiration, reason=EXCLUDED.reason` + _, err := tx.Exec(query, address, encode(time.Now().Add(duration)), reason) + return err + }) + if err != nil { + s.log.Error("failed to ban peer", zap.Error(err)) + } +} + +// Banned returns true if the peer is banned. +func (s *Store) Banned(peer string) (banned bool) { + // normalize the peer into a CIDR subnet + peer, err := normalizePeer(peer) + if err != nil { + s.log.Error("failed to normalize peer", zap.Error(err)) + return false + } + + _, subnet, err := net.ParseCIDR(peer) + if err != nil { + s.log.Error("failed to parse CIDR", zap.Error(err)) + return false + } + + // check all subnets from the given subnet to the max subnet length + var maxMaskLen int + if subnet.IP.To4() != nil { + maxMaskLen = 32 + } else { + maxMaskLen = 128 + } + + checkSubnets := make([]string, 0, maxMaskLen) + for i := maxMaskLen; i > 0; i-- { + _, subnet, err := net.ParseCIDR(subnet.IP.String() + "/" + strconv.Itoa(i)) + if err != nil { + panic("failed to parse CIDR") + } + checkSubnets = append(checkSubnets, subnet.String()) + } + + err = s.transaction(func(tx *txn) error { + query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1` + + var subnet string + var expiration time.Time + err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan(&subnet, decode(&expiration)) + banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows + if err == nil && banned { + s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration)) + } + return err + }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + s.log.Error("failed to check ban status", zap.Error(err)) + } + return +} diff --git a/persist/sqlite/peers_test.go b/persist/sqlite/peers_test.go new file mode 100644 index 0000000..4f3e26e --- /dev/null +++ b/persist/sqlite/peers_test.go @@ -0,0 +1,93 @@ +package sqlite + +import ( + "net" + "path/filepath" + "testing" + "time" + + "go.sia.tech/coreutils/syncer" + "go.uber.org/zap/zaptest" +) + +func TestAddPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + const peer = "1.2.3.4:9981" + + db.AddPeer(peer) + + lastConnect := time.Now().Truncate(time.Second) // stored as unix milliseconds + syncedBlocks := uint64(15) + syncDuration := 5 * time.Second + + db.UpdatePeerInfo(peer, func(info *syncer.PeerInfo) { + info.LastConnect = lastConnect + info.SyncedBlocks = syncedBlocks + info.SyncDuration = syncDuration + }) + if err != nil { + t.Fatal(err) + } + + info, ok := db.PeerInfo(peer) + if !ok { + t.Fatal("expected peer to be in database") + } + + if !info.LastConnect.Equal(lastConnect) { + t.Errorf("expected LastConnect = %v; got %v", lastConnect, info.LastConnect) + } + if info.SyncedBlocks != syncedBlocks { + t.Errorf("expected SyncedBlocks = %d; got %d", syncedBlocks, info.SyncedBlocks) + } + if info.SyncDuration != 5*time.Second { + t.Errorf("expected SyncDuration = %s; got %s", syncDuration, info.SyncDuration) + } +} + +func TestBanPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + const peer = "1.2.3.4" + + if db.Banned(peer) { + t.Fatal("expected peer to not be banned") + } + + // ban the peer + db.Ban(peer, time.Second, "test") + + if !db.Banned(peer) { + t.Fatal("expected peer to be banned") + } + + // wait for the ban to expire + time.Sleep(time.Second) + + if db.Banned(peer) { + t.Fatal("expected peer to not be banned") + } + + // ban a subnet + _, subnet, err := net.ParseCIDR(peer + "/24") + if err != nil { + t.Fatal(err) + } + + t.Log("banning", subnet) + db.Ban(subnet.String(), time.Second, "test") + if !db.Banned(peer) { + t.Fatal("expected peer to be banned") + } +} diff --git a/persist/sqlite/sql.go b/persist/sqlite/sql.go new file mode 100644 index 0000000..fb253d1 --- /dev/null +++ b/persist/sqlite/sql.go @@ -0,0 +1,217 @@ +package sqlite + +import ( + "context" + "database/sql" + "math/rand" + "strings" + "time" + + _ "github.com/mattn/go-sqlite3" // import sqlite3 driver + "go.uber.org/zap" +) + +const ( + longQueryDuration = 10 * time.Millisecond + longTxnDuration = 10 * time.Millisecond +) + +type ( + // A scanner is an interface that wraps the Scan method of sql.Rows and sql.Row + scanner interface { + Scan(dest ...any) error + } + + // A stmt wraps a *sql.Stmt, logging slow queries. + stmt struct { + *sql.Stmt + query string + + log *zap.Logger + } + + // A txn wraps a *sql.Tx, logging slow queries. + txn struct { + *sql.Tx + log *zap.Logger + } + + // A row wraps a *sql.Row, logging slow queries. + row struct { + *sql.Row + log *zap.Logger + } + + // rows wraps a *sql.Rows, logging slow queries. + rows struct { + *sql.Rows + + log *zap.Logger + } +) + +func (r *rows) Next() bool { + start := time.Now() + next := r.Rows.Next() + if dur := time.Since(start); dur > longQueryDuration { + r.log.Debug("slow next", zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return next +} + +func (r *rows) Scan(dest ...any) error { + start := time.Now() + err := r.Rows.Scan(dest...) + if dur := time.Since(start); dur > longQueryDuration { + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return err +} + +func (r *row) Scan(dest ...any) error { + start := time.Now() + err := r.Row.Scan(dest...) + if dur := time.Since(start); dur > longQueryDuration { + r.log.Debug("slow scan", zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return err +} + +func (s *stmt) Exec(args ...any) (sql.Result, error) { + return s.ExecContext(context.Background(), args...) +} + +func (s *stmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { + start := time.Now() + result, err := s.Stmt.ExecContext(ctx, args...) + if dur := time.Since(start); dur > longQueryDuration { + s.log.Debug("slow exec", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return result, err +} + +func (s *stmt) Query(args ...any) (*sql.Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func (s *stmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) { + start := time.Now() + rows, err := s.Stmt.QueryContext(ctx, args...) + if dur := time.Since(start); dur > longQueryDuration { + s.log.Debug("slow query", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return rows, err +} + +func (s *stmt) QueryRow(args ...any) *row { + return s.QueryRowContext(context.Background(), args...) +} + +func (s *stmt) QueryRowContext(ctx context.Context, args ...any) *row { + start := time.Now() + r := s.Stmt.QueryRowContext(ctx, args...) + if dur := time.Since(start); dur > longQueryDuration { + s.log.Debug("slow query row", zap.String("query", s.query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return &row{r, s.log.Named("row")} +} + +// Exec executes a query without returning any rows. The args are for +// any placeholder parameters in the query. +func (tx *txn) Exec(query string, args ...any) (sql.Result, error) { + start := time.Now() + result, err := tx.Tx.Exec(query, args...) + if dur := time.Since(start); dur > longQueryDuration { + tx.log.Debug("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return result, err +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. The caller must call the statement's Close method +// when the statement is no longer needed. +func (tx *txn) Prepare(query string) (*stmt, error) { + start := time.Now() + s, err := tx.Tx.Prepare(query) + if dur := time.Since(start); dur > longQueryDuration { + tx.log.Debug("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } else if err != nil { + return nil, err + } + return &stmt{ + Stmt: s, + query: query, + log: tx.log.Named("statement"), + }, nil +} + +// Query executes a query that returns rows, typically a SELECT. The +// args are for any placeholder parameters in the query. +func (tx *txn) Query(query string, args ...any) (*rows, error) { + start := time.Now() + r, err := tx.Tx.Query(query, args...) + if dur := time.Since(start); dur > longQueryDuration { + tx.log.Debug("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return &rows{r, tx.log.Named("rows")}, err +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. If the query selects no rows, the *Row's +// Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the +// first selected row and discards the rest. +func (tx *txn) QueryRow(query string, args ...any) *row { + start := time.Now() + r := tx.Tx.QueryRow(query, args...) + if dur := time.Since(start); dur > longQueryDuration { + tx.log.Debug("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack")) + } + return &row{r, tx.log.Named("row")} +} + +func queryPlaceHolders(n int) string { + if n == 0 { + return "" + } else if n == 1 { + return "?" + } + var b strings.Builder + b.Grow(((n - 1) * 2) + 1) // ?,? + for i := 0; i < n-1; i++ { + b.WriteString("?,") + } + b.WriteString("?") + return b.String() +} + +func queryArgs[T any](args []T) []any { + if len(args) == 0 { + return nil + } + out := make([]any, len(args)) + for i, arg := range args { + out[i] = arg + } + return out +} + +// getDBVersion returns the current version of the database. +func getDBVersion(db *sql.DB) (version int64) { + // error is ignored -- the database may not have been initialized yet. + db.QueryRow(`SELECT db_version FROM global_settings;`).Scan(&version) + return +} + +// setDBVersion sets the current version of the database. +func setDBVersion(tx *txn, version int64) error { + const query = `UPDATE global_settings SET db_version=$1 RETURNING id;` + var dbID int64 + return tx.QueryRow(query, version).Scan(&dbID) +} + +// jitterSleep sleeps for a random duration between t and t*1.5. +func jitterSleep(t time.Duration) { + time.Sleep(t + time.Duration(rand.Int63n(int64(t/2)))) +} diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go new file mode 100644 index 0000000..e50fda5 --- /dev/null +++ b/persist/sqlite/store.go @@ -0,0 +1,123 @@ +package sqlite + +import ( + "database/sql" + "encoding/hex" + "errors" + "fmt" + "math" + "strings" + "time" + + "go.sia.tech/coreutils/chain" + "go.uber.org/zap" + "lukechampine.com/frand" +) + +type ( + // A Store is a persistent store that uses a SQL database as its backend. + Store struct { + db *sql.DB + log *zap.Logger + + updates []*chain.ApplyUpdate + } +) + +// transaction executes a function within a database transaction. If the +// function returns an error, the transaction is rolled back. Otherwise, the +// transaction is committed. If the transaction fails due to a busy error, it is +// retried up to 10 times before returning. +func (s *Store) transaction(fn func(*txn) error) error { + var err error + txnID := hex.EncodeToString(frand.Bytes(4)) + log := s.log.Named("transaction").With(zap.String("id", txnID)) + start := time.Now() + attempt := 1 + for ; attempt < maxRetryAttempts; attempt++ { + attemptStart := time.Now() + log := log.With(zap.Int("attempt", attempt)) + err = doTransaction(s.db, log, fn) + if err == nil { + // no error, break out of the loop + return nil + } + + // return immediately if the error is not a busy error + if !strings.Contains(err.Error(), "database is locked") { + break + } + // exponential backoff + sleep := time.Duration(math.Pow(factor, float64(attempt))) * time.Millisecond + if sleep > maxBackoff { + sleep = maxBackoff + } + log.Debug("database locked", zap.Duration("elapsed", time.Since(attemptStart)), zap.Duration("totalElapsed", time.Since(start)), zap.Stack("stack"), zap.Duration("retry", sleep)) + jitterSleep(sleep) + } + return fmt.Errorf("transaction failed (attempt %d): %w", attempt, err) +} + +// Close closes the underlying database. +func (s *Store) Close() error { + return s.db.Close() +} + +func sqliteFilepath(fp string) string { + params := []string{ + fmt.Sprintf("_busy_timeout=%d", busyTimeout), + "_foreign_keys=true", + "_journal_mode=WAL", + "_secure_delete=false", + "_cache_size=-65536", // 64MiB + } + return "file:" + fp + "?" + strings.Join(params, "&") +} + +// doTransaction is a helper function to execute a function within a transaction. If fn returns +// an error, the transaction is rolled back. Otherwise, the transaction is +// committed. +func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx *txn) error) error { + start := time.Now() + dbtx, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { + if err := dbtx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) { + log.Error("failed to rollback transaction", zap.Error(err)) + } + // log the transaction if it took longer than txn duration + if time.Since(start) > longTxnDuration { + log.Debug("long transaction", zap.Duration("elapsed", time.Since(start)), zap.Stack("stack"), zap.Bool("failed", err != nil)) + } + }() + + tx := &txn{ + Tx: dbtx, + log: log, + } + if err = fn(tx); err != nil { + return err + } else if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + return nil +} + +// OpenDatabase creates a new SQLite store and initializes the database. If the +// database does not exist, it is created. +func OpenDatabase(fp string, log *zap.Logger) (*Store, error) { + db, err := sql.Open("sqlite3", sqliteFilepath(fp)) + if err != nil { + return nil, err + } + store := &Store{ + db: db, + log: log, + } + if err := store.init(); err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + return store, nil +} diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go new file mode 100644 index 0000000..88f1248 --- /dev/null +++ b/persist/sqlite/wallet.go @@ -0,0 +1,306 @@ +package sqlite + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + + "go.sia.tech/core/types" + "go.sia.tech/walletd/wallet" +) + +func insertAddress(tx *txn, addr types.Address) (id int64, err error) { + const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) +VALUES ($1, $2, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address +RETURNING id` + + err = tx.QueryRow(query, encode(addr), encode(types.ZeroCurrency)).Scan(&id) + return +} + +// WalletEvents returns the events relevant to a wallet, sorted by height descending. +func (s *Store) WalletEvents(walletID string, offset, limit int) (events []wallet.Event, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT ev.id, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data +FROM events ev +INNER JOIN chain_indices ci ON (ev.index_id = ci.id) +WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)) +ORDER BY ci.height DESC, ev.id ASC +LIMIT $2 OFFSET $3` + + rows, err := tx.Query(query, walletID, limit, offset) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var eventID int64 + var event wallet.Event + var eventType string + var eventBuf []byte + + err := rows.Scan(&eventID, decode(&event.Timestamp), &event.Index.Height, decode(&event.Index.ID), &eventType, &eventBuf) + if err != nil { + return fmt.Errorf("failed to scan event: %w", err) + } + + switch eventType { + case wallet.EventTypeTransaction: + var tx wallet.EventTransaction + if err = json.Unmarshal(eventBuf, &tx); err != nil { + return fmt.Errorf("failed to unmarshal transaction event: %w", err) + } + event.Val = &tx + case wallet.EventTypeMissedFileContract: + var m wallet.EventMissedFileContract + if err = json.Unmarshal(eventBuf, &m); err != nil { + return fmt.Errorf("failed to unmarshal missed file contract event: %w", err) + } + event.Val = &m + case wallet.EventTypeMinerPayout: + var m wallet.EventMinerPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return fmt.Errorf("failed to unmarshal payout event: %w", err) + } + event.Val = &m + default: + return fmt.Errorf("unknown event type: %s", eventType) + } + + // event.Relevant = relevantAddresses[eventID] + events = append(events, event) + } + return nil + }) + return +} + +// AddWallet adds a wallet to the database. +func (s *Store) AddWallet(name string, info json.RawMessage) error { + return s.transaction(func(tx *txn) error { + const query = `INSERT INTO wallets (id, extra_data) VALUES ($1, $2)` + + _, err := tx.Exec(query, name, info) + if err != nil { + return fmt.Errorf("failed to insert wallet: %w", err) + } + return nil + }) +} + +// DeleteWallet deletes a wallet from the database. This does not stop tracking +// addresses that were previously associated with the wallet. +func (s *Store) DeleteWallet(name string) error { + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(`DELETE FROM wallets WHERE id=$1`, name) + return err + }) +} + +// Wallets returns a map of wallet names to wallet extra data. +func (s *Store) Wallets() (map[string]json.RawMessage, error) { + wallets := make(map[string]json.RawMessage) + err := s.transaction(func(tx *txn) error { + const query = `SELECT id, extra_data FROM wallets` + + rows, err := tx.Query(query) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var friendlyName string + var extraData json.RawMessage + if err := rows.Scan(&friendlyName, &extraData); err != nil { + return fmt.Errorf("failed to scan wallet: %w", err) + } + wallets[friendlyName] = extraData + } + return nil + }) + return wallets, err +} + +// AddAddress adds an address to a wallet. +func (s *Store) AddAddress(walletID string, address types.Address, info json.RawMessage) error { + return s.transaction(func(tx *txn) error { + addressID, err := insertAddress(tx, address) + if err != nil { + return fmt.Errorf("failed to insert address: %w", err) + } + _, err = tx.Exec(`INSERT INTO wallet_addresses (wallet_id, extra_data, address_id) VALUES ($1, $2, $3)`, walletID, info, addressID) + return err + }) +} + +// RemoveAddress removes an address from a wallet. This does not stop tracking +// the address. +func (s *Store) RemoveAddress(walletID string, address types.Address) error { + return s.transaction(func(tx *txn) error { + const query = `DELETE FROM wallet_addresses WHERE wallet_id=$1 AND address_id=(SELECT id FROM sia_addresses WHERE sia_address=$2)` + _, err := tx.Exec(query, walletID, encode(address)) + return err + }) +} + +// Addresses returns a map of addresses to their extra data for a wallet. +func (s *Store) Addresses(walletID string) (map[types.Address]json.RawMessage, error) { + addresses := make(map[types.Address]json.RawMessage) + err := s.transaction(func(tx *txn) error { + const query = `SELECT sa.sia_address, wa.extra_data +FROM wallet_addresses wa +INNER JOIN sia_addresses sa ON (sa.id = wa.address_id) +WHERE wa.wallet_id=$1` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var address types.Address + var extraData json.RawMessage + if err := rows.Scan(decode(&address), &extraData); err != nil { + return fmt.Errorf("failed to scan address: %w", err) + } + addresses[address] = extraData + } + return nil + }) + return addresses, err +} + +// UnspentSiacoinOutputs returns the unspent siacoin outputs for a wallet. +func (s *Store) UnspentSiacoinOutputs(walletID string) (siacoins []types.SiacoinElement, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siacoin_value, sa.sia_address, se.maturity_height + FROM siacoin_elements se + INNER JOIN sia_addresses sa ON (se.address_id = sa.id) + WHERE se.address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var siacoin types.SiacoinElement + err := rows.Scan(decode(&siacoin.ID), &siacoin.LeafIndex, decodeSlice[types.Hash256](&siacoin.MerkleProof), decode(&siacoin.SiacoinOutput.Value), decode(&siacoin.SiacoinOutput.Address), &siacoin.MaturityHeight) + if err != nil { + return fmt.Errorf("failed to scan siacoin element: %w", err) + } + + siacoins = append(siacoins, siacoin) + } + return nil + }) + return +} + +// UnspentSiafundOutputs returns the unspent siafund outputs for a wallet. +func (s *Store) UnspentSiafundOutputs(walletID string) (siafunds []types.SiafundElement, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT se.id, se.leaf_index, se.merkle_proof, se.siafund_value, se.claim_start, sa.sia_address + FROM siafund_elements se + INNER JOIN sia_addresses sa ON (se.address_id = sa.id) + WHERE se.address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var siafund types.SiafundElement + err := rows.Scan(decode(&siafund.ID), &siafund.LeafIndex, decodeSlice(&siafund.MerkleProof), &siafund.SiafundOutput.Value, decode(&siafund.ClaimStart), decode(&siafund.SiafundOutput.Address)) + if err != nil { + return fmt.Errorf("failed to scan siacoin element: %w", err) + } + siafunds = append(siafunds, siafund) + } + return nil + }) + return +} + +// WalletBalance returns the total balance of a wallet. +func (s *Store) WalletBalance(walletID string) (sc, immatureSC types.Currency, sf uint64, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses sa + INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) + WHERE wa.wallet_id=$1` + + rows, err := tx.Query(query, walletID) + if err != nil { + return err + } + + for rows.Next() { + var addressSC types.Currency + var addressISC types.Currency + var addressSF uint64 + + if err := rows.Scan(decode(&addressSC), decode(&addressISC), decode(&addressSF)); err != nil { + return fmt.Errorf("failed to scan address balance: %w", err) + } + sc = sc.Add(addressSC) + immatureSC = immatureSC.Add(addressISC) + sf += addressSF + } + return nil + }) + return +} + +// AddressBalance returns the balance of a single address. +func (s *Store) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT siacoin_balance, siafund_balance FROM address_balance WHERE sia_address=$1` + return tx.QueryRow(query, encode(address)).Scan(decode(&sc), &sf) + }) + return +} + +// Annotate annotates a list of transactions using the wallet's addresses. +func (s *Store) Annotate(walletID string, txns []types.Transaction) (annotated []wallet.PoolTransaction, err error) { + err = s.transaction(func(tx *txn) error { + const query = `SELECT sa.id FROM sia_addresses sa +INNER JOIN wallet_addresses wa ON (sa.id = wa.address_id) +WHERE wa.wallet_id=$1 AND sa.sia_address=$2 LIMIT 1` + stmt, err := tx.Prepare(query) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + // note: this would be more performant for small wallets to load all + // addresses into memory. However, for larger wallets (> 10K addresses), + // this is time consuming. Instead, the database is queried for each + // address. Monitor performance and consider changing this in the + // future. From a memory perspective, it would be fine to lazy load all + // addresses into memory. + ownsAddress := func(address types.Address) bool { + var dbID int64 + err := stmt.QueryRow(walletID, encode(address)).Scan(dbID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + panic(err) // database error + } + return err == nil + } + + for _, txn := range txns { + ptxn := wallet.Annotate(txn, ownsAddress) + if ptxn.Type != "unrelated" { + annotated = append(annotated, ptxn) + } + } + return nil + }) + return +} diff --git a/wallet/manager.go b/wallet/manager.go new file mode 100644 index 0000000..74039ed --- /dev/null +++ b/wallet/manager.go @@ -0,0 +1,175 @@ +package wallet + +import ( + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.uber.org/zap" +) + +type ( + // A ChainManager manages the consensus state + ChainManager interface { + AddSubscriber(chain.Subscriber, types.ChainIndex) error + RemoveSubscriber(chain.Subscriber) + + BestIndex(height uint64) (types.ChainIndex, bool) + } + + // A Store is a persistent store of wallet data. + Store interface { + chain.Subscriber + + WalletEvents(name string, offset, limit int) ([]Event, error) + AddWallet(name string, info json.RawMessage) error + DeleteWallet(name string) error + Wallets() (map[string]json.RawMessage, error) + + AddAddress(walletID string, address types.Address, info json.RawMessage) error + RemoveAddress(walletID string, address types.Address) error + Addresses(walletID string) (map[types.Address]json.RawMessage, error) + UnspentSiacoinOutputs(walletID string) ([]types.SiacoinElement, error) + UnspentSiafundOutputs(walletID string) ([]types.SiafundElement, error) + Annotate(walletID string, txns []types.Transaction) ([]PoolTransaction, error) + WalletBalance(walletID string) (sc, immature types.Currency, sf uint64, err error) + + AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) + + LastCommittedIndex() (types.ChainIndex, error) + } + + // A Manager manages wallets. + Manager struct { + chain ChainManager + store Store + log *zap.Logger + + mu sync.Mutex + used map[types.Hash256]bool + } +) + +// AddWallet adds the given wallet. +func (m *Manager) AddWallet(name string, info json.RawMessage) error { + return m.store.AddWallet(name, info) +} + +// DeleteWallet deletes the given wallet. +func (m *Manager) DeleteWallet(name string) error { + return m.store.DeleteWallet(name) +} + +// Wallets returns the wallets of the wallet manager. +func (m *Manager) Wallets() (map[string]json.RawMessage, error) { + return m.store.Wallets() +} + +// AddAddress adds the given address to the given wallet. +func (m *Manager) AddAddress(name string, addr types.Address, info json.RawMessage) error { + return m.store.AddAddress(name, addr, info) +} + +// RemoveAddress removes the given address from the given wallet. +func (m *Manager) RemoveAddress(name string, addr types.Address) error { + return m.store.RemoveAddress(name, addr) +} + +// Addresses returns the addresses of the given wallet. +func (m *Manager) Addresses(name string) (map[types.Address]json.RawMessage, error) { + return m.store.Addresses(name) +} + +// Events returns the events of the given wallet. +func (m *Manager) Events(name string, offset, limit int) ([]Event, error) { + return m.store.WalletEvents(name, offset, limit) +} + +// UnspentSiacoinOutputs returns the unspent siacoin outputs of the given wallet +func (m *Manager) UnspentSiacoinOutputs(name string) ([]types.SiacoinElement, error) { + return m.store.UnspentSiacoinOutputs(name) +} + +// UnspentSiafundOutputs returns the unspent siafund outputs of the given wallet +func (m *Manager) UnspentSiafundOutputs(name string) ([]types.SiafundElement, error) { + return m.store.UnspentSiafundOutputs(name) +} + +// Annotate annotates the given transactions with the wallet they belong to. +func (m *Manager) Annotate(name string, pool []types.Transaction) ([]PoolTransaction, error) { + return m.store.Annotate(name, pool) +} + +// WalletBalance returns the balance of the given wallet. +func (m *Manager) WalletBalance(walletID string) (sc, immature types.Currency, sf uint64, err error) { + return m.store.WalletBalance(walletID) +} + +// AddressBalance returns the balance of the given address. +func (m *Manager) AddressBalance(address types.Address) (sc types.Currency, sf uint64, err error) { + return m.store.AddressBalance(address) +} + +// Reserve reserves the given ids for the given duration. +func (m *Manager) Reserve(ids []types.Hash256, duration time.Duration) error { + m.mu.Lock() + defer m.mu.Unlock() + + // check if any of the ids are already reserved + for _, id := range ids { + if m.used[id] { + return fmt.Errorf("output %q already reserved", id) + } + } + + // reserve the ids + for _, id := range ids { + m.used[id] = true + } + + // sleep for the duration and then unreserve the ids + time.AfterFunc(duration, func() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, id := range ids { + delete(m.used, id) + } + }) + return nil +} + +// Subscribe resubscribes the indexer starting at the given height. +func (m *Manager) Subscribe(startHeight uint64) error { + var index types.ChainIndex + if startHeight > 0 { + var ok bool + index, ok = m.chain.BestIndex(startHeight - 1) + if !ok { + return errors.New("invalid height") + } + } + m.chain.RemoveSubscriber(m.store) + return m.chain.AddSubscriber(m.store, index) +} + +// NewManager creates a new wallet manager. +func NewManager(cm ChainManager, store Store, log *zap.Logger) (*Manager, error) { + m := &Manager{ + chain: cm, + store: store, + log: log, + } + + lastTip, err := store.LastCommittedIndex() + if err != nil { + return nil, fmt.Errorf("failed to get last committed index: %w", err) + } else if err := cm.AddSubscriber(store, lastTip); err != nil { + return nil, fmt.Errorf("failed to subscribe to chain manager: %w", err) + } + return m, nil +} diff --git a/wallet/wallet.go b/wallet/wallet.go index ecc9754..7a806b1 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -9,6 +9,13 @@ import ( "go.sia.tech/core/types" ) +// event type constants +const ( + EventTypeTransaction = "transaction" + EventTypeMinerPayout = "miner payout" + EventTypeMissedFileContract = "missed file contract" +) + // StandardTransactionSignature is the most common form of TransactionSignature. // It covers the entire transaction, references a sole public key, and has no // timelock. @@ -143,12 +150,17 @@ type Event struct { Index types.ChainIndex Timestamp time.Time Relevant []types.Address - Val interface{ eventType() string } + Val interface{ EventType() string } } -func (*EventTransaction) eventType() string { return "transaction" } -func (*EventMinerPayout) eventType() string { return "miner payout" } -func (*EventMissedFileContract) eventType() string { return "missed file contract" } +// EventType implements Event. +func (*EventTransaction) EventType() string { return EventTypeTransaction } + +// EventType implements Event. +func (*EventMinerPayout) EventType() string { return EventTypeMinerPayout } + +// EventType implements Event. +func (*EventMissedFileContract) EventType() string { return EventTypeMissedFileContract } // MarshalJSON implements json.Marshaler. func (e Event) MarshalJSON() ([]byte, error) { @@ -163,7 +175,7 @@ func (e Event) MarshalJSON() ([]byte, error) { Timestamp: e.Timestamp, Index: e.Index, Relevant: e.Relevant, - Type: e.Val.eventType(), + Type: e.Val.EventType(), Val: val, }) } @@ -184,11 +196,11 @@ func (e *Event) UnmarshalJSON(data []byte) error { e.Index = s.Index e.Relevant = s.Relevant switch s.Type { - case (*EventTransaction)(nil).eventType(): + case (*EventTransaction)(nil).EventType(): e.Val = new(EventTransaction) - case (*EventMinerPayout)(nil).eventType(): + case (*EventMinerPayout)(nil).EventType(): e.Val = new(EventMinerPayout) - case (*EventMissedFileContract)(nil).eventType(): + case (*EventMissedFileContract)(nil).EventType(): e.Val = new(EventMissedFileContract) } if e.Val == nil { @@ -228,6 +240,7 @@ type V2FileContract struct { Outputs []types.SiacoinElement `json:"outputs,omitempty"` } +// An EventTransaction represents a transaction that affects the wallet. type EventTransaction struct { ID types.TransactionID `json:"id"` SiacoinInputs []types.SiacoinElement `json:"siacoinInputs"` @@ -240,15 +253,19 @@ type EventTransaction struct { Fee types.Currency `json:"fee"` } +// An EventMinerPayout represents a miner payout from a block. type EventMinerPayout struct { SiacoinOutput types.SiacoinElement `json:"siacoinOutput"` } +// An EventMissedFileContract represents a file contract that has expired +// without a storage proof type EventMissedFileContract struct { FileContract types.FileContractElement `json:"fileContract"` MissedOutputs []types.SiacoinElement `json:"missedOutputs"` } +// A ChainUpdate is a set of changes to the consensus state. type ChainUpdate interface { ForEachSiacoinElement(func(sce types.SiacoinElement, spent bool)) ForEachSiafundElement(func(sfe types.SiafundElement, spent bool)) @@ -259,7 +276,7 @@ type ChainUpdate interface { // AppliedEvents extracts a list of relevant events from a chain update. func AppliedEvents(cs consensus.State, b types.Block, cu ChainUpdate, relevant func(types.Address) bool) []Event { var events []Event - addEvent := func(v interface{ eventType() string }, relevant []types.Address) { + addEvent := func(v interface{ EventType() string }, relevant []types.Address) { // dedup relevant addresses seen := make(map[types.Address]bool) unique := relevant[:0]