Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Oct 25, 2022
1 parent e12e5ed commit 4242ad4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 20 deletions.
12 changes: 7 additions & 5 deletions examples/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
def __len__(self) -> int:
return self.n_samples


if __name__ == "__main__":
ml.select_device([0])
ml.seed_everything(2, gpu_dtm=True)
Expand Down Expand Up @@ -103,7 +104,7 @@ def training_epoch_end(self) -> Dict[str, float]:
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))

# train from ckpt (model, optimizer state dict, global epoch, global step)
# train from ckpt
time.sleep(1)
ckpt_path = trainer.last_ckpt_path
optimizer = optim.SGD(model.parameters(), 0.1, 0.9)
Expand All @@ -114,19 +115,20 @@ def training_epoch_end(self) -> Dict[str, float]:
logger.info(trainer.test(ldm.val_dataloader, True, True))
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))
# train from ckpt different optimizer (only model)
# train from ckpt (only model)
time.sleep(1)
ckpt_path = trainer.last_ckpt_path
model, _, _ = ml.load_ckpt(ckpt_path, Device(0))
optimizer = optim.Adam(model.parameters(), 0.001)
lmodel = MyLModule(None, optimizer, loss_fn, metrics, "loss")
lmodel = MyLModule(model, optimizer, loss_fn, metrics, "loss")
ldm = ml.LDataModule(train_dataset, val_dataset, test_dataset, 64)
trainer = ml.Trainer(lmodel, [0], 20, RUNS_DIR, gradient_clip_norm=10,
val_every_n_epoch=10, verbose=True, resume_from_ckpt=ckpt_path)
val_every_n_epoch=10, verbose=True)
logger.info(trainer.test(ldm.val_dataloader, True, True))
logger.info(trainer.fit(ldm.train_dataloader, ldm.val_dataloader))
logger.info(trainer.test(ldm.test_dataloader, True, True))

# only test from ckpt (model, global epoch, global step)
# only test from ckpt
time.sleep(1)
ckpt_path = trainer.last_ckpt_path
lmodel = MyLModule(None, None, loss_fn, metrics, "loss")
Expand Down
21 changes: 6 additions & 15 deletions mini_lightning/_mini_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ def __init__(
self.save_hparams(hparams)
#
if resume_from_ckpt is not None:
self._load_ckpt(resume_from_ckpt, self.device)
logger.info(f"Using ckpt: {resume_from_ckpt}")
self._load_ckpt(resume_from_ckpt, self.device, True)
#
self.lmodel.trainer_init(self)
print_model_info(lmodel.model, None)
Expand Down Expand Up @@ -524,24 +524,15 @@ def _save_ckpt(self, fpath: str) -> None:
}
save_ckpt(fpath, de_parallel(self.lmodel.model), self.lmodel.optimizer, self.global_epoch, **kwargs)

def _load_ckpt(self, fpath: str, map_location: Optional[Device] = None, verbose: bool = False) -> None:
def _load_ckpt(self, fpath: str, map_location: Optional[Device] = None) -> None:
new_model, optimizer_state_dict, mes = load_ckpt(fpath, map_location)
self.lmodel.model = new_model
#
optimizer_name = self.lmodel.optimizer.__class__.__name__
tag = ["Ignore", "Ignore"]
if mes["optimizer_name"] == optimizer_name:
self.lmodel.load_state_dict(None, optimizer_state_dict)
tag[0] = "Success"

if mes["optimizer_name"] == optimizer_name or self.lmodel.optimizer is None:
self.global_epoch = mes["last_epoch"]
self.global_step = mes["global_step"]
tag[1] = "Success"

if verbose:
logger.info(
f"Using ckpt model: Success. optimizer state dict: {tag[0]}. global_epoch, global_step: {tag[1]}")
assert self.lmodel.optimizer is None or optimizer_name == mes["optimizer_name"]
self.lmodel.load_state_dict(None, optimizer_state_dict)
self.global_epoch = mes["last_epoch"]
self.global_step = mes["global_step"]

def _model_saving(self, core_metric: Optional[float]) -> bool:
best_saving = False
Expand Down

0 comments on commit 4242ad4

Please sign in to comment.