diff --git a/channeldb/kvdb/etcd/commit_queue.go b/channeldb/kvdb/etcd/commit_queue.go index f03845650c..8b9c372927 100644 --- a/channeldb/kvdb/etcd/commit_queue.go +++ b/channeldb/kvdb/etcd/commit_queue.go @@ -3,14 +3,12 @@ package etcd import ( + "container/list" "context" "sync" + "time" ) -// commitQueueSize is the maximum number of commits we let to queue up. All -// remaining commits will block on commitQueue.Add(). -const commitQueueSize = 100 - // commitQueue is a simple execution queue to manage conflicts for transactions // and thereby reduce the number of times conflicting transactions need to be // retried. When a new transaction is added to the queue, we first upgrade the @@ -25,9 +23,12 @@ type commitQueue struct { readerMap map[string]int writerMap map[string]int - commitMutex sync.RWMutex - queue chan (func()) - wg sync.WaitGroup + queueMu sync.RWMutex + queueCond *sync.Cond + queue *list.List + freeCount uint32 + + shutdown chan struct{} } // NewCommitQueue creates a new commit queue, with the passed abort context. @@ -36,11 +37,12 @@ func NewCommitQueue(ctx context.Context) *commitQueue { ctx: ctx, readerMap: make(map[string]int), writerMap: make(map[string]int), - queue: make(chan func(), commitQueueSize), + queue: list.New(), + shutdown: make(chan struct{}), } + q.queueCond = sync.NewCond(&q.queueMu) // Start the queue consumer loop. - q.wg.Add(1) go q.mainLoop() return q @@ -48,7 +50,7 @@ func NewCommitQueue(ctx context.Context) *commitQueue { // Wait waits for the queue to stop (after the queue context has been canceled). func (c *commitQueue) Wait() { - c.wg.Wait() + c.signalUntilShutdown() } // Add increases lock counts and queues up tx commit closure for execution. @@ -83,27 +85,38 @@ func (c *commitQueue) Add(commitLoop func(), rset readSet, wset writeSet) { } if blocked { - // Add the transaction to the queue if conflicts with an already - // queued one. + // Add the transaction to the queue if it conflicts with an + // already queued one. It is safe to do so outside the lock, + // since this we know it will be executed serially. c.mx.Unlock() - select { - case c.queue <- commitLoop: - case <-c.ctx.Done(): - } + c.queueCond.L.Lock() + c.queue.PushBack(commitLoop) + c.queueCond.L.Unlock() } else { // To make sure we don't add a new tx to the queue that depends - // on this "unblocked" tx, grab the commitMutex before lifting - // the mutex guarding the lock maps. - c.commitMutex.RLock() + // on this "unblocked" tx. Increment our free counter before + // unlocking so that the mainLoop stops pulling off blocked + // transactions from the queue. + + c.queueCond.L.Lock() + c.freeCount++ + c.queueCond.L.Unlock() + c.mx.Unlock() - // At this point we're safe to execute the "unblocked" tx, as - // we cannot execute blocked tx that may have been read from the - // queue until the commitMutex is held. - commitLoop() + // At this point it is safe to execute the "unblocked" tx, as no + // blocked tx will be read from the queue until the freeCount is + // decremented back to 0. + go func() { + commitLoop() + + c.queueCond.L.Lock() + c.freeCount-- + c.queueCond.L.Unlock() + c.queueCond.Signal() + }() - c.commitMutex.RUnlock() } } @@ -131,20 +144,59 @@ func (c *commitQueue) Done(rset readSet, wset writeSet) { // dependencies. The queue ensures that the top element doesn't conflict with // any other transactions and therefore can be executed freely. func (c *commitQueue) mainLoop() { - defer c.wg.Done() + defer close(c.shutdown) for { + // Wait until there are no unblocked transactions being + // executed, and for there to be at least one blocked + // transaction in our queue. + c.queueCond.L.Lock() + for c.freeCount > 0 || c.queue.Front() == nil { + c.queueCond.Wait() + + // Check the exit condition before looping again. + select { + case <-c.ctx.Done(): + c.queueCond.L.Unlock() + return + default: + } + } + + // Remove the top element from the queue, now that we know there + // are no possible conflicts. + e := c.queue.Front() + top := c.queue.Remove(e).(func()) + c.queueCond.L.Unlock() + + // Check if we need to exit before continuing. select { - case top := <-c.queue: - // Execute the next blocked transaction. As it is - // the top element in the queue it means that it doesn't - // depend on any other transactions anymore. - c.commitMutex.Lock() - top() - c.commitMutex.Unlock() + case <-c.ctx.Done(): + return + default: + } + + // Execute the next blocked transaction. + top() + // Check if we need to exit before continuing. + select { case <-c.ctx.Done(): return + default: + } + } +} + +// signalUntilShutdown strobes the queue's condition variable to ensure the +// mainLoop reliably unblocks to check for the exit condition. +func (c *commitQueue) signalUntilShutdown() { + for { + select { + case <-time.After(time.Millisecond): + c.queueCond.Signal() + case <-c.shutdown: + return } } } diff --git a/channeldb/kvdb/etcd/commit_queue_test.go b/channeldb/kvdb/etcd/commit_queue_test.go index 16ff71006d..0b0585a85f 100644 --- a/channeldb/kvdb/etcd/commit_queue_test.go +++ b/channeldb/kvdb/etcd/commit_queue_test.go @@ -7,7 +7,6 @@ import ( "sync" "sync/atomic" "testing" - "time" "github.com/stretchr/testify/require" ) @@ -15,16 +14,18 @@ import ( // TestCommitQueue tests that non-conflicting transactions commit concurrently, // while conflicting transactions are queued up. func TestCommitQueue(t *testing.T) { - // The duration of each commit. - const commitDuration = time.Millisecond * 500 const numCommits = 4 - var wg sync.WaitGroup commits := make([]string, numCommits) idx := int32(-1) - commit := func(tag string, sleep bool) func() { + commit := func(tag string, commit chan struct{}, + ready <-chan struct{}) func() { + return func() { + if commit != nil { + close(commit) + } defer wg.Done() // Update our log of commit order. Avoid blocking @@ -33,8 +34,8 @@ func TestCommitQueue(t *testing.T) { i := atomic.AddInt32(&idx, 1) commits[i] = tag - if sleep { - time.Sleep(commitDuration) + if ready != nil { + <-ready } } } @@ -68,45 +69,53 @@ func TestCommitQueue(t *testing.T) { defer cancel() wg.Add(numCommits) - t1 := time.Now() + + ready := make(chan struct{}) // Tx1: reads: key1, key2, writes: key3, conflict: none + // Since we simulate that the txn takes a long time, we'll add in a + // new goroutine and wait for the txn commit to start execution. q.Add( - commit("free", true), + commit("free", nil, ready), makeReadSet([]string{"key1", "key2"}), makeWriteSet([]string{"key3"}), ) - // Tx2: reads: key1, key2, writes: key3, conflict: Tx1 + + // Tx2: reads: key1, key5, writes: key3, conflict: Tx1 (key3) + // We don't expect queue add to block as this txn will queue up after + // tx1. q.Add( - commit("blocked1", false), + commit("blocked1", nil, nil), makeReadSet([]string{"key1", "key2"}), makeWriteSet([]string{"key3"}), ) - // Tx3: reads: key1, writes: key4, conflict: none + + // Tx3: reads: key1, key2, writes: key4, conflict: none + // We expect this transaction to be reordered before blocked1, even + // though it was added after since it it doesn't have any conflicts. q.Add( - commit("free", true), + commit("free", nil, ready), makeReadSet([]string{"key1", "key2"}), makeWriteSet([]string{"key4"}), ) - // Tx4: reads: key2, writes: key4 conflict: Tx3 + + // Tx4: reads: key2, writes: key4 conflicts: Tx3 (key4) + // We don't expect queue add to block as this txn will queue up after + // tx2. q.Add( - commit("blocked2", false), + commit("blocked2", nil, nil), makeReadSet([]string{"key2"}), makeWriteSet([]string{"key4"}), ) + // Allow Tx1 to continue with the commit. + close(ready) + // Wait for all commits. wg.Wait() - t2 := time.Now() - - // Expected total execution time: delta. - // 2 * commitDuration <= delta < 3 * commitDuration - delta := t2.Sub(t1) - require.LessOrEqual(t, int64(commitDuration*2), int64(delta)) - require.Greater(t, int64(commitDuration*3), int64(delta)) // Expect that the non-conflicting "free" transactions are executed - // before the blocking ones, and the blocking ones are executed in + // before the conflicting ones, and the conflicting ones are executed in // the order of addition. require.Equal(t, []string{"free", "free", "blocked1", "blocked2"}, diff --git a/channeldb/kvdb/etcd/db.go b/channeldb/kvdb/etcd/db.go index 576591aa60..764a9f1652 100644 --- a/channeldb/kvdb/etcd/db.go +++ b/channeldb/kvdb/etcd/db.go @@ -116,6 +116,8 @@ func (c *commitStatsCollector) callback(succ bool, stats CommitStats) { // db holds a reference to the etcd client connection. type db struct { + ctx context.Context + cancel func() config BackendConfig cli *clientv3.Client commitStatsCollector *commitStatsCollector @@ -168,6 +170,8 @@ func newEtcdBackend(config BackendConfig) (*db, error) { config.Ctx = context.Background() } + ctx, cancel := context.WithCancel(config.Ctx) + tlsInfo := transport.TLSInfo{ CertFile: config.CertFile, KeyFile: config.KeyFile, @@ -180,7 +184,7 @@ func newEtcdBackend(config BackendConfig) (*db, error) { } cli, err := clientv3.New(clientv3.Config{ - Context: config.Ctx, + Context: ctx, Endpoints: []string{config.Host}, DialTimeout: etcdConnectionTimeout, Username: config.User, @@ -198,9 +202,11 @@ func newEtcdBackend(config BackendConfig) (*db, error) { cli.Lease = namespace.NewLease(cli.Lease, config.Namespace) backend := &db{ + ctx: ctx, + cancel: cancel, cli: cli, config: config, - txQueue: NewCommitQueue(config.Ctx), + txQueue: NewCommitQueue(ctx), } if config.CollectCommitStats { @@ -213,7 +219,7 @@ func newEtcdBackend(config BackendConfig) (*db, error) { // getSTMOptions creats all STM options based on the backend config. func (db *db) getSTMOptions() []STMOptionFunc { opts := []STMOptionFunc{ - WithAbortContext(db.config.Ctx), + WithAbortContext(db.ctx), } if db.config.CollectCommitStats { @@ -286,7 +292,7 @@ func (db *db) BeginReadTx() (walletdb.ReadTx, error) { // start a read-only transaction to perform all operations. // This function is part of the walletdb.Db interface implementation. func (db *db) Copy(w io.Writer) error { - ctx, cancel := context.WithTimeout(db.config.Ctx, etcdLongTimeout) + ctx, cancel := context.WithTimeout(db.ctx, etcdLongTimeout) defer cancel() readCloser, err := db.cli.Snapshot(ctx) @@ -302,5 +308,8 @@ func (db *db) Copy(w io.Writer) error { // Close cleanly shuts down the database and syncs all data. // This function is part of the walletdb.Db interface implementation. func (db *db) Close() error { - return db.cli.Close() + err := db.cli.Close() + db.cancel() + db.txQueue.Wait() + return err } diff --git a/channeldb/kvdb/etcd/stm.go b/channeldb/kvdb/etcd/stm.go index 59ac1f457d..1a16bc5176 100644 --- a/channeldb/kvdb/etcd/stm.go +++ b/channeldb/kvdb/etcd/stm.go @@ -287,8 +287,17 @@ func runSTM(s *stm, apply func(STM) error) error { select { case <-done: case <-s.options.ctx.Done(): + return context.Canceled } + // If the transaction executed, we can decrement the read and write lock + // sets and apply an commit stat callbacks. + // + // NOTE: It is not safe to do this in the case where the context is + // canceled, as it might inadvertently unblock other transactions that + // _should_ depend on this one. Furthermore, the executeErr is mutable + // so long as the done channel hasn't returned, so we can't read or + // return it. s.txQueue.Done(s.rset, s.wset) if s.options.commitStatsCallback != nil {