Designate a filename for the "best" model #351
Comments
Unfortunately looking at the PyLightning documentation this isn't supported by the |
I like the name |
Simply saving the best model is already supported by setting |
Yes, but my point is that the user has no way of knowing in advance exactly what that model will be called. And if save_top_k is >1, then it's hard to tell which is the best.
I see two low-code options to address that:
Well yes, but that's by design. That's exactly what saving the top 5 models means. If you don't want that, change the number of models to be saved. 🤷♂️ Likely you don't need any of those lower-performing models anyway.
Addressing the second low-code approach, I'm not terribly familiar with the PyLightning API but wouldn't this create a copy of the best performing weights rather than a symlink? I'm not sure how concerned we are about disk usage.
Yes, it would be a copy of the weights. I don't think that's a real concern though (weights will be smaller than training data or even a run for sequencing anyway).
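For reference, the copy-based approach discussed here could look roughly like the sketch below. The helper name `copy_best_checkpoint` and the alias `best.ckpt` are hypothetical and not part of the project; the only Lightning detail relied on is that `ModelCheckpoint` exposes the path of its best checkpoint as `best_model_path`.

```python
import shutil
from pathlib import Path


def copy_best_checkpoint(best_model_path, alias="best.ckpt"):
    """Copy the best checkpoint next to itself under a fixed, predictable name.

    `best_model_path` is assumed to come from a ModelCheckpoint callback's
    `best_model_path` attribute once training has finished.
    """
    source = Path(best_model_path)
    if not source.is_file():
        raise FileNotFoundError(f"No checkpoint found at {source}")
    # Keep the alias in the same directory as the original checkpoint.
    destination = source.with_name(alias)
    shutil.copy2(source, destination)
    return destination
```

This would be called once after `trainer.fit(...)` returns, for example `copy_best_checkpoint(checkpoint_callback.best_model_path)`, where `checkpoint_callback` is whatever `ModelCheckpoint` instance was passed to the trainer. The cost is one extra copy of the weights on disk.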
I tried introducing another model checkpoint that monitors the |
After investigating some potential workarounds I found a workable low(ish)-code solution. By adding this class

    from pathlib import Path

    # Assumes the classic `pytorch_lightning` package name.
    from pytorch_lightning.callbacks import ModelCheckpoint


    class SaveBestModelCheckpoint(ModelCheckpoint):
        """ModelCheckpoint that also keeps a "best.ckpt" symlink to the best checkpoint."""

        def on_validation_end(self, trainer, pl_module):
            super().on_validation_end(trainer, pl_module)
            if not self.best_model_path:
                return  # no checkpoint has been saved yet
            target_path = Path(self.best_model_path).resolve()
            symlink_path = Path(self.dirpath) / "best.ckpt"
            # Remove the stale link; missing_ok (Python 3.8+) avoids an error on the first call.
            symlink_path.unlink(missing_ok=True)
            symlink_path.symlink_to(target_path)

and replacing

    if config.save_top_k is not None:
        self.callbacks.append(
            ModelCheckpoint(
                dirpath=config.model_save_folder_path,
                monitor="valid_CELoss",
                mode="min",
                save_top_k=config.save_top_k,
            )
        )

with

    if config.save_top_k is not None:
        self.callbacks.append(
            SaveBestModelCheckpoint(
                dirpath=config.model_save_folder_path,
                monitor="valid_CELoss",
                mode="min",
                save_top_k=config.save_top_k,
            )
        )

we can essentially create a thin wrapper around |
So the conflict is between saving the top k models and only the best one? We've discussed this a bit before already, but I'm not sure what the benefit is of saving the best 5 models. Why would you want to use models ranked 2–5? 🤷♂️ Alternatively, we could change it to save all checkpoints, which is default Lightning behavior and doesn't require a |
I agree that the utility of saving the top 5 is not clear to me. I'm fine with saving all models, as long as there is a way to disable that behavior to save disk space.
It would be nice if there were a programmatic way to identify the best-performing model, as measured by validation error. I am thinking that in addition to outputting checkpoints with names like epoch=4-step=450000.ckpt we could output a copy (or symlink) to the one with the best validation error using a name like "final.ckpt".
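As a rough illustration of the symlink variant of this request, the sketch below points a fixed name at a given checkpoint using only the standard library. The function name `link_best_checkpoint` is hypothetical, the alias `final.ckpt` is taken from the suggestion above, and the code assumes a POSIX filesystem and Python 3.8+; it does not depend on any Lightning API.

```python
import os
from pathlib import Path


def link_best_checkpoint(best_model_path, alias="final.ckpt"):
    """Point `alias` (in the checkpoint's directory) at the given checkpoint."""
    target = Path(best_model_path)
    link = target.with_name(alias)
    tmp = target.with_name(alias + ".tmp")
    # Build the new link under a temporary name, then rename it over the old
    # one, so readers never observe a moment where the alias is missing.
    tmp.unlink(missing_ok=True)
    # Link to the bare file name so the checkpoint directory stays relocatable.
    tmp.symlink_to(target.name)
    os.replace(tmp, link)
    return link
```

Calling it again with a different checkpoint path simply repoints the alias, so it could be invoked whenever the best validation error improves. Unlike a copy, the alias breaks if the underlying checkpoint file is later deleted.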