From 5d37218ee149409d4f9147be06399249d05feb59 Mon Sep 17 00:00:00 2001 From: polebug Date: Thu, 16 Nov 2023 21:05:38 +0800 Subject: [PATCH] feat(common/sync): reimplement QuickGroup for IPFS client Co-authored-by: KallyDev --- Makefile | 2 +- common/sync/quickgroup.go | 80 ++++++++++++++++++++++++++++++++++ common/sync/quickgroup_test.go | 41 +++++++++++++++++ 3 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 common/sync/quickgroup.go create mode 100644 common/sync/quickgroup_test.go diff --git a/Makefile b/Makefile index 18e4f0f5d..502e2e3c3 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ lint: generate go run github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2 run test: - go test -cover -v ./... + go test -cover -race -v ./... .PHONY: build build: generate diff --git a/common/sync/quickgroup.go b/common/sync/quickgroup.go new file mode 100644 index 000000000..e01dac8a8 --- /dev/null +++ b/common/sync/quickgroup.go @@ -0,0 +1,80 @@ +package sync + +import ( + "context" + "errors" + "sync" + "sync/atomic" +) + +var ErrorNoResult = errors.New("no result") + +type QuickGroup[R any] interface { + Go(f func(ctx context.Context) (R, error)) + Wait() (R, error) +} + +var _ QuickGroup[any] = (*quickGroup[any])(nil) + +type quickGroup[R any] struct { + waitGroup sync.WaitGroup + ctx context.Context + cancels []context.CancelFunc + result R + err error + done atomic.Bool + locker sync.Mutex +} + +func (q *quickGroup[R]) Go(f func(ctx context.Context) (R, error)) { + if q.done.Load() { + return + } + + // Ensure that the cancel slice is thread-safe. + q.locker.Lock() + + q.waitGroup.Add(1) + + // Create a context to later cancel slow tasks. + ctx, cancel := context.WithCancel(q.ctx) + q.cancels = append(q.cancels, cancel) + index := len(q.cancels) - 1 + + q.locker.Unlock() + + go func() { + defer q.waitGroup.Done() + + if result, err := f(ctx); err == nil { + if !q.done.Swap(true) { // Here is equivalent to sync.Once. + q.result = result + q.err = nil + + // Cancel all other pending tasks. + for i, cancel := range q.cancels { + if i == index { + continue + } + + cancel() + } + } + } + }() +} + +func (q *quickGroup[R]) Wait() (result R, err error) { + q.waitGroup.Wait() + + return q.result, q.err +} + +func NewQuickGroup[R any](ctx context.Context) QuickGroup[R] { + instance := &quickGroup[R]{ + ctx: ctx, + err: ErrorNoResult, // Default error. + } + + return instance +} diff --git a/common/sync/quickgroup_test.go b/common/sync/quickgroup_test.go new file mode 100644 index 000000000..08d241762 --- /dev/null +++ b/common/sync/quickgroup_test.go @@ -0,0 +1,41 @@ +package sync + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestQuickGroup(t *testing.T) { + t.Parallel() + + quickGroup := NewQuickGroup[time.Duration](context.Background()) + + for duration := time.Second; duration > 0; duration -= 100 * time.Millisecond { + duration := duration + + task := func(ctx context.Context) (time.Duration, error) { + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-ctx.Done(): + t.Logf("Task %s: %s", duration, context.Canceled) + + return duration, context.Canceled + case <-timer.C: + t.Logf("Task %s: %s", duration, "done") + + return duration, nil + } + } + + quickGroup.Go(task) + } + + result, err := quickGroup.Wait() + require.NoError(t, err) + require.Equal(t, result, 100*time.Millisecond) +}