This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Adapt to PyTorch Lightning 2.0 (a.k.a. the lightning package)
qmpzzpmq committed Jun 13, 2023
1 parent 928575b commit 27c21e5
Showing 2 changed files with 85 additions and 44 deletions.
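
For orientation, the sketch below shows the compatibility pattern this commit applies in both files: detect pytorch_lightning (1.x) and the standalone lightning (2.0) package independently, then accept objects from whichever is installed. The try/except blocks mirror the diff; the helper is_lightning_module is a hypothetical illustration, not part of the commit.

try:
    import pytorch_lightning as pl
except ImportError:
    LIGHTNING_INSTALLED = False
else:
    LIGHTNING_INSTALLED = True

try:
    import lightning as L
except ImportError:
    LIGHTNING2_INSTALLED = False
else:
    LIGHTNING2_INSTALLED = True

def is_lightning_module(obj) -> bool:
    # True if obj is a LightningModule from either package; guarding each
    # isinstance check with its install flag avoids a NameError when one
    # of the packages is missing.
    if LIGHTNING_INSTALLED and isinstance(obj, pl.LightningModule):
        return True
    if LIGHTNING2_INSTALLED and isinstance(obj, L.LightningModule):
        return True
    return False
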
68 changes: 45 additions & 23 deletions nni/compression/pytorch/utils/evaluator.py
@@ -21,6 +21,13 @@
else:
LIGHTNING_INSTALLED = True

try:
import lightning as L
except ImportError:
LIGHTNING2_INSTALLED = False
else:
LIGHTNING2_INSTALLED = True

try:
from transformers.trainer import Trainer as HFTrainer
except ImportError:
@@ -161,7 +168,7 @@ def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
"""
raise NotImplementedError

def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
"""
Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
model training, and model evaluation.
@@ -312,25 +319,27 @@ class LightningEvaluator(Evaluator):
If the test metric is needed by nni, please make sure to log the metric with key ``default`` in ``LightningModule.test_step()``.
"""

def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
def __init__(self, trainer: pl.Trainer | L.Trainer, data_module: pl.LightningDataModule,
dummy_input: Any | None = None):
assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
lightning2_check = LIGHTNING2_INSTALLED and isinstance(trainer, L.Trainer)
assert (isinstance(trainer, pl.Trainer) or lightning2_check) and is_traceable(trainer), err_msg
err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
lightning2_check = LIGHTNING2_INSTALLED and isinstance(data_module, L.LightningDataModule)
assert (isinstance(data_module, pl.LightningDataModule) or lightning2_check) and is_traceable(data_module), err_msg
self.trainer = trainer
self.data_module = data_module
self._dummy_input = dummy_input

self.model: pl.LightningModule | None = None
self.model: pl.LightningModule | L.LightningModule | None = None
self._ori_model_attr = {}
self._param_names_map: Dict[str, str] | None = None

self._initialization_complete = False

def _init_optimizer_helpers(self, pure_model: pl.LightningModule):
def _init_optimizer_helpers(self, pure_model: pl.LightningModule | L.LightningModule):
assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

self._optimizer_helpers = []
@@ -395,10 +404,14 @@ def _init_optimizer_helpers(self, pure_model: pl.LightningModule):

self._initialization_complete = True

def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
def bind_model(
self,
model: pl.LightningModule | L.LightningModule,
param_names_map: Dict[str, str] | None = None
):
err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
assert self._initialization_complete is True, err_msg
assert isinstance(model, pl.LightningModule)
assert isinstance(model, pl.LightningModule) or isinstance(model, L.LightningModule)
if self.model is not None:
_logger.warning('Already bound a model, will unbind it before bind a new model.')
self.unbind_model()
@@ -425,7 +438,7 @@ def unbind_model(self):
_logger.warning('Did not bind any model, no need to unbind model.')

def _patch_configure_optimizers(self):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)

if self._opt_returned_dicts:
def new_configure_optimizers(_): # type: ignore
@@ -452,11 +465,11 @@ def new_configure_optimizers(_):
self.model.configure_optimizers = types.MethodType(new_configure_optimizers, self.model)

def _revert_configure_optimizers(self):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
self.model.configure_optimizers = self._ori_model_attr['configure_optimizers']

def patch_loss(self, patch: Callable[[Tensor], Tensor]):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
old_training_step = self.model.training_step

def patched_training_step(_, *args, **kwargs):
@@ -470,19 +483,28 @@ def patched_training_step(_, *args, **kwargs):
self.model.training_step = types.MethodType(patched_training_step, self.model)

def revert_loss(self):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
self.model.training_step = self._ori_model_attr['training_step']

def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)

class OptimizerCallback(Callback):
def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
optimizer: Optimizer, opt_idx: int) -> None:
def on_before_optimizer_step(
self,
trainer: pl.Trainer | L.Trainer,
pl_module: pl.LightningModule | L.LightningModule,
optimizer: Optimizer, opt_idx: int
) -> None:
for task in before_step_tasks:
task()

def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer) -> None:
def on_before_zero_grad(
self,
trainer: pl.Trainer | L.Trainer,
pl_module: pl.LightningModule | L.LightningModule,
optimizer: Optimizer,
) -> None:
for task in after_step_tasks:
task()
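
As an aside, the OptimizerCallback above hooks task lists around the optimizer step. A standalone, hedged sketch of the same pattern (the class name StepTaskCallback and its constructor are illustrative, not part of this commit):

from typing import Callable, List

import pytorch_lightning as pl
from torch.optim import Optimizer

class StepTaskCallback(pl.Callback):
    # Run user-supplied tasks immediately before optimizer.step() and
    # immediately before optimizer.zero_grad() (i.e. after the step).
    def __init__(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
        self.before_step_tasks = before_step_tasks
        self.after_step_tasks = after_step_tasks

    def on_before_optimizer_step(self, trainer, pl_module, optimizer: Optimizer, *args) -> None:
        # *args absorbs extra positional arguments (such as opt_idx) that
        # some Lightning releases pass to this hook.
        for task in self.before_step_tasks:
            task()

    def on_before_zero_grad(self, trainer, pl_module, optimizer: Optimizer) -> None:
        for task in self.after_step_tasks:
            task()
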

@@ -496,13 +518,13 @@ def patched_configure_callbacks(_):
self.model.configure_callbacks = types.MethodType(patched_configure_callbacks, self.model)

def revert_optimizer_step(self):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
self.model.configure_callbacks = self._ori_model_attr['configure_callbacks']

def train(self, max_steps: int | None = None, max_epochs: int | None = None):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
# reset trainer
trainer: pl.Trainer = self.trainer.trace_copy().get() # type: ignore
trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get() # type: ignore
# NOTE: by default, lightning dry-runs a few validation steps in Trainer.fit() as a sanity check.
# If we record information in a forward hook, these extra steps would add unwanted records,
# so set Trainer.num_sanity_val_steps = 0 to disable the sanity check.
@@ -529,9 +551,9 @@ def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
"""
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
# reset trainer
trainer: pl.Trainer = self.trainer.trace_copy().get() # type: ignore
trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get() # type: ignore
original_results = trainer.test(self.model, self.data_module)
# del trainer reference, we don't want to dump trainer when we dump the entire model.
self.model.trainer = None
@@ -831,7 +853,7 @@ def __init__(self, trainer: HFTrainer, dummy_input: Any | None = None) -> None:

self._initialization_complete = False

def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule | L.LightningModule):
assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

if self.traced_trainer.optimizer is not None and is_traceable(self.traced_trainer.optimizer):
@@ -862,7 +884,7 @@ def patched_get_optimizer_cls_and_kwargs(args) -> Tuple[Any, Any]:

self._initialization_complete = True

def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
assert self._initialization_complete is True, err_msg
assert isinstance(model, Module)
61 changes: 40 additions & 21 deletions nni/contrib/compression/utils/evaluator.py
@@ -22,6 +22,13 @@
else:
LIGHTNING_INSTALLED = True

try:
import lightning as L
except ImportError:
LIGHTNING2_INSTALLED = False
else:
LIGHTNING2_INSTALLED = True

try:
from transformers.trainer import Trainer as HFTrainer
except ImportError:
@@ -149,7 +156,7 @@ class Evaluator:
_initialization_complete: bool
_hook: List[Hook]

def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule | L.LightningModule):
"""
This is an internal API, ``pure_model`` means the model is the original model passed in by the user,
it should not be the modified model (wrapped, hooked, or patched by NNI).
@@ -164,7 +171,7 @@ def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
"""
raise NotImplementedError

def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
def bind_model(self, model: Module | pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
"""
Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
model training, and model evaluation.
@@ -186,8 +193,12 @@ def unbind_model(self):
"""
raise NotImplementedError

def _optimizer_add_param_group(self, model: Union[torch.nn.Module, pl.LightningModule],
module_name_param_dict: Dict[str, List[Tensor]], optimizers: Optimizer | List[Optimizer]):
def _optimizer_add_param_group(
self,
model: Union[torch.nn.Module, pl.LightningModule, L.LightningModule],
module_name_param_dict: Dict[str, List[Tensor]],
optimizers: Optimizer | List[Optimizer]
):
# used in the bind_model process
def find_param_group(param_groups: List[Dict], module_name: str):
for i, param_group in enumerate(param_groups):
@@ -367,25 +378,33 @@ class LightningEvaluator(Evaluator):
If the the test metric is needed by nni, please make sure log metric with key ``default`` in ``LightningModule.test_step()``.
"""

def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
def __init__(self, trainer: pl.Trainer | L.Trainer, data_module: pl.LightningDataModule | L.LightningDataModule,
dummy_input: Any | None = None):
assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
err_msg = err_msg_p.format(
'pytorch_lightning.Trainer or lightning.Trainer',
'pytorch_lightning.Trainer or lightning.Trainer',
)
lightning2_check = LIGHTNING2_INSTALLED and isinstance(trainer, L.Trainer)
assert (isinstance(trainer, pl.Trainer) or lightning2_check) and is_traceable(trainer), err_msg
err_msg = err_msg_p.format(
'pytorch_lightning.LightningDataModule or lightning.LightningDataModule',
'pytorch_lightning.LightningDataModule or lightning.LightningDataModule',
)
lightning2_check = LIGHTNING2_INSTALLED and isinstance(data_module, L.LightningDataModule)
assert (isinstance(data_module, pl.LightningDataModule) or lightning2_check) and is_traceable(data_module), err_msg
self.trainer = trainer
self.data_module = data_module
self.trainer: pl.Trainer | L.Trainer = trainer
self.data_module: pl.LightningDataModule | L.LightningDataModule = data_module
self._dummy_input = dummy_input

self.model: pl.LightningModule | None = None
self.model: pl.LightningModule | L.LightningModule | None = None
self._ori_model_attr = {}
self._param_names_map: Dict[str, str] | None = None

self._initialization_complete = False

def _init_optimizer_helpers(self, pure_model: pl.LightningModule):
def _init_optimizer_helpers(self, pure_model: pl.LightningModule | L.LightningModule):
assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

self._optimizer_helpers = []
@@ -450,7 +469,7 @@ def _init_optimizer_helpers(self, pure_model: pl.LightningModule):

self._initialization_complete = True

def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
def bind_model(self, model: pl.LightningModule | L.LightningModule, param_names_map: Dict[str, str] | None = None):
err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
assert self._initialization_complete is True, err_msg
assert isinstance(model, pl.LightningModule)
@@ -514,7 +533,7 @@ def new_configure_optimizers(_):
self.model.configure_optimizers = types.MethodType(new_configure_optimizers, self.model)

def _patch_configure_optimizers(self):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
if self._opt_returned_dicts:
def new_configure_optimizers(_): # type: ignore
optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers] # type: ignore
@@ -559,11 +578,11 @@ def patched_training_step(_, *args, **kwargs):
self.model.training_step = types.MethodType(patched_training_step, self.model)

def revert_loss(self):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
self.model.training_step = self._ori_model_attr['training_step']

def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
old_configure_optimizers = self.model.configure_optimizers

def patched_step_factory(old_step):
Expand Down Expand Up @@ -599,13 +618,13 @@ def new_configure_optimizers(_):
self.model.configure_optimizers = types.MethodType(new_configure_optimizers, self.model)

def revert_optimizer_step(self):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
self.model.configure_callbacks = self._ori_model_attr['configure_callbacks']

def train(self, max_steps: int | None = None, max_epochs: int | None = None):
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
# reset trainer
trainer: pl.Trainer = self.trainer.trace_copy().get() # type: ignore
trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get() # type: ignore
# NOTE: by default, lightning dry-runs a few validation steps in Trainer.fit() as a sanity check.
# If we record information in a forward hook, these extra steps would add unwanted records,
# so set Trainer.num_sanity_val_steps = 0 to disable the sanity check.
@@ -632,9 +651,9 @@ def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
"""
assert isinstance(self.model, pl.LightningModule)
assert isinstance(self.model, pl.LightningModule) or isinstance(self.model, L.LightningModule)
# reset trainer
trainer: pl.Trainer = self.trainer.trace_copy().get() # type: ignore
trainer: pl.Trainer | L.Trainer = self.trainer.trace_copy().get() # type: ignore
original_results = trainer.test(self.model, self.data_module)
# del trainer reference, we don't want to dump trainer when we dump the entire model.
self.model.trainer = None
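
To close, a hedged usage sketch of the updated LightningEvaluator with a lightning 2.0 Trainer. It assumes nni.trace wraps the constructors as the error messages in the diff indicate; ToyDataModule is a hypothetical stand-in for a real LightningDataModule.

import nni
import torch
import lightning as L
from torch.utils.data import DataLoader, TensorDataset

from nni.compression.pytorch.utils.evaluator import LightningEvaluator

class ToyDataModule(L.LightningDataModule):
    # Hypothetical minimal data module; a real one would load real data.
    def _loader(self) -> DataLoader:
        data = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
        return DataLoader(data, batch_size=4)

    def train_dataloader(self) -> DataLoader:
        return self._loader()

    def test_dataloader(self) -> DataLoader:
        return self._loader()

# nni.trace records the constructor arguments so the evaluator can later
# re-create fresh copies (see trace_copy().get() in the diff above).
trainer = nni.trace(L.Trainer)(max_epochs=1, num_sanity_val_steps=0)
data_module = nni.trace(ToyDataModule)()
evaluator = LightningEvaluator(trainer, data_module)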
