Add policy-based model checkpointing callback #106

Open · wants to merge 1 commit into base: develop
190 changes: 162 additions & 28 deletions pfl/callback.py
@@ -11,6 +11,7 @@
 import subprocess
 import time
 import typing
+from abc import ABC, abstractmethod
 from collections import OrderedDict
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Union
@@ -684,33 +685,41 @@ def on_train_end(self, *, model: ModelType) -> None:
        writer.close()
 
 
-class ModelCheckpointingCallback(TrainingProcessCallback):
-    """
-    Callback to save model checkpoints. Note that the model checkpoints
-    can also be saved as part of ``RestoreTrainingCallback`` as long as
-    the model is ``Saveable`` and provided in the list of saveeables in
-    the initialization of the callback.
-
-    :param model_checkpoint_dir:
-        A path to disk for saving the trained model. Location
-        will be relative to root dir on current platform.
-    :param checkpoint_frequency:
-        The number of central iterations after which to save a model.
-        When zero (the default), the model is saved once after
-        training is complete.
-    """
-
-    def __init__(self,
-                 model_checkpoint_dir: str,
-                 *,
-                 checkpoint_frequency: int = 0):
-        if get_ops().distributed.local_rank == 0:
-            self.checkpoint_frequency = checkpoint_frequency
-            from pfl.internal.platform.selector import get_platform
-            self.model_checkpoint_dir = get_platform(
-            ).create_checkpoint_directories([model_checkpoint_dir])[0]
-
-    def _should_checkpoint_now(self, central_iteration: int) -> bool:
+class CheckpointPolicy(ABC):
+    """
+    Controls when `PolicyBasedModelCheckpointingCallback` should checkpoint.
+    """
+
+    @abstractmethod
+    def should_checkpoint_now(self, aggregate_metrics: Metrics,
+                              central_iteration: int) -> bool:
+        """
+        Invoked at the end of each central iteration to decide whether
+        a checkpoint should be made.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def should_checkpoint_at_end(self) -> bool:
+        """
+        Invoked at the end of training to decide whether a checkpoint should
+        be made.
+        """
+        raise NotImplementedError
+
+
+class IterationFrequencyCheckpointPolicy(CheckpointPolicy):
+    """
+    Checkpoint policy for `PolicyBasedModelCheckpointingCallback` that
+    saves a checkpoint after every `checkpoint_frequency` iterations if the
+    value is positive, or at the end of training if it is zero.
+    """
+
+    def __init__(self, checkpoint_frequency: int):
+        self.checkpoint_frequency = checkpoint_frequency
+
+    def should_checkpoint_now(self, aggregate_metrics: Metrics,
+                              central_iteration: int) -> bool:
         """
         Return true when the number of `central_iteration`s that have
         completed is a non-zero multiple of `self.checkpoint_frequency`.
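With `checkpoint_frequency=5`, for instance, this predicate fires at zero-indexed central iterations 4, 9, 14, and so on: once after every fifth completed iteration. A custom policy only has to implement the two abstract methods. Below is a minimal sketch (not part of this diff) of a hypothetical warm-up policy; the import paths assume `CheckpointPolicy` ends up exported from `pfl.callback` and that `Metrics` lives in `pfl.metrics`.

    from pfl.callback import CheckpointPolicy
    from pfl.metrics import Metrics


    class WarmupCheckpointPolicy(CheckpointPolicy):
        """Hypothetical policy: skip checkpoints during the first
        `warmup_iterations` central iterations, then checkpoint after
        every iteration, and always save a final checkpoint."""

        def __init__(self, warmup_iterations: int):
            self.warmup_iterations = warmup_iterations

        def should_checkpoint_now(self, aggregate_metrics: Metrics,
                                  central_iteration: int) -> bool:
            # `central_iteration` is zero-indexed, matching the built-in
            # IterationFrequencyCheckpointPolicy above.
            return central_iteration >= self.warmup_iterations

        def should_checkpoint_at_end(self) -> bool:
            return True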
@@ -719,21 +728,146 @@ def _should_checkpoint_now(self, central_iteration: int) -> bool:
                 and central_iteration % self.checkpoint_frequency
                 == (self.checkpoint_frequency - 1))
 
+    def should_checkpoint_at_end(self) -> bool:
+        return self.checkpoint_frequency == 0
+
+
+class MetricImprovementCheckpointPolicy(CheckpointPolicy):
+    """
+    Stateful checkpoint policy for `PolicyBasedModelCheckpointingCallback`
+    to save a checkpoint after any iteration where the value of `metric_name`
+    has improved over the prior best value.
+
+    :param metric_name:
+        The metric whose value to track.
+
+    :param threshold_value:
+        If present, only save a checkpoint if the metric value is better than
+        this value.
+
+    :param performance_is_better:
+        A binary predicate indicating that `lhs` is better than `rhs`.
+
+        For metrics where higher values are better, like precision,
+        you would want to use `operator.gt`, and for metrics like
+        loss, you would want to use `operator.lt` (the default).
+    """
+
+    metric_name: MetricName
+    best_value: Optional[float]
+    performance_is_better: Callable[[Any, Any], bool]
+
+    def __init__(self,
+                 metric_name: MetricName,
+                 *,
+                 threshold_value: Optional[float] = None,
+                 performance_is_better: Callable[[Any, Any],
+                                                 bool] = operator.lt):
+        self.metric_name = metric_name
+        self.best_value = threshold_value
+        self.performance_is_better = performance_is_better
+
+    def should_checkpoint_now(self, aggregate_metrics: Metrics,
+                              central_iteration: int) -> bool:
+        cur_value = get_overall_value(aggregate_metrics[self.metric_name])
+        if (self.best_value is None
+                or self.performance_is_better(cur_value, self.best_value)):
+            self.best_value = cur_value
+            return True
+        return False
+
+    def should_checkpoint_at_end(self) -> bool:
+        return False
+
+
+class PolicyBasedModelCheckpointingCallback(TrainingProcessCallback):
+    """
+    Callback to save model checkpoints after iterations and after
+    training, when indicated by `checkpoint_policy`.
+
+    :param model_checkpoint_dir:
+        A path to disk for saving the trained model.
+        If running on Bolt, this will be a path relative to
+        ``ARTIFACT_DIR``.
+    :param checkpoint_policy:
+        An instance of a `CheckpointPolicy` subclass.
+
+    :param numbered: If true, include the iteration number in each
+        checkpoint's path to save all the checkpoints without
+        overwriting.
+    """
+
+    def __init__(self,
+                 model_checkpoint_dir: str,
+                 *,
+                 checkpoint_policy: CheckpointPolicy,
+                 numbered: bool = False):
+        if get_ops().distributed.local_rank == 0:
+            self.numbered = numbered
+            self.checkpoint_policy = checkpoint_policy
+            from pfl.internal.platform.selector import get_platform
+            self.model_checkpoint_dir_name = model_checkpoint_dir
+            if not numbered:
+                self.model_checkpoint_dir = get_platform(
+                ).create_checkpoint_directories([model_checkpoint_dir])[0]
+
     def after_central_iteration(
             self, aggregate_metrics: Metrics, model: StatefulModel, *,
             central_iteration: int) -> Tuple[bool, Metrics]:
-        if get_ops(
-        ).distributed.local_rank == 0 and self._should_checkpoint_now(
-                central_iteration):
-            model.save(self.model_checkpoint_dir)
+        if get_ops().distributed.local_rank == 0:
+            if self.checkpoint_policy.should_checkpoint_now(
+                    aggregate_metrics, central_iteration):
+                if self.numbered:
+                    from pfl.internal.platform.selector import get_platform
+                    self.model_checkpoint_dir = get_platform(
+                    ).create_checkpoint_directories([
+                        f'{self.model_checkpoint_dir_name}/'
+                        f'{central_iteration:05}'
+                    ])[0]
+                model.save(self.model_checkpoint_dir)
         return False, Metrics()
 
     def on_train_end(self, *, model: StatefulModel) -> None:
-        if get_ops(
-        ).distributed.local_rank == 0 and self.checkpoint_frequency == 0:
+        if get_ops().distributed.local_rank == 0 and (
+                self.checkpoint_policy.should_checkpoint_at_end()):
+            if self.numbered:
+                from pfl.internal.platform.selector import get_platform
+                self.model_checkpoint_dir = get_platform(
+                ).create_checkpoint_directories(
+                    [f'{self.model_checkpoint_dir_name}/final'])[0]
             model.save(self.model_checkpoint_dir)
 
 
+class ModelCheckpointingCallback(PolicyBasedModelCheckpointingCallback):
+    """
+    Callback to save model checkpoints. Note that the model checkpoints
+    can also be saved as part of ``RestoreTrainingCallback`` as long as
+    the model is ``Saveable`` and provided in the list of saveables in
+    the initialization of the callback.
+
+    :param model_checkpoint_dir:
+        A path to disk for saving the trained model. Location
+        will be relative to root dir on current platform.
+    :param checkpoint_frequency:
+        The number of central iterations after which to save a model.
+        When zero (the default), the model is saved once after
+        training is complete.
+    :param numbered: If true, append the iteration number to each
+        checkpoint path to save all the checkpoints without
+        overwriting.
+    """
+
+    def __init__(self,
+                 model_checkpoint_dir: str,
+                 *,
+                 checkpoint_frequency: int = 0,
+                 numbered: bool = False):
+        super().__init__(model_checkpoint_dir,
+                         checkpoint_policy=IterationFrequencyCheckpointPolicy(
+                             checkpoint_frequency),
+                         numbered=numbered)
+
+
 class ProfilerCallback(TrainingProcessCallback):
     """
     Profiles the code using Python's profiler, cProfile.
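For reviewers who want to try the change, here is a hedged usage sketch. The import paths mirror the file under review, but the metric name is a placeholder: `MetricImprovementCheckpointPolicy` needs whatever `MetricName` your evaluation actually reports.

    import operator

    from pfl.callback import (MetricImprovementCheckpointPolicy,
                              ModelCheckpointingCallback,
                              PolicyBasedModelCheckpointingCallback)
    from pfl.metrics import StringMetricName

    # Placeholder name -- substitute the name your central evaluation
    # actually reports its loss under.
    central_loss_name = StringMetricName('central val loss')

    # Save a numbered checkpoint whenever the tracked loss improves on its
    # best value so far (operator.lt is also the default predicate).
    best_model_cb = PolicyBasedModelCheckpointingCallback(
        'checkpoints/best_loss',
        checkpoint_policy=MetricImprovementCheckpointPolicy(
            central_loss_name, performance_is_better=operator.lt),
        numbered=True)

    # The existing API is unchanged: ModelCheckpointingCallback is now a
    # thin wrapper around IterationFrequencyCheckpointPolicy. Save every
    # 10 central iterations.
    periodic_cb = ModelCheckpointingCallback('checkpoints/periodic',
                                             checkpoint_frequency=10)

Both callbacks are then passed wherever the training loop accepts its list of `TrainingProcessCallback`s.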