Skip to content

Commit

Permalink
light refactor and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
buck54321 committed Aug 10, 2024
1 parent 1387d0b commit 32ac5d3
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 12 deletions.
10 changes: 8 additions & 2 deletions wallet/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (
)

type mockChainClient struct {
getBestBlockHeight int32
getBlockHashFunc func() (*chainhash.Hash, error)
getBlockHeader *wire.BlockHeader
}

var _ chain.Interface = (*mockChainClient)(nil)
Expand All @@ -26,20 +29,23 @@ func (m *mockChainClient) Stop() {
func (m *mockChainClient) WaitForShutdown() {}

func (m *mockChainClient) GetBestBlock() (*chainhash.Hash, int32, error) {
return nil, 0, nil
return nil, m.getBestBlockHeight, nil
}

func (m *mockChainClient) GetBlock(*chainhash.Hash) (*wire.MsgBlock, error) {
return nil, nil
}

func (m *mockChainClient) GetBlockHash(int64) (*chainhash.Hash, error) {
if m.getBlockHashFunc != nil {
return m.getBlockHashFunc()
}
return nil, nil
}

func (m *mockChainClient) GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader,
error) {
return nil, nil
return m.getBlockHeader, nil
}

func (m *mockChainClient) IsCurrent() bool {
Expand Down
20 changes: 10 additions & 10 deletions wallet/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ type Wallet struct {
lockedOutpoints map[wire.OutPoint]struct{}
lockedOutpointsMtx sync.Mutex

recovering atomic.Value
recovering atomic.Value // *recoverySyncer
recoveryWindow uint32

// Channels for rescan processing. Requests are added and merged with
Expand Down Expand Up @@ -280,7 +280,7 @@ func (w *Wallet) quitChan() <-chan struct{} {

// Stop signals all wallet goroutines to shutdown.
func (w *Wallet) Stop() {
w.endRecoveryAndWait()
<-w.endRecovery()

w.quitMu.Lock()
quit := w.quit
Expand Down Expand Up @@ -1382,21 +1382,21 @@ type (
heldUnlock chan struct{}
)

// endRecoveryAndWait tells (*Wallet).recovery to stop, if running, and waits
// for it to exit.
func (w *Wallet) endRecoveryAndWait() {
// endRecovery tells (*Wallet).recovery to stop, if running, and returns a
// channel that will be closed when the recovery routine exits.
func (w *Wallet) endRecovery() <-chan struct{} {
if recoverySyncI := w.recovering.Load(); recoverySyncI != nil {
recoverySync := recoverySyncI.(*recoverySyncer)

// If recovery is still running, it will end early with an error
// once we set the quit flag.
atomic.StoreUint32(&recoverySync.quit, 1)

select {
case <-recoverySync.done:
case <-w.quitChan():
}
return recoverySync.done
}
c := make(chan struct{})
close(c)
return c
}

// walletLocker manages the locked/unlocked state of a wallet.
Expand Down Expand Up @@ -1491,7 +1491,7 @@ out:

// We can't lock the manager if recovery is active because we use
// cryptoKeyPriv and cryptoKeyScript in recovery.
w.endRecoveryAndWait()
<-w.endRecovery()

timeout = nil
err := w.Manager.Lock()
Expand Down
120 changes: 120 additions & 0 deletions wallet/wallet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ package wallet
import (
"encoding/hex"
"fmt"
"math"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/btcsuite/btcwallet/wtxmgr"
Expand Down Expand Up @@ -359,3 +363,119 @@ func TestDuplicateAddressDerivation(t *testing.T) {
require.NoError(t, eg.Wait())
}
}

func TestEndRecovery(t *testing.T) {
// This is an unconventional unit test, but I'm trying to keep things as
// succint as possible so that this test is readable without having to mock
// up literally everything.
// The unmonitored goroutine we're looking at is pretty deep:
// SynchronizeRPC -> handleChainNotifications -> syncWithChain -> recovery
// The "deadlock" we're addressing isn't actually a deadlock, but the wallet
// will hang on Stop() -> WaitForShutdown() until (*Wallet).recovery gets
// every single block, which could be hours depending on hardware and
// network factors. The WaitGroup is incremented in SynchronizeRPC, and
// WaitForShutdown will not return until handleChainNotifications returns,
// which is blocked by a running (*Wallet).recovery loop.
// It is noted that the conditions for long recovery are difficult to hit
// when using btcwallet with a fresh seed, because it requires an early
// birthday to be set or established.

w, cleanup := testWallet(t)

blockHashCalled := make(chan struct{})
iterateLoop := func() {
select {
case <-blockHashCalled:
default:
}
}

chainClient := &mockChainClient{
// Force the loop to iterate about forever.
getBestBlockHeight: math.MaxInt32,
// Get control of when the loop iterates.
getBlockHashFunc: func() (*chainhash.Hash, error) {
blockHashCalled <- struct{}{}
return &chainhash.Hash{}, nil
},
// Avoid a panic.
getBlockHeader: &wire.BlockHeader{},
}

go w.recovery(chainClient, &waddrmgr.BlockStamp{})

getBlockHashCalls := func(expCalls int) {
var i int
for {
select {
case <-blockHashCalled:
i++
case <-time.After(time.Second):
t.Fatal("expected BlockHash to be called")
}
if i == expCalls {
break
}
}
}

// Recovery is running
getBlockHashCalls(3)

// Closing the quit channel, e.g. Stop() without endRecovery, alone will not
// end the recovery loop.
w.quitMu.Lock()
close(w.quit)
w.quitMu.Unlock()
// Continues scanning.
getBlockHashCalls(3)

// We're done with this one
atomic.StoreUint32(&w.recovering.Load().(*recoverySyncer).quit, 1)
iterateLoop()
cleanup()

// Try again.
w, cleanup = testWallet(t)
defer cleanup()

// We'll catch the error to make sure we're hitting our desired path. The
// WaitGroup isn't required for the test, but does show how it completes
// shutdown at a higher level.
var err error
w.wg.Add(1)
go func() {
defer w.wg.Done()
err = w.recovery(chainClient, &waddrmgr.BlockStamp{})
}()

done := make(chan struct{})
go func() {
w.WaitForShutdown()
close(done)
}()

// Recovery is running
getBlockHashCalls(3)

// endRecovery is required to exit the unmonitored goroutine.
end := w.endRecovery()
iterateLoop()
<-end

if !strings.EqualFold(err.Error(), "recovery: forced shutdown") {
t.Fatal("wrong error")
}

// testWallet starts a couple of other unrelated goroutines that need to be
// killed, so we still need to close the quit channel.
w.quitMu.Lock()
close(w.quit)
w.quitMu.Unlock()

select {
case <-done:
case <-time.After(time.Second):
t.Fatal("WaitForShutdown never returned")
}
}

0 comments on commit 32ac5d3

Please sign in to comment.