diff --git a/configs/_base_/nas_backbones/nsga_mobilenetv3_supernet.py b/configs/_base_/nas_backbones/nsga_mobilenetv3_supernet.py new file mode 100644 index 000000000..4ecdae032 --- /dev/null +++ b/configs/_base_/nas_backbones/nsga_mobilenetv3_supernet.py @@ -0,0 +1,106 @@ +# search space +arch_setting = dict( + kernel_size=[ # [min_kernel_size, max_kernel_size, step] + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + ], + num_blocks=[ # [min_num_blocks, max_num_blocks, step] + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + ], + expand_ratio=[ # [min_expand_ratio, max_expand_ratio, step] + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [6, 6, 1], # last layer + ], + num_out_channels=[ # [min_channel, max_channel, step] + [16, 16, 1], # first layer + [24, 24, 1], + [24, 24, 1], + [24, 24, 1], + [24, 24, 1], + [40, 40, 1], + [40, 40, 1], + [40, 40, 1], + [40, 40, 1], + [80, 80, 1], + [80, 80, 1], + [80, 80, 1], + [80, 80, 1], + [112, 112, 1], + [112, 112, 1], + [112, 112, 1], + [112, 112, 1], + [160, 160, 1], + [160, 160, 1], + [160, 160, 1], + [160, 160, 1], + [1280, 1280, 1], # last layer + ]) + +input_resizer_cfg = dict( + input_sizes=[[192, 192], [208, 208], [224, 224], [256, 256]]) + +nas_backbone = dict( + type='AttentiveMobileNetV3', + arch_setting=arch_setting, + out_indices=(20, ), + stride_list=[2, 1, 1, 1] * 3 + [1] * 4 + [2] + [1] * 3, + with_se_list=[False] * 4 + [True] * 4 + [False] * 4 + [True] * 8, + act_cfg_list=['ReLU'] * 9 + ['HSwish'] * 13, + conv_cfg=dict(type='OFAConv2d'), + norm_cfg=dict(type='DynamicBatchNorm2d', momentum=0.1)) diff --git a/configs/_base_/settings/cifar10_bs96_nsga.py b/configs/_base_/settings/cifar10_bs96_nsga.py new file mode 100644 index 000000000..bfd69e45b --- /dev/null +++ b/configs/_base_/settings/cifar10_bs96_nsga.py @@ -0,0 +1,49 @@ +# dataset settings +dataset_type = 'mmcls.CIFAR10' +data_preprocessor = dict( + type='mmcls.ClsDataPreprocessor', + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type='mmcls.RandomCrop', crop_size=32, padding=4), + dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'), + dict(type='mmcls.Cutout', shape=16, pad_val=0, prob=1.0), + dict(type='mmcls.PackClsInputs'), +] + +test_pipeline = [ + dict(type='mmcls.PackClsInputs'), +] + +train_dataloader = dict( + batch_size=96, + num_workers=5, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10', + test_mode=False, + pipeline=train_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=96, + num_workers=5, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10/', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=False), + 
persistent_workers=True, +) +val_evaluator = dict(type='mmcls.Accuracy', topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/configs/nas/mmcls/bignas/attentive_mobilenet_supernet_32xb64_in1k.py b/configs/nas/mmcls/bignas/attentive_mobilenet_supernet_32xb64_in1k.py index 3b44dc36f..303fea924 100644 --- a/configs/nas/mmcls/bignas/attentive_mobilenet_supernet_32xb64_in1k.py +++ b/configs/nas/mmcls/bignas/attentive_mobilenet_supernet_32xb64_in1k.py @@ -44,7 +44,7 @@ loss_kl=dict( preds_S=dict(recorder='fc', from_student=True), preds_T=dict(recorder='fc', from_student=False)))), - mutators=dict(type='mmrazor.NasMutator')) + mutator=dict(type='mmrazor.NasMutator')) model_wrapper_cfg = dict( type='mmrazor.BigNASDDP', diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py new file mode 100644 index 000000000..9d0991195 --- /dev/null +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py @@ -0,0 +1,44 @@ +_base_ = ['./nsganetv2_mobilenet_supernet_1xb96_cifar10.py'] + +model = dict(norm_training=True) + +train_dataloader = dict(batch_size=256) +val_dataloader = dict(batch_size=256) +test_dataloader = val_dataloader + +param_scheduler = [dict(type='ConstantLR', factor=1.0)] + +train_cfg = dict( + _delete_=True, + type='mmrazor.NSGA2SearchLoop', + dataloader=_base_.val_dataloader, + evaluator=_base_.val_evaluator, + max_epochs=30, + num_candidates=50, + top_k=10, + num_mutation=25, + num_crossover=25, + mutate_prob=0.3, + constraints_range=dict(flops=(0., 330.)), + score_key='accuracy/top1', + predictor_cfg=dict( + type='mmrazor.MetricPredictor', + encoding_type='normal', + train_samples=2, + handler_cfg=dict(type='mmrazor.GaussProcessHandler')), + finetune_cfg=dict( + model=_base_.model, + train_dataloader=_base_.train_dataloader, + train_cfg=dict(by_epoch=True, max_epochs=2), + optim_wrapper=dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True)), + param_scheduler=param_scheduler, + default_hooks=_base_.default_hooks, + ), +) diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py new file mode 100644 index 000000000..969b9a663 --- /dev/null +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py @@ -0,0 +1,40 @@ +_base_ = ['./nsganetv2_mobilenet_supernet_8xb128_in1k.py'] + +model = dict(norm_training=True) + +param_scheduler = [dict(type='ConstantLR', factor=1.0)] + +train_cfg = dict( + _delete_=True, + type='mmrazor.NSGA2SearchLoop', + dataloader=_base_.val_dataloader, + evaluator=_base_.val_evaluator, + max_epochs=4, + num_candidates=4, + top_k=2, + num_mutation=2, + num_crossover=2, + mutate_prob=0.1, + constraints_range=dict(flops=(0., 360.)), + score_key='accuracy/top1', + predictor_cfg=dict( + type='mmrazor.MetricPredictor', + encoding_type='normal', + train_samples=2, + handler_cfg=dict(type='mmrazor.GaussProcessHandler')), + finetune_cfg=dict( + model=_base_.model, + train_dataloader=_base_.train_dataloader, + train_cfg=dict(by_epoch=True, max_epochs=1), + optim_wrapper=dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=0.025, + momentum=0.9, + weight_decay=3e-4, + nesterov=True)), + param_scheduler=param_scheduler, + default_hooks=_base_.default_hooks, + ), +) diff --git 
a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_1xb96_cifar10.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_1xb96_cifar10.py new file mode 100644 index 000000000..27fced8e7 --- /dev/null +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_1xb96_cifar10.py @@ -0,0 +1,39 @@ +_base_ = [ + 'mmcls::_base_/default_runtime.py', + 'mmcls::_base_/schedules/imagenet_bs2048.py', + 'mmrazor::_base_/settings/cifar10_bs96_nsga.py', + 'mmrazor::_base_/nas_backbones/nsga_mobilenetv3_supernet.py', +] + +supernet = dict( + _scope_='mmrazor', + type='SearchableImageClassifier', + backbone=_base_.nas_backbone, + neck=dict(type='SqueezeMeanPoolingWithDropout', drop_ratio=0.2), + head=dict( + type='DynamicLinearClsHead', + num_classes=1000, + in_channels=1280, + loss=dict( + type='mmcls.LabelSmoothLoss', + num_classes=1000, + label_smooth_val=0.1, + mode='original', + loss_weight=1.0), + topk=(1, 5)), + input_resizer_cfg=_base_.input_resizer_cfg, + connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'), +) + +model = dict( + _scope_='mmrazor', + type='NSGANetV2', + architecture=supernet, + data_preprocessor=_base_.data_preprocessor, + mutator=dict(type='mmrazor.NasMutator')) + +find_unused_parameters = True + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', interval=1, max_keep_ckpts=1, save_best='auto')) diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_8xb128_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_8xb128_in1k.py new file mode 100644 index 000000000..6708a6037 --- /dev/null +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_8xb128_in1k.py @@ -0,0 +1,38 @@ +_base_ = [ + 'mmcls::_base_/default_runtime.py', + 'mmrazor::_base_/settings/imagenet_bs2048_bignas.py', + 'mmrazor::_base_/nas_backbones/nsga_mobilenetv3_supernet.py', +] + +supernet = dict( + _scope_='mmrazor', + type='SearchableImageClassifier', + backbone=_base_.nas_backbone, + neck=dict(type='SqueezeMeanPoolingWithDropout', drop_ratio=0.2), + head=dict( + type='DynamicLinearClsHead', + num_classes=1000, + in_channels=1280, + loss=dict( + type='mmcls.LabelSmoothLoss', + num_classes=1000, + label_smooth_val=0.1, + mode='original', + loss_weight=1.0), + topk=(1, 5)), + input_resizer_cfg=_base_.input_resizer_cfg, + connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'), +) + +model = dict( + _scope_='mmrazor', + type='NSGANetV2', + architecture=supernet, + data_preprocessor=_base_.data_preprocessor, + mutator=dict(type='mmrazor.NasMutator')) + +find_unused_parameters = True + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', interval=1, max_keep_ckpts=1, save_best='auto')) diff --git a/configs/nas/mmcls/onceforall/ofa_mobilenet_supernet_32xb64_in1k.py b/configs/nas/mmcls/onceforall/ofa_mobilenet_supernet_32xb64_in1k.py index c2e0f05ab..53b9ec86f 100644 --- a/configs/nas/mmcls/onceforall/ofa_mobilenet_supernet_32xb64_in1k.py +++ b/configs/nas/mmcls/onceforall/ofa_mobilenet_supernet_32xb64_in1k.py @@ -43,7 +43,7 @@ loss_kl=dict( preds_S=dict(recorder='fc', from_student=True), preds_T=dict(recorder='fc', from_student=False)))), - mutators=dict(type='mmrazor.NasMutator')) + mutator=dict(type='mmrazor.NasMutator')) model_wrapper_cfg = dict( type='mmrazor.BigNASDDP', diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py new file mode 100644 index 000000000..a2ab4cf07 --- /dev/null 
+++ b/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py @@ -0,0 +1,22 @@ +_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py'] + +model = dict(norm_training=True) + +train_cfg = dict( + _delete_=True, + type='mmrazor.NSGA2SearchLoop', + dataloader=_base_.val_dataloader, + evaluator=_base_.val_evaluator, + max_epochs=20, + num_candidates=50, + top_k=10, + num_mutation=25, + num_crossover=25, + mutate_prob=0.1, + constraints_range=dict(flops=(0., 360.)), + predictor_cfg=dict( + type='mmrazor.MetricPredictor', + encoding_type='normal', + train_samples=2, + handler_cfg=dict(type='mmrazor.GaussProcessHandler')), +) diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index 7435fa822..1504ae2c0 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -3,14 +3,14 @@ from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, - GreedySamplerTrainLoop, SelfDistillValLoop, - SingleTeacherDistillValLoop, SlimmableValLoop, - SubnetValLoop) + GreedySamplerTrainLoop, NSGA2SearchLoop, + SelfDistillValLoop, SingleTeacherDistillValLoop, + SlimmableValLoop, SubnetValLoop) __all__ = [ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', - 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop' + 'GreedySamplerTrainLoop', 'SubnetValLoop', 'EstimateResourcesHook', + 'SelfDistillValLoop', 'NSGA2SearchLoop', 'AutoSlimGreedySearchLoop' ] diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 10eb2b598..377f55500 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -4,6 +4,7 @@ from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop from .evolution_search_loop import EvolutionSearchLoop from .iteprune_val_loop import ItePruneValLoop +from .nsganetv2_search_loop import NSGA2SearchLoop from .slimmable_val_loop import SlimmableValLoop from .subnet_sampler_loop import GreedySamplerTrainLoop from .subnet_val_loop import SubnetValLoop @@ -12,5 +13,5 @@ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', - 'ItePruneValLoop', 'AutoSlimGreedySearchLoop' + 'NSGA2SearchLoop', 'ItePruneValLoop', 'AutoSlimGreedySearchLoop' ] diff --git a/mmrazor/engine/runner/attentive_search_loop.py b/mmrazor/engine/runner/attentive_search_loop.py new file mode 100644 index 000000000..ceb830a60 --- /dev/null +++ b/mmrazor/engine/runner/attentive_search_loop.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.registry import LOOPS +from .evolution_search_loop import EvolutionSearchLoop + + +@LOOPS.register_module() +class AttentiveSearchLoop(EvolutionSearchLoop): + """Loop for evolution searching with attentive tricks from AttentiveNAS. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + max_epochs (int): Total searching epochs. Defaults to 20. + max_keep_ckpts (int): The maximum checkpoints of searcher to keep. + Defaults to 3. 
+ resume_from (str, optional): Specify the path of saved .pkl file for + resuming searching. + num_candidates (int): The length of candidate pool. Defaults to 50. + top_k (int): Specify top k candidates based on scores. Defaults to 10. + num_mutation (int): The number of candidates obtained by mutation. + Defaults to 25. + num_crossover (int): The number of candidates obtained by crossover. + Defaults to 25. + mutate_prob (float): The probability of mutation. Defaults to 0.1. + flops_range (tuple, optional): It is used for screening candidates. + resource_estimator_cfg (dict): The config for building estimator, which + is used to estimate the flops of a sampled subnet. Defaults to + None, which means default config is used. + score_key (str): Specify one metric in evaluation results to score + candidates. Defaults to 'accuracy_top-1'. + init_candidates (str, optional): The candidates file path, which is + used to init `self.candidates`. Its format is usually in .yaml + format. Defaults to None. + """ + + def _init_pareto(self): + # TODO (gaoyang): Fix apis with mmrazor2.0 + for k, v in self.constraints.items(): + if not isinstance(v, (list, tuple)): + self.constraints[k] = (0, v) + + assert len(self.constraints) == 1, 'Only accept one kind of constraint.' + self.pareto_candidates = dict() + constraints = list(self.constraints.items())[0] + discretize_step = self.pareto_mode['discretize_step'] + ds = discretize_step + # find the left bound + while ds + 0.5 * discretize_step < constraints[1][0]: + ds += discretize_step + self.pareto_candidates[ds] = [] + # find the right bound + while ds - 0.5 * discretize_step < constraints[1][1]: + self.pareto_candidates[ds] = [] + ds += discretize_step diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index c1a73d4c3..fbc9cfe31 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -9,9 +9,8 @@ import numpy as np import torch from mmengine import fileio -from mmengine.dist import broadcast_object_list from mmengine.evaluator import Evaluator -from mmengine.runner import EpochBasedTrainLoop +from mmengine.runner import EpochBasedTrainLoop, Runner from mmengine.utils import is_list_of from torch.utils.data import DataLoader @@ -53,6 +52,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop, CalibrateBNMixin): Defaults to None. predictor_cfg (dict, Optional): Used for building a metric predictor. Defaults to None. + finetune_cfg (dict, Optional): Used for building an extra runner to + finetune the searched model. Defaults to None. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. 
init_candidates (str, optional): The candidates file path, which is @@ -62,8 +63,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop, CalibrateBNMixin): def __init__(self, runner, - dataloader: Union[DataLoader, Dict], evaluator: Union[Evaluator, Dict, List], + dataloader: Union[DataLoader, Dict], max_epochs: int = 20, max_keep_ckpts: int = 3, resume_from: Optional[str] = None, @@ -77,6 +78,7 @@ def __init__(self, constraints_range: Dict[str, Any] = dict(flops=(0., 330.)), estimator_cfg: Optional[Dict] = None, predictor_cfg: Optional[Dict] = None, + finetune_cfg: Optional[Dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -84,6 +86,7 @@ def __init__(self, self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore + if hasattr(self.dataloader.dataset, 'metainfo'): self.evaluator.dataset_meta = self.dataloader.dataset.metainfo else: @@ -103,6 +106,7 @@ def __init__(self, self.crossover_prob = crossover_prob self.max_keep_ckpts = max_keep_ckpts self.resume_from = resume_from + self.trade_off = dict(max_score_key=100) self.fp16 = False if init_candidates is None: @@ -134,6 +138,10 @@ def __init__(self, self.model.mutator.search_groups self.predictor = TASK_UTILS.build(self.predictor_cfg) + self.finetune_cfg = finetune_cfg + if finetune_cfg is not None: + self.finetune_runner = self.build_finetune_runner(finetune_cfg) + def run(self) -> None: """Launch searching.""" self.runner.call_hook('before_train') @@ -173,7 +181,8 @@ def run_epoch(self) -> None: self.candidates.extend(self.top_k_candidates) self.candidates.sort_by(key_indicator='score', reverse=True) - self.top_k_candidates = Candidates(self.candidates.data[:self.top_k]) + self.top_k_candidates = \ + Candidates(self.candidates[:self.top_k]) # type: ignore scores_after = self.top_k_candidates.scores self.runner.logger.info(f'top k scores after update: ' @@ -195,33 +204,28 @@ def sample_candidates(self) -> None: """Update candidate pool contains specified number of candicates.""" candidates_resources = [] init_candidates = len(self.candidates) - if self.runner.rank == 0: - while len(self.candidates) < self.num_candidates: - candidate = self.model.sample_subnet() - is_pass, result = self._check_constraints( - random_subnet=candidate) - if is_pass: - self.candidates.append(candidate) - candidates_resources.append(result) - self.candidates = Candidates(self.candidates.data) - else: - self.candidates = Candidates([dict(a=0)] * self.num_candidates) + while len(self.candidates) < self.num_candidates: + candidate = self.model.mutator.sample_choices() + is_pass, result = self._check_constraints(candidate) + if is_pass: + self.candidates.append(candidate) + candidates_resources.append(result) + self.candidates = Candidates(self.candidates.data) if len(candidates_resources) > 0: self.candidates.update_resources( candidates_resources, - start=len(self.candidates.data) - len(candidates_resources)) + start=len(self.candidates) - len(candidates_resources)) assert init_candidates + len( candidates_resources) == self.num_candidates - # broadcast candidates to val with multi-GPUs. 
- broadcast_object_list(self.candidates.data) - def update_candidates_scores(self) -> None: """Validate candicate one by one from the candicate pool, and update top-k candicates.""" for i, candidate in enumerate(self.candidates.subnets): self.model.mutator.set_choices(candidate) + if self.finetune_cfg: + self.finetune_step(self.model) metrics = self._val_candidate(use_predictor=self.use_predictor) score = round(metrics[self.score_key], 2) \ if len(metrics) != 0 else 0. @@ -246,9 +250,7 @@ def gen_mutation_candidates(self): break mutation_candidate = self._mutation() - - is_pass, result = self._check_constraints( - random_subnet=mutation_candidate) + is_pass, result = self._check_constraints(mutation_candidate) if is_pass: mutation_candidates.append(mutation_candidate) mutation_resources.append(result) @@ -259,7 +261,7 @@ return mutation_candidates def gen_crossover_candidates(self): - """Generate specofied number of crossover candicates.""" + """Generate specified number of crossover candidates.""" crossover_resources = [] crossover_candidates: List = [] crossover_iter = 0 @@ -270,9 +272,7 @@ break crossover_candidate = self._crossover() - - is_pass, result = self._check_constraints( - random_subnet=crossover_candidate) + is_pass, result = self._check_constraints(crossover_candidate) if is_pass: crossover_candidates.append(crossover_candidate) crossover_resources.append(result) @@ -285,7 +285,7 @@ def _mutation(self) -> SupportRandomSubnet: """Mutate with the specified mutate_prob.""" candidate1 = random.choice(self.top_k_candidates.subnets) - candidate2 = self.model.sample_subnet() + candidate2 = self.model.mutator.sample_choices() candidate = crossover(candidate1, candidate2, prob=self.mutate_prob) return candidate @@ -441,3 +441,55 @@ def _init_predictor(self): f'Predictor pre-trained, saved in {predictor_dir}.') self.use_predictor = True self.candidates = Candidates() + + def build_finetune_runner(self, finetune_cfg: Dict) -> Runner: + """Build a runner for finetuning the sliced_model.""" + finetune_cfg.update(work_dir=self.runner.work_dir) + finetune_cfg.update(launcher=self.runner.launcher) + finetune_cfg.update(env_cfg=dict(dist_cfg=dict(backend='nccl'))) + + runner = Runner.from_cfg(finetune_cfg) + + from mmengine.hooks import CheckpointHook + + # remove CheckpointHook to avoid extra problems. 
+ for hook in runner._hooks: + if isinstance(hook, CheckpointHook): + runner._hooks.remove(hook) + break + + return runner + + def finetune_step(self, model): + """Finetune before candidates evaluation.""" + self.finetune_runner._train_loop = \ + self.finetune_runner.build_train_loop( + self.finetune_cfg.train_cfg) # type: ignore + + self.finetune_runner.optim_wrapper = \ + self.finetune_runner.build_optim_wrapper( + self.finetune_cfg.optim_wrapper) + + self.finetune_runner.scale_lr(self.finetune_runner.optim_wrapper, + self.finetune_runner.auto_scale_lr) + + self.finetune_runner.param_schedulers = \ + self.finetune_runner.build_param_scheduler( + self.finetune_cfg.param_scheduler) # type: ignore + + model.train() + self.runner.logger.info('Start finetuning...') + self.finetune_runner.model = model + self.finetune_runner.call_hook('before_run') + + self.finetune_runner.optim_wrapper.initialize_count_status( + self.finetune_runner.model, self.finetune_runner._train_loop.iter, + self.finetune_runner._train_loop.max_iters) + + self.model = self.finetune_runner.train_loop.run() + self.finetune_runner.train_loop._iter = 0 + self.finetune_runner.train_loop._epoch = 0 + + self.finetune_runner.call_hook('after_run') + self.runner.logger.info('End finetuning...') + model.eval() diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py new file mode 100644 index 000000000..62f418abd --- /dev/null +++ b/mmrazor/engine/runner/nsganetv2_search_loop.py @@ -0,0 +1,264 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp + +import numpy as np +from mmengine import fileio + +try: + from pymoo.algorithms.so_genetic_algorithm import GA + from pymoo.factory import get_algorithm, get_crossover, get_mutation + from pymoo.optimize import minimize + from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting +except ImportError: + from mmrazor.utils import get_placeholder + GA = get_placeholder('pymoo') + get_algorithm = get_placeholder('pymoo') + get_crossover = get_placeholder('pymoo') + get_mutation = get_placeholder('pymoo') + minimize = get_placeholder('pymoo') + NonDominatedSorting = get_placeholder('pymoo') + +from mmrazor.registry import LOOPS +from mmrazor.structures import Candidates +from .attentive_search_loop import AttentiveSearchLoop +from .utils.pymoo_utils import (AuxiliarySingleLevelProblem, + HighTradeoffPoints, SubsetProblem) +from .utils.pymoo_utils.problems import (BinaryCrossover, RandomMutation, + RandomSampling) + + +@LOOPS.register_module() +class NSGA2SearchLoop(AttentiveSearchLoop): + """Evolution search loop with NSGA2 optimizer.""" + + def run_epoch(self) -> None: + """Iterate one epoch. + + Steps: + 0. Collect archives and predictor. + 1. Sample some new candidates from the supernet. Then append them + to the candidates to make its number equal to the specified + number. + 2. Validate these candidates (step 1) and update their scores. + 3. Pick the top k candidates based on the scores (step 2), which + will be used in mutation and crossover. + 4. Implement mutation and crossover to generate better candidates. 
+ """ + self.archive = Candidates() + if len(self.candidates) > 0: + for subnet, score, flops in zip( + self.candidates.subnets, self.candidates.scores, + self.candidates.resources('flops')): + self.update_candidates_scores() + if self.trade_off['max_score_key'] != 0: + score = self.trade_off['max_score_key'] - score + self.archive.append(subnet) + self.archive.set_score(-1, score) + self.archive.set_resource(-1, flops, 'flops') + + self.sample_candidates(random=(self._epoch == 0), archive=self.archive) + self.update_candidates_scores() + + scores_before = self.top_k_candidates.scores + self.runner.logger.info(f'top k scores before update: ' + f'{scores_before}') + + self.candidates.extend(self.top_k_candidates) + self.sort_candidates() + self.top_k_candidates = \ + Candidates(self.candidates[:self.top_k]) # type: ignore + + scores_after = self.top_k_candidates.scores + self.runner.logger.info(f'top k scores after update: ' + f'{scores_after}') + + mutation_candidates = self.gen_mutation_candidates() + self.candidates_mutator_crossover = Candidates(mutation_candidates) + crossover_candidates = self.gen_crossover_candidates() + self.candidates_mutator_crossover.extend(crossover_candidates) + + assert len(self.candidates_mutator_crossover + ) <= self.num_candidates, 'Total of mutation and \ + crossover should be less than the number of candidates.' + + self.candidates = self.candidates_mutator_crossover + self._epoch += 1 + + def sample_candidates(self, random: bool = True, archive=None) -> None: + if random: + super().sample_candidates() + else: + candidates = self.sample_candidates_with_nsga2( + archive, self.num_candidates) + new_candidates = [] + candidates_resources = [] + for candidate in candidates: + is_pass, result = self._check_constraints(candidate) + if is_pass: + new_candidates.append(candidate) + candidates_resources.append(result) + self.candidates = Candidates(new_candidates) + + if len(candidates_resources) > 0: + self.candidates.update_resources( + candidates_resources, + start=len(self.candidates) - len(candidates_resources)) + + def sample_candidates_with_nsga2(self, archive: Candidates, + num_candidates): + """Searching for candidates with high-fidelity evaluation.""" + F = np.column_stack((archive.scores, archive.resources('flops'))) + front_index = NonDominatedSorting().do( + F, only_non_dominated_front=True) + + fronts = np.array(archive.subnets)[front_index] + fronts = np.array( + [self.predictor.model2vector(cand) for cand in fronts]) + fronts = self.predictor.preprocess(fronts) + + # initialize the candidate finding optimization problem + problem = AuxiliarySingleLevelProblem(self, len(fronts[0])) + + # initiate a multi-objective solver to optimize the problem + method = get_algorithm( + 'nsga2', + pop_size=40, + sampling=fronts, + crossover=get_crossover('int_two_point', prob=0.9), + mutation=get_mutation('int_pm', eta=1.0), + eliminate_duplicates=True) + + result = minimize(problem, method, ('n_gen', 4), seed=1, verbose=True) + + check_list = [] + for x in result.pop.get('X'): + check_list.append(self.predictor.vector2model(x)) + + not_duplicate = np.logical_not( + [x in archive.subnets for x in check_list]) + + sub_problem = SubsetProblem(result.pop[not_duplicate].get('F')[:, 1], + F[front_index, 1], num_candidates) + + sub_method = GA( + pop_size=100, + sampling=RandomSampling(), + crossover=BinaryCrossover(), + mutation=RandomMutation(), + eliminate_duplicates=True) + + sub_result = minimize( + sub_problem, sub_method, ('n_gen', 2), seed=1, verbose=False) + 
indices = sub_result.X + + _X = result.pop.get('X')[not_duplicate][indices] + + candidates = [] + for x in _X: + x = x[0] if isinstance(x[0], list) else x + candidates.append(self.predictor.vector2model(x)) + + return candidates + + def sort_candidates(self) -> None: + """Sort candidates for both single- and multi-objective optimization.""" + assert self.trade_off is not None, ( + '`self.trade_off` is required when sorting candidates in ' + 'NSGA2SearchLoop. Got `self.trade_off` is None.') + + max_score_key = self.trade_off.get('max_score_key', 100) + max_score_key = np.array(max_score_key) + + patches = [] + for score, flops in zip(self.candidates.scores, + self.candidates.resources('flops')): + patches.append((score, flops)) + patches = np.array(patches) + + if max_score_key != 0: + patches[:, 0] = max_score_key - patches[:, 0] # type: ignore + + sort_idx = np.argsort(patches[:, 0]) # type: ignore + F = patches[sort_idx] + + dm = HighTradeoffPoints(n_survive=len(patches)) + candidate_index = dm.do(F) + candidate_index = sort_idx[candidate_index] + + self.candidates = \ + [self.candidates[idx] for idx in candidate_index] # type: ignore + + def _save_searcher_ckpt(self): + """Save searcher ckpt, which is different from common ckpt. + + It mainly contains the candidate pool, the top-k candidates with scores + and the current epoch. + """ + if self.runner.rank == 0: + rmse, rho, tau = 0, 0, 0 + if len(self.archive) > 0: + top1_err_pred = self.fit_predictor(self.archive) + + self.candidates = self.archive + self.use_predictor = False + self.update_candidates_scores() + + rmse, rho, tau = self.predictor.get_correlation( + top1_err_pred, np.array(self.candidates.scores)) + + save_for_resume = dict() + save_for_resume['_epoch'] = self._epoch + for k in ['candidates', 'top_k_candidates']: + save_for_resume[k] = getattr(self, k) + fileio.dump( + save_for_resume, + osp.join(self.runner.work_dir, + f'search_epoch_{self._epoch}.pkl')) + + correlation_str = 'fitting ' + correlation_str += f'RMSE = {rmse:.4f}, ' + correlation_str += f'Spearmans Rho = {rho:.4f}, ' + correlation_str += f'Kendalls Tau = {tau:.4f}' + + self.pareto_mode = False + if self.pareto_mode: + step_str = '\n' + for step, candidates in self.pareto_candidates.items(): + if len(candidates) > 0: + step_str += f'step: {step}: ' + step_str += f'{candidates[0][self.score_key]}\n' + self.runner.logger.info( + f'Epoch:[{self._epoch}/{self._max_epochs}] ' + f'Top1_score: {step_str} ' + f'{correlation_str}') + else: + self.runner.logger.info( + f'Epoch:[{self._epoch}/{self._max_epochs}] ' + f'Top1_score: {self.top_k_candidates.scores[0]} ' + f'{correlation_str}') + + if self.max_keep_ckpts > 0: + cur_ckpt = self._epoch + 1 + redundant_ckpts = range(1, cur_ckpt - self.max_keep_ckpts) + for _step in redundant_ckpts: + ckpt_path = osp.join(self.runner.work_dir, + f'search_epoch_{_step}.pkl') + if osp.isfile(ckpt_path): + os.remove(ckpt_path) + + def fit_predictor(self, candidates): + """Predict performance using predictor.""" + assert self.predictor.initialize is True + + metrics = [] + for i, candidate in enumerate(candidates.subnets): + self.model.mutator.set_choices(candidate) + self.finetune_step(self.model) + metric = self._val_candidate(use_predictor=True) + metrics.append(metric[self.score_key]) + + max_score_key = self.trade_off.get('max_score_key', 0.) + assert max_score_key > 0. 
+ metrics = max_score_key - np.array(metrics) + return metrics diff --git a/mmrazor/engine/runner/utils/pymoo_utils/__init__.py b/mmrazor/engine/runner/utils/pymoo_utils/__init__.py new file mode 100644 index 000000000..9fa1ac34b --- /dev/null +++ b/mmrazor/engine/runner/utils/pymoo_utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .high_tradeoff_points import HighTradeoffPoints +from .problems import AuxiliarySingleLevelProblem, SubsetProblem + +__all__ = [ + 'AuxiliarySingleLevelProblem', 'SubsetProblem', 'HighTradeoffPoints' +] diff --git a/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py b/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py new file mode 100644 index 000000000..021fa3e74 --- /dev/null +++ b/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +try: + from pymoo.configuration import Configuration + from pymoo.model.decision_making import (DecisionMaking, NeighborFinder, + find_outliers_upper_tail, + normalize) + Configuration.show_compile_hint = False +except ImportError: + from mmrazor.utils import get_placeholder + DecisionMaking = get_placeholder('pymoo') + NeighborFinder = get_placeholder('pymoo') + normalize = get_placeholder('pymoo') + find_outliers_upper_tail = get_placeholder('pymoo') + Configuration = get_placeholder('pymoo') + + +class HighTradeoffPoints(DecisionMaking): + """Method for multi-objective optimization. + + Args: + ratio (float): weight between score_key and sec_obj, details in + demo/nas/demo.ipynb. + epsilon (float): specify a radius for each neighbour. + n_survive (int): how many high-tradeoff points will be returned finally. + """ + + def __init__(self, epsilon=0.125, n_survive=None, **kwargs) -> None: + super().__init__(**kwargs) + self.epsilon = epsilon + self.n_survive = n_survive # number of points to be selected + + def _do(self, F, **kwargs): + n, m = F.shape + + if self.normalize: + F = normalize( + F, + self.ideal_point, + self.nadir_point, + estimate_bounds_if_none=True) + + neighbors_finder = NeighborFinder( + F, epsilon=0.125, n_min_neigbors='auto', consider_2d=False) + + mu = np.full(n, -np.inf) + + for i in range(n): + + # for each neighbour in a specific radius of that solution + neighbors = neighbors_finder.find(i) + + # calculate the trade-off to all neighbours + diff = F[neighbors] - F[i] + + # calculate sacrifice and gain + sacrifice = np.maximum(0, diff).sum(axis=1) + gain = np.maximum(0, -diff).sum(axis=1) + + np.warnings.filterwarnings('ignore') + tradeoff = sacrifice / gain + + # otherwise find the one with the smallest trade-off + mu[i] = np.nanmin(tradeoff) + + if self.n_survive is not None: + return np.argsort(mu)[-self.n_survive:] + else: + # return points with trade-off > 2*sigma + return find_outliers_upper_tail(mu) diff --git a/mmrazor/engine/runner/utils/pymoo_utils/problems.py b/mmrazor/engine/runner/utils/pymoo_utils/problems.py new file mode 100644 index 000000000..38fcf5fa1 --- /dev/null +++ b/mmrazor/engine/runner/utils/pymoo_utils/problems.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import numpy as np + +try: + from pymoo.model.crossover import Crossover + from pymoo.model.mutation import Mutation + from pymoo.model.problem import Problem + from pymoo.model.sampling import Sampling +except ImportError: + from mmrazor.utils import get_placeholder + Crossover = get_placeholder('pymoo') + Mutation = get_placeholder('pymoo') + Problem = get_placeholder('pymoo') + Sampling = get_placeholder('pymoo') + + +class RandomSampling(Sampling): + + def _do(self, problem, n_samples, **kwargs): + X = np.full((n_samples, problem.n_var), False, dtype=np.bool) + + for k in range(n_samples): + index = np.random.permutation(problem.n_var)[:problem.n_max] + X[k, index] = True + + return X + + +class BinaryCrossover(Crossover): + + def __init__(self): + super().__init__(2, 1) + + def _do(self, problem, X, **kwargs): + n_parents, n_matings, n_var = X.shape + + _X = np.full((self.n_offsprings, n_matings, problem.n_var), False) + + for k in range(n_matings): + p1, p2 = X[0, k], X[1, k] + + both_are_true = np.logical_and(p1, p2) + _X[0, k, both_are_true] = True + + n_remaining = problem.n_max - np.sum(both_are_true) + + index = np.where(np.logical_xor(p1, p2))[0] + + S = index[np.random.permutation(len(index))][:n_remaining] + _X[0, k, S] = True + + return _X + + +class RandomMutation(Mutation): + + def _do(self, problem, X, **kwargs): + for i in range(X.shape[0]): + X[i, :] = X[i, :] + is_false = np.where(np.logical_not(X[i, :]))[0] + is_true = np.where(X[i, :])[0] + try: + X[i, np.random.choice(is_false)] = True + X[i, np.random.choice(is_true)] = False + except ValueError: + pass + + return X + + +class AuxiliarySingleLevelProblem(Problem): + """The optimization problem for finding the next N candidate + architectures.""" + + def __init__(self, search_loop, n_var=15, sec_obj='flops'): + super().__init__(n_var=n_var, n_obj=2, n_constr=0, type_var=np.int) + + self.search_loop = search_loop + self.predictor = self.search_loop.predictor + self.sec_obj = sec_obj + + self.xl = np.zeros(self.n_var) + # upper bound for each variable, automatically calculated from the search space + self.xu = [] + from mmrazor.models.mutables import OneShotMutableChannelUnit + for mutable in self.predictor.search_groups.values(): + if isinstance(mutable[0], OneShotMutableChannelUnit): + if mutable[0].num_channels > 0: + self.xu.append(mutable[0].num_channels - 1) + else: + if mutable[0].num_choices > 0: + self.xu.append(mutable[0].num_choices - 1) + self.xu = np.array(self.xu) + + def _evaluate(self, x, out, *args, **kwargs): + """Evaluate results.""" + f = np.full((x.shape[0], self.n_obj), np.nan) + # predicted top1 error + top1_err = self.predictor.handler.predict(x)[:, 0] + for i in range(len(x)): + candidate = self.predictor.vector2model(x[i]) + _, resource = self.search_loop._check_constraints(candidate) + f[i, 0] = top1_err[i] + f[i, 1] = resource[self.sec_obj] + + out['F'] = f + + +class SubsetProblem(Problem): + """Select a subset to diversify the Pareto front.""" + + def __init__(self, candidates, archive, K): + super().__init__( + n_var=len(candidates), + n_obj=1, + n_constr=1, + xl=0, + xu=1, + type_var=bool) + self.archive = archive + self.candidates = candidates + self.n_max = K + + def _evaluate(self, x, out, *args, **kwargs): + f = np.full((x.shape[0], 1), np.nan) + g = np.full((x.shape[0], 1), np.nan) + + for i, _x in enumerate(x): + # append selected candidates to archive then sort + tmp = np.sort(np.concatenate((self.archive, self.candidates[_x]))) + f[i, 0] = np.std(np.diff(tmp)) + # we penalize if the number 
of selected candidates is not exactly K + g[i, 0] = (self.n_max - np.sum(_x))**2 + + out['F'] = f + out['G'] = g diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 3f649c426..c885e745b 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -4,7 +4,7 @@ FpnTeacherDistill, OverhaulFeatureDistillation, SelfDistill, SingleTeacherDistill) from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP, - BigNAS, BigNASDDP, Darts, DartsDDP) + BigNAS, BigNASDDP, Darts, DartsDDP, NSGANetV2) from .pruning import DCFF, SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm @@ -14,5 +14,5 @@ 'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS', - 'BigNASDDP' + 'BigNASDDP', 'NSGANetV2' ] diff --git a/mmrazor/models/algorithms/nas/__init__.py b/mmrazor/models/algorithms/nas/__init__.py index 6a9c29161..967156a12 100644 --- a/mmrazor/models/algorithms/nas/__init__.py +++ b/mmrazor/models/algorithms/nas/__init__.py @@ -4,9 +4,11 @@ from .bignas import BigNAS, BigNASDDP from .darts import Darts, DartsDDP from .dsnas import DSNAS, DSNASDDP +from .nsganetv2 import NSGANetV2 from .spos import SPOS __all__ = [ 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'BigNAS', 'BigNASDDP', 'Darts', - 'DartsDDP', 'DSNAS', 'DSNASDDP', 'DSNAS', 'DSNASDDP', 'Autoformer' + 'DartsDDP', 'DSNAS', 'DSNASDDP', 'DSNAS', 'DSNASDDP', 'Autoformer', + 'NSGANetV2' ] diff --git a/mmrazor/models/algorithms/nas/nsganetv2.py b/mmrazor/models/algorithms/nas/nsganetv2.py new file mode 100644 index 000000000..fad3d99d0 --- /dev/null +++ b/mmrazor/models/algorithms/nas/nsganetv2.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.distillers import ConfigurableDistiller +from mmrazor.models.mutators import NasMutator +from mmrazor.registry import MODELS +from mmrazor.utils import ValidFixMutable +from ..base import BaseAlgorithm, LossResults + +VALID_MUTATOR_TYPE = Union[NasMutator, Dict] +VALID_DISTILLER_TYPE = Union[ConfigurableDistiller, Dict] + + +@MODELS.register_module() +class NSGANetV2(BaseAlgorithm): + """Implementation of `NSGANetV2 `_ + + NSGANetV2 generates task-specific models that are competitive under + multiple competing objectives. + + NSGANetV2 comprises two surrogates, one at the architecture level to + improve sample efficiency and one at the weights level, through a supernet, + to improve gradient descent training efficiency. + + The logic of the search part is implemented in + :class:`mmrazor.engine.NSGA2SearchLoop`. + + Args: + architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel` + or built model. Corresponding to supernet in NAS algorithm. + mutator (VALID_MUTATOR_TYPE): The config of :class:`NasMutator` or + built mutator. + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. Defaults to None. + data_preprocessor (Optional[Union[dict, nn.Module]]): The pre-process + config of :class:`BaseDataPreprocessor`. Defaults to None. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.2. 
+ backbone_dropout_stages (List): Stages in which dropout is applied. + Defaults to [6, 7]. + norm_training (bool): Whether to set norm layers to training mode, + namely, not freezing running stats (mean and var). Note: Effect on + Batch Norm and its variants only. Defaults to False. + init_cfg (Optional[dict]): Init config for ``BaseModule``. + Defaults to None. + """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator: VALID_MUTATOR_TYPE, + fix_subnet: Optional[ValidFixMutable] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + drop_path_rate: float = 0.2, + backbone_dropout_stages: List = [6, 7], + norm_training: bool = False, + init_cfg: Optional[dict] = None): + super().__init__(architecture, data_preprocessor, init_cfg) + + if fix_subnet: + # Avoid circular import + from mmrazor.structures import load_fix_subnet + + # According to fix_subnet, delete the unchosen part of supernet + load_fix_subnet(self, fix_subnet) + self.is_supernet = False + else: + self.mutator = self._build_mutator(mutator) + self.mutator.prepare_from_supernet(self.architecture) + + self.sample_kinds = ['max', 'min'] + self.is_supernet = True + + self.drop_path_rate = drop_path_rate + self.backbone_dropout_stages = backbone_dropout_stages + self.norm_training = norm_training + + def _build_mutator(self, mutator: VALID_MUTATOR_TYPE = None) -> NasMutator: + """Build mutator.""" + if isinstance(mutator, dict): + mutator = MODELS.build(mutator) + if not isinstance(mutator, NasMutator): + raise TypeError('mutator should be a `dict` or `NasMutator` ' + f'instance, but got {type(mutator)}.') + return mutator + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + if self.is_supernet: + self.mutator.set_choices(self.mutator.sample_choices()) + return self.architecture(batch_inputs, data_samples, mode='loss') + else: + return self.architecture(batch_inputs, data_samples, mode='loss') + + def train(self, mode=True): + """Convert the model into eval mode while keeping normalization layers + unfrozen.""" + + super().train(mode) + if self.norm_training and not mode: + for module in self.architecture.modules(): + if isinstance(module, _BatchNorm): + module.training = True diff --git a/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py b/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py index 5a6e15cbc..6e4b62805 100644 --- a/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py +++ b/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py @@ -92,14 +92,9 @@ def __init__(self, self.arch_setting = arch_setting self.widen_factor = widen_factor self.out_indices = out_indices - for index in out_indices: - if index not in range(0, 8): - raise ValueError('the item in out_indices must in ' - f'range(0, 8). But received {index}') if frozen_stages not in range(-1, 8): raise ValueError('frozen_stages must in range(-1, 8). 
' f'But received {frozen_stages}') - self.out_indices = out_indices self.frozen_stages = frozen_stages self.conv_cfg = conv_cfg diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index 65d5a44d6..ed3a9e617 100644 --- a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -62,7 +62,8 @@ def fix_chosen(self, chosen=None): def dump_chosen(self) -> DumpChosen: """Dump chosen.""" - meta = dict(max_channels=self.mask.size(0)) + meta = dict( + max_channels=self.mask.size(0), all_choices=self.candidate_choices) chosen = self.export_chosen() return DumpChosen(chosen=chosen, meta=meta) diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py index 569f01ebc..4af878d1a 100644 --- a/mmrazor/models/mutators/group_mixin.py +++ b/mmrazor/models/mutators/group_mixin.py @@ -221,7 +221,15 @@ def build_search_groups( f'The duplicate keys are {duplicate_keys}. ' \ 'Please check if there are duplicate keys in the `custom_group`.' - return search_groups + # TODO: update search groups + new_search_groups = dict() + for group_id, module in search_groups.items(): + if module[0].alias: + new_search_groups[module[0].alias] = module + else: + new_search_groups[group_id] = module + + return new_search_groups def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], name2mutable: Dict[str, MUTABLE_TYPE], diff --git a/mmrazor/models/task_modules/predictor/metric_predictor.py b/mmrazor/models/task_modules/predictor/metric_predictor.py index a05e2a1b4..796c3ac5b 100644 --- a/mmrazor/models/task_modules/predictor/metric_predictor.py +++ b/mmrazor/models/task_modules/predictor/metric_predictor.py @@ -96,6 +96,10 @@ def model2vector( vector_dict: Dict[str, list] = \ dict(normal_vector=[], onehot_vector=[]) + assert len(model.keys()) == len(self.search_groups.keys()), ( + f'Length mismatch for model({len(model.keys())}) and search_groups' + f'({len(self.search_groups.keys())}).') + for key, choice in model.items(): if isinstance(choice, DumpChosen): assert choice.meta is not None, ( @@ -105,8 +109,16 @@ def model2vector( len(choice.meta['all_choices']), dtype=np.int) _chosen_index = choice.meta['all_choices'].index(choice.chosen) else: - assert len(self.search_groups[index]) == 1 - choices = self.search_groups[index][0].choices + if key is not None: + from mmrazor.models.mutables import MutableChannelUnit + if isinstance(self.search_groups[key][0], + MutableChannelUnit): + choices = self.search_groups[key][0].candidate_choices + else: + choices = self.search_groups[key][0].choices + else: + assert len(self.search_groups[index]) == 1 + choices = self.search_groups[index][0].choices onehot = np.zeros(len(choices), dtype=np.int) _chosen_index = choices.index(choice) onehot[_chosen_index] = 1 @@ -126,18 +138,26 @@ def vector2model(self, vector: np.array) -> Dict[str, str]: Returns: Dict[str, str]: converted model. 
""" + from mmrazor.models.mutables import OneShotMutableChannelUnit + start = 0 model = {} - for key, value in self.search_groups.items(): + vector = np.squeeze(vector) + for name, mutables in self.search_groups.items(): + if isinstance(mutables[0], OneShotMutableChannelUnit): + choices = mutables[0].candidate_choices + else: + choices = mutables[0].choices + if self.encoding_type == 'onehot': - index = np.where(vector[start:start + - len(value[0].choices)] == 1)[0][0] - start += len(value) + index = np.where(vector[start:start + len(choices)] == 1)[0][0] + start += len(choices) else: index = vector[start] start += 1 - chosen = value[0].choices[int(index)] - model[key] = chosen + + chosen = choices[int(index)] if len(choices) > 1 else choices[0] + model[name] = chosen return model diff --git a/mmrazor/structures/subnet/candidate.py b/mmrazor/structures/subnet/candidate.py index 9f0ebc344..3d1d8deaa 100644 --- a/mmrazor/structures/subnet/candidate.py +++ b/mmrazor/structures/subnet/candidate.py @@ -6,6 +6,7 @@ class Candidates(UserList): """The data structure of sampled candidate. The format is Union[Dict[str, Dict], List[Dict[str, Dict]]]. + Examples: >>> candidates = Candidates() >>> subnet_1 = {'1': 'choice1', '2': 'choice2'} @@ -66,21 +67,8 @@ def resources(self, key_indicator: str = 'flops') -> List[float]: @property def subnets(self) -> List[Dict]: """The subnets of candidates.""" - import copy assert len(self.data) > 0, ('Got empty candidates.') - if 'value_subnet' in self.data[0]: - subnets = [] - for data in self.data: - subnet = dict() - _data = copy.deepcopy(data) - for k1 in ['value_subnet', 'channel_subnet']: - for k2 in self._indicators: - _data[k1].pop(k2) - subnet[k1] = _data[k1] - subnets.append(subnet) - return subnets - else: - return [eval(key) for item in self.data for key, _ in item.items()] + return [eval(key) for item in self.data for key, _ in item.items()] def _format(self, data: _format_input) -> _format_return: """Transform [Dict, ...] to Union[Dict[str, Dict], List[Dict[str,