From f253cc8ef8d312ce433c5f1dc77adacd25ca3305 Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Tue, 12 Nov 2024 13:03:38 -0500 Subject: [PATCH] Add a policy-based model checkpointing callback and add a best-metric checkpoint policy --- pfl/callback.py | 190 +++++++++++++++++++++++++++++++++++------ tests/test_callback.py | 140 ++++++++++++++++++++++++++++-- 2 files changed, 295 insertions(+), 35 deletions(-) diff --git a/pfl/callback.py b/pfl/callback.py index e5809be..17bc188 100644 --- a/pfl/callback.py +++ b/pfl/callback.py @@ -11,6 +11,7 @@ import subprocess import time import typing +from abc import ABC, abstractmethod from collections import OrderedDict from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Union @@ -684,33 +685,41 @@ def on_train_end(self, *, model: ModelType) -> None: writer.close() -class ModelCheckpointingCallback(TrainingProcessCallback): +class CheckpointPolicy(ABC): + """ + Controls when `PolicyBasedModelCheckpointingCallback` should checkpoint. """ - Callback to save model checkpoints. Note that the model checkpoints - can also be saved as part of ``RestoreTrainingCallback`` as long as - the model is ``Saveable`` and provided in the list of saveeables in - the initialization of the callback. - :param model_checkpoint_dir: - A path to disk for saving the trained model. Location - will be relative to root dir on current platform. - :param checkpoint_frequency: - The number of central iterations after which to save a model. - When zero (the default), the model is saved once after - training is complete. + @abstractmethod + def should_checkpoint_now(self, aggregate_metrics: Metrics, + central_iteration: int) -> bool: + """ + Invoked at the end of each central iteration to decide whether + a checkpoint should be made. + """ + raise NotImplementedError + + @abstractmethod + def should_checkpoint_at_end(self) -> bool: + """ + Invoked at the end of training to decide whether a checkpoint should + be made. + """ + raise NotImplementedError + + +class IterationFrequencyCheckpointPolicy: + """ + Checkpoint policy for `PolicyBasedModelCheckpointingCallback` that + saves a checkpoint after every `checkpoint_frequency` iterations if the + value is positive or at the end of training if it is zero. """ - def __init__(self, - model_checkpoint_dir: str, - *, - checkpoint_frequency: int = 0): - if get_ops().distributed.local_rank == 0: - self.checkpoint_frequency = checkpoint_frequency - from pfl.internal.platform.selector import get_platform - self.model_checkpoint_dir = get_platform( - ).create_checkpoint_directories([model_checkpoint_dir])[0] + def __init__(self, checkpoint_frequency: int): + self.checkpoint_frequency = checkpoint_frequency - def _should_checkpoint_now(self, central_iteration: int) -> bool: + def should_checkpoint_now(self, aggregate_metrics: Metrics, + central_iteration: int) -> bool: """ Return true when the number of `central_iteration`s that have completed is a non-zero multiple of `self.checkpoint_frequency`. @@ -719,21 +728,146 @@ def _should_checkpoint_now(self, central_iteration: int) -> bool: and central_iteration % self.checkpoint_frequency == (self.checkpoint_frequency - 1)) + def should_checkpoint_at_end(self) -> bool: + return self.checkpoint_frequency == 0 + + +class MetricImprovementCheckpointPolicy(CheckpointPolicy): + """ + Stateful checkpoint policy for `PolicyBasedModelCheckpointingCallback` + to save a checkpoint after any iteration where the value of `metric_name` + has improved versus the prior best value. + + :param metric_name: + The metrics whose value to track. + + :param threshold_value: + If present, only save a checkpoint if the metric value is better than + this value. + + :param performance_is_better: + A binary predicate indicating that `lhs` is better `rhs`. + + For metrics where higher values are better, like precision, + you would want to use `operator.gt`, and for metrics like + loss, you would want to use `operator.lt` (the default). + """ + + metric_name: MetricName + best_value: float | None + performance_is_better: Callable[[Any, Any], bool] + + def __init__(self, + metric_name: MetricName, + *, + threshold_value: float | None = None, + performance_is_better: Callable[[Any, Any], + bool] = operator.lt): + self.metric_name = metric_name + self.best_value = threshold_value + self.performance_is_better = performance_is_better + + def should_checkpoint_now(self, aggregate_metrics: Metrics, + central_iteration: int): + cur_value = get_overall_value(aggregate_metrics[self.metric_name]) + if (self.best_value is None + or self.performance_is_better(cur_value, self.best_value)): + self.best_value = cur_value + return True + return False + + def should_checkpoint_at_end(self): + return False + + +class PolicyBasedModelCheckpointingCallback(TrainingProcessCallback): + """ + Callback to save model checkpoints after iterations and after + training, when indicated by `policy`. + + :param model_checkpoint_dir: + A path to disk for saving the trained model. + If running on Bolt, this will be a path relative to + ``ARTIFACT_DIR``. + :param policy: + An instance of a `CheckpointPolicy` subclass. + + :param numbered: If true, include the iteration number in each + checkpoint's path to save all the checkpoints without + overwriting. + """ + + def __init__(self, + model_checkpoint_dir: str, + *, + checkpoint_policy: CheckpointPolicy, + numbered: bool = False): + if get_ops().distributed.local_rank == 0: + self.numbered = numbered + self.checkpoint_policy = checkpoint_policy + from pfl.internal.platform.selector import get_platform + self.model_checkpoint_dir_name = model_checkpoint_dir + if not numbered: + self.model_checkpoint_dir = get_platform( + ).create_checkpoint_directories([model_checkpoint_dir])[0] + def after_central_iteration( self, aggregate_metrics: Metrics, model: StatefulModel, *, central_iteration: int) -> Tuple[bool, Metrics]: - if get_ops( - ).distributed.local_rank == 0 and self._should_checkpoint_now( - central_iteration): - model.save(self.model_checkpoint_dir) + if get_ops().distributed.local_rank == 0: + if self.checkpoint_policy.should_checkpoint_now( + aggregate_metrics, central_iteration): + if self.numbered: + from pfl.internal.platform.selector import get_platform + self.model_checkpoint_dir = get_platform( + ).create_checkpoint_directories([ + f'{self.model_checkpoint_dir_name}/' + f'{central_iteration:05}' + ])[0] + model.save(self.model_checkpoint_dir) return False, Metrics() def on_train_end(self, *, model: StatefulModel) -> None: - if get_ops( - ).distributed.local_rank == 0 and self.checkpoint_frequency == 0: + if get_ops().distributed.local_rank == 0 and ( + self.checkpoint_policy.should_checkpoint_at_end()): + if self.numbered: + from pfl.internal.platform.selector import get_platform + self.model_checkpoint_dir = get_platform( + ).create_checkpoint_directories( + [f'{self.model_checkpoint_dir_name}/final'])[0] model.save(self.model_checkpoint_dir) +class ModelCheckpointingCallback(PolicyBasedModelCheckpointingCallback): + """ + Callback to save model checkpoints. Note that the model checkpoints + can also be saved as part of ``RestoreTrainingCallback`` as long as + the model is ``Saveable`` and provided in the list of saveeables in + the initialization of the callback. + + :param model_checkpoint_dir: + A path to disk for saving the trained model. Location + will be relative to root dir on current platform. + :param checkpoint_frequency: + The number of central iterations after which to save a model. + When zero (the default), the model is saved once after + training is complete. + :param numbered: If true, append the iteration number to each + checkpoint path to save all the checkpoints without + overwriting. + """ + + def __init__(self, + model_checkpoint_dir: str, + *, + checkpoint_frequency: int = 0, + numbered: bool = False): + super().__init__(model_checkpoint_dir, + checkpoint_policy=IterationFrequencyCheckpointPolicy( + checkpoint_frequency), + numbered=numbered) + + class ProfilerCallback(TrainingProcessCallback): """ Profiles the code using Python's profiler, cProfile. diff --git a/tests/test_callback.py b/tests/test_callback.py index 2f1f973..7311bbf 100644 --- a/tests/test_callback.py +++ b/tests/test_callback.py @@ -14,9 +14,12 @@ AggregateMetricsToDisk, CentralEvaluationCallback, CentralEvaluationWithEMACallback, + CheckpointPolicy, ConvergenceCallback, EarlyStoppingCallback, + MetricImprovementCheckpointPolicy, ModelCheckpointingCallback, + PolicyBasedModelCheckpointingCallback, ProfilerCallback, RestoreTrainingCallback, StopwatchCallback, @@ -33,6 +36,8 @@ from pfl.model.base import StatefulModel from pfl.model.ema import CentralExponentialMovingAverage +# pylint: disable=too-many-lines + @pytest.fixture(scope='module') def central_dataset(): @@ -530,25 +535,146 @@ def test_on_train_end(self, tmp_path, model, mock_model_params): assert mock_writer.close.call_count == 3 -@pytest.mark.parametrize('checkpoint_frequency,expected_call_count', [ - (0, 1), - (1, 2), - (2, 1), +@pytest.mark.parametrize('checkpoint_frequency,expected_call_count,numbered', [ + (0, 1, True), + (1, 2, True), + (2, 1, True), + (0, 1, False), + (1, 2, False), + (2, 1, False), ]) def test_model_checkpointing_callback(checkpoint_frequency, - expected_call_count, tmp_path): + expected_call_count, numbered, tmp_path): platform = MagicMock(spec=GenericPlatform) - platform.create_checkpoint_directories.return_value = [str(tmp_path)] + platform.create_checkpoint_directories.side_effect = lambda dirs: dirs model = MagicMock(spec=StatefulModel) with patch('pfl.internal.platform.selector.get_platform', return_value=platform): callback = ModelCheckpointingCallback( - str(tmp_path), checkpoint_frequency=checkpoint_frequency) + str(tmp_path), + checkpoint_frequency=checkpoint_frequency, + numbered=numbered) callback.after_central_iteration(Metrics(), model, central_iteration=0) callback.after_central_iteration(Metrics(), model, central_iteration=1) callback.on_train_end(model=model) assert model.save.call_count == expected_call_count + call_args_list = model.save.call_args_list + if numbered: + if checkpoint_frequency != 0: + for idx in range(model.save.call_count): + central_iteration = (idx + 1) * checkpoint_frequency - 1 + assert call_args_list[idx].args == ( + f'{tmp_path}/{central_iteration:05}', ) + else: + assert call_args_list[0].args == (f'{tmp_path}/final', ) + else: + for call_args in model.save.call_args_list: + assert call_args.args == (f'{tmp_path}', ) + + +@pytest.mark.parametrize('policy_results,should_checkpoint_at_end,numbered', [ + ([False, False], False, False), + ([True, False], False, False), + ([False, True], False, False), + ([True, True], False, False), + ([False, False], False, True), + ([True, False], False, True), + ([False, True], False, True), + ([True, True], False, True), + ([False, False], True, False), + ([True, False], True, False), + ([False, True], True, False), + ([True, True], True, False), + ([False, False], True, True), + ([True, False], True, True), + ([False, True], True, True), + ([True, True], True, True), +]) +def test_policy_based_model_checkpointing_callback(policy_results, + should_checkpoint_at_end, + numbered, tmp_path): + platform = MagicMock(spec=GenericPlatform) + platform.create_checkpoint_directories.side_effect = lambda dirs: dirs + model = MagicMock(spec=StatefulModel) + policy = MagicMock(spec=CheckpointPolicy) + policy.should_checkpoint_now.side_effect = policy_results + policy.should_checkpoint_at_end.return_value = should_checkpoint_at_end + with patch('pfl.internal.platform.selector.get_platform', + return_value=platform): + callback = PolicyBasedModelCheckpointingCallback( + str(tmp_path), checkpoint_policy=policy, numbered=numbered) + callback.after_central_iteration(Metrics(), model, central_iteration=0) + callback.after_central_iteration(Metrics(), model, central_iteration=1) + callback.on_train_end(model=model) + + call_args_list = model.save.call_args_list + + expected_call_count = should_checkpoint_at_end + sum(policy_results) + assert model.save.call_count == expected_call_count + + if numbered: + call_args_iter = iter(call_args_list) + for central_iteration, checkpointed in enumerate(policy_results): + if checkpointed: + assert next(call_args_iter).args == ( + f'{tmp_path}/{central_iteration:05}', ) + if should_checkpoint_at_end: + assert next(call_args_iter).args == (f'{tmp_path}/final', ) + else: + for call_args in model.save.call_args_list: + assert call_args.args == (f'{tmp_path}', ) + + +@pytest.mark.parametrize( + 'metric_values,threshold,expected_call_count,numbered', [ + ([3, 2, 1], 2, 1, False), + ([1, 0, 0], 2, 2, False), + ([0, 0, 0], 2, 1, False), + ([3, 2, 1], 2, 1, True), + ([1, 0, 0], 2, 2, True), + ([0, 0, 0], 1, 1, True), + ([3, 2, 1], None, 3, False), + ([1, 0, 0], None, 2, False), + ([0, 0, 0], None, 1, False), + ([3, 2, 1], None, 3, True), + ([1, 0, 0], None, 2, True), + ([0, 0, 0], None, 1, True), + ]) +def test_metric_improvement_model_checkpointing_callback( + metric_values, threshold, expected_call_count, numbered, tmp_path): + platform = MagicMock(spec=GenericPlatform) + platform.create_checkpoint_directories.side_effect = lambda dirs: dirs + model = MagicMock(spec=StatefulModel) + policy = MetricImprovementCheckpointPolicy(metric_name='metric_name', + threshold_value=threshold) + with patch('pfl.internal.platform.selector.get_platform', + return_value=platform): + callback = PolicyBasedModelCheckpointingCallback( + str(tmp_path), checkpoint_policy=policy, numbered=False) + for central_iteration in range(3): + callback.after_central_iteration( + Metrics({'metric_name': + metric_values[central_iteration]}.items()), + model, + central_iteration=central_iteration) + callback.on_train_end(model=model) + + call_args_list = model.save.call_args_list + + assert model.save.call_count == expected_call_count + + if numbered: + call_args_iter = iter(call_args_list) + for central_iteration, (lhs, rhs) in enumerate( + zip(metric_values[:-1], metric_values[1:])): + if policy.performance_is_better(lhs, rhs): + assert next(call_args_iter).args == ( + f'{tmp_path}/{central_iteration:05}', ) + else: + for call_args in model.save.call_args_list: + assert call_args.args == (f'{tmp_path}', ) + class TestProfilerCallback: