diff --git a/example/complex-http-crawler/main.go b/example/complex-http-crawler/main.go index fd6b01b..3d169a9 100644 --- a/example/complex-http-crawler/main.go +++ b/example/complex-http-crawler/main.go @@ -15,6 +15,8 @@ type Options struct { MaxRetries int `short:"r" long:"max-retries" description:"max retries" default:"3"` MaxRuntimePerTaskSeconds int `short:"t" long:"max-runtime-per-task-seconds" description:"max runtime per task seconds" default:"60"` NumWorkers int `short:"n" long:"num-workers" description:"number of workers" default:"32"` + NumShards int `short:"s" long:"num-shards" description:"number of shards" default:"1"` + Shard int `short:"d" long:"shard" description:"shard" default:"0"` } var opts Options @@ -27,7 +29,15 @@ func init() { } func main() { - scheduler := gojob.NewScheduler(opts.NumWorkers, opts.MaxRuntimePerTaskSeconds, opts.MaxRetries, opts.OutputFilePath) + scheduler := gojob.NewScheduler(). + SetNumWorkers(opts.NumWorkers). + SetMaxRetries(opts.MaxRetries). + SetMaxRuntimePerTaskSeconds(opts.MaxRuntimePerTaskSeconds). + SetNumShards(int64(opts.NumShards)). + SetShard(int64(opts.Shard)). + SetOutputFilePath(opts.OutputFilePath) + scheduler.Start() + for line := range util.Cat(opts.InputFilePath) { scheduler.Submit(model.New(string(line))) } diff --git a/example/simple-http-crawler/main.go b/example/simple-http-crawler/main.go index 4a01ace..a12287e 100644 --- a/example/simple-http-crawler/main.go +++ b/example/simple-http-crawler/main.go @@ -29,7 +29,13 @@ func (t *MyTask) Do() error { } func main() { - scheduler := gojob.NewScheduler(1, 4, 8, "output.txt") + scheduler := gojob.NewScheduler(). + SetNumWorkers(8). + SetMaxRetries(4). + SetMaxRuntimePerTaskSeconds(16). + SetNumShards(4). 
+ SetShard(0) + scheduler.Start() for line := range util.Cat("input.txt") { scheduler.Submit(New(line)) } diff --git a/example/sleeper/main.go b/example/sleeper/main.go index 9df10d5..9cbb8d5 100644 --- a/example/sleeper/main.go +++ b/example/sleeper/main.go @@ -25,7 +25,12 @@ func (t *MyTask) Do() error { } func main() { - scheduler := gojob.NewScheduler(8, 4, 16, "output.txt") + scheduler := gojob.NewScheduler(). + SetNumWorkers(8). + SetMaxRetries(4). + SetMaxRuntimePerTaskSeconds(16). + SetNumShards(4). + SetShard(0) scheduler.Start() for i := 0; i < 256; i++ { scheduler.Submit(New(i, rand.Intn(10))) diff --git a/gojob.go b/gojob.go index 35b934e..3f07585 100644 --- a/gojob.go +++ b/gojob.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "sync" + "sync/atomic" "time" ) @@ -17,6 +18,7 @@ type Task interface { } type BasicTask struct { + Index int64 `json:"index"` StartedAt int64 `json:"started_at"` FinishedAt int64 `json:"finished_at"` NumTries int `json:"num_tries"` @@ -24,9 +26,14 @@ type BasicTask struct { Error string `json:"error"` } -func NewBasicTask(task Task) *BasicTask { +func NewBasicTask(index int64, task Task) *BasicTask { return &BasicTask{ - Task: task, + Index: index, + StartedAt: 0, + FinishedAt: 0, + NumTries: 0, + Task: task, + Error: "", } } @@ -36,32 +43,92 @@ type Scheduler struct { OutputFilePath string MaxRetries int MaxRuntimePerTaskSeconds int - TaskChan chan Task + NumShards int64 + Shard int64 + NumTasks atomic.Int64 + TaskChan chan *BasicTask LogChan chan string taskWg *sync.WaitGroup logWg *sync.WaitGroup } // NewScheduler creates a new scheduler -func NewScheduler(numWorkers int, maxRetries int, maxRuntimePerTaskSeconds int, outputFilePath string) *Scheduler { +func NewScheduler() *Scheduler { scheduler := &Scheduler{ - NumWorkers: numWorkers, - OutputFilePath: outputFilePath, - MaxRetries: maxRetries, - MaxRuntimePerTaskSeconds: maxRuntimePerTaskSeconds, - TaskChan: make(chan Task), + NumWorkers: 1, + OutputFilePath: "-", + MaxRetries: 4, 
+ MaxRuntimePerTaskSeconds: 16, + NumShards: 1, + Shard: 0, + NumTasks: atomic.Int64{}, + TaskChan: make(chan *BasicTask), LogChan: make(chan string), taskWg: &sync.WaitGroup{}, logWg: &sync.WaitGroup{}, } - scheduler.Start() return scheduler } +// SetNumShards sets the number of shards, default is 1 which means no sharding +func (s *Scheduler) SetNumShards(numShards int64) *Scheduler { + if numShards <= 0 { + panic("numShards must be greater than 0") + } + s.NumShards = numShards + return s +} + +// SetShard sets the shard (from 0 to NumShards-1); call SetNumShards first, the range check uses the current NumShards +func (s *Scheduler) SetShard(shard int64) *Scheduler { + if shard < 0 || shard >= s.NumShards { + panic("shard must be in [0, NumShards)") + } + s.Shard = shard + return s +} + +// SetNumWorkers sets the number of workers +func (s *Scheduler) SetNumWorkers(numWorkers int) *Scheduler { + if numWorkers <= 0 { + panic("numWorkers must be greater than 0") + } + s.NumWorkers = numWorkers + return s +} + +// SetOutputFilePath sets the output file path +func (s *Scheduler) SetOutputFilePath(outputFilePath string) *Scheduler { + s.OutputFilePath = outputFilePath + return s +} + +// SetMaxRetries sets the max retries +func (s *Scheduler) SetMaxRetries(maxRetries int) *Scheduler { + if maxRetries <= 0 { + panic("maxRetries must be greater than 0") + } + s.MaxRetries = maxRetries + return s +} + +// SetMaxRuntimePerTaskSeconds sets the max runtime per task seconds +func (s *Scheduler) SetMaxRuntimePerTaskSeconds(maxRuntimePerTaskSeconds int) *Scheduler { + if maxRuntimePerTaskSeconds <= 0 { + panic("maxRuntimePerTaskSeconds must be greater than 0") + } + s.MaxRuntimePerTaskSeconds = maxRuntimePerTaskSeconds + return s +} + // Submit submits a task to the scheduler func (s *Scheduler) Submit(task Task) { - s.taskWg.Add(1) - s.TaskChan <- task + // Reserve the index in one atomic step so concurrent Submit calls never share an index + index := s.NumTasks.Add(1) - 1 + if (index % s.NumShards) == s.Shard { + s.taskWg.Add(1) + s.TaskChan <- NewBasicTask(index, task) + } } // Start starts the scheduler @@ -84,25 
+151,24 @@ func (s *Scheduler) Wait() { func (s *Scheduler) Worker() { for task := range s.TaskChan { // Start task - bt := NewBasicTask(task) for i := 0; i < s.MaxRetries; i++ { err := func() error { - bt.StartedAt = time.Now().UnixMicro() + task.StartedAt = time.Now().UnixMicro() defer func() { - bt.NumTries++ - bt.FinishedAt = time.Now().UnixMicro() + task.NumTries++ + task.FinishedAt = time.Now().UnixMicro() }() - return RunWithTimeout(task.Do, time.Duration(s.MaxRuntimePerTaskSeconds)*time.Second) + return RunWithTimeout(task.Task.Do, time.Duration(s.MaxRuntimePerTaskSeconds)*time.Second) }() if err != nil { - bt.Error = err.Error() + task.Error = err.Error() } else { - bt.Error = "" + task.Error = "" break } } // Serialize task - data, err := json.Marshal(bt) + data, err := json.Marshal(task) if err != nil { slog.Error("error occured while serializing task", slog.String("error", err.Error())) } else {