Genericize the batch controller
Signed-off-by: JmPotato <[email protected]>
JmPotato committed Nov 8, 2024
1 parent c9e532c commit ca09490
Showing 2 changed files with 117 additions and 101 deletions.
139 changes: 70 additions & 69 deletions client/tso_batch_controller.go → client/batch_controller.go
@@ -16,38 +16,43 @@ package pd

import (
"context"
"runtime/trace"
"time"

"github.com/tikv/pd/client/tsoutil"
"github.com/prometheus/client_golang/prometheus"
)

type tsoBatchController struct {
type batchController[T any] struct {
maxBatchSize int
// bestBatchSize is a dynamic size that changes based on the effect of the current batch.
bestBatchSize int

collectedRequests []*tsoRequest
collectedRequests []T
collectedRequestCount int

// The finisher function to cancel collected requests when an internal error occurs.
cancelFinisher func(int, T)
// The observer to record the best batch size.
bestBatchObserver prometheus.Histogram
// The time after getting the first request and the token, and before performing extra batching.
extraBatchingStartTime time.Time
}

func newTSOBatchController(maxBatchSize int) *tsoBatchController {
return &tsoBatchController{
func newBatchController[T any](maxBatchSize int, cancelFinisher func(int, T), bestBatchObserver prometheus.Histogram) *batchController[T] {
return &batchController[T]{
maxBatchSize: maxBatchSize,
bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will converge to (current_batch_size - 4) */
collectedRequests: make([]*tsoRequest, maxBatchSize+1),
collectedRequests: make([]T, maxBatchSize+1),
collectedRequestCount: 0,
cancelFinisher: cancelFinisher,
bestBatchObserver: bestBatchObserver,
}
}

// fetchPendingRequests will start a new round of batch collecting from the channel.
// It returns a nil error if everything goes well, otherwise a non-nil error, which means we should stop the service.
// It's guaranteed that if this function fails after collecting some requests, these requests will be cancelled
// when the function returns, so the caller doesn't need to clear them manually.
func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequestCh <-chan *tsoRequest, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
var tokenAcquired bool
defer func() {
if errRet != nil {
@@ -56,17 +61,17 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequ
if tokenAcquired {
tokenCh <- struct{}{}
}
tbc.finishCollectedRequests(0, 0, 0, invalidStreamID, errRet)
bc.finishCollectedRequests(bc.cancelFinisher)
}
}()

// Wait until BOTH the first request and the token have arrived.
// TODO: `tbc.collectedRequestCount` should never be non-zero here. Consider adding an assertion.
tbc.collectedRequestCount = 0
// TODO: `bc.collectedRequestCount` should never be non-zero here. Consider adding an assertion.
bc.collectedRequestCount = 0
for {
// If the batch size reaches the maxBatchSize limit but the token hasn't arrived yet, don't receive more
// requests, and return when the token is ready.
if tbc.collectedRequestCount >= tbc.maxBatchSize && !tokenAcquired {
if bc.collectedRequestCount >= bc.maxBatchSize && !tokenAcquired {
select {
case <-ctx.Done():
return ctx.Err()
@@ -78,9 +83,9 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequ
select {
case <-ctx.Done():
return ctx.Err()
case req := <-tsoRequestCh:
// Start to batch when the first TSO request arrives.
tbc.pushRequest(req)
case req := <-requestCh:
// Start to batch when the first request arrives.
bc.pushRequest(req)
// A request arrives but the token is not ready yet. Continue waiting, and also allow collecting the next
// request if it arrives.
continue
@@ -89,50 +94,49 @@ func (tbc *tsoBatchController) fetchPendingRequests(ctx context.Context, tsoRequ
}

// The token is ready. If the first request didn't arrive, wait for it.
if tbc.collectedRequestCount == 0 {
if bc.collectedRequestCount == 0 {
select {
case <-ctx.Done():
return ctx.Err()
case firstRequest := <-tsoRequestCh:
tbc.pushRequest(firstRequest)
case firstRequest := <-requestCh:
bc.pushRequest(firstRequest)
}
}

// Both token and the first request have arrived.
break
}

tbc.extraBatchingStartTime = time.Now()
bc.extraBatchingStartTime = time.Now()

// This loop is for trying our best to collect more requests, so we use `tbc.maxBatchSize` here.
// This loop is for trying our best to collect more requests, so we use `bc.maxBatchSize` here.
fetchPendingRequestsLoop:
for tbc.collectedRequestCount < tbc.maxBatchSize {
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case tsoReq := <-tsoRequestCh:
tbc.pushRequest(tsoReq)
case req := <-requestCh:
bc.pushRequest(req)
case <-ctx.Done():
return ctx.Err()
default:
break fetchPendingRequestsLoop
}
}

// Check whether we should fetch more pending TSO requests from the channel.
// TODO: maybe consider the actual load that returns through a TSO response from PD server.
if tbc.collectedRequestCount >= tbc.maxBatchSize || maxBatchWaitInterval <= 0 {
// Check whether we should fetch more pending requests from the channel.
if bc.collectedRequestCount >= bc.maxBatchSize || maxBatchWaitInterval <= 0 {
return nil
}

// Fetches more pending TSO requests from the channel.
// Try to collect `tbc.bestBatchSize` requests, or wait `maxBatchWaitInterval`
// when `tbc.collectedRequestCount` is less than the `tbc.bestBatchSize`.
if tbc.collectedRequestCount < tbc.bestBatchSize {
// Fetches more pending requests from the channel.
// Try to collect `bc.bestBatchSize` requests, or wait `maxBatchWaitInterval`
// when `bc.collectedRequestCount` is less than the `bc.bestBatchSize`.
if bc.collectedRequestCount < bc.bestBatchSize {
after := time.NewTimer(maxBatchWaitInterval)
defer after.Stop()
for tbc.collectedRequestCount < tbc.bestBatchSize {
for bc.collectedRequestCount < bc.bestBatchSize {
select {
case tsoReq := <-tsoRequestCh:
tbc.pushRequest(tsoReq)
case req := <-requestCh:
bc.pushRequest(req)
case <-ctx.Done():
return ctx.Err()
case <-after.C:
@@ -141,13 +145,13 @@ fetchPendingRequestsLoop:
}
}

// Do an additional non-blocking try. Here we test the length with `tbc.maxBatchSize` instead
// of `tbc.bestBatchSize` because trying our best to fetch more requests is necessary so that
// we can adjust the `tbc.bestBatchSize` dynamically later.
for tbc.collectedRequestCount < tbc.maxBatchSize {
// Do an additional non-blocking try. Here we test the length with `bc.maxBatchSize` instead
// of `bc.bestBatchSize` because trying our best to fetch more requests is necessary so that
// we can adjust the `bc.bestBatchSize` dynamically later.
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case tsoReq := <-tsoRequestCh:
tbc.pushRequest(tsoReq)
case req := <-requestCh:
bc.pushRequest(req)
case <-ctx.Done():
return ctx.Err()
default:
@@ -159,27 +163,27 @@ fetchPendingRequestsLoop:

// fetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly
// before calling this function.
func (tbc *tsoBatchController) fetchRequestsWithTimer(ctx context.Context, tsoRequestCh <-chan *tsoRequest, timer *time.Timer) error {
func (bc *batchController[T]) fetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error {
batchingLoop:
for tbc.collectedRequestCount < tbc.maxBatchSize {
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case <-ctx.Done():
return ctx.Err()
case req := <-tsoRequestCh:
tbc.pushRequest(req)
case req := <-requestCh:
bc.pushRequest(req)
case <-timer.C:
break batchingLoop
}
}

// Try to collect more requests in a non-blocking way.
nonWaitingBatchLoop:
for tbc.collectedRequestCount < tbc.maxBatchSize {
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case <-ctx.Done():
return ctx.Err()
case req := <-tsoRequestCh:
tbc.pushRequest(req)
case req := <-requestCh:
bc.pushRequest(req)
default:
break nonWaitingBatchLoop
}
@@ -188,38 +192,35 @@ nonWaitingBatchLoop:
return nil
}

func (tbc *tsoBatchController) pushRequest(tsoReq *tsoRequest) {
tbc.collectedRequests[tbc.collectedRequestCount] = tsoReq
tbc.collectedRequestCount++
func (bc *batchController[T]) pushRequest(req T) {
bc.collectedRequests[bc.collectedRequestCount] = req
bc.collectedRequestCount++
}

func (tbc *tsoBatchController) getCollectedRequests() []*tsoRequest {
return tbc.collectedRequests[:tbc.collectedRequestCount]
func (bc *batchController[T]) getCollectedRequests() []T {
return bc.collectedRequests[:bc.collectedRequestCount]
}

// adjustBestBatchSize stabilizes the latency with the AIAD algorithm.
func (tbc *tsoBatchController) adjustBestBatchSize() {
tsoBestBatchSize.Observe(float64(tbc.bestBatchSize))
length := tbc.collectedRequestCount
if length < tbc.bestBatchSize && tbc.bestBatchSize > 1 {
func (bc *batchController[T]) adjustBestBatchSize() {
bc.bestBatchObserver.Observe(float64(bc.bestBatchSize))
length := bc.collectedRequestCount
if length < bc.bestBatchSize && bc.bestBatchSize > 1 {
// Waited too long to collect requests; reduce the target batch size.
tbc.bestBatchSize--
} else if length > tbc.bestBatchSize+4 /* Hard-coded number, in order to make `tbc.bestBatchSize` stable */ &&
tbc.bestBatchSize < tbc.maxBatchSize {
tbc.bestBatchSize++
bc.bestBatchSize--
} else if length > bc.bestBatchSize+4 /* Hard-coded number, in order to make `bc.bestBatchSize` stable */ &&
bc.bestBatchSize < bc.maxBatchSize {
bc.bestBatchSize++
}
}

func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, streamID string, err error) {
for i := range tbc.collectedRequestCount {
tsoReq := tbc.collectedRequests[i]
// Retrieve the request context before the request is done to trace without race.
requestCtx := tsoReq.requestCtx
tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits)
tsoReq.streamID = streamID
tsoReq.tryDone(err)
trace.StartRegion(requestCtx, "pdclient.tsoReqDequeue").End()
func (bc *batchController[T]) finishCollectedRequests(finisher func(int, T)) {
if finisher == nil {
finisher = bc.cancelFinisher
}
for i := range bc.collectedRequestCount {
finisher(i, bc.collectedRequests[i])
}
// Prevent the finished requests from being processed again.
tbc.collectedRequestCount = 0
bc.collectedRequestCount = 0
}
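
For reference, the controller above is now parameterized over the request type T, with per-request completion injected through the cancelFinisher/finisher callbacks and the metric injected as a prometheus.Histogram. Below is a minimal sketch of how a caller could instantiate it. It assumes the code sits in the same pd package; the request type, finisher behavior, histogram name, and helper names are illustrative only and are not part of this commit — the real wiring lives in the other changed file, which is not rendered above.

package pd

import (
	"context"

	"github.com/prometheus/client_golang/prometheus"
)

// exampleRequest stands in for a concrete request type such as *tsoRequest.
type exampleRequest struct {
	done chan error // receives the final result of the request
}

// newExampleController wires up the generic batch controller for exampleRequest.
func newExampleController(maxBatchSize int) *batchController[*exampleRequest] {
	// The cancel finisher is called for every collected request when batching
	// fails internally, so half-collected batches never leak to the caller.
	cancel := func(_ int, req *exampleRequest) {
		req.done <- context.Canceled
	}
	// The histogram records bestBatchSize on every adjustBestBatchSize call;
	// the metric name here is made up for the example.
	hist := prometheus.NewHistogram(prometheus.HistogramOpts{
		Name: "example_best_batch_size",
	})
	return newBatchController[*exampleRequest](maxBatchSize, cancel, hist)
}

A dispatcher loop would presumably then call fetchPendingRequests to gather a batch, process getCollectedRequests(), call adjustBestBatchSize(), and complete the batch with finishCollectedRequests using a success finisher.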