diff --git a/federatedscope/core/trainers/trainer.py b/federatedscope/core/trainers/trainer.py
index 689f64abe..41be4ad71 100644
--- a/federatedscope/core/trainers/trainer.py
+++ b/federatedscope/core/trainers/trainer.py
@@ -279,9 +279,10 @@ def _run_routine(self, mode, hooks_set, dataset_name=None):
         return self.ctx.num_samples
 
     @lifecycle(LIFECYCLE.EPOCH)
-    def _run_epoch(self, hooks_set):
-        for epoch_i in range(
-                getattr(self.ctx, f"num_{self.ctx.cur_split}_epoch")):
+    def _run_epoch(self, hooks_set, run_step=-1):
+        if run_step == -1:
+            run_step = getattr(self.ctx, f"num_{self.ctx.cur_split}_epoch")
+        for epoch_i in range(run_step):
             self.ctx.cur_epoch_i = CtxVar(epoch_i, "epoch")
 
             for hook in hooks_set["on_epoch_start"]:
@@ -293,9 +294,10 @@ def _run_epoch(self, hooks_set):
                 hook(self.ctx)
 
     @lifecycle(LIFECYCLE.BATCH)
-    def _run_batch(self, hooks_set):
-        for batch_i in range(
-                getattr(self.ctx, f"num_{self.ctx.cur_split}_batch")):
+    def _run_batch(self, hooks_set, run_step=-1):
+        if run_step == -1:
+            run_step = getattr(self.ctx, f"num_{self.ctx.cur_split}_batch")
+        for batch_i in range(run_step):
             self.ctx.cur_batch_i = CtxVar(batch_i, LIFECYCLE.BATCH)
 
             for hook in hooks_set["on_batch_start"]:
diff --git a/federatedscope/llm/trainer/trainer.py b/federatedscope/llm/trainer/trainer.py
index 405cbf0c0..905c69e4d 100644
--- a/federatedscope/llm/trainer/trainer.py
+++ b/federatedscope/llm/trainer/trainer.py
@@ -3,7 +3,7 @@
 from federatedscope.register import register_trainer
 from federatedscope.core.trainers import GeneralTorchTrainer
 from federatedscope.core.trainers.context import CtxVar
-from federatedscope.core.trainers.enums import LIFECYCLE
+from federatedscope.core.trainers.enums import MODE, LIFECYCLE
 from federatedscope.core.monitors.monitor import Monitor
 from federatedscope.llm.model.adapter_builder import AdapterModel
 
@@ -54,6 +54,11 @@ def _hook_on_batch_backward(self, ctx):
 
     def _hook_on_batch_end(self, ctx):
         if ctx.skip_this_batch:
+            # Retry with new data in train and finetune
+            if ctx.cur_mode == MODE.TRAIN:
+                self._run_batch(self.hooks_in_train, run_step=1)
+            elif ctx.cur_mode == MODE.FINETUNE:
+                self._run_batch(self.hooks_in_ft, run_step=1)
             return
 
         ctx.num_samples += ctx.batch_size
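
For context, here is a minimal, self-contained sketch of the retry pattern this patch introduces. All names here (`ToyTrainer`, `skip_prob`, etc.) are hypothetical stand-ins, not FederatedScope APIs: the point is only to show how a `run_step=-1` default preserves the original "run all batches" behavior, while `run_step=1` lets the end-of-batch hook re-enter the batch loop once to replace a skipped batch with fresh data.

```python
import random


class ToyTrainer:
    """Hypothetical stand-in for the trainer's batch loop (not FederatedScope code)."""

    def __init__(self, num_train_batch=3, skip_prob=0.3):
        self.num_train_batch = num_train_batch
        self.skip_prob = skip_prob
        self.num_samples = 0

    def _run_batch(self, run_step=-1):
        # run_step == -1 keeps the original behavior: run the configured
        # number of batches. A positive value runs exactly that many steps.
        if run_step == -1:
            run_step = self.num_train_batch
        for _ in range(run_step):
            # Stand-in for a batch whose loss came out NaN/inf and was flagged
            # as skip_this_batch in the real trainer.
            skip_this_batch = random.random() < self.skip_prob
            self._hook_on_batch_end(skip_this_batch)

    def _hook_on_batch_end(self, skip_this_batch):
        if skip_this_batch:
            # Re-enter the batch loop for exactly one step, i.e. retry the
            # skipped batch with freshly drawn data, then stop counting it.
            self._run_batch(run_step=1)
            return
        self.num_samples += 1


trainer = ToyTrainer()
trainer._run_batch()
print(trainer.num_samples)  # 3: every skipped step was retried until one succeeded
```

One reading of this design: because the retry goes back through `_run_batch` rather than calling the hooks directly, the batch-level lifecycle (decorated with `@lifecycle(LIFECYCLE.BATCH)` in the patch) still wraps the replacement step. Note that if the data kept producing skipped batches, the retry would recurse without bound; the patch as shown does not cap the number of retries.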