diff --git a/internal/errgroup/errgroup.go b/internal/errgroup/errgroup.go index 984bdf0..078c970 100644 --- a/internal/errgroup/errgroup.go +++ b/internal/errgroup/errgroup.go @@ -16,7 +16,7 @@ import ( // The Group is superior errgroup.Group which aborts whole group // execution when parent context is cancelled type Group struct { - grp errgroup.Group + grp *errgroup.Group ctx context.Context } @@ -24,12 +24,15 @@ type Group struct { // so the Go method would respect parent context cancellation func WithContext(ctx context.Context) (*Group, context.Context) { grp, child_ctx := errgroup.WithContext(ctx) - return &Group{grp: *grp, ctx: ctx}, child_ctx + return &Group{grp: grp, ctx: ctx}, child_ctx } // Go runs the provided f function in a dedicated goroutine and waits for its // completion or for the parent context cancellation. func (g *Group) Go(f func() error) { + if g.grp == nil { + g.grp = &errgroup.Group{} + } g.grp.Go(g.wrap(f)) } @@ -38,6 +41,9 @@ func (g *Group) Go(f func() error) { // If the error group was created via WithContext then the Wait returns error // of cancelled parent context prior any functions calls complete. func (g *Group) Wait() error { + if g.grp == nil { + g.grp = &errgroup.Group{} + } return g.grp.Wait() } @@ -49,6 +55,9 @@ func (g *Group) Wait() error { // // The limit must not be modified while any goroutines in the group are active. func (g *Group) SetLimit(n int) { + if g.grp == nil { + g.grp = &errgroup.Group{} + } g.grp.SetLimit(n) } @@ -57,6 +66,9 @@ func (g *Group) SetLimit(n int) { // // The return value reports whether the goroutine was started. func (g *Group) TryGo(f func() error) bool { + if g.grp == nil { + g.grp = &errgroup.Group{} + } return g.grp.TryGo(g.wrap(f)) } diff --git a/rep_test.go b/rep_test.go index 9d5dba5..e449f68 100644 --- a/rep_test.go +++ b/rep_test.go @@ -134,6 +134,7 @@ func TestCancellation(t *testing.T) { defer wg.Done() repCtx, cancel := context.WithCancel(context.Background()) + defer cancel() rep := zmq4.NewRep(repCtx) defer rep.Close()