From 19bcc3ec675f08acf09fc4dbe28fba9a164da30b Mon Sep 17 00:00:00 2001 From: Leonid Emar-Kar Date: Thu, 13 Jun 2024 12:52:37 +0100 Subject: [PATCH 1/2] primary update of the Task interface with context --- pkg/utils/timeout.go | 4 ++-- pkg/utils/timeout_test.go | 3 ++- scheduler_test.go | 3 ++- task.go | 4 +++- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pkg/utils/timeout.go b/pkg/utils/timeout.go index f1d8aa4..16e65d7 100644 --- a/pkg/utils/timeout.go +++ b/pkg/utils/timeout.go @@ -5,7 +5,7 @@ import ( "time" ) -func RunWithTimeout(f func() error, timeout time.Duration) error { +func RunWithTimeout(f func(ctx context.Context) error, timeout time.Duration) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -14,7 +14,7 @@ func RunWithTimeout(f func() error, timeout time.Duration) error { go func() { defer close(done) - done <- f() + done <- f(ctx) }() select { diff --git a/pkg/utils/timeout_test.go b/pkg/utils/timeout_test.go index d451354..2b130e4 100644 --- a/pkg/utils/timeout_test.go +++ b/pkg/utils/timeout_test.go @@ -1,6 +1,7 @@ package utils_test import ( + "context" "testing" "time" @@ -8,7 +9,7 @@ import ( ) func TestRunWithTimeout(t *testing.T) { - task := func() error { + task := func(_ context.Context) error { time.Sleep(2 * time.Second) return nil } diff --git a/scheduler_test.go b/scheduler_test.go index 3360344..e6bef34 100644 --- a/scheduler_test.go +++ b/scheduler_test.go @@ -1,6 +1,7 @@ package gojob_test import ( + "context" "fmt" "reflect" "sort" @@ -46,7 +47,7 @@ func newTask(i int, writer *safeWriter) *schedulerTestTask { } } -func (t *schedulerTestTask) Do() error { +func (t *schedulerTestTask) Do(_ context.Context) error { t.writer.WriteString(fmt.Sprintf("%d\n", t.I)) return nil } diff --git a/task.go b/task.go index 378d59a..499865b 100644 --- a/task.go +++ b/task.go @@ -1,6 +1,8 @@ package gojob import ( + "context" + "github.com/google/uuid" ) @@ -9,7 +11,7 @@ type Task interface { // Do starts the task, returns error if failed // If an error is returned, the task will be retried until MaxRetries // You can set MaxRetries by calling SetMaxRetries on the scheduler - Do() error + Do(context.Context) error } type basicTask struct { From 2e2540780fd0cc43224b13302a7ab2b077198479 Mon Sep 17 00:00:00 2001 From: Leonid Emar-Kar Date: Thu, 13 Jun 2024 13:02:49 +0100 Subject: [PATCH 2/2] update examples with context --- examples/complex-http-crawler/pkg/model/task.go | 3 ++- examples/metadata/main.go | 3 ++- examples/nopper/main.go | 3 ++- examples/prometheus/main.go | 3 ++- examples/random-error/main.go | 3 ++- examples/result-channel/main.go | 3 ++- examples/simple-http-crawler/main.go | 3 ++- examples/sleeper/main.go | 3 ++- examples/tcp-port-scanner/task.go | 7 +++++-- 9 files changed, 21 insertions(+), 10 deletions(-) diff --git a/examples/complex-http-crawler/pkg/model/task.go b/examples/complex-http-crawler/pkg/model/task.go index 4c3b000..753239b 100644 --- a/examples/complex-http-crawler/pkg/model/task.go +++ b/examples/complex-http-crawler/pkg/model/task.go @@ -1,6 +1,7 @@ package model import ( + "context" "net/http" "time" ) @@ -18,7 +19,7 @@ func New(url string) *MyTask { } } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { transport := &http.Transport{ DisableCompression: true, } diff --git a/examples/metadata/main.go b/examples/metadata/main.go index 5af0517..a1bc24b 100644 --- a/examples/metadata/main.go +++ b/examples/metadata/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "math/rand" "time" @@ -15,7 +16,7 @@ func New() *MyTask { return &MyTask{} } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) return nil } diff --git a/examples/nopper/main.go b/examples/nopper/main.go index 22f658f..6940ba2 100644 --- a/examples/nopper/main.go +++ b/examples/nopper/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "math/rand" "time" @@ -14,7 +15,7 @@ func New() *MyTask { return &MyTask{} } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) return nil } diff --git a/examples/prometheus/main.go b/examples/prometheus/main.go index 632ff89..0e140a0 100644 --- a/examples/prometheus/main.go +++ b/examples/prometheus/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "net/http" @@ -18,7 +19,7 @@ func New(url string) *MyTask { } } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { response, err := http.Get(t.Url) if err != nil { return err diff --git a/examples/random-error/main.go b/examples/random-error/main.go index 9b2ff96..6e87cac 100644 --- a/examples/random-error/main.go +++ b/examples/random-error/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "math/rand" "time" @@ -22,7 +23,7 @@ func New(index int, sleepSeconds int) *MyTask { } } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { time.Sleep(time.Duration(t.SleepSeconds) * time.Second) if rand.Float64() < t.ErrorProbability { return errors.New("an error occurred") diff --git a/examples/result-channel/main.go b/examples/result-channel/main.go index a806eac..2412c67 100644 --- a/examples/result-channel/main.go +++ b/examples/result-channel/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "log/slog" @@ -20,7 +21,7 @@ func New(url string) *MyTask { } } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { response, err := http.Get(t.Url) if err != nil { return err diff --git a/examples/simple-http-crawler/main.go b/examples/simple-http-crawler/main.go index 097baef..5023c42 100644 --- a/examples/simple-http-crawler/main.go +++ b/examples/simple-http-crawler/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "net/http" @@ -18,7 +19,7 @@ func New(url string) *MyTask { } } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { response, err := http.Get(t.Url) if err != nil { return err diff --git a/examples/sleeper/main.go b/examples/sleeper/main.go index 9803b65..925f543 100644 --- a/examples/sleeper/main.go +++ b/examples/sleeper/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "math/rand" "time" @@ -19,7 +20,7 @@ func New(index int, sleepSeconds int) *MyTask { } } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { time.Sleep(time.Duration(t.SleepSeconds) * time.Second) return nil } diff --git a/examples/tcp-port-scanner/task.go b/examples/tcp-port-scanner/task.go index ff5d3ce..94d2a7d 100644 --- a/examples/tcp-port-scanner/task.go +++ b/examples/tcp-port-scanner/task.go @@ -1,6 +1,9 @@ package main -import "net" +import ( + "context" + "net" +) type MyTask struct { IP string `json:"ip"` @@ -16,7 +19,7 @@ func New(ip string, port uint16) *MyTask { } } -func (t *MyTask) Do() error { +func (t *MyTask) Do(_ context.Context) error { conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ IP: net.ParseIP(t.IP), Port: int(t.Port),