diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 6a3549b0..ab027e1c 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -89,6 +89,7 @@ def check_model_weights(model, ckpt_path, total_equal=False): def main(args): + very_begining_time = time.time() # init setting skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps @@ -305,6 +306,7 @@ def main(args): optimizer=optimizer, beta2_scheduler=beta2_scheduler, trainer=trainer, + very_begining_time=very_begining_time, start_time=start_time, loss=loss, moe_loss=moe_loss,