diff --git a/error/error.go b/error/error.go index adc9776c4..2cf04a65d 100644 --- a/error/error.go +++ b/error/error.go @@ -212,6 +212,14 @@ func (e *ErrTxnTooLarge) Error() string { return fmt.Sprintf("txn too large, size: %v.", e.Size) } +type ErrKeyTooLarge struct { + KeySize int +} + +func (e *ErrKeyTooLarge) Error() string { + return fmt.Sprintf("key size too large, size: %v.", e.KeySize) +} + // ErrEntryTooLarge is the error when a key value entry is too large. type ErrEntryTooLarge struct { Limit uint64 diff --git a/go.mod b/go.mod index 6c40958f3..d74ccfc22 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/prometheus/client_model v0.5.0 github.com/stretchr/testify v1.8.2 github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a - github.com/tikv/pd/client v0.0.0-20240620115049-049de1761e56 + github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 github.com/twmb/murmur3 v1.1.3 go.etcd.io/etcd/api/v3 v3.5.10 go.etcd.io/etcd/client/v3 v3.5.10 diff --git a/go.sum b/go.sum index 88bbc2899..7dcfef12b 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,7 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -111,8 +112,8 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= -github.com/tikv/pd/client v0.0.0-20240620115049-049de1761e56 h1:7TLLfwrKoty9UeJMsSopRTXYw8ooxcF0Z1fegXhIgks= -github.com/tikv/pd/client v0.0.0-20240620115049-049de1761e56/go.mod h1:EHHidLItrJGh0jqfdfFhIHG5vwkR8+43tFnp7v7iv1Q= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 h1:oAYc4m5Eu1OY9ogJ103VO47AYPHvhtzbUPD8L8B67Qk= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31/go.mod h1:W5a0sDadwUpI9k8p7M77d3jo253ZHdmua+u4Ho4Xw8U= github.com/twmb/murmur3 v1.1.3 h1:D83U0XYKcHRYwYIpBKf3Pks91Z0Byda/9SJ8B6EMRcA= github.com/twmb/murmur3 v1.1.3/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/integration_tests/go.mod b/integration_tests/go.mod index df4630f3e..2cd7662be 100644 --- a/integration_tests/go.mod +++ b/integration_tests/go.mod @@ -12,7 +12,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.14.1 github.com/tikv/client-go/v2 v2.0.8-0.20240626064248-4a72526f6c30 - github.com/tikv/pd/client v0.0.0-20240620115049-049de1761e56 + github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 go.uber.org/goleak v1.3.0 ) diff --git a/integration_tests/go.sum b/integration_tests/go.sum index 63ae1f91b..df723a21d 100644 --- a/integration_tests/go.sum +++ b/integration_tests/go.sum 
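The new ErrKeyTooLarge mirrors the existing ErrTxnTooLarge/ErrEntryTooLarge types, so callers can unwrap it with errors.As. A minimal caller-side sketch; the describeSetError helper and its wording are hypothetical, only the ErrKeyTooLarge type and its KeySize field come from this patch:

package main

import (
	"errors"
	"fmt"

	tikverr "github.com/tikv/client-go/v2/error"
)

// describeSetError shows how a caller might distinguish the new error type
// returned by MemBuffer.Set when a key exceeds the maximum key length.
func describeSetError(err error) string {
	var keyErr *tikverr.ErrKeyTooLarge
	if errors.As(err, &keyErr) {
		return fmt.Sprintf("key rejected: %d bytes exceeds the maximum key length", keyErr.KeySize)
	}
	return err.Error()
}

func main() {
	err := &tikverr.ErrKeyTooLarge{KeySize: 1 << 16}
	fmt.Println(describeSetError(err))
}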
@@ -474,8 +474,8 @@ github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tikv/pd/client v0.0.0-20240620115049-049de1761e56 h1:7TLLfwrKoty9UeJMsSopRTXYw8ooxcF0Z1fegXhIgks= -github.com/tikv/pd/client v0.0.0-20240620115049-049de1761e56/go.mod h1:EHHidLItrJGh0jqfdfFhIHG5vwkR8+43tFnp7v7iv1Q= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31 h1:oAYc4m5Eu1OY9ogJ103VO47AYPHvhtzbUPD8L8B67Qk= +github.com/tikv/pd/client v0.0.0-20241111073742-238d4d79ea31/go.mod h1:W5a0sDadwUpI9k8p7M77d3jo253ZHdmua+u4Ho4Xw8U= github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ3LSFUzyeuhs= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= diff --git a/internal/client/client_interceptor.go b/internal/client/client_interceptor.go index 46dbf4bd2..b82ae7e01 100644 --- a/internal/client/client_interceptor.go +++ b/internal/client/client_interceptor.go @@ -124,13 +124,13 @@ func buildResourceControlInterceptor( resp, err := next(target, req) if resp != nil { respInfo := resourcecontrol.MakeResponseInfo(resp) - consumption, err = resourceControlInterceptor.OnResponse(resourceGroupName, reqInfo, respInfo) + consumption, waitDuration, err = resourceControlInterceptor.OnResponseWait(ctx, resourceGroupName, reqInfo, respInfo) if err != nil { return nil, err } if ruDetails != nil { detail := ruDetails.(*util.RUDetails) - detail.Update(consumption, time.Duration(0)) + detail.Update(consumption, waitDuration) } } return resp, err diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go index 4fa76b401..743374510 100644 --- a/internal/unionstore/art/art.go +++ b/internal/unionstore/art/art.go @@ -47,7 +47,7 @@ type ART struct { len int size int - // The lastTraversedNode stores addr in uint64 of the last traversed node, include search and recursiveInsert. + // The lastTraversedNode stores addr in uint64 of the last traversed node, includes search and recursiveInsert. // Compare to atomic.Pointer, atomic.Uint64 can avoid heap allocation, so it's more efficient. lastTraversedNode atomic.Uint64 hitCount atomic.Uint64 @@ -68,8 +68,13 @@ func New() *ART { } func (t *ART) Get(key []byte) ([]byte, error) { + if t.vlogInvalid { + // panic for easier debugging. + panic("vlog is reset") + } + // 1. search the leaf node. - _, leaf := t.search(key) + _, leaf := t.traverse(key, false) if leaf == nil || leaf.vAddr.IsNull() { return nil, tikverr.ErrNotExist } @@ -79,7 +84,7 @@ func (t *ART) Get(key []byte) ([]byte, error) { // GetFlags returns the latest flags associated with key. func (t *ART) GetFlags(key []byte) (kv.KeyFlags, error) { - _, leaf := t.search(key) + _, leaf := t.traverse(key, false) if leaf == nil { return 0, tikverr.ErrNotExist } @@ -90,6 +95,17 @@ func (t *ART) GetFlags(key []byte) (kv.KeyFlags, error) { } func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { + if t.vlogInvalid { + // panic for easier debugging. 
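The cache consulted before walking the tree is a single slot holding the arena address of the last traversed leaf; packing that address into an atomic.Uint64 (rather than an atomic.Pointer) keeps it off the heap, as the comment above notes. A simplified, self-contained sketch of the idea; checkKeyInCache/updateLastTraversed are not shown in this diff, and all types below are illustrative:

package main

import (
	"fmt"
	"sync/atomic"
)

// lastNodeCache is an illustrative stand-in for ART's lastTraversedNode,
// hitCount and missCount fields: one atomic slot remembering the address of
// the last visited leaf, plus hit/miss counters.
type lastNodeCache struct {
	lastAddr  atomic.Uint64
	hitCount  atomic.Uint64
	missCount atomic.Uint64
}

// lookup returns the cached address when the leaf stored there still carries
// the probe key; keyOf stands in for decoding the leaf behind an arena address.
func (c *lastNodeCache) lookup(key string, keyOf func(addr uint64) (string, bool)) (uint64, bool) {
	addr := c.lastAddr.Load()
	if k, ok := keyOf(addr); ok && k == key {
		c.hitCount.Add(1)
		return addr, true
	}
	c.missCount.Add(1)
	return 0, false
}

func (c *lastNodeCache) update(addr uint64) { c.lastAddr.Store(addr) }

func main() {
	leaves := map[uint64]string{42: "foo"}
	keyOf := func(addr uint64) (string, bool) { k, ok := leaves[addr]; return k, ok }

	var c lastNodeCache
	c.update(42)
	_, hit := c.lookup("foo", keyOf)
	fmt.Println(hit, c.hitCount.Load(), c.missCount.Load()) // true 1 0
}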
+ panic("vlog is reset") + } + + if len(key) > MaxKeyLen { + return &tikverr.ErrKeyTooLarge{ + KeySize: len(key), + } + } + if value != nil { if size := uint64(len(key) + len(value)); size > t.entrySizeLimit { return &tikverr.ErrEntryTooLarge{ @@ -98,11 +114,12 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { } } } + if len(t.stages) == 0 { t.dirty = true } // 1. create or search the existing leaf in the tree. - addr, leaf := t.recursiveInsert(key) + addr, leaf := t.traverse(key, true) // 2. set the value and flags. t.setValue(addr, leaf, value, ops) if uint64(t.Size()) > t.bufferSizeLimit { @@ -111,8 +128,8 @@ func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { return nil } -// search wraps searchImpl with cache. -func (t *ART) search(key artKey) (arena.MemdbArenaAddr, *artLeaf) { +// traverse wraps search and recursiveInsert with cache. +func (t *ART) traverse(key artKey, insert bool) (arena.MemdbArenaAddr, *artLeaf) { // check cache addr, leaf, found := t.checkKeyInCache(key) if found { @@ -120,17 +137,21 @@ func (t *ART) search(key artKey) (arena.MemdbArenaAddr, *artLeaf) { return addr, leaf } t.missCount.Add(1) - addr, leaf = t.searchImpl(key) + if insert { + addr, leaf = t.recursiveInsert(key) + } else { + addr, leaf = t.search(key) + } if !addr.IsNull() { t.updateLastTraversed(addr) } return addr, leaf } -// searchImpl looks up the leaf with the given key. +// search looks up the leaf with the given key. // It returns the memory arena address and leaf itself it there is a match leaf, // returns arena.NullAddr and nil if the key is not found. -func (t *ART) searchImpl(key artKey) (arena.MemdbArenaAddr, *artLeaf) { +func (t *ART) search(key artKey) (arena.MemdbArenaAddr, *artLeaf) { current := t.root if current == nullArtNode { return arena.NullAddr, nil @@ -179,25 +200,9 @@ func (t *ART) searchImpl(key artKey) (arena.MemdbArenaAddr, *artLeaf) { } } -// recursiveInsert wraps recursiveInsertImpl with cache. -func (t *ART) recursiveInsert(key artKey) (arena.MemdbArenaAddr, *artLeaf) { - addr, leaf, found := t.checkKeyInCache(key) - if found { - t.hitCount.Add(1) - return addr, leaf - } - t.missCount.Add(1) - addr, leaf = t.recursiveInsertImpl(key) - if !addr.IsNull() { - t.updateLastTraversed(addr) - } - return addr, leaf -} - -// recursiveInsertImpl returns the node address of the key. +// recursiveInsert returns the node address of the key. // It will insert the key if not exists, returns the newly inserted or existing leaf. -func (t *ART) recursiveInsertImpl(key artKey) (arena.MemdbArenaAddr, *artLeaf) { - +func (t *ART) recursiveInsert(key artKey) (arena.MemdbArenaAddr, *artLeaf) { // lazy init root node and allocator. // this saves memory for read only txns. if t.root.addr.IsNull() { @@ -300,11 +305,7 @@ func (t *ART) expandLeaf(key artKey, depth uint32, prev, current artNode) (arena newAn.addChild(&t.allocator, l2Key.charAt(int(depth)), !l2Key.valid(int(depth)), leaf2Addr) // swap the old leaf with the new node4. - if prev == nullArtNode { - t.root = newAn - } else { - prev.replaceChild(&t.allocator, key.charAt(prevDepth), newAn) - } + prev.replaceChild(&t.allocator, key.charAt(prevDepth), newAn) return leaf2Addr.addr, leaf2 } @@ -328,30 +329,25 @@ func (t *ART) expandNode(key artKey, depth, mismatchIdx uint32, prev, current ar newN4.setPrefix(key[depth:], mismatchIdx) // update prefix for old node and move it as a child of the new node. 
+ var prefix artKey if currNode.prefixLen <= maxPrefixLen { - nodeKey := currNode.prefix[mismatchIdx] - currNode.prefixLen -= mismatchIdx + 1 - copy(currNode.prefix[:], currNode.prefix[mismatchIdx+1:]) - newAn.addChild(&t.allocator, nodeKey, false, current) + // node.prefixLen <= maxPrefixLen means all the prefix is in the prefix array. + // The char at mismatchIdx will be stored in the index of new node. + prefix = currNode.prefix[:] } else { - currNode.prefixLen -= mismatchIdx + 1 + // Unless, we need to find the prefix in the leaf. + // Any leaves in the node should have the same prefix, we use minimum node here. leafArtNode := minimum(&t.allocator, current) - leaf := leafArtNode.asLeaf(&t.allocator) - leafKey := artKey(leaf.GetKey()) - kMin := depth + mismatchIdx + 1 - kMax := depth + mismatchIdx + 1 + min(currNode.prefixLen, maxPrefixLen) - copy(currNode.prefix[:], leafKey[kMin:kMax]) - newAn.addChild(&t.allocator, leafKey.charAt(int(depth+mismatchIdx)), !leafKey.valid(int(depth)), current) + prefix = leafArtNode.asLeaf(&t.allocator).GetKey()[depth : depth+currNode.prefixLen] } + nodeChar := prefix[mismatchIdx] + currNode.setPrefix(prefix[mismatchIdx+1:], currNode.prefixLen-mismatchIdx-1) + newAn.addChild(&t.allocator, nodeChar, false, current) // insert the artLeaf into new node newLeafAddr, newLeaf := t.newLeaf(key) newAn.addChild(&t.allocator, key.charAt(int(depth+mismatchIdx)), !key.valid(int(depth+mismatchIdx)), newLeafAddr) - if prev == nullArtNode { - t.root = newAn - } else { - prev.replaceChild(&t.allocator, key.charAt(prevDepth), newAn) - } + prev.replaceChild(&t.allocator, key.charAt(prevDepth), newAn) return newLeafAddr.addr, newLeaf } @@ -577,7 +573,6 @@ func (t *ART) SelectValueHistory(key []byte, predicate func(value []byte) bool) return nil, nil } return t.allocator.vlogAllocator.GetValue(result), nil - } func (t *ART) SetMemoryFootprintChangeHook(hook func(uint64)) { diff --git a/internal/unionstore/art/art_arena.go b/internal/unionstore/art/art_arena.go index 84ca21127..f196bf0cb 100644 --- a/internal/unionstore/art/art_arena.go +++ b/internal/unionstore/art/art_arena.go @@ -25,6 +25,8 @@ import ( // reusing blocks reduces the memory pieces. type nodeArena struct { arena.MemdbArena + // The ART node will expand to a higher capacity, and the address of the freed node will be stored in the free list for reuse. + // By reusing the freed node, memory usage and fragmentation can be reduced. freeNode4 []arena.MemdbArenaAddr freeNode16 []arena.MemdbArenaAddr freeNode48 []arena.MemdbArenaAddr diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go index ad0c8521c..3db268612 100644 --- a/internal/unionstore/art/art_iterator.go +++ b/internal/unionstore/art/art_iterator.go @@ -16,6 +16,7 @@ package art import ( "bytes" + "fmt" "sort" "github.com/pkg/errors" @@ -156,140 +157,161 @@ func (it *Iterator) init(lowerBound, upperBound []byte) { } } + startKey, endKey := lowerBound, upperBound if it.reverse { - it.inner.idxes, it.inner.nodes = it.seek(upperBound) - if len(lowerBound) == 0 { - it.endAddr = arena.NullAddr - } else { - helper := new(baseIter) - helper.allocator = &it.tree.allocator - helper.idxes, helper.nodes = it.seek(lowerBound) - if it.inner.compare(helper) > 0 { - // lowerBound is inclusive, call next to find the smallest leaf node that >= lowerBound. 
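The new comment in art_arena.go describes per-size free lists: when a node grows (say node4 to node16), the old node's arena address is pushed onto freeNode4 and handed out again by the next node4 allocation, which keeps memory usage and fragmentation down. A self-contained sketch of that reuse with simplified types; the real arena code is not shown here:

package main

import "fmt"

type addr uint64

// nodePool mimics the free-list idea: allocNode4 prefers a previously freed
// slot over growing the arena.
type nodePool struct {
	freeNode4 []addr
	nextAddr  addr
}

func (p *nodePool) allocNode4() addr {
	if n := len(p.freeNode4); n > 0 {
		a := p.freeNode4[n-1]
		p.freeNode4 = p.freeNode4[:n-1]
		return a
	}
	p.nextAddr++
	return p.nextAddr
}

func (p *nodePool) freeNode4Addr(a addr) { p.freeNode4 = append(p.freeNode4, a) }

func main() {
	var p nodePool
	a := p.allocNode4() // fresh slot
	p.freeNode4Addr(a)  // node grew, slot returned to the pool
	b := p.allocNode4() // reused slot
	fmt.Println(a == b) // true
}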
- it.endAddr = helper.next().addr - if it.inner.compare(helper) < 0 || len(helper.idxes) == 0 { - it.valid = false - } - return - } - it.valid = false - } - return + startKey, endKey = upperBound, lowerBound } - it.inner.idxes, it.inner.nodes = it.seek(lowerBound) - if len(upperBound) == 0 { - it.endAddr = arena.NullAddr + if len(startKey) == 0 { + it.inner.seekToFirst(it.tree.root, it.reverse) } else { - helper := new(baseIter) - helper.allocator = &it.tree.allocator - helper.idxes, helper.nodes = it.seek(upperBound) - if it.inner.compare(helper) < 0 { - // upperBound is exclusive, so we move the helper cursor to the previous node, which is the true endAddr. - it.endAddr = helper.prev().addr - if it.inner.compare(helper) > 0 || len(helper.idxes) == 0 { - it.valid = false - } - return - } + it.inner.seek(it.tree.root, startKey) + } + if len(endKey) == 0 { + it.endAddr = arena.NullAddr + return + } + + helper := new(baseIter) + helper.allocator = &it.tree.allocator + helper.seek(it.tree.root, endKey) + cmp := it.inner.compare(helper) + if cmp == 0 { + // no keys exist between start key and end key, set the iterator to invalid. it.valid = false return } + + if it.reverse { + it.endAddr = helper.next().addr + } else { + it.endAddr = helper.prev().addr + } + cmp = it.inner.compare(helper) + if cmp == 0 { + // the current key is valid. + return + } + // in asc scan, if cmp > 0, it means current key is larger than end key, set the iterator to invalid. + // in desc scan, if cmp < 0, it means current key is less than end key, set the iterator to invalid. + if cmp < 0 == it.reverse || len(helper.idxes) == 0 { + it.valid = false + } } -// seek the first node and index that >= key, return the indexes and nodes of the lookup path -// nodes[0] is the root node -func (it *Iterator) seek(key artKey) ([]int, []artNode) { - curr := it.tree.root - depth := uint32(0) - idxes := make([]int, 0, 8) - nodes := make([]artNode, 0, 8) +// baseIter is the inner iterator for ART tree. +// You need to call seek or seekToFirst to initialize the iterator. +// after initialization, you can call next or prev to iterate the ART. +// next or prev returns nullArtNode if there is no more leaf node. +type baseIter struct { + allocator *artAllocator + // the baseIter iterate the ART tree in DFS order, the idxes and nodes are the current visiting stack. + idxes []int + nodes []artNode +} + +// seekToFirst seeks the boundary of the tree, sequential or reverse. +func (it *baseIter) seekToFirst(root artNode, reverse bool) { + // if the seek key is empty, it means -inf or +inf, return root node directly. + it.nodes = []artNode{root} + if reverse { + it.idxes = []int{node256cap} + } else { + it.idxes = []int{inplaceIndex} + } +} + +// seek the first node and index that >= key, nodes[0] is the root node +func (it *baseIter) seek(root artNode, key artKey) { if len(key) == 0 { - // if the seek key is empty, it means -inf or +inf, return root node directly. - nodes = append(nodes, curr) - if it.reverse { - idxes = append(idxes, node256cap) - } else { - idxes = append(idxes, inplaceIndex) - } - return idxes, nodes + panic("seek with empty key is not allowed") } + curr := root + depth := uint32(0) var node *nodeBase for { if curr.isLeaf() { if key.valid(int(depth)) { - lf := curr.asLeaf(&it.tree.allocator) + lf := curr.asLeaf(it.allocator) if bytes.Compare(key, lf.GetKey()) > 0 { // the seek key is not exist, and it's longer and larger than the current leaf's key. // e.g. key: [1, 1, 1], leaf: [1, 1]. 
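For reference, the bound semantics that init() encodes (and that TestIterSeekNoResult exercises): the lower bound is inclusive, the upper bound is exclusive, and IterReverse walks the same range from the upper end down, taking (upperBound, lowerBound). A usage sketch in the style of the existing tests; it assumes the art test package context (New, testing, require), and the concrete keys are made up:

func TestIterBoundsSketch(t *testing.T) {
	tree := New()
	for _, k := range [][]byte{{1}, {2}, {3}} {
		require.Nil(t, tree.Set(k, k))
	}

	// Ascending: [lowerBound, upperBound) yields {1}, {2}.
	it, err := tree.Iter([]byte{1}, []byte{3})
	require.Nil(t, err)
	for _, want := range [][]byte{{1}, {2}} {
		require.True(t, it.Valid())
		require.Equal(t, want, it.Key())
		require.Nil(t, it.Next())
	}
	require.False(t, it.Valid())

	// Descending: IterReverse(upperBound, lowerBound) yields {2}, {1}.
	rit, err := tree.IterReverse([]byte{3}, []byte{1})
	require.Nil(t, err)
	for _, want := range [][]byte{{2}, {1}} {
		require.True(t, rit.Valid())
		require.Equal(t, want, rit.Key())
		require.Nil(t, rit.Next())
	}
	require.False(t, rit.Valid())
}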
- idxes[len(idxes)-1]++ + it.idxes[len(it.idxes)-1]++ } } - break + return } - node = curr.asNode(&it.tree.allocator) + it.nodes = append(it.nodes, curr) + node = curr.asNode(it.allocator) if node.prefixLen > 0 { - mismatchIdx := node.matchDeep(&it.tree.allocator, &curr, key, depth) + mismatchIdx := node.matchDeep(it.allocator, &curr, key, depth) if mismatchIdx < node.prefixLen { - // no leaf node is match with the seek key - leafNode := minimum(&it.tree.allocator, curr) - leafKey := leafNode.asLeaf(&it.tree.allocator).GetKey() - if mismatchIdx+depth == uint32(len(key)) || key[depth+mismatchIdx] < leafKey[depth+mismatchIdx] { - // key < leafKey, set index to -1 means all the children are larger than the seek key - idxes = append(idxes, -1) + // Check whether the seek key is smaller than the prefix, as no leaf node matches the seek key. + // If the seek key is smaller than the prefix, all the children are located on the right side of the seek key. + // Otherwise, the children are located on the left side of the seek key. + var prefix []byte + if mismatchIdx < maxPrefixLen { + prefix = node.prefix[:] } else { - // key > leafKey, set index to 256 means all the children are less than the seek key - idxes = append(idxes, node256cap) + leafNode := minimum(it.allocator, curr) + prefix = leafNode.asLeaf(it.allocator).getKeyDepth(depth) } - nodes = append(nodes, curr) - return idxes, nodes + if mismatchIdx+depth == uint32(len(key)) || key[depth+mismatchIdx] < prefix[mismatchIdx] { + // mismatchIdx + depth == len(key) indicates that the seek key is a prefix of any leaf in the current node, implying that key < leafKey. + // If key < leafKey, set index to -1 means all the children are larger than the seek key. + it.idxes = append(it.idxes, -1) + } else { + // If key > prefix, set index to 256 means all the children are smaller than the seek key. + it.idxes = append(it.idxes, node256cap) + } + return } depth += min(mismatchIdx, node.prefixLen) } - nodes = append(nodes, curr) char := key.charAt(int(depth)) - idx, next := curr.findChild(&it.tree.allocator, char, !key.valid(int(depth))) + idx, next := curr.findChild(it.allocator, char, !key.valid(int(depth))) if next.addr.IsNull() { - nextIdx := 0 - switch curr.kind { - case typeNode4: - n4 := curr.asNode4(&it.tree.allocator) - for ; nextIdx < int(n4.nodeNum); nextIdx++ { - if n4.keys[nextIdx] >= char { - break - } - } - case typeNode16: - n16 := curr.asNode16(&it.tree.allocator) - nextIdx, _ = sort.Find(int(n16.nodeNum), func(i int) int { - if n16.keys[i] < char { - return 1 - } - return -1 - }) - case typeNode48: - n48 := curr.asNode48(&it.tree.allocator) - nextIdx = n48.nextPresentIdx(int(char)) - case typeNode256: - n256 := curr.asNode256(&it.tree.allocator) - nextIdx = n256.nextPresentIdx(int(char)) - } - idxes = append(idxes, nextIdx) - return idxes, nodes + nextIdx := seekToIdx(it.allocator, curr, char) + it.idxes = append(it.idxes, nextIdx) + return } - idxes = append(idxes, idx) + it.idxes = append(it.idxes, idx) curr = next depth++ } - return idxes, nodes } -type baseIter struct { - allocator *artAllocator - idxes []int - nodes []artNode +// seekToIdx finds the index where all nodes before it are less than the given character. 
+func seekToIdx(a *artAllocator, curr artNode, char byte) int { + var nextIdx int + switch curr.kind { + case typeNode4: + n4 := curr.asNode4(a) + for ; nextIdx < int(n4.nodeNum); nextIdx++ { + if n4.keys[nextIdx] >= char { + break + } + } + case typeNode16: + n16 := curr.asNode16(a) + nextIdx, _ = sort.Find(int(n16.nodeNum), func(i int) int { + if n16.keys[i] < char { + return 1 + } + return -1 + }) + case typeNode48: + n48 := curr.asNode48(a) + nextIdx = n48.nextPresentIdx(int(char)) + case typeNode256: + n256 := curr.asNode256(a) + nextIdx = n256.nextPresentIdx(int(char)) + default: + panic("invalid node type") + } + return nextIdx } // compare compares the path of nodes, return 1 if self > other, -1 if self < other, 0 if self == other @@ -318,197 +340,209 @@ func (it *baseIter) compare(other *baseIter) int { // next returns the next leaf node // it returns nullArtNode if there is no more leaf node func (it *baseIter) next() artNode { - depth := len(it.nodes) - 1 - curr := it.nodes[depth] - idx := it.idxes[depth] - switch curr.kind { - case typeNode4: - n4 := it.allocator.getNode4(curr.addr) - if idx == inplaceIndex { - idx = 0 // mark in-place leaf is visited - it.idxes[depth] = idx - if !n4.inplaceLeaf.addr.IsNull() { - return n4.inplaceLeaf + for { + depth := len(it.nodes) - 1 + curr := it.nodes[depth] + idx := it.idxes[depth] + var child *artNode + switch curr.kind { + case typeNode4: + n4 := it.allocator.getNode4(curr.addr) + if idx == inplaceIndex { + idx = 0 // mark in-place leaf is visited + it.idxes[depth] = idx + if !n4.inplaceLeaf.addr.IsNull() { + return n4.inplaceLeaf + } + } else if idx == node4cap { + break } - } else if idx == node4cap { - break - } - if idx < int(n4.nodeNum) { - it.idxes[depth] = idx - child := n4.children[idx] - if child.kind == typeLeaf { - it.idxes[depth]++ - return child + if idx >= 0 && idx < int(n4.nodeNum) { + it.idxes[depth] = idx + child = &n4.children[idx] + } else if idx >= int(n4.nodeNum) { + // idx >= n4.nodeNum means this node is drain, break to pop stack. + break + } else { + panicForInvalidIndex(idx) } - it.nodes = append(it.nodes, child) - it.idxes = append(it.idxes, inplaceIndex) - return it.next() - } - case typeNode16: - n16 := it.allocator.getNode16(curr.addr) - if idx == inplaceIndex { - idx = 0 // mark in-place leaf is visited - it.idxes[depth] = idx - if !n16.inplaceLeaf.addr.IsNull() { - return n16.inplaceLeaf + case typeNode16: + n16 := it.allocator.getNode16(curr.addr) + if idx == inplaceIndex { + idx = 0 // mark in-place leaf is visited + it.idxes[depth] = idx + if !n16.inplaceLeaf.addr.IsNull() { + return n16.inplaceLeaf + } + } else if idx == node16cap { + break } - } else if idx == node16cap { - break - } - if idx < int(n16.nodeNum) { - it.idxes[depth] = idx - child := n16.children[idx] - if child.kind == typeLeaf { - it.idxes[depth]++ - return child + if idx >= 0 && idx < int(n16.nodeNum) { + it.idxes[depth] = idx + child = &n16.children[idx] + } else if idx >= int(n16.nodeNum) { + // idx >= n16.nodeNum means this node is drain, break to pop stack. 
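seekToIdx's node16 branch leans on sort.Find: the comparator returns 1 while keys[i] < char and -1 afterwards, so Find yields the first index whose key is >= char (or nodeNum when every key is smaller). A standalone illustration of that comparator with made-up keys:

package main

import (
	"fmt"
	"sort"
)

func main() {
	keys := []byte{3, 7, 9, 200} // sorted key bytes of a hypothetical node16
	for _, char := range []byte{1, 7, 10, 255} {
		idx, _ := sort.Find(len(keys), func(i int) int {
			if keys[i] < char {
				return 1
			}
			return -1
		})
		fmt.Printf("char %d -> idx %d\n", char, idx) // 1->0, 7->1, 10->3, 255->4
	}
}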
+ break + } else { + panicForInvalidIndex(idx) } - it.nodes = append(it.nodes, child) - it.idxes = append(it.idxes, inplaceIndex) - return it.next() - } - case typeNode48: - n48 := it.allocator.getNode48(curr.addr) - if idx == inplaceIndex { - idx = 0 // mark in-place leaf is visited - it.idxes[depth] = idx - if !n48.inplaceLeaf.addr.IsNull() { - return n48.inplaceLeaf + case typeNode48: + n48 := it.allocator.getNode48(curr.addr) + if idx == inplaceIndex { + idx = 0 // mark in-place leaf is visited + it.idxes[depth] = idx + if !n48.inplaceLeaf.addr.IsNull() { + return n48.inplaceLeaf + } + } else if idx == node256cap { + break } - } else if idx == node256cap { - break - } - idx = n48.nextPresentIdx(idx) - if idx < node256cap { - it.idxes[depth] = idx - child := n48.children[n48.keys[idx]] - if child.kind == typeLeaf { - it.idxes[depth]++ - return child + idx = n48.nextPresentIdx(idx) + if idx >= 0 && idx < node256cap { + it.idxes[depth] = idx + child = &n48.children[n48.keys[idx]] + } else if idx == node256cap { + // idx == node256cap means this node is drain, break to pop stack. + break + } else { + panicForInvalidIndex(idx) } - it.nodes = append(it.nodes, child) - it.idxes = append(it.idxes, inplaceIndex) - return it.next() - } - case typeNode256: - n256 := it.allocator.getNode256(curr.addr) - if idx == inplaceIndex { - idx = 0 // mark in-place leaf is visited - it.idxes[depth] = idx - if !n256.inplaceLeaf.addr.IsNull() { - return n256.inplaceLeaf + case typeNode256: + n256 := it.allocator.getNode256(curr.addr) + if idx == inplaceIndex { + idx = 0 // mark in-place leaf is visited + it.idxes[depth] = idx + if !n256.inplaceLeaf.addr.IsNull() { + return n256.inplaceLeaf + } + } else if idx == node256cap { + break + } + idx = n256.nextPresentIdx(idx) + if idx >= 0 && idx < 256 { + it.idxes[depth] = idx + child = &n256.children[idx] + } else if idx == node256cap { + // idx == node256cap means this node is drain, break to pop stack. 
+ break + } else { + panicForInvalidIndex(idx) } - } else if idx == 256 { - break + default: + panic("invalid node type") } - idx = n256.nextPresentIdx(idx) - if idx < 256 { - it.idxes[depth] = idx - child := n256.children[idx] + if child != nil { if child.kind == typeLeaf { it.idxes[depth]++ - return child + return *child } - it.nodes = append(it.nodes, child) + it.nodes = append(it.nodes, *child) it.idxes = append(it.idxes, inplaceIndex) - return it.next() + continue } + it.nodes = it.nodes[:depth] + it.idxes = it.idxes[:depth] + if depth == 0 { + return nullArtNode + } + it.idxes[depth-1]++ } - it.nodes = it.nodes[:depth] - it.idxes = it.idxes[:depth] - if depth == 0 { - return nullArtNode - } - it.idxes[depth-1]++ - return it.next() } func (it *baseIter) prev() artNode { - depth := len(it.nodes) - 1 - curr := it.nodes[depth] - idx := it.idxes[depth] - idx-- - switch curr.kind { - case typeNode4: - n4 := it.allocator.getNode4(curr.addr) - idx = min(idx, int(n4.nodeNum)-1) - if idx >= 0 { - it.idxes[depth] = idx - child := n4.children[idx] - if child.kind == typeLeaf { - return child - } - it.nodes = append(it.nodes, child) - it.idxes = append(it.idxes, node256cap) - return it.prev() - } else if idx == inplaceIndex { - it.idxes[depth] = idx - if !n4.inplaceLeaf.addr.IsNull() { - return n4.inplaceLeaf - } - } - case typeNode16: - n16 := it.allocator.getNode16(curr.addr) - idx = min(idx, int(n16.nodeNum)-1) - if idx >= 0 { - it.idxes[depth] = idx - child := n16.children[idx] - if child.kind == typeLeaf { - return child - } - it.nodes = append(it.nodes, child) - it.idxes = append(it.idxes, node256cap) - return it.prev() - } else if idx == inplaceIndex { - it.idxes[depth] = idx - if !n16.inplaceLeaf.addr.IsNull() { - return n16.inplaceLeaf - } - } - case typeNode48: - n48 := it.allocator.getNode48(curr.addr) - if idx >= 0 { - idx = n48.prevPresentIdx(idx) - } - if idx >= 0 { - it.idxes[depth] = idx - child := n48.children[n48.keys[idx]] - if child.kind == typeLeaf { - return child + for { + depth := len(it.nodes) - 1 + curr := it.nodes[depth] + idx := it.idxes[depth] + idx-- + if idx != notExistIndex { + var child *artNode + switch curr.kind { + case typeNode4: + n4 := it.allocator.getNode4(curr.addr) + idx = min(idx, int(n4.nodeNum)-1) + if idx >= 0 { + it.idxes[depth] = idx + child = &n4.children[idx] + } else if idx == inplaceIndex { + it.idxes[depth] = idx + if !n4.inplaceLeaf.addr.IsNull() { + return n4.inplaceLeaf + } + } else { + panicForInvalidIndex(idx) + } + case typeNode16: + n16 := it.allocator.getNode16(curr.addr) + idx = min(idx, int(n16.nodeNum)-1) + if idx >= 0 { + it.idxes[depth] = idx + child = &n16.children[idx] + } else if idx == inplaceIndex { + it.idxes[depth] = idx + if !n16.inplaceLeaf.addr.IsNull() { + return n16.inplaceLeaf + } + } else { + panicForInvalidIndex(idx) + } + case typeNode48: + n48 := it.allocator.getNode48(curr.addr) + if idx >= 0 && n48.present[idx>>n48s]&(1<<(idx%n48m)) == 0 { + // if idx >= 0 and n48.keys[idx] is not present, goto the previous present key. + // for idx < 0, we check inplaceLeaf later. 
+ idx = n48.prevPresentIdx(idx) + } + if idx >= 0 { + it.idxes[depth] = idx + child = &n48.children[n48.keys[idx]] + } else if idx == inplaceIndex { + it.idxes[depth] = idx + if !n48.inplaceLeaf.addr.IsNull() { + return n48.inplaceLeaf + } + } else { + panicForInvalidIndex(idx) + } + case typeNode256: + n256 := it.allocator.getNode256(curr.addr) + if idx >= 0 && n256.present[idx>>n48s]&(1<<(idx%n48m)) == 0 { + // if idx >= 0 and n256.keys[idx] is not present, goto the previous present key. + // for idx < 0, we check inplaceLeaf later. + idx = n256.prevPresentIdx(idx) + } + if idx >= 0 { + it.idxes[depth] = idx + child = &n256.children[idx] + } else if idx == inplaceIndex { + it.idxes[depth] = idx + if !n256.inplaceLeaf.addr.IsNull() { + return n256.inplaceLeaf + } + } else { + panicForInvalidIndex(idx) + } + default: + panic("invalid node type") } - it.nodes = append(it.nodes, child) - it.idxes = append(it.idxes, node256cap) - return it.prev() - } else if idx == inplaceIndex { - it.idxes[depth] = idx - if !n48.inplaceLeaf.addr.IsNull() { - return n48.inplaceLeaf + if child != nil { + if child.kind == typeLeaf { + return *child + } + it.nodes = append(it.nodes, *child) + it.idxes = append(it.idxes, node256cap) + continue } } - case typeNode256: - n256 := it.allocator.getNode256(curr.addr) - if idx >= 0 { - idx = n256.prevPresentIdx(idx) - } - if idx >= 0 { - it.idxes[depth] = idx - child := n256.children[idx] - if child.kind == typeLeaf { - return child - } - it.nodes = append(it.nodes, child) - it.idxes = append(it.idxes, node256cap) - return it.prev() - } else if idx == -1 { - it.idxes[depth] = idx - if !n256.inplaceLeaf.addr.IsNull() { - return n256.inplaceLeaf - } + it.nodes = it.nodes[:depth] + it.idxes = it.idxes[:depth] + if depth == 0 { + return nullArtNode } } - it.nodes = it.nodes[:depth] - it.idxes = it.idxes[:depth] - if depth == 0 { - return nullArtNode - } - return it.prev() +} + +func panicForInvalidIndex(idx int) { + msg := fmt.Sprintf("ART iterator meets an invalid index %d", idx) + panic(msg) } diff --git a/internal/unionstore/art/art_iterator_test.go b/internal/unionstore/art/art_iterator_test.go index 7dcd5046e..b21bcce4a 100644 --- a/internal/unionstore/art/art_iterator_test.go +++ b/internal/unionstore/art/art_iterator_test.go @@ -86,7 +86,8 @@ func TestIterSeekLeaf(t *testing.T) { key := []byte{byte(i)} it, err := tree.Iter(key, nil) require.Nil(t, err) - idxes, nodes := it.seek(key) + it.inner.seek(tree.root, key) + idxes, nodes := it.inner.idxes, it.inner.nodes require.Greater(t, len(idxes), 0) require.Equal(t, len(idxes), len(nodes)) leafNode := nodes[len(nodes)-1].at(&tree.allocator, idxes[len(idxes)-1]) @@ -198,10 +199,10 @@ func TestSeekInExistNode(t *testing.T) { } require.Equal(t, tree.root.kind, kind) for i := 0; i < cnt-1; i++ { - it := &Iterator{ - tree: tree, - } - idxes, _ := it.seek([]byte{byte(2*i + 1)}) + helper := new(baseIter) + helper.allocator = &tree.allocator + helper.seek(tree.root, []byte{byte(2*i + 1)}) + idxes := helper.idxes expect := 0 switch kind { case typeNode4, typeNode16: @@ -222,3 +223,147 @@ func TestSeekInExistNode(t *testing.T) { check(New(), typeNode48) check(New(), typeNode256) } + +func TestSeekToIdx(t *testing.T) { + tree := New() + check := func(kind nodeKind) { + var addr arena.MemdbArenaAddr + switch kind { + case typeNode4: + addr, _ = tree.allocator.allocNode4() + case typeNode16: + addr, _ = tree.allocator.allocNode16() + case typeNode48: + addr, _ = tree.allocator.allocNode48() + case typeNode256: + addr, _ = 
tree.allocator.allocNode256() + } + node := artNode{kind: kind, addr: addr} + lfAddr, _ := tree.allocator.allocLeaf([]byte{10}) + lfNode := artNode{kind: typeLeaf, addr: lfAddr} + node.addChild(&tree.allocator, 10, false, lfNode) + + var ( + existIdx int + maxIdx int + ) + switch kind { + case typeNode4, typeNode16: + existIdx = 0 + maxIdx = 1 + case typeNode48, typeNode256: + existIdx = 10 + maxIdx = node256cap + } + + nextIdx := seekToIdx(&tree.allocator, node, 1) + require.Equal(t, existIdx, nextIdx) + nextIdx = seekToIdx(&tree.allocator, node, 11) + require.Equal(t, maxIdx, nextIdx) + } + + check(typeNode4) + check(typeNode16) + check(typeNode48) + check(typeNode256) +} + +func TestIterateHandle(t *testing.T) { + tree := New() + h := tree.Staging() + require.Nil(t, tree.Set([]byte{1}, []byte{2})) + it := tree.IterWithFlags(nil, nil) + handle := it.Handle() + + require.Equal(t, tree.GetKeyByHandle(handle), []byte{1}) + val, valid := tree.GetValueByHandle(handle) + require.True(t, valid) + require.Equal(t, val, []byte{2}) + + tree.Cleanup(h) + require.Equal(t, tree.GetKeyByHandle(handle), []byte{1}) + _, valid = tree.GetValueByHandle(handle) + require.False(t, valid) +} + +func TestSeekPrefixMismatch(t *testing.T) { + tree := New() + + shortPrefix := make([]byte, 10) + longPrefix := make([]byte, 30) + for i := 0; i < len(shortPrefix); i++ { + shortPrefix[i] = 1 + } + for i := 0; i < len(longPrefix); i++ { + longPrefix[i] = 2 + } + + keys := [][]byte{ + append(shortPrefix, 1), + append(shortPrefix, 2), + append(longPrefix, 3), + append(longPrefix, 4), + } + for _, key := range keys { + require.Nil(t, tree.Set(key, key)) + } + + it, err := tree.Iter(append(shortPrefix[:len(shortPrefix)-1], 0), append(longPrefix[:len(longPrefix)-1], 3)) + require.Nil(t, err) + for _, key := range keys { + require.True(t, it.Valid()) + require.Equal(t, it.Key(), key) + require.Equal(t, it.Value(), key) + require.Nil(t, it.Next()) + } + require.False(t, it.Valid()) +} + +func TestIterPositionCompare(t *testing.T) { + compare := func(idx1, idx2 []int) int { + helper1, helper2 := new(baseIter), new(baseIter) + helper1.idxes, helper2.idxes = idx1, idx2 + return helper1.compare(helper2) + } + + require.Equal(t, compare([]int{1, 2, 3}, []int{1, 2, 3}), 0) + require.Equal(t, compare([]int{1, 2, 2}, []int{1, 2, 3}), -1) + require.Equal(t, compare([]int{1, 2, 4}, []int{1, 2, 3}), 1) + require.Equal(t, compare([]int{1, 2, 3}, []int{1, 2}), 1) + require.Equal(t, compare([]int{1, 2}, []int{1, 2, 3}), -1) +} + +func TestIterSeekNoResult(t *testing.T) { + check := func(kind nodeKind) { + var child int + switch kind { + case typeNode4: + child = 0 + case typeNode16: + child = node4cap + 1 + case typeNode48: + child = node16cap + 1 + case typeNode256: + child = node48cap + 1 + } + // let the node expand to target kind + tree := New() + for i := 0; i < child; i++ { + require.Nil(t, tree.Set([]byte{1, byte(i)}, []byte{1, byte(i)})) + } + + require.Nil(t, tree.Set([]byte{1, 100}, []byte{1, 100})) + require.Nil(t, tree.Set([]byte{1, 200}, []byte{1, 200})) + it, err := tree.Iter([]byte{1, 100, 1}, []byte{1, 200}) + require.Nil(t, err) + require.False(t, it.Valid()) + it, err = tree.IterReverse([]byte{1, 200}, []byte{1, 100, 1}) + require.Nil(t, err) + require.False(t, it.Valid()) + } + + check(typeNode4) + check(typeNode16) + check(typeNode48) + check(typeNode256) +} diff --git a/internal/unionstore/art/art_node.go b/internal/unionstore/art/art_node.go index 75c8d1110..c3e533ef4 100644 --- a/internal/unionstore/art/art_node.go 
+++ b/internal/unionstore/art/art_node.go @@ -18,6 +18,7 @@ import ( "bytes" "math" "math/bits" + "runtime" "sort" "testing" "unsafe" @@ -38,11 +39,14 @@ const ( ) const ( - maxPrefixLen = 20 - node4cap = 4 - node16cap = 16 - node48cap = 48 - node256cap = 256 + maxPrefixLen = 20 + node4cap = 4 + node16cap = 16 + node48cap = 48 + node256cap = 256 + // inplaceIndex is a special index to indicate the index of an in-place leaf in a node, + // the in-place leaf has the same key with its parent node and doesn't occupy the quota of the node. + // the other valid index of a node is [0, nodeNum), all the other leaves in the node have larger key than the in-place leaf. inplaceIndex = -1 notExistIndex = -2 ) @@ -117,6 +121,8 @@ type node256 struct { children [node256cap]artNode } +const MaxKeyLen = math.MaxUint16 + type artLeaf struct { vAddr arena.MemdbArenaAddr klen uint16 @@ -305,14 +311,7 @@ func (n *nodeBase) setPrefix(key artKey, prefixLen uint32) { // Node if the nodeBase.prefixLen > maxPrefixLen and the returned mismatch index equals to maxPrefixLen, // key[maxPrefixLen:] will not be checked by this function. func (n *nodeBase) match(key artKey, depth uint32) uint32 /* mismatch index */ { - idx := uint32(0) - limit := min(min(n.prefixLen, maxPrefixLen), uint32(len(key))-depth) - for ; idx < limit; idx++ { - if n.prefix[idx] != key[idx+depth] { - return idx - } - } - return idx + return longestCommonPrefix(key[depth:], n.prefix[:min(n.prefixLen, maxPrefixLen)], 0) } // matchDeep returns the mismatch index of the key and the node's prefix. @@ -347,13 +346,53 @@ func (an *artNode) asNode256(a *artAllocator) *node256 { return a.getNode256(an.addr) } +// for amd64 and arm64 architectures, we use the chunk comparison to speed up finding the longest common prefix. +const enableChunkComparison = runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" + // longestCommonPrefix returns the length of the longest common prefix of two keys. // the LCP is calculated from the given depth, you need to guarantee l1Key[:depth] equals to l2Key[:depth] before calling this function. func longestCommonPrefix(l1Key, l2Key artKey, depth uint32) uint32 { - idx, limit := depth, min(uint32(len(l1Key)), uint32(len(l2Key))) - // TODO: possible optimization - // Compare the key by loop can be very slow if the final LCP is large. - // Maybe optimize it by comparing the key in chunks if the limit exceeds certain threshold. + if enableChunkComparison { + return longestCommonPrefixByChunk(l1Key, l2Key, depth) + } + // For other architectures, we use the byte-by-byte comparison. + idx, limit := depth, uint32(min(len(l1Key), len(l2Key))) + for ; idx < limit; idx++ { + if l1Key[idx] != l2Key[idx] { + break + } + } + return idx - depth +} + +// longestCommonPrefixByChunk compares two keys by 8 bytes at a time, which is significantly faster when the keys are long. +// Note this function only support architecture which is under little-endian and can read memory across unaligned address. 
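longestCommonPrefixByChunk locates the first differing byte inside an 8-byte chunk by XOR-ing the two little-endian words and dividing the trailing-zero count by 8. A tiny standalone check of that arithmetic, using encoding/binary in place of the unsafe load (the math is the same):

package main

import (
	"encoding/binary"
	"fmt"
	"math/bits"
)

func main() {
	// Two 8-byte chunks that first differ at byte index 5.
	a := []byte{1, 2, 3, 4, 5, 6, 7, 8}
	b := []byte{1, 2, 3, 4, 5, 9, 7, 8}

	// On a little-endian load, byte i of the slice becomes bits [8i, 8i+8) of
	// the word, so the lowest set bit of the XOR lies inside the first
	// differing byte.
	x := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b)
	fmt.Println(bits.TrailingZeros64(x) >> 3) // 5
}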
+func longestCommonPrefixByChunk(l1Key, l2Key artKey, depth uint32) uint32 { + idx, limit := depth, uint32(min(len(l1Key), len(l2Key))) + + if idx == limit { + return 0 + } + + p1 := unsafe.Pointer(&l1Key[depth]) + p2 := unsafe.Pointer(&l2Key[depth]) + + // Compare 8 bytes at a time + remaining := limit - depth + for remaining >= 8 { + if *(*uint64)(p1) != *(*uint64)(p2) { + // Find first different byte using trailing zeros + xor := *(*uint64)(p1) ^ *(*uint64)(p2) + return limit - remaining + uint32(bits.TrailingZeros64(xor)>>3) - depth + } + + p1 = unsafe.Add(p1, 8) + p2 = unsafe.Add(p2, 8) + remaining -= 8 + } + + // Compare rest bytes + idx = limit - remaining for ; idx < limit; idx++ { if l1Key[idx] != l2Key[idx] { break @@ -394,8 +433,8 @@ func minimum(a *artAllocator, an artNode) artNode { } idx := n256.nextPresentIdx(0) an = n256.children[idx] - case typeInvalid: - return nullArtNode + default: + panic("invalid node kind") } } } @@ -526,52 +565,52 @@ func (an *artNode) addChild(a *artAllocator, c byte, inplace bool, child artNode case typeNode256: return an.addChild256(a, c, child) } - return false + panic("add child failed") } func (an *artNode) addChild4(a *artAllocator, c byte, child artNode) bool { - node := an.asNode4(a) + n4 := an.asNode4(a) - if node.nodeNum >= node4cap { - an.grow(a) - an.addChild(a, c, false, child) + if n4.nodeNum >= node4cap { + an.growNode4(n4, a) + an.addChild16(a, c, child) return true } i := uint8(0) - for ; i < node.nodeNum; i++ { - if c <= node.keys[i] { - if testing.Testing() && c == node.keys[i] { + for ; i < n4.nodeNum; i++ { + if c <= n4.keys[i] { + if testing.Testing() && c == n4.keys[i] { panic("key already exists") } break } } - if i < node.nodeNum { - copy(node.keys[i+1:node.nodeNum+1], node.keys[i:node.nodeNum]) - copy(node.children[i+1:node.nodeNum+1], node.children[i:node.nodeNum]) + if i < n4.nodeNum { + copy(n4.keys[i+1:n4.nodeNum+1], n4.keys[i:n4.nodeNum]) + copy(n4.children[i+1:n4.nodeNum+1], n4.children[i:n4.nodeNum]) } - node.keys[i] = c - node.children[i] = child - node.nodeNum++ + n4.keys[i] = c + n4.children[i] = child + n4.nodeNum++ return false } func (an *artNode) addChild16(a *artAllocator, c byte, child artNode) bool { - node := an.asNode16(a) + n16 := an.asNode16(a) - if node.nodeNum >= node16cap { - an.grow(a) - an.addChild(a, c, false, child) + if n16.nodeNum >= node16cap { + an.growNode16(n16, a) + an.addChild48(a, c, child) return true } - i, found := sort.Find(int(node.nodeNum), func(i int) int { - if node.keys[i] < c { + i, found := sort.Find(int(n16.nodeNum), func(i int) int { + if n16.keys[i] < c { return 1 } - if node.keys[i] == c { + if n16.keys[i] == c { return 0 } return -1 @@ -581,47 +620,47 @@ func (an *artNode) addChild16(a *artAllocator, c byte, child artNode) bool { panic("key already exists") } - if i < int(node.nodeNum) { - copy(node.keys[i+1:node.nodeNum+1], node.keys[i:node.nodeNum]) - copy(node.children[i+1:node.nodeNum+1], node.children[i:node.nodeNum]) + if i < int(n16.nodeNum) { + copy(n16.keys[i+1:n16.nodeNum+1], n16.keys[i:n16.nodeNum]) + copy(n16.children[i+1:n16.nodeNum+1], n16.children[i:n16.nodeNum]) } - node.keys[i] = c - node.children[i] = child - node.nodeNum++ + n16.keys[i] = c + n16.children[i] = child + n16.nodeNum++ return false } func (an *artNode) addChild48(a *artAllocator, c byte, child artNode) bool { - node := an.asNode48(a) + n48 := an.asNode48(a) - if node.nodeNum >= node48cap { - an.grow(a) - an.addChild(a, c, false, child) + if n48.nodeNum >= node48cap { + an.growNode48(n48, 
a) + an.addChild256(a, c, child) return true } - if testing.Testing() && node.present[c>>n48s]&(1<<(c%n48m)) != 0 { + if testing.Testing() && n48.present[c>>n48s]&(1<<(c%n48m)) != 0 { panic("key already exists") } - node.keys[c] = node.nodeNum - node.present[c>>n48s] |= 1 << (c % n48m) - node.children[node.nodeNum] = child - node.nodeNum++ + n48.keys[c] = n48.nodeNum + n48.present[c>>n48s] |= 1 << (c % n48m) + n48.children[n48.nodeNum] = child + n48.nodeNum++ return false } func (an *artNode) addChild256(a *artAllocator, c byte, child artNode) bool { - node := an.asNode256(a) + n256 := an.asNode256(a) - if testing.Testing() && node.present[c>>n48s]&(1<<(c%n48m)) != 0 { + if testing.Testing() && n256.present[c>>n48s]&(1<<(c%n48m)) != 0 { panic("key already exists") } - node.present[c>>n48s] |= 1 << (c % n48m) - node.children[c] = child - node.nodeNum++ + n256.present[c>>n48s] |= 1 << (c % n48m) + n256.children[c] = child + n256.nodeNum++ return false } @@ -632,49 +671,47 @@ func (n *nodeBase) copyMeta(src *nodeBase) { copy(n.prefix[:], src.prefix[:]) } -func (an *artNode) grow(a *artAllocator) { - switch an.kind { - case typeNode4: - n4 := an.asNode4(a) - newAddr, n16 := a.allocNode16() - n16.copyMeta(&n4.nodeBase) +func (an *artNode) growNode4(n4 *node4, a *artAllocator) { + newAddr, n16 := a.allocNode16() + n16.copyMeta(&n4.nodeBase) - copy(n16.keys[:], n4.keys[:]) - copy(n16.children[:], n4.children[:]) + copy(n16.keys[:], n4.keys[:]) + copy(n16.children[:], n4.children[:]) - // replace addr and free node4 - a.freeNode4(an.addr) - an.kind = typeNode16 - an.addr = newAddr - case typeNode16: - n16 := an.asNode16(a) - newAddr, n48 := a.allocNode48() - n48.copyMeta(&n16.nodeBase) - - for i := uint8(0); i < n16.nodeBase.nodeNum; i++ { - ch := n16.keys[i] - n48.keys[ch] = i - n48.present[ch>>n48s] |= 1 << (ch % n48m) - n48.children[i] = n16.children[i] - } + // replace addr and free node4 + a.freeNode4(an.addr) + an.kind = typeNode16 + an.addr = newAddr +} - // replace addr and free node16 - a.freeNode16(an.addr) - an.kind = typeNode48 - an.addr = newAddr - case typeNode48: - n48 := an.asNode48(a) - newAddr, n256 := a.allocNode256() - n256.copyMeta(&n48.nodeBase) +func (an *artNode) growNode16(n16 *node16, a *artAllocator) { + newAddr, n48 := a.allocNode48() + n48.copyMeta(&n16.nodeBase) - for i := n48.nextPresentIdx(0); i < node256cap; i = n48.nextPresentIdx(i + 1) { - n256.children[i] = n48.children[n48.keys[i]] - } - copy(n256.present[:], n48.present[:]) + for i := uint8(0); i < n16.nodeBase.nodeNum; i++ { + ch := n16.keys[i] + n48.keys[ch] = i + n48.present[ch>>n48s] |= 1 << (ch % n48m) + n48.children[i] = n16.children[i] + } + + // replace addr and free node16 + a.freeNode16(an.addr) + an.kind = typeNode48 + an.addr = newAddr +} - // replace addr and free node48 - a.freeNode48(an.addr) - an.kind = typeNode256 - an.addr = newAddr +func (an *artNode) growNode48(n48 *node48, a *artAllocator) { + newAddr, n256 := a.allocNode256() + n256.copyMeta(&n48.nodeBase) + + for i := n48.nextPresentIdx(0); i < node256cap; i = n48.nextPresentIdx(i + 1) { + n256.children[i] = n48.children[n48.keys[i]] } + copy(n256.present[:], n48.present[:]) + + // replace addr and free node48 + a.freeNode48(an.addr) + an.kind = typeNode256 + an.addr = newAddr } diff --git a/internal/unionstore/art/art_node_test.go b/internal/unionstore/art/art_node_test.go index da4ad5da0..641689f25 100644 --- a/internal/unionstore/art/art_node_test.go +++ b/internal/unionstore/art/art_node_test.go @@ -265,3 +265,164 @@ func TestLCP(t 
*testing.T) { require.Equal(t, uint32(5-i), longestCommonPrefix(k1, k2, uint32(i))) } } + +func TestNodeAddChild(t *testing.T) { + var allocator artAllocator + allocator.init() + + check := func(n artNode) { + require.Equal(t, uint8(0), n.asNode(&allocator).nodeNum) + lfAddr, _ := allocator.allocLeaf([]byte{1, 2, 3, 4, 5}) + lfNode := artNode{kind: typeLeaf, addr: lfAddr} + n.addChild(&allocator, 1, false, lfNode) + require.Equal(t, uint8(1), n.asNode(&allocator).nodeNum) + require.Panics(t, func() { + // addChild should panic if the key is existed. + n.addChild(&allocator, 1, false, lfNode) + }) + n.addChild(&allocator, 2, false, lfNode) + require.Equal(t, uint8(2), n.asNode(&allocator).nodeNum) + // inplace leaf won't be counted in nodeNum + n.addChild(&allocator, 0, true, lfNode) + require.Equal(t, uint8(2), n.asNode(&allocator).nodeNum) + } + + addr, _ := allocator.allocNode4() + check(artNode{kind: typeNode4, addr: addr}) + addr, _ = allocator.allocNode16() + check(artNode{kind: typeNode16, addr: addr}) + addr, _ = allocator.allocNode48() + check(artNode{kind: typeNode48, addr: addr}) + addr, _ = allocator.allocNode256() + check(artNode{kind: typeNode256, addr: addr}) +} + +func TestNodeGrow(t *testing.T) { + var allocator artAllocator + allocator.init() + + check := func(n artNode) { + capacities := map[nodeKind]int{ + typeNode4: node4cap, + typeNode16: node16cap, + typeNode48: node48cap, + } + growTypes := map[nodeKind]nodeKind{ + typeNode4: typeNode16, + typeNode16: typeNode48, + typeNode48: typeNode256, + } + + capacity, ok := capacities[n.kind] + require.True(t, ok) + beforeKind := n.kind + afterKind, ok := growTypes[n.kind] + require.True(t, ok) + + for i := 0; i < capacity; i++ { + lfAddr, _ := allocator.allocLeaf([]byte{byte(i)}) + lfNode := artNode{kind: typeLeaf, addr: lfAddr} + n.addChild(&allocator, byte(i), false, lfNode) + require.Equal(t, beforeKind, n.kind) + } + lfAddr, _ := allocator.allocLeaf([]byte{byte(capacity)}) + lfNode := artNode{kind: typeLeaf, addr: lfAddr} + n.addChild(&allocator, byte(capacity), false, lfNode) + require.Equal(t, afterKind, n.kind) + } + + addr, _ := allocator.allocNode4() + check(artNode{kind: typeNode4, addr: addr}) + addr, _ = allocator.allocNode16() + check(artNode{kind: typeNode16, addr: addr}) + addr, _ = allocator.allocNode48() + check(artNode{kind: typeNode48, addr: addr}) +} + +func TestReplaceChild(t *testing.T) { + var allocator artAllocator + allocator.init() + + check := func(n artNode) { + require.Equal(t, uint8(0), n.asNode(&allocator).nodeNum) + lfAddr, _ := allocator.allocLeaf([]byte{1, 2, 3, 4, 5}) + lfNode := artNode{kind: typeLeaf, addr: lfAddr} + n.addChild(&allocator, 1, false, lfNode) + require.Equal(t, uint8(1), n.asNode(&allocator).nodeNum) + newLfAddr, _ := allocator.allocLeaf([]byte{1, 2, 3, 4, 4}) + newLfNode := artNode{kind: typeLeaf, addr: newLfAddr} + require.Panics(t, func() { + // replaceChild should panic if the key is not existed. 
+ n.replaceChild(&allocator, 2, newLfNode) + }) + n.replaceChild(&allocator, 1, newLfNode) + require.Equal(t, uint8(1), n.asNode(&allocator).nodeNum) + _, childLf := n.findChild(&allocator, 1, false) + require.NotEqual(t, childLf.addr, lfAddr) + require.Equal(t, childLf.addr, newLfAddr) + } + + addr, _ := allocator.allocNode4() + check(artNode{kind: typeNode4, addr: addr}) + addr, _ = allocator.allocNode16() + check(artNode{kind: typeNode16, addr: addr}) + addr, _ = allocator.allocNode48() + check(artNode{kind: typeNode48, addr: addr}) + addr, _ = allocator.allocNode256() + check(artNode{kind: typeNode256, addr: addr}) + +} + +func TestMinimumNode(t *testing.T) { + var allocator artAllocator + allocator.init() + + check := func(kind nodeKind) { + var addr arena.MemdbArenaAddr + + switch kind { + case typeNode4: + addr, _ = allocator.allocNode4() + case typeNode16: + addr, _ = allocator.allocNode16() + case typeNode48: + addr, _ = allocator.allocNode48() + case typeNode256: + addr, _ = allocator.allocNode256() + } + + node := artNode{kind: kind, addr: addr} + + for _, char := range []byte{255, 127, 63, 0} { + lfAddr, _ := allocator.allocLeaf([]byte{char}) + lfNode := artNode{kind: typeLeaf, addr: lfAddr} + node.addChild(&allocator, char, false, lfNode) + minNode := minimum(&allocator, node) + require.Equal(t, typeLeaf, minNode.kind) + require.Equal(t, lfAddr, minNode.addr) + } + + lfAddr, _ := allocator.allocLeaf([]byte{0}) + lfNode := artNode{kind: typeLeaf, addr: lfAddr} + node.addChild(&allocator, 0, true, lfNode) + minNode := minimum(&allocator, node) + require.Equal(t, typeLeaf, minNode.kind) + require.Equal(t, lfAddr, minNode.addr) + } + + check(typeNode4) + check(typeNode16) + check(typeNode48) + check(typeNode256) +} + +func TestKey2Chunk(t *testing.T) { + key := artKey([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) + + for i := 0; i < len(key); i++ { + diffKey := make(artKey, len(key)) + copy(diffKey, key) + diffKey[i] = 255 + require.Equal(t, uint32(i), longestCommonPrefix(key, diffKey, 0)) + } +} diff --git a/internal/unionstore/art/art_snapshot.go b/internal/unionstore/art/art_snapshot.go index 1aab70a4d..6f470dda4 100644 --- a/internal/unionstore/art/art_snapshot.go +++ b/internal/unionstore/art/art_snapshot.go @@ -74,7 +74,7 @@ type SnapGetter struct { } func (snap *SnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) { - addr, lf := snap.tree.search(key) + addr, lf := snap.tree.traverse(key, false) if addr.IsNull() { return nil, tikverr.ErrNotExist } diff --git a/internal/unionstore/art/art_test.go b/internal/unionstore/art/art_test.go index f5dc0b7d6..b5f4148c7 100644 --- a/internal/unionstore/art/art_test.go +++ b/internal/unionstore/art/art_test.go @@ -109,21 +109,21 @@ func TestFlag(t *testing.T) { require.Nil(t, err) require.True(t, flags.HasLocked()) // iterate can also see the flags - //it, err := tree.Iter(nil, nil) - //require.Nil(t, err) - //require.True(t, it.Valid()) - //require.Equal(t, it.Key(), []byte{0}) - //require.Equal(t, it.Value(), []byte{0}) - //require.True(t, it.Flags().HasPresumeKeyNotExists()) - //require.False(t, it.Flags().HasLocked()) - //require.Nil(t, it.Next()) - //require.True(t, it.Valid()) - //require.Equal(t, it.Key(), []byte{1}) - //require.Equal(t, it.Value(), []byte{1}) - //require.True(t, it.Flags().HasLocked()) - //require.False(t, it.Flags().HasPresumeKeyNotExists()) - //require.Nil(t, it.Next()) - //require.False(t, it.Valid()) + it, err := tree.Iter(nil, nil) + require.Nil(t, err) + require.True(t, 
it.Valid()) + require.Equal(t, it.Key(), []byte{0}) + require.Equal(t, it.Value(), []byte{0}) + require.True(t, it.Flags().HasPresumeKeyNotExists()) + require.False(t, it.Flags().HasLocked()) + require.Nil(t, it.Next()) + require.True(t, it.Valid()) + require.Equal(t, it.Key(), []byte{1}) + require.Equal(t, it.Value(), []byte{1}) + require.True(t, it.Flags().HasLocked()) + require.False(t, it.Flags().HasPresumeKeyNotExists()) + require.Nil(t, it.Next()) + require.False(t, it.Valid()) } func TestLongPrefix1(t *testing.T) { @@ -184,6 +184,14 @@ func TestFlagOnlyKey(t *testing.T) { require.Error(t, err) } +func TestSearchPrefixMisatch(t *testing.T) { + tree := New() + tree.Set([]byte{1, 1, 1, 1, 1, 1}, []byte{1, 1, 1, 1, 1, 1}) + tree.Set([]byte{1, 1, 1, 1, 1, 2}, []byte{1, 1, 1, 1, 1, 2}) + _, err := tree.Get([]byte{1, 1, 1, 3, 1, 1}) + require.NotNil(t, err) +} + func TestSearchOptimisticMismatch(t *testing.T) { tree := New() prefix := make([]byte, 22) @@ -230,3 +238,27 @@ func TestExpansion(t *testing.T) { require.Equal(t, n4.keys[:2], []byte{1, 255}) require.Equal(t, n4.children[1].asLeaf(&tree.allocator).GetKey(), append(prefix, []byte{1, 255, 2}...)) } + +func TestDiscardValues(t *testing.T) { + tree := New() + tree.Set([]byte{1}, []byte{2}) + it := tree.IterWithFlags(nil, nil) + handle := it.Handle() + val, exist := tree.GetValueByHandle(handle) + require.Equal(t, val, []byte{2}) + require.True(t, exist) + key := tree.GetKeyByHandle(handle) + require.Equal(t, key, []byte{1}) + + tree.DiscardValues() + _, exist = tree.GetValueByHandle(handle) + require.False(t, exist) + key = tree.GetKeyByHandle(handle) + require.Equal(t, key, []byte{1}) + require.Panics(t, func() { + tree.Get([]byte{3}) + }) + require.Panics(t, func() { + tree.Set([]byte{3}, []byte{4}) + }) +} diff --git a/internal/unionstore/memdb_bench_test.go b/internal/unionstore/memdb_bench_test.go index e6477d2d0..8a2c3e5d4 100644 --- a/internal/unionstore/memdb_bench_test.go +++ b/internal/unionstore/memdb_bench_test.go @@ -38,6 +38,7 @@ import ( "context" "encoding/binary" "math/rand" + "slices" "testing" ) @@ -250,3 +251,23 @@ func BenchmarkMemBufferCache(b *testing.B) { b.Run("RBT", func(b *testing.B) { fn(b, newRbtDBWithContext()) }) b.Run("ART", func(b *testing.B) { fn(b, newArtDBWithContext()) }) } + +func BenchmarkMemBufferSetGetLongKey(b *testing.B) { + fn := func(b *testing.B, buffer MemBuffer) { + keys := make([][]byte, b.N) + for i := 0; i < b.N; i++ { + keys[i] = make([]byte, 1024) + binary.BigEndian.PutUint64(keys[i], uint64(i)) + slices.Reverse(keys[i]) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + buffer.Set(keys[i], keys[i]) + } + for i := 0; i < b.N; i++ { + buffer.Get(context.Background(), keys[i]) + } + } + b.Run("RBT", func(b *testing.B) { fn(b, newRbtDBWithContext()) }) + b.Run("ART", func(b *testing.B) { fn(b, newArtDBWithContext()) }) +} diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 2786de13c..6003059ae 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -37,14 +37,19 @@ package unionstore import ( + "bytes" "context" "encoding/binary" "fmt" + "math" + "strconv" + "strings" "testing" leveldb "github.com/pingcap/goleveldb/leveldb/memdb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + tikverr "github.com/tikv/client-go/v2/error" "github.com/tikv/client-go/v2/kv" ) @@ -392,6 +397,8 @@ func testReset(t *testing.T, db interface { db.Reset() _, err := db.Get(context.Background(), []byte{0, 0, 0, 
0}) assert.NotNil(err) + _, err = db.GetFlags([]byte{0, 0, 0, 0}) + assert.NotNil(err) it, _ := db.Iter(nil, nil) assert.False(it.Valid()) } @@ -495,6 +502,9 @@ func testDirty(t *testing.T, createDb func() MemBuffer) { func TestFlags(t *testing.T) { testFlags(t, newRbtDBWithContext(), func(db MemBuffer) Iterator { return db.(*rbtDBWithContext).IterWithFlags(nil, nil) }) testFlags(t, newArtDBWithContext(), func(db MemBuffer) Iterator { return db.(*artDBWithContext).IterWithFlags(nil, nil) }) + + testFlags(t, newRbtDBWithContext(), func(db MemBuffer) Iterator { return db.(*rbtDBWithContext).IterReverseWithFlags(nil) }) + testFlags(t, newArtDBWithContext(), func(db MemBuffer) Iterator { return db.(*artDBWithContext).IterReverseWithFlags(nil) }) } func testFlags(t *testing.T, db MemBuffer, iterWithFlags func(db MemBuffer) Iterator) { @@ -538,6 +548,10 @@ func testFlags(t *testing.T, db MemBuffer, iterWithFlags func(db MemBuffer) Iter for ; it.Valid(); it.Next() { k := binary.BigEndian.Uint32(it.Key()) assert.True(k%2 == 0) + hasValue := it.(interface { + HasValue() bool + }).HasValue() + assert.False(hasValue) } for i := uint32(0); i < cnt; i++ { @@ -1046,6 +1060,49 @@ func testSnapshotGetIter(t *testing.T, db MemBuffer) { assert.Equal(reverseIter.Key(), []byte{byte(1)}) assert.Equal(reverseIter.Value(), []byte{byte(50)}) } + + db.(interface { + Reset() + }).Reset() + db.UpdateFlags([]byte{255}, kv.SetPresumeKeyNotExists) + // set (2, 2) ... (100, 100) in snapshot + for i := 1; i < 50; i++ { + db.Set([]byte{byte(2 * i)}, []byte{byte(2 * i)}) + } + h := db.Staging() + // set (0, 0) (1, 2) (2, 4) ... (100, 200) in staging + for i := 0; i < 100; i++ { + db.Set([]byte{byte(i)}, []byte{byte(2 * i)}) + } + + snapGetter := db.SnapshotGetter() + v, err := snapGetter.Get(context.Background(), []byte{byte(2)}) + assert.Nil(err) + assert.Equal(v, []byte{byte(2)}) + _, err = snapGetter.Get(context.Background(), []byte{byte(1)}) + assert.NotNil(err) + _, err = snapGetter.Get(context.Background(), []byte{byte(254)}) + assert.NotNil(err) + _, err = snapGetter.Get(context.Background(), []byte{byte(255)}) + assert.NotNil(err) + + it := db.SnapshotIter(nil, nil) + // snapshot iter only see the snapshot data + for i := 1; i < 50; i++ { + assert.Equal(it.Key(), []byte{byte(2 * i)}) + assert.Equal(it.Value(), []byte{byte(2 * i)}) + assert.True(it.Valid()) + it.Next() + } + it = db.SnapshotIterReverse(nil, nil) + for i := 49; i >= 1; i-- { + assert.Equal(it.Key(), []byte{byte(2 * i)}) + assert.Equal(it.Value(), []byte{byte(2 * i)}) + assert.True(it.Valid()) + it.Next() + } + assert.False(it.Valid()) + db.Release(h) } func TestCleanupKeepPersistentFlag(t *testing.T) { @@ -1172,3 +1229,101 @@ func testMemBufferCache(t *testing.T, buffer MemBuffer) { assert.Equal(v, []byte{2, 2}) }) } + +func TestMemDBLeafFragmentation(t *testing.T) { + // RBT cannot pass the leaf fragmentation test. + testMemDBLeafFragmentation(t, newArtDBWithContext()) +} + +func testMemDBLeafFragmentation(t *testing.T, buffer MemBuffer) { + assert := assert.New(t) + h := buffer.Staging() + mem := buffer.Mem() + for i := 0; i < 10; i++ { + for k := 0; k < 100; k++ { + buffer.Set([]byte(strings.Repeat(strconv.Itoa(k), 256)), []byte("value")) + } + cur := buffer.Mem() + if mem == 0 { + mem = cur + } else { + assert.LessOrEqual(cur, mem) + } + buffer.Cleanup(h) + h = buffer.Staging() + } +} + +func TestReadOnlyZeroMem(t *testing.T) { + // read only MemBuffer should not allocate heap memory. 
+ assert.Zero(t, newRbtDBWithContext().Mem()) + assert.Zero(t, newArtDBWithContext().Mem()) +} + +func TestKeyValueOversize(t *testing.T) { + check := func(t *testing.T, db MemBuffer) { + key := make([]byte, math.MaxUint16) + overSizeKey := make([]byte, math.MaxUint16+1) + + assert.Nil(t, db.Set(key, overSizeKey)) + err := db.Set(overSizeKey, key) + assert.NotNil(t, err) + assert.Equal(t, err.(*tikverr.ErrKeyTooLarge).KeySize, math.MaxUint16+1) + } + + check(t, newRbtDBWithContext()) + check(t, newArtDBWithContext()) +} + +func TestSetMemoryFootprintChangeHook(t *testing.T) { + check := func(t *testing.T, db MemBuffer) { + memoryConsumed := uint64(0) + assert.False(t, db.MemHookSet()) + db.SetMemoryFootprintChangeHook(func(mem uint64) { + memoryConsumed = mem + }) + assert.True(t, db.MemHookSet()) + + assert.Zero(t, memoryConsumed) + db.Set([]byte{1}, []byte{1}) + assert.NotZero(t, memoryConsumed) + } + + check(t, newRbtDBWithContext()) + check(t, newArtDBWithContext()) +} + +func TestSelectValueHistory(t *testing.T) { + check := func(t *testing.T, db interface { + MemBuffer + SelectValueHistory(key []byte, predicate func(value []byte) bool) ([]byte, error) + }) { + db.Set([]byte{1}, []byte{1}) + h := db.Staging() + db.Set([]byte{1}, []byte{1, 1}) + + val, err := db.SelectValueHistory([]byte{1}, func(value []byte) bool { return bytes.Equal(value, []byte{1}) }) + assert.Nil(t, err) + assert.Equal(t, val, []byte{1}) + val, err = db.SelectValueHistory([]byte{1}, func(value []byte) bool { return bytes.Equal(value, []byte{1, 1}) }) + assert.Nil(t, err) + assert.Equal(t, val, []byte{1, 1}) + val, err = db.SelectValueHistory([]byte{1}, func(value []byte) bool { return bytes.Equal(value, []byte{1, 1, 1}) }) + assert.Nil(t, err) + assert.Nil(t, val) + _, err = db.SelectValueHistory([]byte{2}, func([]byte) bool { return false }) + assert.NotNil(t, err) + + db.Cleanup(h) + + val, err = db.SelectValueHistory([]byte{1}, func(value []byte) bool { return bytes.Equal(value, []byte{1}) }) + assert.Nil(t, err) + assert.Equal(t, val, []byte{1}) + val, err = db.SelectValueHistory([]byte{1}, func(value []byte) bool { return bytes.Equal(value, []byte{1, 1}) }) + assert.Nil(t, err) + assert.Nil(t, val) + } + + check(t, newRbtDBWithContext()) + check(t, newArtDBWithContext()) +} diff --git a/internal/unionstore/rbt/rbt.go b/internal/unionstore/rbt/rbt.go index 40b3234c3..06884fb42 100644 --- a/internal/unionstore/rbt/rbt.go +++ b/internal/unionstore/rbt/rbt.go @@ -324,6 +324,12 @@ func (db *RBT) Set(key []byte, value []byte, ops ...kv.FlagsOp) error { panic("vlog is reset") } + if len(key) > MaxKeyLen { + return &tikverr.ErrKeyTooLarge{ + KeySize: len(key), + } + } + if value != nil { if size := uint64(len(key) + len(value)); size > db.entrySizeLimit { return &tikverr.ErrEntryTooLarge{ @@ -837,6 +843,8 @@ func (a MemdbNodeAddr) getRight(db *RBT) MemdbNodeAddr { return db.getNode(a.right) } +const MaxKeyLen = math.MaxUint16 + type memdbNode struct { up arena.MemdbArenaAddr left arena.MemdbArenaAddr diff --git a/metrics/metrics.go b/metrics/metrics.go index e608d800e..ce247600c 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -41,80 +41,82 @@ import ( // Client metrics. 
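As an aside before the metrics changes: the `MaxKeyLen` guard added to `rbt.Set` above surfaces oversized keys through the typed `ErrKeyTooLarge`. A minimal caller-side sketch of distinguishing that error (the helper name and package are illustrative, not part of this patch):

```go
package example

import (
	"errors"
	"fmt"
	"math"

	tikverr "github.com/tikv/client-go/v2/error"
)

// reportSetError classifies the error returned by MemBuffer.Set when a key
// exceeds the MaxKeyLen (math.MaxUint16) limit enforced by the guard above.
func reportSetError(err error) string {
	var keyErr *tikverr.ErrKeyTooLarge
	if errors.As(err, &keyErr) {
		return fmt.Sprintf("key of %d bytes rejected, limit is %d", keyErr.KeySize, math.MaxUint16)
	}
	return fmt.Sprintf("other error: %v", err)
}
```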
var ( - TiKVTxnCmdHistogram *prometheus.HistogramVec - TiKVBackoffHistogram *prometheus.HistogramVec - TiKVSendReqHistogram *prometheus.HistogramVec - TiKVSendReqCounter *prometheus.CounterVec - TiKVSendReqTimeCounter *prometheus.CounterVec - TiKVRPCNetLatencyHistogram *prometheus.HistogramVec - TiKVCoprocessorHistogram *prometheus.HistogramVec - TiKVLockResolverCounter *prometheus.CounterVec - TiKVRegionErrorCounter *prometheus.CounterVec - TiKVRPCErrorCounter *prometheus.CounterVec - TiKVTxnWriteKVCountHistogram *prometheus.HistogramVec - TiKVTxnWriteSizeHistogram *prometheus.HistogramVec - TiKVRawkvCmdHistogram *prometheus.HistogramVec - TiKVRawkvSizeHistogram *prometheus.HistogramVec - TiKVTxnRegionsNumHistogram *prometheus.HistogramVec - TiKVLoadSafepointCounter *prometheus.CounterVec - TiKVSecondaryLockCleanupFailureCounter *prometheus.CounterVec - TiKVRegionCacheCounter *prometheus.CounterVec - TiKVLoadRegionCounter *prometheus.CounterVec - TiKVLoadRegionCacheHistogram *prometheus.HistogramVec - TiKVLocalLatchWaitTimeHistogram prometheus.Histogram - TiKVStatusDuration *prometheus.HistogramVec - TiKVStatusCounter *prometheus.CounterVec - TiKVBatchSendTailLatency prometheus.Histogram - TiKVBatchSendLoopDuration *prometheus.SummaryVec - TiKVBatchRecvLoopDuration *prometheus.SummaryVec - TiKVBatchHeadArrivalInterval *prometheus.SummaryVec - TiKVBatchBestSize *prometheus.SummaryVec - TiKVBatchMoreRequests *prometheus.SummaryVec - TiKVBatchWaitOverLoad prometheus.Counter - TiKVBatchPendingRequests *prometheus.HistogramVec - TiKVBatchRequests *prometheus.HistogramVec - TiKVBatchRequestDuration *prometheus.SummaryVec - TiKVBatchClientUnavailable prometheus.Histogram - TiKVBatchClientWaitEstablish prometheus.Histogram - TiKVBatchClientRecycle prometheus.Histogram - TiKVRangeTaskStats *prometheus.GaugeVec - TiKVRangeTaskPushDuration *prometheus.HistogramVec - TiKVTokenWaitDuration prometheus.Histogram - TiKVTxnHeartBeatHistogram *prometheus.HistogramVec - TiKVTTLManagerHistogram prometheus.Histogram - TiKVPessimisticLockKeysDuration prometheus.Histogram - TiKVTTLLifeTimeReachCounter prometheus.Counter - TiKVNoAvailableConnectionCounter prometheus.Counter - TiKVTwoPCTxnCounter *prometheus.CounterVec - TiKVAsyncCommitTxnCounter *prometheus.CounterVec - TiKVOnePCTxnCounter *prometheus.CounterVec - TiKVStoreLimitErrorCounter *prometheus.CounterVec - TiKVGRPCConnTransientFailureCounter *prometheus.CounterVec - TiKVPanicCounter *prometheus.CounterVec - TiKVForwardRequestCounter *prometheus.CounterVec - TiKVTSFutureWaitDuration prometheus.Histogram - TiKVSafeTSUpdateCounter *prometheus.CounterVec - TiKVMinSafeTSGapSeconds *prometheus.GaugeVec - TiKVReplicaSelectorFailureCounter *prometheus.CounterVec - TiKVRequestRetryTimesHistogram prometheus.Histogram - TiKVTxnCommitBackoffSeconds prometheus.Histogram - TiKVTxnCommitBackoffCount prometheus.Histogram - TiKVSmallReadDuration prometheus.Histogram - TiKVReadThroughput prometheus.Histogram - TiKVUnsafeDestroyRangeFailuresCounterVec *prometheus.CounterVec - TiKVPrewriteAssertionUsageCounter *prometheus.CounterVec - TiKVGrpcConnectionState *prometheus.GaugeVec - TiKVAggressiveLockedKeysCounter *prometheus.CounterVec - TiKVStoreSlowScoreGauge *prometheus.GaugeVec - TiKVFeedbackSlowScoreGauge *prometheus.GaugeVec - TiKVHealthFeedbackOpsCounter *prometheus.CounterVec - TiKVPreferLeaderFlowsGauge *prometheus.GaugeVec - TiKVStaleReadCounter *prometheus.CounterVec - TiKVStaleReadReqCounter *prometheus.CounterVec - TiKVStaleReadBytes *prometheus.CounterVec - 
TiKVPipelinedFlushLenHistogram prometheus.Histogram - TiKVPipelinedFlushSizeHistogram prometheus.Histogram - TiKVPipelinedFlushDuration prometheus.Histogram + TiKVTxnCmdHistogram *prometheus.HistogramVec + TiKVBackoffHistogram *prometheus.HistogramVec + TiKVSendReqHistogram *prometheus.HistogramVec + TiKVSendReqCounter *prometheus.CounterVec + TiKVSendReqTimeCounter *prometheus.CounterVec + TiKVRPCNetLatencyHistogram *prometheus.HistogramVec + TiKVCoprocessorHistogram *prometheus.HistogramVec + TiKVLockResolverCounter *prometheus.CounterVec + TiKVRegionErrorCounter *prometheus.CounterVec + TiKVRPCErrorCounter *prometheus.CounterVec + TiKVTxnWriteKVCountHistogram *prometheus.HistogramVec + TiKVTxnWriteSizeHistogram *prometheus.HistogramVec + TiKVRawkvCmdHistogram *prometheus.HistogramVec + TiKVRawkvSizeHistogram *prometheus.HistogramVec + TiKVTxnRegionsNumHistogram *prometheus.HistogramVec + TiKVLoadSafepointCounter *prometheus.CounterVec + TiKVSecondaryLockCleanupFailureCounter *prometheus.CounterVec + TiKVRegionCacheCounter *prometheus.CounterVec + TiKVLoadRegionCounter *prometheus.CounterVec + TiKVLoadRegionCacheHistogram *prometheus.HistogramVec + TiKVLocalLatchWaitTimeHistogram prometheus.Histogram + TiKVStatusDuration *prometheus.HistogramVec + TiKVStatusCounter *prometheus.CounterVec + TiKVBatchSendTailLatency prometheus.Histogram + TiKVBatchSendLoopDuration *prometheus.SummaryVec + TiKVBatchRecvLoopDuration *prometheus.SummaryVec + TiKVBatchHeadArrivalInterval *prometheus.SummaryVec + TiKVBatchBestSize *prometheus.SummaryVec + TiKVBatchMoreRequests *prometheus.SummaryVec + TiKVBatchWaitOverLoad prometheus.Counter + TiKVBatchPendingRequests *prometheus.HistogramVec + TiKVBatchRequests *prometheus.HistogramVec + TiKVBatchRequestDuration *prometheus.SummaryVec + TiKVBatchClientUnavailable prometheus.Histogram + TiKVBatchClientWaitEstablish prometheus.Histogram + TiKVBatchClientRecycle prometheus.Histogram + TiKVRangeTaskStats *prometheus.GaugeVec + TiKVRangeTaskPushDuration *prometheus.HistogramVec + TiKVTokenWaitDuration prometheus.Histogram + TiKVTxnHeartBeatHistogram *prometheus.HistogramVec + TiKVTTLManagerHistogram prometheus.Histogram + TiKVPessimisticLockKeysDuration prometheus.Histogram + TiKVTTLLifeTimeReachCounter prometheus.Counter + TiKVNoAvailableConnectionCounter prometheus.Counter + TiKVTwoPCTxnCounter *prometheus.CounterVec + TiKVAsyncCommitTxnCounter *prometheus.CounterVec + TiKVOnePCTxnCounter *prometheus.CounterVec + TiKVStoreLimitErrorCounter *prometheus.CounterVec + TiKVGRPCConnTransientFailureCounter *prometheus.CounterVec + TiKVPanicCounter *prometheus.CounterVec + TiKVForwardRequestCounter *prometheus.CounterVec + TiKVTSFutureWaitDuration prometheus.Histogram + TiKVSafeTSUpdateCounter *prometheus.CounterVec + TiKVMinSafeTSGapSeconds *prometheus.GaugeVec + TiKVReplicaSelectorFailureCounter *prometheus.CounterVec + TiKVRequestRetryTimesHistogram prometheus.Histogram + TiKVTxnCommitBackoffSeconds prometheus.Histogram + TiKVTxnCommitBackoffCount prometheus.Histogram + TiKVSmallReadDuration prometheus.Histogram + TiKVReadThroughput prometheus.Histogram + TiKVUnsafeDestroyRangeFailuresCounterVec *prometheus.CounterVec + TiKVPrewriteAssertionUsageCounter *prometheus.CounterVec + TiKVGrpcConnectionState *prometheus.GaugeVec + TiKVAggressiveLockedKeysCounter *prometheus.CounterVec + TiKVStoreSlowScoreGauge *prometheus.GaugeVec + TiKVFeedbackSlowScoreGauge *prometheus.GaugeVec + TiKVHealthFeedbackOpsCounter *prometheus.CounterVec + TiKVPreferLeaderFlowsGauge 
*prometheus.GaugeVec + TiKVStaleReadCounter *prometheus.CounterVec + TiKVStaleReadReqCounter *prometheus.CounterVec + TiKVStaleReadBytes *prometheus.CounterVec + TiKVPipelinedFlushLenHistogram prometheus.Histogram + TiKVPipelinedFlushSizeHistogram prometheus.Histogram + TiKVPipelinedFlushDuration prometheus.Histogram + TiKVValidateReadTSFromPDCount prometheus.Counter + TiKVLowResolutionTSOUpdateIntervalSecondsGauge prometheus.Gauge ) // Label constants. @@ -834,6 +836,22 @@ func initMetrics(namespace, subsystem string, constLabels prometheus.Labels) { Buckets: prometheus.ExponentialBuckets(0.0005, 2, 28), // 0.5ms ~ 18h }) + TiKVValidateReadTSFromPDCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "validate_read_ts_from_pd_count", + Help: "Counter of validating read ts by getting a timestamp from PD", + }) + + TiKVLowResolutionTSOUpdateIntervalSecondsGauge = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "low_resolution_tso_update_interval_seconds", + Help: "The actual working update interval for the low resolution TSO. As there are adaptive mechanism internally, this value may differ from the config.", + }) + initShortcuts() } @@ -928,6 +946,8 @@ func RegisterMetrics() { prometheus.MustRegister(TiKVPipelinedFlushLenHistogram) prometheus.MustRegister(TiKVPipelinedFlushSizeHistogram) prometheus.MustRegister(TiKVPipelinedFlushDuration) + prometheus.MustRegister(TiKVValidateReadTSFromPDCount) + prometheus.MustRegister(TiKVLowResolutionTSOUpdateIntervalSecondsGauge) } // readCounter reads the value of a prometheus.Counter. diff --git a/oracle/oracle.go b/oracle/oracle.go index 5579d3568..a4ffa1b93 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -51,6 +51,12 @@ type Oracle interface { GetLowResolutionTimestamp(ctx context.Context, opt *Option) (uint64, error) GetLowResolutionTimestampAsync(ctx context.Context, opt *Option) Future SetLowResolutionTimestampUpdateInterval(time.Duration) error + // GetStaleTimestamp generates a timestamp based on the recently fetched timestamp and the elapsed time since + // when that timestamp was fetched. The result is expected to be about `prevSecond` seconds before the current + // time. + // WARNING: This method does not guarantee whether the generated timestamp is legal for accessing the data. + // Neither is it safe to use it for verifying the legality of another calculated timestamp. + // Be sure to validate the timestamp before using it to access the data. GetStaleTimestamp(ctx context.Context, txnScope string, prevSecond uint64) (uint64, error) IsExpired(lockTimestamp, TTL uint64, opt *Option) bool UntilExpired(lockTimeStamp, TTL uint64, opt *Option) int64 @@ -61,6 +67,13 @@ type Oracle interface { // GetAllTSOKeyspaceGroupMinTS gets a minimum timestamp from all TSO keyspace groups. GetAllTSOKeyspaceGroupMinTS(ctx context.Context) (uint64, error) + + // ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts + // that has been allocated by the oracle, so that it's safe to use this ts to perform snapshot read, stale read, + // etc. + // Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot + // has been GCed. + ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error } // Future is a future which promises to return a timestamp. 
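The new `ValidateSnapshotReadTS` method pairs naturally with `GetStaleTimestamp`, whose doc comment above warns that the generated timestamp must be validated before use. A hedged sketch of that flow (the wrapper function is illustrative, not part of this patch):

```go
package example

import (
	"context"

	"github.com/tikv/client-go/v2/oracle"
)

// staleReadTS derives a timestamp roughly prevSecond seconds in the past and
// validates it with the oracle before it is used for a stale read.
func staleReadTS(ctx context.Context, o oracle.Oracle, prevSecond uint64) (uint64, error) {
	opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope}
	ts, err := o.GetStaleTimestamp(ctx, oracle.GlobalTxnScope, prevSecond)
	if err != nil {
		return 0, err
	}
	// GetStaleTimestamp gives no legality guarantee, so double-check that ts
	// does not exceed what the oracle has already allocated.
	if err := o.ValidateSnapshotReadTS(ctx, ts, opt); err != nil {
		return 0, err
	}
	return ts, nil
}
```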
diff --git a/oracle/oracles/local.go b/oracle/oracles/local.go index f8016f468..bf9b43c89 100644 --- a/oracle/oracles/local.go +++ b/oracle/oracles/local.go @@ -39,6 +39,7 @@ import ( "sync" "time" + "github.com/pingcap/errors" "github.com/tikv/client-go/v2/oracle" ) @@ -148,3 +149,14 @@ func (l *localOracle) SetExternalTimestamp(ctx context.Context, newTimestamp uin func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) { return l.getExternalTimestamp(ctx) } + +func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { + currentTS, err := l.GetTimestamp(ctx, opt) + if err != nil { + return errors.Errorf("fail to validate read timestamp: %v", err) + } + if currentTS < readTS { + return errors.Errorf("cannot set read timestamp to a future time") + } + return nil +} diff --git a/oracle/oracles/mock.go b/oracle/oracles/mock.go index 633d97537..cab3335ab 100644 --- a/oracle/oracles/mock.go +++ b/oracle/oracles/mock.go @@ -137,6 +137,17 @@ func (o *MockOracle) SetLowResolutionTimestampUpdateInterval(time.Duration) erro return nil } +func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { + currentTS, err := o.GetTimestamp(ctx, opt) + if err != nil { + return errors.Errorf("fail to validate read timestamp: %v", err) + } + if currentTS < readTS { + return errors.Errorf("cannot set read timestamp to a future time") + } + return nil +} + // IsExpired implements oracle.Oracle interface. func (o *MockOracle) IsExpired(lockTimestamp, TTL uint64, _ *oracle.Option) bool { o.RLock() diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 83dd41f3c..3c4cb02fa 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -48,19 +48,109 @@ import ( "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" "go.uber.org/zap" + "golang.org/x/sync/singleflight" ) var _ oracle.Oracle = &pdOracle{} const slowDist = 30 * time.Millisecond +type adaptiveUpdateTSIntervalState int + +const ( + adaptiveUpdateTSIntervalStateNone adaptiveUpdateTSIntervalState = iota + // adaptiveUpdateTSIntervalStateNormal represents the state that the adaptive update ts interval is synced with the + // configuration without performing any automatic adjustment. + adaptiveUpdateTSIntervalStateNormal + // adaptiveUpdateTSIntervalStateAdapting represents the state that there have recently been some stale read / snapshot + // read operations requesting a short staleness (now - readTS is near or exceeds the current update interval), + // so we automatically shrink the update interval. Otherwise, read operations may not have a low resolution ts + // that is new enough for checking the legality of the read ts, forcing them to fetch the latest ts from PD, + // which is time-consuming. + adaptiveUpdateTSIntervalStateAdapting + // adaptiveUpdateTSIntervalStateRecovering represents the state that the update ts interval has previously been shrunk + // to adapt to reads with short staleness, but there haven't been any such read operations for a while, so we + // gradually recover the update interval to the configured value. + adaptiveUpdateTSIntervalStateRecovering + // adaptiveUpdateTSIntervalStateUnadjustable represents the state that the user has configured a very short update + // interval, so that we don't have any space to automatically adjust it.
+ adaptiveUpdateTSIntervalStateUnadjustable ) + +func (s adaptiveUpdateTSIntervalState) String() string { + switch s { + case adaptiveUpdateTSIntervalStateNormal: + return "normal" + case adaptiveUpdateTSIntervalStateAdapting: + return "adapting" + case adaptiveUpdateTSIntervalStateRecovering: + return "recovering" + case adaptiveUpdateTSIntervalStateUnadjustable: + return "unadjustable" + default: + return fmt.Sprintf("unknown(%v)", int(s)) + } +} + +const ( + // minAllowedAdaptiveUpdateTSInterval is the lower bound of the adaptive update ts interval, to avoid an abnormal + // read operation causing the update interval to become too short. + minAllowedAdaptiveUpdateTSInterval = 500 * time.Millisecond + // adaptiveUpdateTSIntervalShrinkingPreserve is the duration that we additionally shrink by when adapting to a read + // operation that requires a short staleness. + adaptiveUpdateTSIntervalShrinkingPreserve = 100 * time.Millisecond + // adaptiveUpdateTSIntervalBlockRecoverThreshold is the threshold of the difference between the current update + // interval and the staleness requested by a read operation, below which the update interval is kept from recovering back to + // normal. + adaptiveUpdateTSIntervalBlockRecoverThreshold = 200 * time.Millisecond + // adaptiveUpdateTSIntervalRecoverPerSecond is the duration that the update interval should grow per second when + // recovering to normal state from adapting state. + adaptiveUpdateTSIntervalRecoverPerSecond = 20 * time.Millisecond + // adaptiveUpdateTSIntervalDelayBeforeRecovering is the duration that we should hold the current adaptive update + // interval before turning back to normal state. + adaptiveUpdateTSIntervalDelayBeforeRecovering = 5 * time.Minute +) + // pdOracle is an Oracle that uses a placement driver client as source. type pdOracle struct { c pd.Client // txn_scope (string) -> lastTSPointer (*atomic.Pointer[lastTSO]) - lastTSMap sync.Map - quit chan struct{} + lastTSMap sync.Map + quit chan struct{} + // The configured interval to update the low resolution ts. Set by SetLowResolutionTimestampUpdateInterval. + // For TiDB, this is directly controlled by the system variable `tidb_low_resolution_tso_update_interval`. lastTSUpdateInterval atomic.Int64 + // The actual interval to update the low resolution ts. If the configured one is too large to satisfy the + // requirement of the stale read or snapshot read, the actual interval can be automatically set to a shorter + // value than lastTSUpdateInterval. + // This value may also be updated by SetLowResolutionTimestampUpdateInterval, which happens when the + // user adjusts the update interval manually. + adaptiveLastTSUpdateInterval atomic.Int64 + + adaptiveUpdateIntervalState struct { + // The mutex to avoid racing between updateTS goroutine and SetLowResolutionTimestampUpdateInterval. + mu sync.Mutex + // The most recent time that a stale read / snapshot read requests a timestamp that is close enough to + // the current adaptive update interval. If there has been such a request recently, the adaptive interval + // should avoid falling back to the original (configured) value. + // Stored in unix milliseconds so that it can be accessed atomically. + lastShortStalenessReadTime atomic.Int64 + // When a request needs the update interval to shrink immediately, it sends the duration it expects to + // this channel. + shrinkIntervalCh chan time.Duration + + // Only accessed in updateTS goroutine. No need to use atomic value.
+ lastTick time.Time + // Describes the current state. + state adaptiveUpdateTSIntervalState + } + + // When the low resolution ts is not new enough and there are many concurrent stale read / snapshot read + // operations that need to validate the read ts, we can use this to avoid too many concurrent GetTS calls by + // reusing a result for different `ValidateSnapshotReadTS` calls. This can be done because + // we don't require the ts for validation to be strictly the latest one. + // Note that the result can't be reused for different txnScopes. The txnScope is used as the key. + tsForValidation singleflight.Group } // lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched. @@ -69,25 +159,39 @@ type lastTSO struct { arrival uint64 } +type PDOracleOptions struct { + // The duration to update the last ts, i.e., the low resolution ts. + UpdateInterval time.Duration + // Disable the background periodic update of the last ts. This is for test purposes only. + NoUpdateTS bool +} + // NewPdOracle create an Oracle that uses a pd client source. // Refer https://github.com/tikv/pd/blob/master/client/client.go for more details. // PdOracle maintains `lastTS` to store the last timestamp got from PD server. If // `GetTimestamp()` is not called after `lastTSUpdateInterval`, it will be called by // itself to keep up with the timestamp on PD server. -func NewPdOracle(pdClient pd.Client, updateInterval time.Duration) (oracle.Oracle, error) { +func NewPdOracle(pdClient pd.Client, options *PDOracleOptions) (oracle.Oracle, error) { + if options.UpdateInterval <= 0 { + return nil, fmt.Errorf("updateInterval must be > 0") + } + o := &pdOracle{ c: pdClient, quit: make(chan struct{}), lastTSUpdateInterval: atomic.Int64{}, } - err := o.SetLowResolutionTimestampUpdateInterval(updateInterval) - if err != nil { - return nil, err - } + o.adaptiveUpdateIntervalState.shrinkIntervalCh = make(chan time.Duration, 1) + o.lastTSUpdateInterval.Store(int64(options.UpdateInterval)) + o.adaptiveLastTSUpdateInterval.Store(int64(options.UpdateInterval)) + o.adaptiveUpdateIntervalState.lastTick = time.Now() + ctx := context.TODO() - go o.updateTS(ctx) + if !options.NoUpdateTS { + go o.updateTS(ctx) + } // Initialize the timestamp of the global txnScope by Get. - _, err = o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + _, err := o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) if err != nil { o.Close() return nil, err @@ -241,28 +345,172 @@ func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) { return last, true } +func (o *pdOracle) nextUpdateInterval(now time.Time, requiredStaleness time.Duration) time.Duration { + o.adaptiveUpdateIntervalState.mu.Lock() + defer o.adaptiveUpdateIntervalState.mu.Unlock() + + configuredInterval := time.Duration(o.lastTSUpdateInterval.Load()) + prevAdaptiveUpdateInterval := time.Duration(o.adaptiveLastTSUpdateInterval.Load()) + lastReachDropThresholdTime := time.UnixMilli(o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load()) + + currentAdaptiveUpdateInterval := prevAdaptiveUpdateInterval + + // Shortcut + const none = adaptiveUpdateTSIntervalStateNone + + // The following `checkX` functions check whether it should transition to the X state. Each returns + // a tuple representing (state, newInterval). + // When `checkX` returns a valid state, it means that the current situation matches the state.
In this case, it + // also returns the new interval that should be used next. + // When it returns `none`, we need to check whether it should transition to another state. For each call to + // nextUpdateInterval, if every `checkX` function returns `none`, the previous state is kept unchanged. + + checkUnadjustable := func() (adaptiveUpdateTSIntervalState, time.Duration) { + // If the user has configured a very short interval, we don't have any space to adjust it. Just use + // the user's configured value directly. + if configuredInterval <= minAllowedAdaptiveUpdateTSInterval { + return adaptiveUpdateTSIntervalStateUnadjustable, configuredInterval + } + return none, 0 + } + + checkNormal := func() (adaptiveUpdateTSIntervalState, time.Duration) { + // If the current actual update interval is synced with the configured value, and it's not in the unadjustable state, + // then it's the normal state. + if configuredInterval > minAllowedAdaptiveUpdateTSInterval && currentAdaptiveUpdateInterval == configuredInterval { + return adaptiveUpdateTSIntervalStateNormal, currentAdaptiveUpdateInterval + } + return none, 0 + } + + checkAdapting := func() (adaptiveUpdateTSIntervalState, time.Duration) { + if requiredStaleness != 0 && requiredStaleness < currentAdaptiveUpdateInterval && currentAdaptiveUpdateInterval > minAllowedAdaptiveUpdateTSInterval { + // If we are calculating the interval because of a request that requires a shorter staleness, we shrink the + // update interval immediately to adapt to it. + // We shrink the update interval to a value slightly lower than the requested staleness to avoid potential + // frequent shrinking operations. But there's a lower bound to prevent loading ts too frequently. + newInterval := max(requiredStaleness-adaptiveUpdateTSIntervalShrinkingPreserve, minAllowedAdaptiveUpdateTSInterval) + return adaptiveUpdateTSIntervalStateAdapting, newInterval + } + + if currentAdaptiveUpdateInterval != configuredInterval && now.Sub(lastReachDropThresholdTime) < adaptiveUpdateTSIntervalDelayBeforeRecovering { + // There is a recent request that requires a short staleness. Keep the current adaptive interval. + // If it's not in the adapting state, it was possibly in the recovering state before, and it stops recovering + // because there is a new read operation requesting a short staleness. + return adaptiveUpdateTSIntervalStateAdapting, currentAdaptiveUpdateInterval + } + + return none, 0 + } + + checkRecovering := func() (adaptiveUpdateTSIntervalState, time.Duration) { + if currentAdaptiveUpdateInterval == configuredInterval || now.Sub(lastReachDropThresholdTime) < adaptiveUpdateTSIntervalDelayBeforeRecovering { + return none, 0 + } + + timeSinceLastTick := now.Sub(o.adaptiveUpdateIntervalState.lastTick) + newInterval := currentAdaptiveUpdateInterval + time.Duration(timeSinceLastTick.Seconds()*float64(adaptiveUpdateTSIntervalRecoverPerSecond)) + if newInterval > configuredInterval { + newInterval = configuredInterval + } + + return adaptiveUpdateTSIntervalStateRecovering, newInterval + } + + // Check the specified states in order, until the state becomes determined. + // If it's still undetermined after all checks, keep the previous state.
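To make the arithmetic in `checkAdapting` and `checkRecovering` concrete: with the constants defined earlier (a 500ms floor, 100ms shrinking preserve, and a 20ms/s recovery rate), a read requiring 1s of staleness under a 2s configured interval shrinks the adaptive interval to 900ms, and growing back to 2s then takes roughly 55 seconds of quiet ticks once the 5-minute delay has passed. A standalone sketch of that calculation (an editorial illustration, not code from this patch):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	const (
		minAllowed        = 500 * time.Millisecond
		shrinkingPreserve = 100 * time.Millisecond
		recoverPerSecond  = 20 * time.Millisecond
	)
	configured := 2 * time.Second
	requiredStaleness := time.Second

	// Shrinking: mirror checkAdapting's formula.
	shrunk := requiredStaleness - shrinkingPreserve
	if shrunk < minAllowed {
		shrunk = minAllowed
	}
	fmt.Println(shrunk) // 900ms

	// Recovery: the interval grows by 20ms per elapsed second until it reaches the configured value.
	recoverySeconds := float64(configured-shrunk) / float64(recoverPerSecond)
	fmt.Printf("%.0f seconds to recover\n", recoverySeconds) // 55 seconds
}
```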
+ nextState := func(checkFuncs ...func() (adaptiveUpdateTSIntervalState, time.Duration)) time.Duration { + for _, f := range checkFuncs { + state, newInterval := f() + if state == none { + continue + } + + currentAdaptiveUpdateInterval = newInterval + + // If the final state is the recovering state, do an additional step to check whether it can go back to + // normal state immediately. + if state == adaptiveUpdateTSIntervalStateRecovering { + var nextState adaptiveUpdateTSIntervalState + nextState, newInterval = checkNormal() + if nextState != none { + state = nextState + currentAdaptiveUpdateInterval = newInterval + } + } + + o.adaptiveLastTSUpdateInterval.Store(int64(currentAdaptiveUpdateInterval)) + if o.adaptiveUpdateIntervalState.state != state { + logutil.BgLogger().Info("adaptive update ts interval state transition", + zap.Duration("configuredInterval", configuredInterval), + zap.Duration("prevAdaptiveUpdateInterval", prevAdaptiveUpdateInterval), + zap.Duration("newAdaptiveUpdateInterval", currentAdaptiveUpdateInterval), + zap.Duration("requiredStaleness", requiredStaleness), + zap.Stringer("prevState", o.adaptiveUpdateIntervalState.state), + zap.Stringer("newState", state)) + o.adaptiveUpdateIntervalState.state = state + } + + return currentAdaptiveUpdateInterval + } + return currentAdaptiveUpdateInterval + } + + var newInterval time.Duration + if requiredStaleness != 0 { + newInterval = nextState(checkUnadjustable, checkAdapting) + } else { + newInterval = nextState(checkUnadjustable, checkAdapting, checkNormal, checkRecovering) + } + + metrics.TiKVLowResolutionTSOUpdateIntervalSecondsGauge.Set(newInterval.Seconds()) + + return newInterval +} + func (o *pdOracle) updateTS(ctx context.Context) { - currentInterval := o.lastTSUpdateInterval.Load() - ticker := time.NewTicker(time.Duration(currentInterval)) + currentInterval := time.Duration(o.lastTSUpdateInterval.Load()) + ticker := time.NewTicker(currentInterval) defer ticker.Stop() + + doUpdate := func(now time.Time) { + // Update the timestamp for each txnScope + o.lastTSMap.Range(func(key, _ interface{}) bool { + txnScope := key.(string) + ts, err := o.getTimestamp(ctx, txnScope) + if err != nil { + logutil.Logger(ctx).Error("updateTS error", zap.String("txnScope", txnScope), zap.Error(err)) + return true + } + o.setLastTS(ts, txnScope) + return true + }) + + o.adaptiveUpdateIntervalState.lastTick = now + } + for { select { - case <-ticker.C: - // Update the timestamp for each txnScope - o.lastTSMap.Range(func(key, _ interface{}) bool { - txnScope := key.(string) - ts, err := o.getTimestamp(ctx, txnScope) - if err != nil { - logutil.Logger(ctx).Error("updateTS error", zap.String("txnScope", txnScope), zap.Error(err)) - return true - } - o.setLastTS(ts, txnScope) - return true - }) - newInterval := o.lastTSUpdateInterval.Load() + case now := <-ticker.C: + doUpdate(now) + + newInterval := o.nextUpdateInterval(now, 0) + if newInterval != currentInterval { + currentInterval = newInterval + ticker.Reset(currentInterval) + } + + case requiredStaleness := <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + now := time.Now() + newInterval := o.nextUpdateInterval(now, requiredStaleness) if newInterval != currentInterval { currentInterval = newInterval - ticker.Reset(time.Duration(currentInterval)) + + if time.Since(o.adaptiveUpdateIntervalState.lastTick) >= currentInterval { + doUpdate(time.Now()) + } + + ticker.Reset(currentInterval) } case <-o.quit: return @@ -296,11 +544,35 @@ func (f lowResolutionTsFuture) Wait() (uint64, error) { // 
SetLowResolutionTimestampUpdateInterval sets the refresh interval for low resolution timestamps. Note this will take // effect up to the previous update interval amount of time after being called. -func (o *pdOracle) SetLowResolutionTimestampUpdateInterval(updateInterval time.Duration) error { - if updateInterval <= 0 { +// This setting may not be strictly followed. If Stale Read requests too new data to be available, the low resolution +// ts may be actually updated in a shorter interval than the configured one. +func (o *pdOracle) SetLowResolutionTimestampUpdateInterval(newUpdateInterval time.Duration) error { + if newUpdateInterval <= 0 { return fmt.Errorf("updateInterval must be > 0") } - o.lastTSUpdateInterval.Store(updateInterval.Nanoseconds()) + + o.adaptiveUpdateIntervalState.mu.Lock() + defer o.adaptiveUpdateIntervalState.mu.Unlock() + + prevConfigured := o.lastTSUpdateInterval.Swap(int64(newUpdateInterval)) + adaptiveUpdateInterval := o.adaptiveLastTSUpdateInterval.Load() + + var adaptiveUpdateIntervalUpdated bool + + if adaptiveUpdateInterval == prevConfigured || newUpdateInterval < time.Duration(adaptiveUpdateInterval) { + // If the adaptive update interval is the same as the configured one, treat it as the adaptive adjusting + // mechanism not taking effect. So update it immediately. + // If the new configured interval is short so that it's smaller than the current adaptive interval, also shrink + // the adaptive interval immediately. + o.adaptiveLastTSUpdateInterval.Store(int64(newUpdateInterval)) + adaptiveUpdateIntervalUpdated = true + } + logutil.Logger(context.Background()).Info("updated low resolution ts update interval", + zap.Duration("previous", time.Duration(prevConfigured)), + zap.Duration("new", newUpdateInterval), + zap.Duration("prevAdaptiveUpdateInterval", time.Duration(adaptiveUpdateInterval)), + zap.Bool("adaptiveUpdateIntervalUpdated", adaptiveUpdateIntervalUpdated)) + return nil } @@ -366,3 +638,84 @@ func (o *pdOracle) SetExternalTimestamp(ctx context.Context, ts uint64) error { func (o *pdOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) { return o.c.GetExternalTimestamp(ctx) } + +func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Option) (uint64, error) { + ch := o.tsForValidation.DoChan(opt.TxnScope, func() (interface{}, error) { + metrics.TiKVValidateReadTSFromPDCount.Inc() + + // If the call that triggers the execution of this function is canceled by the context, other calls that are + // waiting for reusing the same result should not be canceled. So pass context.Background() instead of the + // current ctx. + res, err := o.GetTimestamp(context.Background(), opt) + return res, err + }) + select { + case <-ctx.Done(): + return 0, errors.WithStack(ctx.Err()) + case res := <-ch: + if res.Err != nil { + return 0, errors.WithStack(res.Err) + } + return res.Val.(uint64), nil + } +} + +func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { + latestTS, err := o.GetLowResolutionTimestamp(ctx, opt) + // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double-check. + // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function + // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. 
+ if err != nil || readTS > latestTS { + currentTS, err := o.getCurrentTSForValidation(ctx, opt) + if err != nil { + return errors.Errorf("fail to validate read timestamp: %v", err) + } + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + if readTS > currentTS { + return errors.Errorf("cannot set read timestamp to a future time") + } + } else { + estimatedCurrentTS, err := o.getStaleTimestamp(opt.TxnScope, 0) + if err != nil { + logutil.Logger(ctx).Warn("failed to estimate current ts by getStaleTimestamp for auto-adjusting update low resolution ts interval", + zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope)) + } else { + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now()) + } + } + return nil +} + +func (o *pdOracle) adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS uint64, currentTS uint64, now time.Time) { + requiredStaleness := oracle.GetTimeFromTS(currentTS).Sub(oracle.GetTimeFromTS(readTS)) + + // Do not acquire the mutex, as here we only need a rough check. + // So it's possible that we get inconsistent values from these two atomic fields, but it won't cause any problem. + currentUpdateInterval := time.Duration(o.adaptiveLastTSUpdateInterval.Load()) + + if requiredStaleness <= currentUpdateInterval+adaptiveUpdateTSIntervalBlockRecoverThreshold { + // Record the most recent time when there's a read operation requesting a staleness close enough to the + // current update interval. + nowMillis := now.UnixMilli() + last := o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load() + if last < nowMillis { + // Do not retry if the CAS fails (which may happen when there are other goroutines updating it + // concurrently), as we don't actually need to set it strictly. + o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.CompareAndSwap(last, nowMillis) + } + } + + if requiredStaleness <= currentUpdateInterval && currentUpdateInterval > minAllowedAdaptiveUpdateTSInterval { + // Considering system time / PD time drifts, it's possible that we get a non-positive value from the + // calculation. Make sure it's always positive before passing it to the updateTS goroutine. + // Note that the `nextUpdateInterval` method expects requiredStaleness to be non-zero when triggered + // by this path. + requiredStaleness = max(requiredStaleness, time.Millisecond) + // Do a non-blocking send to notify it to change the interval immediately. If the channel is + // busy, it means that there's another concurrent call trying to update it. Just skip it in this case. + select { + case o.adaptiveUpdateIntervalState.shrinkIntervalCh <- requiredStaleness: + default: + } + } +} diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index c9fc24449..b1fa6d3e3 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -32,7 +32,7 @@ // See the License for the specific language governing permissions and // limitations under the License.
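`getCurrentTSForValidation` above uses `singleflight.Group.DoChan` so that concurrent validations within the same txnScope share a single PD request while each caller still honors its own context. A minimal standalone illustration of that pattern (the `sharedFetch` helper and its `fetch` callback are assumptions, not the client's code):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/singleflight"
)

var group singleflight.Group

// sharedFetch deduplicates concurrent calls with the same key: only one fetch
// runs; the others wait on the channel and reuse its result.
func sharedFetch(ctx context.Context, key string, fetch func(context.Context) (uint64, error)) (uint64, error) {
	ch := group.DoChan(key, func() (interface{}, error) {
		// Run under context.Background() so that one caller canceling does not
		// fail the other callers waiting on the same result.
		return fetch(context.Background())
	})
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case res := <-ch:
		if res.Err != nil {
			return 0, res.Err
		}
		return res.Val.(uint64), nil
	}
}

func main() {
	fetch := func(context.Context) (uint64, error) {
		time.Sleep(10 * time.Millisecond) // simulate a PD round trip
		return 42, nil
	}
	fmt.Println(sharedFetch(context.Background(), "global", fetch))
}
```

Running the fetch under `context.Background()` mirrors the comment in the patch: cancellation of the initiating caller must not fail the other waiters that reuse the result.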
-package oracles_test +package oracles import ( "context" @@ -44,25 +44,24 @@ import ( "github.com/stretchr/testify/assert" "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/oracle/oracles" pd "github.com/tikv/pd/client" ) func TestPDOracle_UntilExpired(t *testing.T) { lockAfter, lockExp := 10, 15 - o := oracles.NewEmptyPDOracle() + o := NewEmptyPDOracle() start := time.Now() - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) lockTs := oracle.GoTimeToTS(start.Add(time.Duration(lockAfter)*time.Millisecond)) + 1 waitTs := o.UntilExpired(lockTs, uint64(lockExp), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) assert.Equal(t, int64(lockAfter+lockExp), waitTs) } func TestPdOracle_GetStaleTimestamp(t *testing.T) { - o := oracles.NewEmptyPDOracle() + o := NewEmptyPDOracle() start := time.Now() - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 10) assert.Nil(t, err) assert.WithinDuration(t, start.Add(-10*time.Second), oracle.GetTimeFromTS(ts), 2*time.Second) @@ -90,7 +89,7 @@ func (c *MockPdClient) GetTS(ctx context.Context) (int64, int64, error) { func TestPdOracle_SetLowResolutionTimestampUpdateInterval(t *testing.T) { pdClient := MockPdClient{} - o := oracles.NewPdOracleWithClient(&pdClient) + o := NewPdOracleWithClient(&pdClient) ctx := context.TODO() wg := sync.WaitGroup{} @@ -131,7 +130,7 @@ func TestPdOracle_SetLowResolutionTimestampUpdateInterval(t *testing.T) { assert.LessOrEqual(t, elapsed, 3*updateInterval) } - oracles.StartTsUpdateLoop(o, ctx, &wg) + StartTsUpdateLoop(o, ctx, &wg) // Check each update interval. Note that since these are in increasing // order the time for the new interval to take effect is always less // than the new interval. If we iterated in opposite order, then we'd have @@ -150,8 +149,8 @@ func TestPdOracle_SetLowResolutionTimestampUpdateInterval(t *testing.T) { } func TestNonFutureStaleTSO(t *testing.T) { - o := oracles.NewEmptyPDOracle() - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now())) + o := NewEmptyPDOracle() + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now())) for i := 0; i < 100; i++ { time.Sleep(10 * time.Millisecond) now := time.Now() @@ -160,7 +159,7 @@ func TestNonFutureStaleTSO(t *testing.T) { closeCh := make(chan struct{}) go func() { time.Sleep(100 * time.Microsecond) - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now)) + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now)) close(closeCh) }() CHECK: @@ -180,3 +179,320 @@ func TestNonFutureStaleTSO(t *testing.T) { } } } + +func TestAdaptiveUpdateTSInterval(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + now := time.Now() + + mockTS := func(beforeNow time.Duration) uint64 { + return oracle.ComposeTS(oracle.GetPhysical(now.Add(-beforeNow)), 1) + } + mustNotifyShrinking := func(expectedRequiredStaleness time.Duration) { + // Normally this channel should be checked in pdOracle.updateTS method. Here we are testing the layer below the + // updateTS method, so we just do this assert to ensure the message is sent to this channel. 
+ select { + case requiredStaleness := <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + assert.Equal(t, expectedRequiredStaleness, requiredStaleness) + default: + assert.Fail(t, "expects notifying shrinking update interval immediately, but no message received") + } + } + mustNoNotify := func() { + select { + case <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + assert.Fail(t, "expects not notifying shrinking update interval immediately, but message was received") + default: + } + } + + now = now.Add(time.Second * 2) + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + now = now.Add(time.Second * 2) + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) + + now = now.Add(time.Second) + // Simulate a read requesting a staleness larger than 2s, in which case nothing special will happen. + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second*3), mockTS(0), now) + mustNoNotify() + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + + now = now.Add(time.Second) + // Simulate a read requesting a staleness less than 2s, in which case it should trigger immediate shrinking on the + // update interval. + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second), mockTS(0), now) + mustNotifyShrinking(time.Second) + expectedInterval := time.Second - adaptiveUpdateTSIntervalShrinkingPreserve + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, time.Second)) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + assert.Equal(t, now.UnixMilli(), o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load()) + + // Let read with short staleness continue happening. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering / 2) + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second), mockTS(0), now) + mustNoNotify() + assert.Equal(t, now.UnixMilli(), o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load()) + + // The adaptiveUpdateTSIntervalDelayBeforeRecovering has not been elapsed since the last time there is a read with short + // staleness. The update interval won't start being reset at this time. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering/2 + time.Second) + o.adaptiveUpdateIntervalState.lastTick = now.Add(-time.Second) + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + + // The adaptiveUpdateTSIntervalDelayBeforeRecovering has been elapsed. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering / 2) + o.adaptiveUpdateIntervalState.lastTick = now.Add(-time.Second) + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second * 2) + // No effect if the required staleness didn't trigger the threshold. 
+ o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(expectedInterval+adaptiveUpdateTSIntervalBlockRecoverThreshold*2), mockTS(0), now) + mustNoNotify() + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond * 2 + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + + // If there's a read operation requires a staleness that is close enough to the current adaptive update interval, + // then block the update interval from recovering. + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second) + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(expectedInterval+adaptiveUpdateTSIntervalBlockRecoverThreshold/2), mockTS(0), now) + mustNoNotify() + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second) + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + + // Now adaptiveUpdateTSIntervalDelayBeforeRecovering + 1s has been elapsed. Continue recovering. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering) + o.adaptiveUpdateIntervalState.lastTick = now.Add(-time.Second) + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + + // Without any other interruption, the update interval will gradually recover to the same value as configured. + for { + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second) + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond + if expectedInterval >= time.Second*2 { + break + } + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + } + expectedInterval = time.Second * 2 + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) + + // Test adjusting configurations manually. + // When the adaptive update interval is not taking effect, the actual used update interval follows the change of + // the configuration immediately. + err = o.SetLowResolutionTimestampUpdateInterval(time.Second * 1) + assert.NoError(t, err) + assert.Equal(t, time.Second, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, time.Second, o.nextUpdateInterval(now, 0)) + + err = o.SetLowResolutionTimestampUpdateInterval(time.Second * 2) + assert.NoError(t, err) + assert.Equal(t, time.Second*2, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + + // If the adaptive update interval is taking effect, the configuration change doesn't immediately affect the actual + // update interval. 
+ now = now.Add(time.Second) + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second), mockTS(0), now) + mustNotifyShrinking(time.Second) + expectedInterval = time.Second - adaptiveUpdateTSIntervalShrinkingPreserve + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, time.Second)) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + err = o.SetLowResolutionTimestampUpdateInterval(time.Second * 3) + assert.NoError(t, err) + assert.Equal(t, expectedInterval, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + err = o.SetLowResolutionTimestampUpdateInterval(time.Second) + assert.NoError(t, err) + assert.Equal(t, expectedInterval, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + + // ...unless it's set to a value shorter than the current actual update interval. + err = o.SetLowResolutionTimestampUpdateInterval(time.Millisecond * 800) + assert.NoError(t, err) + assert.Equal(t, time.Millisecond*800, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, time.Millisecond*800, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) + + // If the configured value is too short, the actual update interval won't be adaptive + err = o.SetLowResolutionTimestampUpdateInterval(minAllowedAdaptiveUpdateTSInterval / 2) + assert.NoError(t, err) + assert.Equal(t, minAllowedAdaptiveUpdateTSInterval/2, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, minAllowedAdaptiveUpdateTSInterval/2, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateUnadjustable, o.adaptiveUpdateIntervalState.state) +} + +func TestValidateSnapshotReadTS(t *testing.T) { + pdClient := MockPdClient{} + o, err := NewPdOracle(&pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + }) + assert.NoError(t, err) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateSnapshotReadTS(ctx, 1, opt) + assert.NoError(t, err) + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to + // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. + err = o.ValidateSnapshotReadTS(ctx, ts+1, opt) + assert.NoError(t, err) + // It can't pass if the readTS is newer than previous ts + 2. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + err = o.ValidateSnapshotReadTS(ctx, ts+2, opt) + assert.Error(t, err) + + // Simulate other PD clients requests a timestamp. 
+ ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + pdClient.logicalTimestamp.Add(2) + err = o.ValidateSnapshotReadTS(ctx, ts+3, opt) + assert.NoError(t, err) +} + +type MockPDClientWithPause struct { + MockPdClient + mu sync.Mutex +} + +func (c *MockPDClientWithPause) GetTS(ctx context.Context) (int64, int64, error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.MockPdClient.GetTS(ctx) +} + +func (c *MockPDClientWithPause) Pause() { + c.mu.Lock() +} + +func (c *MockPDClientWithPause) Resume() { + c.mu.Unlock() +} + +func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { + pdClient := &MockPDClientWithPause{} + o, err := NewPdOracle(pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + defer o.Close() + + asyncValidate := func(ctx context.Context, readTS uint64) chan error { + ch := make(chan error, 1) + go func() { + err := o.ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + ch <- err + }() + return ch + } + + noResult := func(ch chan error) { + select { + case <-ch: + assert.FailNow(t, "a ValidateSnapshotReadTS operation is not blocked while it's expected to be blocked") + default: + } + } + + cancelIndices := []int{-1, -1, 0, 1} + for i, ts := range []uint64{100, 200, 300, 400} { + // Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail. + + // We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls + // doesn't affect other calls that are waiting + cancelIndex := cancelIndices[i] + + pdClient.Pause() + + results := make([]chan error, 0, 5) + + ctx, cancel := context.WithCancel(context.Background()) + + getCtx := func(index int) context.Context { + if cancelIndex == index { + return ctx + } else { + return context.Background() + } + } + + results = append(results, asyncValidate(getCtx(0), ts-2)) + results = append(results, asyncValidate(getCtx(1), ts+2)) + results = append(results, asyncValidate(getCtx(2), ts-1)) + results = append(results, asyncValidate(getCtx(3), ts+1)) + results = append(results, asyncValidate(getCtx(4), ts)) + + expectedSucceeds := []bool{true, false, true, false, true} + + time.Sleep(time.Millisecond * 50) + for _, ch := range results { + noResult(ch) + } + + cancel() + + for i, ch := range results { + if i == cancelIndex { + select { + case err := <-ch: + assert.Errorf(t, err, "index: %v", i) + assert.Containsf(t, err.Error(), "context canceled", "index: %v", i) + case <-time.After(time.Second): + assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i) + } + } else { + noResult(ch) + } + } + + // ts will be the next ts returned to these validation calls. + pdClient.logicalTimestamp.Store(int64(ts - 1)) + pdClient.Resume() + for i, ch := range results { + if i == cancelIndex { + continue + } + + select { + case err = <-ch: + case <-time.After(time.Second): + assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i) + } + if expectedSucceeds[i] { + assert.NoErrorf(t, err, "index: %v", i) + } else { + assert.Errorf(t, err, "index: %v", i) + assert.NotContainsf(t, err.Error(), "context canceled", "index: %v", i) + } + } + } +} diff --git a/tikv/kv.go b/tikv/kv.go index 7c45137b2..db375ae57 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -259,7 +259,9 @@ func requestHealthFeedbackFromKVClient(ctx context.Context, addr string, tikvCli // NewKVStore creates a new TiKV store instance. 
func NewKVStore(uuid string, pdClient pd.Client, spkv SafePointKV, tikvclient Client, opt ...Option) (*KVStore, error) { - o, err := oracles.NewPdOracle(pdClient, defaultOracleUpdateInterval) + o, err := oracles.NewPdOracle(pdClient, &oracles.PDOracleOptions{ + UpdateInterval: defaultOracleUpdateInterval, + }) if err != nil { return nil, err }
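Finally, because `NewPdOracle` now takes an options struct (as the kv.go hunk above shows), external callers that constructed the oracle directly need a small migration. A sketch under assumed surrounding code:

```go
package example

import (
	"time"

	"github.com/tikv/client-go/v2/oracle"
	"github.com/tikv/client-go/v2/oracle/oracles"
	pd "github.com/tikv/pd/client"
)

// newOracle shows the post-change construction; before this patch the call
// was oracles.NewPdOracle(pdClient, 2*time.Second).
func newOracle(pdClient pd.Client) (oracle.Oracle, error) {
	return oracles.NewPdOracle(pdClient, &oracles.PDOracleOptions{
		UpdateInterval: 2 * time.Second,
		// NoUpdateTS disables the background refresh loop; intended for tests only.
		NoUpdateTS: false,
	})
}
```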