From b3a0069ab81de3a68306747ae835ccc98066b34f Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Fri, 2 Dec 2022 15:59:15 +0800 Subject: [PATCH 01/12] update nsga2 search with pymoo_v0.50 --- ...enet_search_nsga2_predictor_8xb128_in1k.py | 22 + mmrazor/engine/__init__.py | 5 +- mmrazor/engine/runner/__init__.py | 4 +- .../engine/runner/attentive_search_loop.py | 56 ++ .../engine/runner/evolution_search_loop.py | 1 + .../engine/runner/nsganetv2_search_loop.py | 253 +++++++ .../runner/utils/high_tradeoff_points.py | 81 +++ mmrazor/models/algorithms/nas/nsganetv2.py | 102 +++ mmrazor/models/task_modules/__init__.py | 1 + .../multi_object_optimizer/__init__.py | 9 + .../multi_object_optimizer/base_optimizer.py | 210 ++++++ .../genetic_optimizer.py | 88 +++ .../multi_object_optimizer/nsga2_optimizer.py | 150 ++++ .../problem/__init__.py | 5 + .../problem/auxiliary_singlelevel_problem.py | 36 + .../problem/base_problem.py | 327 +++++++++ .../problem/subset_problem.py | 34 + .../multi_object_optimizer/utils/__init__.py | 1 + .../utils/domin_matrix.py | 134 ++++ .../multi_object_optimizer/utils/helper.py | 668 ++++++++++++++++++ .../multi_object_optimizer/utils/selection.py | 490 +++++++++++++ .../predictor/metric_predictor.py | 2 +- 22 files changed, 2675 insertions(+), 4 deletions(-) create mode 100644 configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py create mode 100644 mmrazor/engine/runner/attentive_search_loop.py create mode 100644 mmrazor/engine/runner/nsganetv2_search_loop.py create mode 100644 mmrazor/engine/runner/utils/high_tradeoff_points.py create mode 100644 mmrazor/models/algorithms/nas/nsganetv2.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/__init__.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py create mode 100644 mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py new file mode 100644 index 000000000..a2ab4cf07 --- /dev/null +++ b/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py @@ -0,0 +1,22 @@ +_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py'] + +model = dict(norm_training=True) + +train_cfg = dict( + _delete_=True, + type='mmrazor.NSGA2SearchLoop', + dataloader=_base_.val_dataloader, + evaluator=_base_.val_evaluator, + max_epochs=20, + num_candidates=50, + top_k=10, + num_mutation=25, + num_crossover=25, + mutate_prob=0.1, + constraints_range=dict(flops=(0., 360.)), + predictor_cfg=dict( + type='mmrazor.MetricPredictor', + 
encoding_type='normal',
+        train_samples=2,
+        handler_cfg=dict(type='mmrazor.GaussProcessHandler')),
+)
diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py
index f2df86a83..973b9c761 100644
--- a/mmrazor/engine/__init__.py
+++ b/mmrazor/engine/__init__.py
@@ -4,12 +4,13 @@
 from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
                      DartsIterBasedTrainLoop, EvolutionSearchLoop,
                      GreedySamplerTrainLoop, SelfDistillValLoop,
-                     SingleTeacherDistillValLoop, SlimmableValLoop)
+                     SingleTeacherDistillValLoop, SlimmableValLoop,
+                     NSGA2SearchLoop)
 
 __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
-    'SelfDistillValLoop'
+    'SelfDistillValLoop', 'NSGA2SearchLoop'
 ]
diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py
index 9715a4e6b..6f1befa3b 100644
--- a/mmrazor/engine/runner/__init__.py
+++ b/mmrazor/engine/runner/__init__.py
@@ -3,11 +3,13 @@
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
 from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
 from .evolution_search_loop import EvolutionSearchLoop
+from .nsganetv2_search_loop import NSGA2SearchLoop
 from .slimmable_val_loop import SlimmableValLoop
 from .subnet_sampler_loop import GreedySamplerTrainLoop
 
 __all__ = [
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
+    'NSGA2SearchLoop'
 ]
diff --git a/mmrazor/engine/runner/attentive_search_loop.py b/mmrazor/engine/runner/attentive_search_loop.py
new file mode 100644
index 000000000..ceb830a60
--- /dev/null
+++ b/mmrazor/engine/runner/attentive_search_loop.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.registry import LOOPS
+from .evolution_search_loop import EvolutionSearchLoop
+
+
+@LOOPS.register_module()
+class AttentiveSearchLoop(EvolutionSearchLoop):
+    """Loop for evolution searching with attentive tricks from AttentiveNAS.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        max_epochs (int): Total searching epochs. Defaults to 20.
+        max_keep_ckpts (int): The maximum checkpoints of searcher to keep.
+            Defaults to 3.
+        resume_from (str, optional): Specify the path of saved .pkl file for
+            resuming searching.
+        num_candidates (int): The length of candidate pool. Defaults to 50.
+        top_k (int): Specify top k candidates based on scores. Defaults to 10.
+        num_mutation (int): The number of candidates got by mutation.
+            Defaults to 25.
+        num_crossover (int): The number of candidates got by crossover.
+            Defaults to 25.
+        mutate_prob (float): The probability of mutation. Defaults to 0.1.
+        flops_range (tuple, optional): FLOPs range used for screening
+            candidates.
+        resource_estimator_cfg (dict): The config for building the estimator,
+            which is used to estimate the FLOPs of sampled subnets. Defaults
+            to None, which means the default config is used.
+        score_key (str): Specify one metric in evaluation results to score
+            candidates. Defaults to 'accuracy_top-1'.
+        init_candidates (str, optional): The candidates file path, which is
+            used to init `self.candidates`. Its format is usually in .yaml
+            format. Defaults to None.
+    """
+
+    def _init_pareto(self):
+        # TODO (gaoyang): Fix apis with mmrazor2.0
+        for k, v in self.constraints.items():
+            if not isinstance(v, (list, tuple)):
+                self.constraints[k] = (0, v)
+
+        assert len(self.constraints) == 1, \
+            'Only one kind of constraint is accepted.'
+        self.pareto_candidates = dict()
+        constraints = list(self.constraints.items())[0]
+        discretize_step = self.pareto_mode['discretize_step']
+        ds = discretize_step
+        # find the left bound
+        while ds + 0.5 * discretize_step < constraints[1][0]:
+            ds += discretize_step
+            self.pareto_candidates[ds] = []
+        # find the right bound
+        while ds - 0.5 * discretize_step < constraints[1][1]:
+            self.pareto_candidates[ds] = []
+            ds += discretize_step
diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py
index fc907f3aa..6d62d99fa 100644
--- a/mmrazor/engine/runner/evolution_search_loop.py
+++ b/mmrazor/engine/runner/evolution_search_loop.py
@@ -96,6 +96,7 @@ def __init__(self,
         self.crossover_prob = crossover_prob
         self.max_keep_ckpts = max_keep_ckpts
         self.resume_from = resume_from
+        self.trade_off = dict(max_score_key=40)
 
         if init_candidates is None:
             self.candidates = Candidates()
diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py
new file mode 100644
index 000000000..73aadd4f6
--- /dev/null
+++ b/mmrazor/engine/runner/nsganetv2_search_loop.py
@@ -0,0 +1,253 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from copy import deepcopy
+
+import numpy as np
+from mmengine import fileio
+from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
+
+from mmrazor.models.task_modules import (AuxiliarySingleLevelProblem,
+                                         GeneticOptimizer, NSGA2Optimizer,
+                                         SubsetProblem)
+from mmrazor.registry import LOOPS
+from mmrazor.structures import Candidates, export_fix_subnet
+from .attentive_search_loop import AttentiveSearchLoop
+from .utils.high_tradeoff_points import HighTradeoffPoints
+
+# from pymoo.algorithms.moo.nsga2 import NSGA2 as NSGA2Optimizer
+# from pymoo.algorithms.soo.nonconvex.ga import GA as GeneticOptimizer
+# from pymoo.optimize import minimize
+
+
+@LOOPS.register_module()
+class NSGA2SearchLoop(AttentiveSearchLoop):
+    """Evolution search loop with the NSGA2 optimizer."""
+
+    def run_epoch(self) -> None:
+        """Iterate one epoch.
+
+        Steps:
+            0. Collect the archive and the predictor.
+            1. Sample new candidates from the supernet and append them to
+                the candidate pool until it reaches the specified size.
+            2. Validate the candidates from step 1 and update their scores.
+            3. Pick the top-k candidates based on the scores from step 2;
+                they will be used in mutation and crossover.
+            4. Apply mutation and crossover to generate better candidates.
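+
+        Note:
+            Candidate scores are converted to errors via
+            ``max_score_key - score`` before entering the archive, so both
+            objectives (error and flops) are minimized by NSGA-II.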
+ """ + archive = Candidates() + for subnet, score, flops in zip(self.candidates.subnets, + self.candidates.scores, + self.candidates.resources('flops')): + if self.trade_off['max_score_key'] != 0: + score = self.trade_off['max_score_key'] - score + archive.append(subnet) + archive.set_score(-1, score) + archive.set_resource(-1, flops, 'flops') + + self.sample_candidates(random=(self._epoch == 0), archive=archive) + self.update_candidates_scores() + + scores_before = self.top_k_candidates.scores + self.runner.logger.info(f'top k scores before update: ' + f'{scores_before}') + + self.candidates.extend(self.top_k_candidates) + self.sort_candidates() + self.top_k_candidates = Candidates(self.candidates[:self.top_k]) + + scores_after = self.top_k_candidates.scores + self.runner.logger.info(f'top k scores after update: ' + f'{scores_after}') + + mutation_candidates = self.gen_mutation_candidates() + self.candidates_mutator_crossover = Candidates(mutation_candidates) + crossover_candidates = self.gen_crossover_candidates() + self.candidates_mutator_crossover.extend(crossover_candidates) + + assert len(self.candidates_mutator_crossover + ) <= self.num_candidates, 'Total of mutation and \ + crossover should be less than the number of candidates.' + + self.candidates = self.candidates_mutator_crossover + self._epoch += 1 + + def sample_candidates(self, random: bool = True, archive=None) -> None: + if random: + super().sample_candidates() + else: + candidates = self.sample_candidates_with_nsga2( + archive, self.num_candidates) + new_candidates = [] + candidates_resources = [] + for candidate in candidates: + is_pass, result = self._check_constraints(candidate) + if is_pass: + new_candidates.append(candidate) + candidates_resources.append(result) + self.candidates = Candidates(new_candidates) + + if len(candidates_resources) > 0: + self.candidates.update_resources( + candidates_resources, + start=len(self.candidates.data)-len(candidates_resources)) + + def sample_candidates_with_nsga2(self, archive: Candidates, num_candidates): + """Searching for candidates with high-fidelity evaluation.""" + F = np.column_stack((archive.scores, archive.resources('flops'))) + front_index = NonDominatedSorting().do(F, only_non_dominated_front=True) + + fronts = np.array(archive.subnets)[front_index] + fronts = np.array([self.predictor.model2vector(cand) for cand in fronts]) + fronts = self.predictor.preprocess(fronts) + + # initialize the candidate finding optimization problem + problem = AuxiliarySingleLevelProblem(self, len(fronts[0])) + + # initiate a multi-objective solver to optimize the problem + method = NSGA2Optimizer( + pop_size=4, + sampling=fronts, # initialize with current nd archs + eliminate_duplicates=True, + logger=self.runner.logger) + + # # kick-off the search + method.initialize(problem, n_gen=2, verbose=True) + result = method.solve() + + # check for duplicates + check_list = [] + for x in result['pop'].get('X'): + assert x is not None + check_list.append(self.predictor.vector2model(x)) + + not_duplicate = np.logical_not( + [x in archive.subnets for x in check_list]) + + # extra process after nsga2 search + sub_problem = SubsetProblem(result['pop'][not_duplicate].get('F')[:, 1], + F[front_index, 1], + num_candidates) + sub_method = GeneticOptimizer(pop_size=num_candidates, + eliminate_duplicates=True) + sub_method.initialize(sub_problem, n_gen=4, verbose=False) + indices = sub_method.solve()['X'] + + candidates = Candidates() + pop = result['pop'][not_duplicate][indices] + for x in pop.get('X'): + 
candidates.append(self.predictor.vector2model(x))
+
+        return candidates
+
+    def sort_candidates(self) -> None:
+        """Sort candidates in single- and multi-objective optimization."""
+        assert self.trade_off is not None, (
+            '`self.trade_off` is required when sorting candidates in '
+            'NSGA2SearchLoop. Got self.trade_off is None.')
+        ratio = self.trade_off.get('ratio', 1)
+        multiple_obj_score = []
+        for score, flops in zip(self.candidates.scores,
+                                self.candidates.resources('flops')):
+            multiple_obj_score.append((score, flops))
+        multiple_obj_score = np.array(multiple_obj_score)
+        max_score_key = self.trade_off.get('max_score_key', 100)
+        if max_score_key != 0:
+            multiple_obj_score[:, 0] = \
+                max_score_key - multiple_obj_score[:, 0]
+        sort_idx = np.argsort(multiple_obj_score[:, 0])
+        F = multiple_obj_score[sort_idx]
+        dm = HighTradeoffPoints(ratio, n_survive=len(multiple_obj_score))
+        candidate_index = dm.do(F)
+        candidate_index = sort_idx[candidate_index]
+        self.candidates = [self.candidates[idx] for idx in candidate_index]
+
+    def _save_searcher_ckpt(self, archive=None):
+        """Save searcher ckpt, which is different from common ckpt.
+
+        It mainly contains the candidate pool, the top-k candidates with
+        scores and the current epoch.
+        """
+        archive = archive if archive is not None else []
+        if self.runner.rank == 0:
+            rmse, rho, tau = 0, 0, 0
+            if len(archive) > 0:
+                top1_err_pred = self.fit_predictor(archive)
+                rmse, rho, tau = self.predictor.get_correlation(
+                    top1_err_pred, np.array([x[1] for x in archive]))
+
+            save_for_resume = dict()
+            save_for_resume['_epoch'] = self._epoch
+            for k in ['candidates', 'top_k_candidates']:
+                save_for_resume[k] = getattr(self, k)
+            fileio.dump(
+                save_for_resume,
+                osp.join(self.runner.work_dir,
+                         f'search_epoch_{self._epoch}.pkl'))
+
+            correlation_str = 'fitting '
+            # correlation_str += f'{self.predictor.type}: '
+            correlation_str += f'RMSE = {rmse:.4f}, '
+            correlation_str += f"Spearman's Rho = {rho:.4f}, "
+            correlation_str += f"Kendall's Tau = {tau:.4f}"
+
+            if getattr(self, 'pareto_mode', False):
+                step_str = '\n'
+                for step, candidates in self.pareto_candidates.items():
+                    if len(candidates) > 0:
+                        step_str += f'step: {step}: '
+                        step_str += f'{candidates[0][self.score_key]}\n'
+                self.runner.logger.info(
+                    f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
+                    f'top1_score: {step_str} '
+                    f'{correlation_str}')
+            else:
+                self.runner.logger.info(
+                    f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
+                    f'top1_score: {self.top_k_candidates.scores[0]} '
+                    f'{correlation_str}')
+
+    def fit_predictor(self, candidates):
+        """Fit the predictor and return the predicted top-1 errors."""
+        inputs = [export_fix_subnet(x) for x in candidates.subnets]
+        inputs = np.array([self.predictor.model2vector(x) for x in inputs])
+
+        targets = np.array([x[1] for x in candidates])
+
+        if not self.predictor.pretrained:
+            self.predictor.fit(inputs, targets)
+
+        metrics = self.predictor.predict(inputs)
+        max_score_key = self.trade_off.get('max_score_key', 100)
+        if max_score_key != 0:
+            for i in range(len(metrics)):
+                metrics[i] = max_score_key - metrics[i]
+        return metrics
+
+    def finetune_step(self, model):
+        """Finetune the model before candidates evaluation."""
+        # TODO (gaoyang): update with 2.0 version.
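+        # NOTE: this loop assumes the runner config provides
+        # `self.optim_wrapper`, `self._finetune_epoch` and
+        # `self._max_finetune_epochs`; none of them are initialized in this
+        # file yet (see the TODO above).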
+        self.runner.logger.info('start finetuning...')
+        model.train()
+        while self._finetune_epoch < self._max_finetune_epochs:
+            self.runner.call_hook('before_train_epoch')
+            for idx, data_batch in enumerate(self.dataloader):
+                self.runner.call_hook(
+                    'before_train_iter',
+                    batch_idx=idx,
+                    data_batch=data_batch)
+
+                outputs = model.train_step(
+                    data_batch, optim_wrapper=self.optim_wrapper)
+
+                self.runner.call_hook(
+                    'after_train_iter',
+                    batch_idx=idx,
+                    data_batch=data_batch,
+                    outputs=outputs)
+
+            self.runner.call_hook('after_train_epoch')
+            self._finetune_epoch += 1
+
+        model.eval()
diff --git a/mmrazor/engine/runner/utils/high_tradeoff_points.py b/mmrazor/engine/runner/utils/high_tradeoff_points.py
new file mode 100644
index 000000000..38a627c18
--- /dev/null
+++ b/mmrazor/engine/runner/utils/high_tradeoff_points.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from pymoo.config import Config
+from pymoo.core.decision_making import (DecisionMaking, NeighborFinder,
+                                        find_outliers_upper_tail)
+from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
+from pymoo.util.normalization import normalize
+
+Config.warnings['not_compiled'] = False
+
+
+class HighTradeoffPoints(DecisionMaking):
+    """Decision-making method for multi-objective optimization.
+
+    Args:
+        ratio (float): Weight between the score objective and ``sec_obj``;
+            details in demo/nas/demo.ipynb.
+        epsilon (float): Radius used to find the neighbours of each solution.
+        n_survive (int): How many high-tradeoff points to return.
+    """
+
+    def __init__(self,
+                 ratio=1,
+                 epsilon=0.125,
+                 n_survive=None,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+        self.n_survive = n_survive
+        self.ratio = ratio
+
+    def _do(self, data, **kwargs):
+        front = NonDominatedSorting().do(data, only_non_dominated_front=True)
+        F = data[front, :]
+
+        n, m = F.shape
+        F = normalize(F, self.ideal, self.nadir)
+        F[:, 1] = F[:, 1] * self.ratio
+
+        neighbors_finder = NeighborFinder(
+            F, epsilon=self.epsilon, n_min_neigbors='auto', consider_2d=False)
+
+        mu = np.full(n, -np.inf)
+
+        for i in range(n):
+
+            # for each neighbour in a specific radius of that solution
+            neighbors = neighbors_finder.find(i)
+
+            # calculate the trade-off to all neighbours
+            diff = F[neighbors] - F[i]
+
+            # calculate sacrifice and gain
+            sacrifice = np.maximum(0, diff).sum(axis=1)
+            gain = np.maximum(0, -diff).sum(axis=1)
+
+            # ignore division warnings from zero-gain neighbours
+            with np.errstate(divide='ignore', invalid='ignore'):
+                tradeoff = sacrifice / gain
+
+            # take the smallest trade-off among the neighbours
+            mu[i] = np.nanmin(tradeoff)
+
+        # if a top-k is requested
+        if self.n_survive is not None:
+            n_survive = min(self.n_survive, len(mu))
+            index = np.argsort(mu)[-n_survive:][::-1]
+            front_survive = front[index]
+
+            self.n_survive -= n_survive
+            if self.n_survive == 0:
+                return front_survive
+            # in case the survivors in the front are not enough for top-k
+            index = np.array(list(set(np.arange(len(data))) - set(front)))
+            unused_data = data[index]
+            no_front_survive = index[self._do(unused_data)]
+
+            return np.concatenate([front_survive, no_front_survive])
+        else:
+            # return points with trade-off > 2 * sigma
+            mu = find_outliers_upper_tail(mu)
+            return mu if len(mu) else []
diff --git a/mmrazor/models/algorithms/nas/nsganetv2.py b/mmrazor/models/algorithms/nas/nsganetv2.py
new file mode 100644
index 000000000..8eaad2f38
--- /dev/null
+++ b/mmrazor/models/algorithms/nas/nsganetv2.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
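+# NSGANetV2 (Lu et al., arXiv:2007.10396) couples a supernet with surrogate
+# predictors for multi-objective NAS; this module only wraps the supernet
+# side as an MMRazor algorithm.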
+from typing import Dict, List, Optional, Union
+
+import torch
+from mmengine.model import BaseModel
+from mmengine.structures import BaseDataElement
+from torch import nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmrazor.models.distillers import ConfigurableDistiller
+from mmrazor.models.mutators import OneShotModuleMutator
+from mmrazor.models.mutators.base_mutator import BaseMutator
+from mmrazor.registry import MODELS
+from mmrazor.structures.subnet.fix_subnet import load_fix_subnet
+from mmrazor.utils import SingleMutatorRandomSubnet, ValidFixMutable
+from ..base import BaseAlgorithm, LossResults
+
+VALID_MUTATOR_TYPE = Union[BaseMutator, Dict]
+VALID_MUTATORS_TYPE = Dict[str, Union[BaseMutator, Dict]]
+VALID_DISTILLER_TYPE = Union[ConfigurableDistiller, Dict]
+
+
+@MODELS.register_module()
+class NSGANetV2(BaseAlgorithm):
+    """Implementation of the NSGANetV2 supernet algorithm.
+
+    Args:
+        architecture (dict | :obj:`BaseModel`): The config of the
+            architecture or a built model.
+        mutator (dict | :obj:`BaseMutator`): The config of the mutator or a
+            built mutator.
+        fix_subnet (str | dict, optional): The path of a yaml file or a
+            loaded fix_subnet dict. Defaults to None.
+        data_preprocessor (dict | nn.Module, optional): The pre-process
+            config. Defaults to None.
+        init_cfg (dict, optional): Initialization config. Defaults to None.
+        drop_prob (float): Drop path probability. Defaults to 0.2.
+        norm_training (bool): Whether to keep BN layers in training mode
+            when the algorithm is switched to eval mode. Defaults to False.
+    """
+
+    def __init__(self,
+                 architecture: Union[BaseModel, Dict],
+                 mutator: VALID_MUTATORS_TYPE,
+                 # distiller: VALID_DISTILLER_TYPE,
+                 fix_subnet: Optional[ValidFixMutable] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
+                 init_cfg: Optional[dict] = None,
+                 drop_prob: float = 0.2,
+                 norm_training: bool = False):
+        super().__init__(architecture, data_preprocessor, init_cfg)
+
+        if fix_subnet:
+            # According to fix_subnet, delete the unchosen part of supernet
+            load_fix_subnet(self.architecture, fix_subnet)
+            self.is_supernet = False
+        else:
+            # Mutator is an essential component of the NAS algorithm. It
+            # provides some APIs commonly used by NAS.
+            # Before using it, you must do some preparations according to
+            # the supernet.
+            self._build_mutator(mutator)
+            self.mutator.prepare_from_supernet(self.architecture)
+            self.is_supernet = True
+
+        self.drop_prob = drop_prob
+        self.norm_training = norm_training
+
+    def _build_mutator(self, mutator: VALID_MUTATOR_TYPE) -> BaseMutator:
+        """build mutator."""
+        assert mutator is not None, \
+            'mutator cannot be None when fix_subnet is None.'
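+        # Accept either a ready-built mutator instance or its config dict;
+        # in both cases `self.mutator` is set as a side effect.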
+ if isinstance(mutator, OneShotModuleMutator): + self.mutator = mutator + elif isinstance(mutator, dict): + self.mutator = MODELS.build(mutator) + else: + raise TypeError('mutator should be a `dict` or ' + f'`OneShotModuleMutator` instance, but got ' + f'{type(mutator)}') + return mutator + + def sample_subnet(self) -> SingleMutatorRandomSubnet: + """Random sample subnet by mutator.""" + return self.mutator.sample_choices() + + def set_subnet(self, subnet: SingleMutatorRandomSubnet): + """Set the subnet sampled by :meth:sample_subnet.""" + self.mutator.set_choices(subnet) + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + if self.is_supernet: + random_subnet = self.sample_subnet() + self.set_subnet(random_subnet) + return self.architecture(batch_inputs, data_samples, mode='loss') + else: + return self.architecture(batch_inputs, data_samples, mode='loss') + + def train(self, mode=True): + """Convert the model into eval mode while keep normalization layer + unfreezed.""" + + super().train(mode) + if self.norm_training and not mode: + for module in self.architecture.modules(): + if isinstance(module, _BatchNorm): + module.training = True diff --git a/mmrazor/models/task_modules/__init__.py b/mmrazor/models/task_modules/__init__.py index b86bebbb9..4d152d383 100644 --- a/mmrazor/models/task_modules/__init__.py +++ b/mmrazor/models/task_modules/__init__.py @@ -4,5 +4,6 @@ from .predictor import * # noqa: F401,F403 from .recorder import * # noqa: F401,F403 from .tracer import * # noqa: F401,F403 +from .multi_object_optimizer import * # noqa: F401,F403 __all__ = ['ResourceEstimator'] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/__init__.py new file mode 100644 index 000000000..fc985e92c --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .genetic_optimizer import GeneticOptimizer +from .nsga2_optimizer import NSGA2Optimizer +from .problem import AuxiliarySingleLevelProblem, SubsetProblem + +__all__ = [ + 'AuxiliarySingleLevelProblem', 'SubsetProblem', + 'GeneticOptimizer', 'NSGA2Optimizer' +] \ No newline at end of file diff --git a/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py new file mode 100644 index 000000000..796fa37a2 --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# copied and modified from https://github.com/anyoptimization/pymoo +from abc import abstractmethod + +import numpy as np +from pymoo.util.optimum import filter_optimum + +# from pymoo.core.evaluator import Evaluator +# from pymoo.core.population import Population +from .utils.helper import Evaluator, Population + + +class BaseOptimizer(): + """This class represents the abstract class for any algorithm to be + implemented. The solve method provides a wrapper function which does + validate the input. + + Args: + problem : + Problem to be solved by the algorithm + verbose (bool): + If true information during the algorithm execution are displayed + save_history (bool): + If true, a current snapshot of each generation is saved. + pf (numpy.array): + The Pareto-front for the given problem. 
If provided performance + metrics are printed during execution. + return_least_infeasible (bool): + Whether the algorithm should return the least infeasible solution, + if no solution was found. + evaluator : :class:`~pymoo.model.evaluator.Evaluator` + The evaluator which can be used to make modifications before + calling the evaluate function of a problem. + """ + + def __init__(self, **kwargs): + # ! + # DEFAULT SETTINGS OF ALGORITHM + # ! + # set the display variable supplied to the algorithm + self.display = kwargs.get('display') + self.logger = kwargs.get('logger') + # ! + # Attributes to be set later on for each problem run + # ! + # the optimization problem as an instance + self.problem = None + + self.return_least_infeasible = None + # whether the history should be saved or not + self.save_history = None + # whether the algorithm should print output in this run or not + self.verbose = None + # the random seed that was used + self.seed = None + # the pareto-front of the problem - if it exist or passed + self.pf = None + # the function evaluator object (can be used to inject code) + self.evaluator = None + # the current number of generation or iteration + self.n_gen = None + # the history object which contains the list + self.history = None + # the current solutions stored - here considered as population + self.pop = None + # the optimum found by the algorithm + self.opt = None + # can be used to store additional data in submodules + self.data = {} + + def initialize( + self, + problem, + pf=True, + evaluator=None, + # START Default minimize + seed=None, + verbose=False, + save_history=False, + return_least_infeasible=False, + # END Default minimize + n_gen=1, + display=None, + # END Overwrite by minimize + **kwargs): + + # set the problem that is optimized for the current run + self.problem = problem + + # set the provided pareto front + self.pf = pf + + # by default make sure an evaluator exists if nothing is passed + if evaluator is None: + evaluator = Evaluator() + self.evaluator = evaluator + + # ! + # START Default minimize + # ! + # if this run should be verbose or not + self.verbose = verbose + # whether the least infeasible should be returned or not + self.return_least_infeasible = return_least_infeasible + # whether the history should be stored or not + self.save_history = save_history + + # set the random seed in the algorithm object + self.seed = seed + if self.seed is None: + self.seed = np.random.randint(0, 10000000) + np.random.seed(self.seed) + # ! + # END Default minimize + # ! 
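+        # an explicitly passed display overrides the default one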
+ + if display is not None: + self.display = display + + # other run dependent variables that are reset + self.n_gen = n_gen + self.history = [] + self.pop = Population() + self.opt = None + + def solve(self): + + # the result object to be finally returned + res = {} + + # initialize the first population and evaluate it + self._initialize() + self._set_optimum() + + self.current_gen = 0 + # while termination criterion not fulfilled + while self.current_gen < self.n_gen: + self.current_gen += 1 + self.next() + + # store the resulting population + res['pop'] = self.pop + + # get the optimal solution found + opt = self.opt + + # if optimum is not set + if len(opt) == 0: + opt = None + + # if no feasible solution has been found + elif not np.any(opt.get('feasible')): + if self.return_least_infeasible: + opt = filter_optimum(opt, least_infeasible=True) + else: + opt = None + + # set the optimum to the result object + res['opt'] = opt + + # if optimum is set to none to not report anything + if opt is None: + X, F, CV, G = None, None, None, None + + # otherwise get the values from the population + else: + X, F, CV, G = self.opt.get('X', 'F', 'CV', 'G') + + # if single-objective problem and only one solution was found + if self.problem.n_obj == 1 and len(X) == 1: + X, F, CV, G = X[0], F[0], CV[0], G[0] + + # set all the individual values + res['X'], res['F'], res['CV'], res['G'] = X, F, CV, G + + # create the result object + res['problem'], res['pf'] = self.problem, self.pf + res['history'] = self.history + + return res + + def next(self): + # call next of the implementation of the algorithm + self._next() + + # set the optimum - only done if the algorithm did not do it yet + self._set_optimum() + + # do what needs to be done each generation + self._each_iteration() + + # method that is called each iteration to call some algorithms regularly + def _each_iteration(self): + + # display the output if defined by the algorithm + if self.logger: + self.logger.info(f'Generation:[{self.current_gen}/{self.n_gen}] ' + f'evaluate {self.evaluator.n_eval} solutions, ' + f'find {len(self.opt)} optimal solution.') + + def _finalize(self): + pass + + @abstractmethod + def _initialize(self): + pass + + @abstractmethod + def _next(self): + pass diff --git a/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py new file mode 100644 index 000000000..7af88ec79 --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
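+# A single-objective GA; NSGA2SearchLoop uses it to solve SubsetProblem
+# when picking the final candidate subset.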
+# copied and modified from https://github.com/anyoptimization/pymoo
+import numpy as np
+from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
+
+from mmrazor.registry import TASK_UTILS
+from .nsga2_optimizer import NSGA2Optimizer
+from .utils.helper import FitnessSurvival, Individual, Population
+from .utils.selection import (BinaryCrossover, MyMutation, MySampling,
+                              TournamentSelection)
+
+
+@TASK_UTILS.register_module()
+class GeneticOptimizer(NSGA2Optimizer):
+    """Single-objective genetic algorithm."""
+
+    def __init__(self,
+                 pop_size=100,
+                 sampling=MySampling(),
+                 selection=TournamentSelection(
+                     func_comp='comp_by_cv_and_fitness'),
+                 crossover=BinaryCrossover(),
+                 mutation=MyMutation(),
+                 eliminate_duplicates=True,
+                 n_offsprings=None,
+                 display=None,
+                 **kwargs):
+        """
+        Args:
+            pop_size (int): The population size.
+            sampling: Sampling strategy of the initial population.
+            selection: Mating selection operator.
+            crossover: Crossover operator.
+            mutation: Mutation operator.
+            eliminate_duplicates (bool): Whether to eliminate duplicates.
+            n_offsprings (int, optional): Number of offspring per generation.
+        """
+
+        super().__init__(
+            pop_size=pop_size,
+            sampling=sampling,
+            selection=selection,
+            crossover=crossover,
+            mutation=mutation,
+            survival=FitnessSurvival(),
+            eliminate_duplicates=eliminate_duplicates,
+            n_offsprings=n_offsprings,
+            display=display,
+            **kwargs)
+
+    def _set_optimum(self, force=False):
+        self.opt = filter_optimum(self.pop, least_infeasible=True)
+
+
+def filter_optimum(pop, least_infeasible=False):
+    # first only choose feasible solutions
+    ret = pop[pop.get('feasible')[:, 0]]
+
+    # if at least one feasible solution was found
+    if len(ret) > 0:
+
+        # then check the objective values
+        F = ret.get('F')
+
+        if F.shape[1] > 1:
+            Index = NonDominatedSorting().do(F, only_non_dominated_front=True)
+            ret = ret[Index]
+
+        else:
+            ret = ret[np.argmin(F)]
+
+    # no feasible solution was found
+    else:
+        # if the flag is enabled, report the least infeasible
+        if least_infeasible:
+            ret = pop[np.argmin(pop.get('CV'))]
+        # otherwise just return none
+        else:
+            ret = None
+
+    if isinstance(ret, Individual):
+        ret = Population().create(ret)
+
+    return ret
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py
new file mode 100644
index 000000000..9ebf36583
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py
@@ -0,0 +1,150 @@
+# Copyright (c) OpenMMLab. All rights reserved.
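+# NSGA-II (Deb et al., 2002): elitist multi-objective GA based on
+# non-dominated sorting and crowding-distance diversity preservation.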
+# copied and modified from https://github.com/anyoptimization/pymoo
+import numpy as np
+
+from mmrazor.registry import TASK_UTILS
+from .base_optimizer import BaseOptimizer
+from .utils.helper import (DefaultDuplicateElimination, Individual,
+                           Initialization, Survival)
+from .utils.selection import (IntegerFromFloatMutation, Mating,
+                              PointCrossover, TournamentSelection,
+                              binary_tournament)
+
+# from pymoo.algorithms.moo.nsga2 import binary_tournament
+# from pymoo.core.mating import Mating
+# from pymoo.core.survival import Survival
+# from pymoo.core.individual import Individual
+# from pymoo.core.initialization import Initialization
+# from pymoo.core.duplicate import DefaultDuplicateElimination
+# from pymoo.operators.crossover.pntx import PointCrossover
+# from pymoo.operators.selection.tournament import TournamentSelection
+# from .packages.selection import IntegerFromFloatMutation
+
+
+@TASK_UTILS.register_module()
+class NSGA2Optimizer(BaseOptimizer):
+    """Implementation of the NSGA-II search method.
+
+    Args:
+        pop_size (int): The population size.
+        sampling: Sampling strategy of the initial population.
+        selection: Mating selection operator.
+        crossover: Crossover operator.
+        mutation: Mutation operator.
+        eliminate_duplicates (bool): Whether to eliminate duplicates.
+        n_offsprings (int, optional): Number of offspring per generation.
+    """
+
+    def __init__(self,
+                 pop_size=100,
+                 sampling=None,
+                 selection=TournamentSelection(func_comp=binary_tournament),
+                 crossover=PointCrossover(n_points=2),
+                 mutation=IntegerFromFloatMutation(eta=1.0),
+                 eliminate_duplicates=True,
+                 n_offsprings=None,
+                 display=None,
+                 survival=Survival(),
+                 repair=None,
+                 **kwargs):
+        super().__init__(
+            pop_size=pop_size,
+            sampling=sampling,
+            selection=selection,
+            crossover=crossover,
+            mutation=mutation,
+            survival=survival,
+            eliminate_duplicates=eliminate_duplicates,
+            n_offsprings=n_offsprings,
+            display=display,
+            **kwargs)
+
+        # the population size used
+        self.pop_size = pop_size
+
+        # the survival strategy for the genetic algorithm; keep the passed
+        # strategy so that subclasses can override it
+        self.survival = survival
+
+        # number of offsprings to generate through recombination
+        self.n_offsprings = n_offsprings
+
+        # if the number of offspring is not set
+        if self.n_offsprings is None:
+            self.n_offsprings = pop_size
+
+        # the object to be used to represent an individual
+        self.individual = Individual(rank=np.inf, crowding=-1)
+
+        # set the duplicate detection class
+        if isinstance(eliminate_duplicates, bool):
+            if eliminate_duplicates:
+                self.eliminate_duplicates = DefaultDuplicateElimination()
+            else:
+                self.eliminate_duplicates = None
+        else:
+            self.eliminate_duplicates = eliminate_duplicates
+
+        self.initialization = Initialization(
+            sampling,
+            individual=self.individual,
+            repair=repair,
+            eliminate_duplicates=self.eliminate_duplicates)
+
+        self.mating = Mating(
+            selection,
+            crossover,
+            mutation,
+            repair=repair,
+            eliminate_duplicates=self.eliminate_duplicates,
+            n_max_iterations=100)
+
+        # other run specific data updated whenever solve is called
+        self.n_gen = None
+        self.pop = None
+        self.off = None
+
+    def _initialize(self):
+
+        # create the initial population
+        pop = self.initialization.do(
+            self.problem, self.pop_size, algorithm=self)
+
+        # then evaluate using the objective function
+        self.evaluator.eval(self.problem, pop, algorithm=self)
+
+        # that call is a dummy survival to set attributes
+        # that are necessary for the mating selection
+        if self.survival:
+            pop = self.survival.do(self.problem, pop, len(pop), algorithm=self)
+
+        self.pop, self.off = pop, pop
+
+    def _next(self):
+
+        # do the mating using the current population
+        self.off = self.mating.do(
+            self.problem,
+            self.pop,
+            n_offsprings=self.n_offsprings,
+            algorithm=self)
+
+        # if the mating could not generate any new offspring
+        if len(self.off) == 0:
+            return
+
+        # evaluate the offspring
+        self.evaluator.eval(self.problem, self.off, algorithm=self)
+
+        # merge the offspring with the current population
+        self.pop = self.pop.merge(self.off)
+
+        # then do the survival selection
+        if self.survival:
+            self.pop = self.survival.do(
+                self.problem, self.pop, self.pop_size, algorithm=self)
+
+    def _set_optimum(self, **kwargs):
+        if not np.any(self.pop.get('feasible')):
+            self.opt = self.pop[[np.argmin(self.pop.get('CV'))]]
+        else:
+            self.opt = self.pop[self.pop.get('rank') == 0]
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py
new file mode 100644
index 000000000..426e6283d
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .auxiliary_singlelevel_problem import AuxiliarySingleLevelProblem
+from .subset_problem import SubsetProblem
+
+__all__ = ['AuxiliarySingleLevelProblem', 'SubsetProblem']
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py
new file mode 100644
index 000000000..060b60a34
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from pymoo.core.problem import Problem as BaseProblem
+
+
+class AuxiliarySingleLevelProblem(BaseProblem):
+    """The optimization problem for finding the next N candidate
+    architectures."""
+
+    def __init__(self, searcher, dim=15, sec_obj='flops'):
+        super().__init__(n_var=dim, n_obj=2, vtype=np.int32)
+
+        self.searcher = searcher
+        self.predictor = self.searcher.predictor
+        self.sec_obj = sec_obj
+
+        self.xl = np.zeros(self.n_var)
+        # upper bound for each variable, automatically calculated from the
+        # search space
+        self.xu = []
+        for mutable in self.predictor.search_groups.values():
+            if mutable[0].num_choices > 0:
+                self.xu.append(mutable[0].num_choices - 1)
+        self.xu = np.array(self.xu)
+
+    def _evaluate(self, x, out, *args, **kwargs):
+        """Evaluate results."""
+        f = np.full((x.shape[0], self.n_obj), np.nan)
+        # predicted top-1 error
+        top1_err = self.predictor.handler.predict(x)[:, 0]
+        for i in range(len(x)):
+            candidate = self.predictor.vector2model(x[i])
+            _, resource = self.searcher._check_constraints(candidate)
+            f[i, 0] = top1_err[i]
+            f[i, 1] = resource[self.sec_obj]
+
+        out['F'] = f
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py
new file mode 100644
index 000000000..b82f0be93
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py
@@ -0,0 +1,327 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# copied and modified from https://github.com/anyoptimization/pymoo
+from abc import abstractmethod
+
+import numpy as np
+
+
+def at_least_2d_array(x, extend_as='row'):
+    if not isinstance(x, np.ndarray):
+        x = np.array([x])
+
+    if x.ndim == 1:
+        if extend_as == 'row':
+            x = x[None, :]
+        elif extend_as == 'column':
+            x = x[:, None]
+
+    return x
+
+
+class BaseProblem():
+    """Superclass for each problem that is defined.
+
+    It provides attributes such as the number of variables, number of
+    objectives or constraints. Also, the lower and upper bounds are stored.
+    If available, the Pareto front, nadir point and ideal point are stored.
+    """
+
+    def __init__(self,
+                 n_var=-1,
+                 n_obj=-1,
+                 n_constr=0,
+                 xl=None,
+                 xu=None,
+                 type_var=np.double,
+                 evaluation_of='auto',
+                 parallelization=None,
+                 elementwise_evaluation=False,
+                 callback=None):
+        """
+        Args:
+            n_var (int):
+                number of variables
+            n_obj (int):
+                number of objectives
+            n_constr (int):
+                number of constraints
+            xl (np.array or int):
+                lower bounds for the variables.
+            xu (np.array or int):
+                upper bounds for the variables.
+            type_var (numpy type):
+                type of the variable to be evaluated.
+            elementwise_evaluation (bool):
+                whether ``_evaluate`` is called once per row of X.
+            parallelization (str or tuple):
+                See :ref:`nb_parallelization` for guidance on parallelization.
+        """
+
+        # number of variables for this problem
+        self.n_var = n_var
+
+        # type of the variable to be evaluated
+        self.type_var = type_var
+
+        # number of objectives
+        self.n_obj = n_obj
+
+        # number of constraints
+        self.n_constr = n_constr
+
+        # allow just an integer for xl and xu if all bounds are equal
+        if n_var > 0 and not isinstance(xl, np.ndarray) and xl is not None:
+            self.xl = np.ones(n_var) * xl
+        else:
+            self.xl = xl
+
+        if n_var > 0 and not isinstance(xu, np.ndarray) and xu is not None:
+            self.xu = np.ones(n_var) * xu
+        else:
+            self.xu = xu
+
+        # the pareto set and front will be calculated only once
+        self._pareto_front = None
+        self._pareto_set = None
+        self._ideal_point, self._nadir_point = None, None
+
+        # actually defines what _evaluate is setting during the evaluation
+        if evaluation_of == 'auto':
+            # by default F is set, and G if the problem does have constraints
+            self.evaluation_of = ['F']
+            if self.n_constr > 0:
+                self.evaluation_of.append('G')
+        else:
+            self.evaluation_of = evaluation_of
+
+        self.elementwise_evaluation = elementwise_evaluation
+
+        self.parallelization = parallelization
+
+        # store the callback if defined
+        self.callback = callback
+
+    def nadir_point(self):
+        """Return nadir_point (np.array):
+
+        The nadir point for a multi-objective problem.
+        """
+        # if the nadir point has not been calculated yet
+        if self._nadir_point is None:
+
+            # calculate the pareto front if it has not happened yet
+            if self._pareto_front is None:
+                self.pareto_front()
+
+            # if already done or it was successful - calculate the nadir point
+            if self._pareto_front is not None:
+                self._nadir_point = np.max(self._pareto_front, axis=0)
+
+        return self._nadir_point
+
+    def ideal_point(self):
+        """
+        Returns
+        -------
+        ideal_point (np.array):
+            The ideal point for a multi-objective problem. If single-objective
+            it returns the best possible solution.
+        """
+
+        # if the ideal point has not been calculated yet
+        if self._ideal_point is None:
+
+            # calculate the pareto front if it has not happened yet
+            if self._pareto_front is None:
+                self.pareto_front()
+
+            # if already done or it was successful - calculate the ideal point
+            if self._pareto_front is not None:
+                self._ideal_point = np.min(self._pareto_front, axis=0)
+
+        return self._ideal_point
+
+    def pareto_front(self,
+                     *args,
+                     use_cache=True,
+                     exception_if_failing=True,
+                     **kwargs):
+        """
+        Args:
+            args : Some problem implementations need more information to
+                create the Pareto front.
+            exception_if_failing (bool):
+                Whether to throw an exception when generating the Pareto front
+                has failed.
+            use_cache (bool):
+                Whether to return the cached Pareto front if it was already
+                calculated.
+ + Returns: + P (np.array): + The Pareto front of a given problem. + + """ + if not use_cache or self._pareto_front is None: + try: + pf = self._calc_pareto_front(*args, **kwargs) + if pf is not None: + self._pareto_front = at_least_2d_array(pf) + + except Exception as e: + if exception_if_failing: + raise e + + return self._pareto_front + + def pareto_set(self, *args, use_cache=True, **kwargs): + """ + Returns: + S (np.array): + Returns the pareto set for a problem. + """ + if not use_cache or self._pareto_set is None: + self._pareto_set = at_least_2d_array( + self._calc_pareto_set(*args, **kwargs)) + + return self._pareto_set + + def evaluate(self, + X, + *args, + return_values_of='auto', + return_as_dictionary=False, + **kwargs): + """Evaluate the given problem. + + The function values set as defined in the function. + The constraint values are meant to be positive if infeasible. + + Args: + + X (np.array): + A two dimensional matrix where each row is a point to evaluate + and each column a variable. + + return_as_dictionary (bool): + If this is true than only one object, a dictionary, + is returned. + return_values_of (list of strings): + Allowed is ["F", "CV", "G", "dF", "dG", "dCV", "feasible"] + where the d stands for derivative and h stands for hessian + matrix. + + + Returns: + A dictionary, if return_as_dictionary enabled, or a list of values + as defined in return_values_of. + """ + + # call the callback of the problem + if self.callback is not None: + self.callback(X) + + only_single_value = len(np.shape(X)) == 1 + X = np.atleast_2d(X) + + # check the dimensionality of the problem and the given input + if X.shape[1] != self.n_var: + raise Exception('Input dimension %s are not equal to n_var %s!' % + (X.shape[1], self.n_var)) + + if type(return_values_of) == str and return_values_of == 'auto': + return_values_of = ['F'] + if self.n_constr > 0: + return_values_of.append('CV') + + out = {} + for val in return_values_of: + out[val] = None + + out = self._evaluate_batch(X, False, out, *args, **kwargs) + + # if constraint violation should be returned as well + if self.n_constr == 0: + CV = np.zeros([X.shape[0], 1]) + else: + CV = self.calc_constraint_violation(out['G']) + + if 'CV' in return_values_of: + out['CV'] = CV + + # if an additional boolean flag for feasibility should be returned + if 'feasible' in return_values_of: + out['feasible'] = (CV <= 0) + + # if asked for a value but not set in the evaluation set to None + for val in return_values_of: + if val not in out: + out[val] = None + + if only_single_value: + for key in out.keys(): + if out[key] is not None: + out[key] = out[key][0, :] + + if return_as_dictionary: + return out + else: + + # if just a single value do not return a tuple + if len(return_values_of) == 1: + return out[return_values_of[0]] + else: + return tuple([out[val] for val in return_values_of]) + + def _evaluate_batch(self, X, calc_gradient, out, *args, **kwargs): + self._evaluate(X, out, *args, **kwargs) + for key in out.keys(): + if len(np.shape(out[key])) == 1: + out[key] = out[key][:, None] + + return out + + @abstractmethod + def _evaluate(self, x, out, *args, **kwargs): + pass + + def has_bounds(self): + return self.xl is not None and self.xu is not None + + def bounds(self): + return self.xl, self.xu + + def name(self): + return self.__class__.__name__ + + def _calc_pareto_front(self, *args, **kwargs): + """Method that either loads or calculates the pareto front. This is + only done ones and the pareto front is stored. 
+ + Returns: + pf (np.array): Pareto front as array. + """ + pass + + def _calc_pareto_set(self, *args, **kwargs): + pass + + # some problem information + def __str__(self): + s = '# name: %s\n' % self.name() + s += '# n_var: %s\n' % self.n_var + s += '# n_obj: %s\n' % self.n_obj + s += '# n_constr: %s\n' % self.n_constr + s += '# f(xl): %s\n' % self.evaluate(self.xl)[0] + s += '# f((xl+xu)/2): %s\n' % self.evaluate( + (self.xl + self.xu) / 2.0)[0] + s += '# f(xu): %s\n' % self.evaluate(self.xu)[0] + return s + + @staticmethod + def calc_constraint_violation(G): + if G is None: + return None + elif G.shape[1] == 0: + return np.zeros(G.shape[0])[:, None] + else: + return np.sum(G * (G > 0), axis=1)[:, None] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py new file mode 100644 index 000000000..2a0af040a --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +from .base_problem import BaseProblem + + +class SubsetProblem(BaseProblem): + """select a subset to diversify the pareto front.""" + + def __init__(self, candidates, archive, K): + super().__init__( + n_var=len(candidates), + n_obj=1, + n_constr=1, + xl=0, + xu=1, + type_var=bool) + self.archive = archive + self.candidates = candidates + self.n_max = K + + def _evaluate(self, x, out, *args, **kwargs): + f = np.full((x.shape[0], 1), np.nan) + g = np.full((x.shape[0], 1), np.nan) + + for i, _x in enumerate(x): + # append selected candidates to archive then sort + tmp = np.sort(np.concatenate((self.archive, self.candidates[_x]))) + f[i, 0] = np.std(np.diff(tmp)) + # we penalize if the number of selected candidates is not exactly K + g[i, 0] = (self.n_max - np.sum(_x))**2 + + out['F'] = f + out['G'] = g diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py new file mode 100644 index 000000000..cf06da663 --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
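+# Pairwise Pareto-domination utilities (minimization convention).
+# A hedged worked example: for F = [[1, 2], [2, 1], [3, 3]],
+# fast_non_dominated_sort(F) yields the fronts [[0, 1], [2]] -- points 0
+# and 1 are mutually non-dominated and both dominate point 2.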
+import numpy as np + + +def get_relation(a, b, cva=None, cvb=None): + + if cva is not None and cvb is not None: + if cva < cvb: + return 1 + elif cvb < cva: + return -1 + + val = 0 + for i in range(len(a)): + if a[i] < b[i]: + # indifferent because once better and once worse + if val == -1: + return 0 + val = 1 + elif b[i] < a[i]: + # indifferent because once better and once worse + if val == 1: + return 0 + val = -1 + return val + + +def calc_domination_matrix_loop(F, G): + n = F.shape[0] + CV = np.sum(G * (G > 0).astype(np.float32), axis=1) + M = np.zeros((n, n)) + for i in range(n): + for j in range(i + 1, n): + M[i, j] = get_relation(F[i, :], F[j, :], CV[i], CV[j]) + M[j, i] = -M[i, j] + + return M + + +def calc_domination_matrix(F, _F=None, epsilon=0.0): + """ + if G is None or len(G) == 0: + constr = np.zeros((F.shape[0], F.shape[0])) + else: + # consider the constraint violation + # CV = Problem.calc_constraint_violation(G) + # constr = (CV < CV) * 1 + (CV > CV) * -1 + + CV = Problem.calc_constraint_violation(G)[:, 0] + constr = (CV[:, None] < CV) * 1 + (CV[:, None] > CV) * -1 + """ + + if _F is None: + _F = F + + # look at the obj for dom + n = F.shape[0] + m = _F.shape[0] + + L = np.repeat(F, m, axis=0) + R = np.tile(_F, (n, 1)) + + smaller = np.reshape(np.any(L + epsilon < R, axis=1), (n, m)) + larger = np.reshape(np.any(L > R + epsilon, axis=1), (n, m)) + + M = np.logical_and(smaller, np.logical_not(larger)) * 1 \ + + np.logical_and(larger, np.logical_not(smaller)) * -1 + + # if cv equal then look at dom + # M = constr + (constr == 0) * dom + + return M + + +def fast_non_dominated_sort(F, **kwargs): + M = calc_domination_matrix(F) + + # calculate the dominance matrix + n = M.shape[0] + + fronts = [] + + if n == 0: + return fronts + + # final rank that will be returned + n_ranked = 0 + ranked = np.zeros(n, dtype=np.int32) + is_dominating = [[] for _ in range(n)] + + # storage for the number of solutions dominated this one + n_dominated = np.zeros(n) + + current_front = [] + + for i in range(n): + + for j in range(i + 1, n): + rel = M[i, j] + if rel == 1: + is_dominating[i].append(j) + n_dominated[j] += 1 + elif rel == -1: + is_dominating[j].append(i) + n_dominated[i] += 1 + + if n_dominated[i] == 0: + current_front.append(i) + ranked[i] = 1.0 + n_ranked += 1 + + # append the first front to the current front + fronts.append(current_front) + + # while not all solutions are assigned to a pareto front + while n_ranked < n: + + next_front = [] + + # for each individual in the current front + for i in current_front: + + # all solutions that are dominated by this individuals + for j in is_dominating[i]: + n_dominated[j] -= 1 + if n_dominated[j] == 0: + next_front.append(j) + ranked[j] = 1.0 + n_ranked += 1 + + fronts.append(next_front) + current_front = next_front + + return fronts diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py new file mode 100644 index 000000000..b0362f333 --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py @@ -0,0 +1,668 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# copied from https://github.com/anyoptimization/pymoo +import copy + +import numpy as np +import scipy +from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting + + +def default_attr(pop): + return pop.get('X') + + +def cdist(x, y): + return scipy.spatial.distance.cdist(x, y) + + +class DuplicateElimination: + """Implementation of Elimination. 
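+
+    ``do()`` filters duplicated individuals from a population, optionally
+    also against other populations; subclasses implement ``_do`` to mark
+    the duplicates.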
+ + func: function to execute. + """ + + def __init__(self, func=None) -> None: + super().__init__() + self.func = func + + if self.func is None: + self.func = default_attr + + def do(self, pop, *args, return_indices=False, to_itself=True): + original = pop + + if to_itself: + pop = pop[~self._do(pop, None, np.full(len(pop), False))] + + for arg in args: + if len(arg) > 0: + + if len(pop) == 0: + break + elif len(arg) == 0: + continue + else: + pop = pop[~self._do(pop, arg, np.full(len(pop), False))] + + if return_indices: + no_duplicate, is_duplicate = [], [] + H = set(pop) + + for Index, ind in enumerate(original): + if ind in H: + no_duplicate.append(Index) + else: + is_duplicate.append(Index) + + return pop, no_duplicate, is_duplicate + else: + return pop + + +class DefaultDuplicateElimination(DuplicateElimination): + """Implementation of DefaultDuplicate Elimination. + + epsilon(float): smallest dist for judge duplication. + """ + + def __init__(self, epsilon=1e-16, **kwargs) -> None: + super().__init__(**kwargs) + self.epsilon = epsilon + + def calc_dist(self, pop, other=None): + X = self.func(pop) + + if other is None: + D = cdist(X, X) + D[np.triu_indices(len(X))] = np.inf + else: + _X = self.func(other) + D = cdist(X, _X) + + return D + + def _do(self, pop, other, is_duplicate): + D = self.calc_dist(pop, other) + D[np.isnan(D)] = np.inf + + is_duplicate[np.any(D < self.epsilon, axis=1)] = True + return is_duplicate + + +class Individual: + """Class for each individual in search step.""" + + def __init__(self, + X=None, + F=None, + CV=None, + G=None, + feasible=None, + **kwargs) -> None: + self.X = X + self.F = F + self.CV = CV + self.G = G + self.feasible = feasible + self.data = kwargs + self.attr = set(self.__dict__.keys()) + + def has(self, key): + return key in self.attr or key in self.data + + def set(self, key, value): + if key in self.attr: + self.__dict__[key] = value + else: + self.data[key] = value + + def copy(self): + ind = copy.copy(self) + ind.data = self.data.copy() + return ind + + def get(self, keys): + if keys in self.data: + return self.data[keys] + elif keys in self.attr: + return self.__dict__[keys] + else: + return None + + +class Population(np.ndarray): + """Class for all the population in search step.""" + + def __new__(cls, n_individuals=0, individual=Individual()): + obj = super(Population, cls).__new__( + cls, n_individuals, dtype=individual.__class__).view(cls) + for Index in range(n_individuals): + obj[Index] = individual.copy() + obj.individual = individual + return obj + + def merge(self, a, b=None): + if b: + a, b = pop_from_array_or_individual(a), \ + pop_from_array_or_individual(b) + a.merge(b) + else: + other = pop_from_array_or_individual(a) + if len(self) == 0: + return other + else: + obj = np.concatenate([self, other]).view(Population) + obj.individual = self.individual + return obj + + def copy(self): + pop = Population(n_individuals=len(self), individual=self.individual) + for Index in range(len(self)): + pop[Index] = self[Index] + return pop + + def has(self, key): + return all([ind.has(key) for ind in self]) + + def __deepcopy__(self, memo): + return self.copy() + + @classmethod + def create(cls, *args): + pop = np.concatenate([ + pop_from_array_or_individual(arg) for arg in args + ]).view(Population) + pop.individual = Individual() + return pop + + def new(self, *args): + + if len(args) == 1: + return Population( + n_individuals=args[0], individual=self.individual) + else: + n = len(args[1]) if len(args) > 0 else 0 + pop = 
Population(n_individuals=n, individual=self.individual) + if len(args) > 0: + pop.set(*args) + return pop + + def collect(self, func, to_numpy=True): + val = [] + for Index in range(len(self)): + val.append(func(self[Index])) + if to_numpy: + val = np.array(val) + return val + + def set(self, *args): + + for Index in range(int(len(args) / 2)): + + key, values = args[Index * 2], args[Index * 2 + 1] + is_iterable = hasattr(values, + '__len__') and not isinstance(values, str) + + if is_iterable and len(values) != len(self): + raise Exception( + 'Population Set Attribute Error: ' + 'Number of values and population size do not match!') + + for Index in range(len(self)): + val = values[Index] if is_iterable else values + self[Index].set(key, val) + + return self + + def get(self, *args, to_numpy=True): + + val = {} + for c in args: + val[c] = [] + + for Index in range(len(self)): + + for c in args: + val[c].append(self[Index].get(c)) + + res = [val[c] for c in args] + if to_numpy: + res = [np.array(e) for e in res] + + if len(args) == 1: + return res[0] + else: + return tuple(res) + + def __array_finalize__(self, obj): + if obj is None: + return + self.individual = getattr(obj, 'individual', None) + + +def pop_from_array_or_individual(array, pop=None): + # the population type can be different + if pop is None: + pop = Population() + + # provide a whole population object + if isinstance(array, Population): + pop = array + elif isinstance(array, np.ndarray): + pop = pop.new('X', np.atleast_2d(array)) + elif isinstance(array, Individual): + pop = Population(1) + pop[0] = array + else: + return None + + return pop + + +class Initialization: + """Initiallize step.""" + + def __init__(self, + sampling, + individual=Individual(), + repair=None, + eliminate_duplicates=None) -> None: + + super().__init__() + self.sampling = sampling + self.individual = individual + self.repair = repair + self.eliminate_duplicates = eliminate_duplicates + + def do(self, problem, n_samples, **kwargs): + + # provide a whole population object + if isinstance(self.sampling, Population): + pop = self.sampling + + else: + pop = Population(0, individual=self.individual) + if isinstance(self.sampling, np.ndarray): + pop = pop.new('X', self.sampling) + else: + pop = self.sampling.do(problem, n_samples, pop=pop, **kwargs) + + # repair all solutions that are not already evaluated + if self.repair: + Index = [k for k in range(len(pop)) if pop[k].F is None] + pop = self.repair.do(problem, pop[Index], **kwargs) + + if self.eliminate_duplicates is not None: + pop = self.eliminate_duplicates.do(pop) + + return pop + + +def split_by_feasibility(pop, sort_infeasbible_by_cv=True): + CV = pop.get('CV') + + b = (CV <= 0) + + feasible = np.where(b)[0] + infeasible = np.where(np.logical_not(b))[0] + + if sort_infeasbible_by_cv: + infeasible = infeasible[np.argsort(CV[infeasible, 0])] + + return feasible, infeasible + + +class Survival: + """The survival process is implemented inheriting from this class, which + selects from a population only specific individuals to survive. 
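+
+    The default ``_do`` implements NSGA-II style survival: non-dominated
+    sorting until the splitting front, followed by crowding-distance
+    based truncation within that front.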
+
+    Parameters
+    ----------
+    filter_infeasible : bool
+        Whether infeasible solutions should be filtered out before the
+        survival selection
+    """
+
+    def __init__(self, filter_infeasible=True):
+        self.filter_infeasible = filter_infeasible
+
+    def do(self, problem, pop, n_survive, return_indices=False, **kwargs):
+
+        # if the split should be done beforehand
+        if self.filter_infeasible and problem.n_constr > 0:
+            feasible, infeasible = split_by_feasibility(
+                pop, sort_infeasbible_by_cv=True)
+
+            # if there is no feasible solution at all
+            if len(feasible) == 0:
+                survivors = pop[infeasible[:n_survive]]
+
+            # if there are feasible solutions in the population
+            else:
+                survivors = pop.new()
+
+                # if feasible solutions do exist
+                if len(feasible) > 0:
+                    survivors = self._do(problem, pop[feasible],
+                                         min(len(feasible), n_survive),
+                                         **kwargs)
+
+                # if infeasible solutions need to be added
+                if len(survivors) < n_survive:
+                    least_infeasible = infeasible[:n_survive - len(feasible)]
+                    survivors = survivors.merge(pop[least_infeasible])
+
+        else:
+            survivors = self._do(problem, pop, n_survive, **kwargs)
+
+        if return_indices:
+            H = {}
+            for k, ind in enumerate(pop):
+                H[ind] = k
+            return [H[survivor] for survivor in survivors]
+        else:
+            return survivors
+
+    def _do(self, problem, pop, n_survive, D=None, **kwargs):
+
+        # get the objective space values and objects
+        F = pop.get('F').astype(float, copy=False)
+
+        # the final indices of surviving individuals
+        survivors = []
+
+        # do the non-dominated sorting until splitting front
+        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)
+
+        for k, front in enumerate(fronts):
+
+            # calculate the crowding distance of the front
+            crowding_of_front = calc_crowding_distance(F[front, :])
+            # save rank and crowding in the individual class
+            for j, Index in enumerate(front):
+                pop[Index].set('rank', k)
+                pop[Index].set('crowding', crowding_of_front[j])
+
+            # current front sorted by crowding distance if splitting
+            if len(survivors) + len(front) > n_survive:
+                Index = randomized_argsort(
+                    crowding_of_front, order='descending', method='numpy')
+                Index = Index[:(n_survive - len(survivors))]
+
+            # otherwise take the whole front unsorted
+            else:
+                Index = np.arange(len(front))
+
+            # extend the survivors by all or selected individuals
+            survivors.extend(front[Index])
+
+        return pop[survivors]
+
+
+class FitnessSurvival(Survival):
+    """Survival class for Fitness."""
+
+    def __init__(self) -> None:
+        super().__init__(True)
+
+    def _do(self, problem, pop, n_survive, out=None, **kwargs):
+        F = pop.get('F')
+
+        if F.shape[1] != 1:
+            raise ValueError(
+                'FitnessSurvival can only be used for single-objective '
+                'problems!')
+
+        return pop[np.argsort(F[:, 0])[:n_survive]]
+
+
+def find_duplicates(X, epsilon=1e-16):
+    # calculate the distance matrix from each point to another
+    D = cdist(X, X)
+
+    # set the upper triangle (including the diagonal) to infinity
+    D[np.triu_indices(len(X))] = np.inf
+
+    # set as duplicate if a point is really close to this one
+    is_duplicate = np.any(D < epsilon, axis=1)
+
+    return is_duplicate
+
+
+def calc_crowding_distance(F, filter_out_duplicates=True):
+    n_points, n_obj = F.shape
+
+    if n_points <= 2:
+        return np.full(n_points, np.inf)
+
+    else:
+
+        if filter_out_duplicates:
+            # filter out solutions which are duplicates
+            is_unique = np.where(
+                np.logical_not(find_duplicates(F, epsilon=1e-24)))[0]
+        else:
+            # set every point to be unique without checking it
+            is_unique = np.arange(n_points)
+
+        # index the unique points of the array
+        _F = F[is_unique]
+
+        # sort 
each column and get index
+        Index = np.argsort(_F, axis=0, kind='mergesort')
+
+        # sort the objective space values for the whole matrix
+        _F = _F[Index, np.arange(n_obj)]
+
+        # calculate the distance from each point to the last and next
+        dist = np.row_stack([_F, np.full(n_obj, np.inf)]) - np.row_stack(
+            [np.full(n_obj, -np.inf), _F])
+
+        # calculate the norm for each objective
+        norm = np.max(_F, axis=0) - np.min(_F, axis=0)
+        norm[norm == 0] = np.nan
+
+        # prepare the distance to last and next vectors
+        dist_to_last, dist_to_next = dist, np.copy(dist)
+        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[
+            1:] / norm
+
+        dist_to_last[np.isnan(dist_to_last)] = 0.0
+        dist_to_next[np.isnan(dist_to_next)] = 0.0
+
+        # sum up the distance to next and last and norm by objectives
+        J = np.argsort(Index, axis=0)
+        _cd = np.sum(
+            dist_to_last[J, np.arange(n_obj)] +
+            dist_to_next[J, np.arange(n_obj)],
+            axis=1) / n_obj
+
+        # write the crowding distances back; duplicates keep distance zero
+        crowding = np.zeros(n_points)
+        crowding[is_unique] = _cd
+
+    # crowding[np.isinf(crowding)] = 1e+14
+    return crowding
+
+
+def randomized_argsort(A, method='numpy', order='ascending'):
+    if method == 'numpy':
+        P = np.random.permutation(len(A))
+        Index = np.argsort(A[P], kind='quicksort')
+        Index = P[Index]
+
+    elif method == 'quicksort':
+        Index = quicksort(A)
+
+    else:
+        raise Exception('Randomized sort method not known.')
+
+    if order == 'ascending':
+        return Index
+    elif order == 'descending':
+        return np.flip(Index, axis=0)
+    else:
+        raise Exception('Unknown sorting order: ascending or descending.')
+
+
+def swap(M, a, b):
+    tmp = M[a]
+    M[a] = M[b]
+    M[b] = tmp
+
+
+def quicksort(A):
+    Index = np.arange(len(A))
+    _quicksort(A, Index, 0, len(A) - 1)
+    return Index
+
+
+def _quicksort(A, Index, left, right):
+    if left < right:
+
+        index = np.random.randint(left, right + 1)
+        swap(Index, right, index)
+
+        pivot = A[Index[right]]
+
+        # index of the last element smaller than or equal to the pivot
+        last_smaller = left - 1
+
+        for j in range(left, right):
+
+            if A[Index[j]] <= pivot:
+                last_smaller += 1
+                swap(Index, last_smaller, j)
+
+        index = last_smaller + 1
+        swap(Index, right, index)
+
+        _quicksort(A, Index, left, index - 1)
+        _quicksort(A, Index, index + 1, right)
+
+
+def random_permuations(n, input):
+    perms = []
+    for _ in range(n):
+        perms.append(np.random.permutation(input))
+    P = np.concatenate(perms)
+    return P
+
+
+def crossover_mask(X, M):
+    # swap the masked entries between the first two parents
+    _X = np.copy(X)
+    _X[0][M] = X[1][M]
+    _X[1][M] = X[0][M]
+
+    return _X
+
+
+def at_least_2d_array(x, extend_as='row'):
+    if not isinstance(x, np.ndarray):
+        x = np.array([x])
+
+    if x.ndim == 1:
+        if extend_as == 'row':
+            x = x[None, :]
+        elif extend_as == 'column':
+            x = x[:, None]
+
+    return x
+
+
+def repair_out_of_bounds(problem, X):
+    xl, xu = problem.xl, problem.xu
+
+    only_1d = (X.ndim == 1)
+    X = at_least_2d_array(X)
+
+    if xl is not None:
+        xl = np.repeat(xl[None, :], X.shape[0], axis=0)
+        X[X < xl] = xl[X < xl]
+
+    if xu is not None:
+        xu = np.repeat(xu[None, :], X.shape[0], axis=0)
+        X[X > xu] = xu[X > xu]
+
+    if only_1d:
+        return X[0, :]
+    else:
+        return X
+
+
+def denormalize(x, x_min, x_max):
+
+    if x_max is None:
+        _range = 1
+    else:
+        _range = (x_max - x_min)
+
+    return x * _range + x_min
+
+
+class Evaluator:
+    """The evaluator class which is used during the algorithm execution to
+    evaluate solutions and count the number of evaluations."""
+
+    def __init__(self, evaluate_values_of=['F', 'CV', 'G']):
+        self.n_eval = 0
+        self.evaluate_values_of = evaluate_values_of
+
+    
def eval(self, problem, pop, **kwargs): + """This function is used to return the result of one valid evaluation. + + Parameters + ---------- + problem : class + The problem which is used to be evaluated + pop : np.array or Population object + kwargs : dict + Additional arguments which might be necessary for the problem to + evaluate. + """ + + is_individual = isinstance(pop, Individual) + is_numpy_array = isinstance( + pop, np.ndarray) and not isinstance(pop, Population) + + # make sure the object is a population + if is_individual or is_numpy_array: + pop = Population().create(pop) + + # find indices to be evaluated + Index = [k for k in range(len(pop)) if pop[k].F is None] + + # update the function evaluation counter + self.n_eval += len(Index) + + # actually evaluate all solutions using the function + if len(Index) > 0: + self._eval(problem, pop[Index], **kwargs) + + # set the feasibility attribute if cv exists + for ind in pop[Index]: + cv = ind.get('CV') + if cv is not None: + ind.set('feasible', cv <= 0) + + if is_individual: + return pop[0] + elif is_numpy_array: + if len(pop) == 1: + pop = pop[0] + return tuple([pop.get(e) for e in self.evaluate_values_of]) + else: + return pop + + def _eval(self, problem, pop, **kwargs): + + out = problem.evaluate( + pop.get('X'), + return_values_of=self.evaluate_values_of, + return_as_dictionary=True, + **kwargs) + + for key, val in out.items(): + if val is None: + continue + else: + pop.set(key, val) diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py new file mode 100644 index 000000000..3ec1c644e --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py @@ -0,0 +1,490 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
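+# Selection, crossover and mutation operators for the NSGA-II based search.
+# A minimal usage sketch (illustrative only; assumes an already evaluated
+# Population `pop` and a `problem` with integer design variables):
+#
+#   selection = TournamentSelection(func_comp='binary_tournament')
+#   crossover = PointCrossover(n_points=2)
+#   mutation = IntegerFromFloatMutation(eta=20)
+#   mating = Mating(selection, crossover, mutation)
+#   offspring = mating.do(problem, pop, n_offsprings=len(pop))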
+import math
+
+import numpy as np
+
+from .domin_matrix import get_relation
+# from pymoo.core.population import Population
+# from pymoo.util.misc import crossover_mask, random_permuations
+# from pymoo.operators.repair.bounce_back import bounce_back_by_problem
+from .helper import (Population, crossover_mask, random_permuations,
+                     repair_out_of_bounds)
+
+
+def binary_tournament(pop, P, algorithm, **kwargs):
+    if P.shape[1] != 2:
+        raise ValueError('Only implemented for binary tournament!')
+
+    S = np.full(P.shape[0], np.nan)
+
+    for Index in range(P.shape[0]):
+
+        a, b = P[Index, 0], P[Index, 1]
+
+        # if at least one solution is infeasible
+        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
+            S[Index] = compare(
+                a,
+                pop[a].CV,
+                b,
+                pop[b].CV,
+                method='smaller_is_better',
+                return_random_if_equal=True)
+        else:
+            rel = get_relation(pop[a].F, pop[b].F)
+            if rel == 1:
+                S[Index] = a
+            elif rel == -1:
+                S[Index] = b
+        # if rank or domination relation didn't make a decision
+        if np.isnan(S[Index]):
+            S[Index] = compare(
+                a,
+                pop[a].get('crowding'),
+                b,
+                pop[b].get('crowding'),
+                method='larger_is_better',
+                return_random_if_equal=True)
+
+    return S[:, None].astype(int, copy=False)
+
+
+def comp_by_cv_and_fitness(pop, P, **kwargs):
+    S = np.full(P.shape[0], np.nan)
+
+    for Index in range(P.shape[0]):
+        a, b = P[Index, 0], P[Index, 1]
+
+        # if at least one solution is infeasible
+        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
+            S[Index] = compare(
+                a,
+                pop[a].CV,
+                b,
+                pop[b].CV,
+                method='smaller_is_better',
+                return_random_if_equal=True)
+
+        # both solutions are feasible, compare by fitness
+        else:
+            S[Index] = compare(
+                a,
+                pop[a].F,
+                b,
+                pop[b].F,
+                method='smaller_is_better',
+                return_random_if_equal=True)
+
+    return S[:, None].astype(int)
+
+
+def compare(a, a_val, b, b_val, method, return_random_if_equal=False):
+    if method == 'larger_is_better':
+        if a_val > b_val:
+            return a
+        elif a_val < b_val:
+            return b
+        else:
+            if return_random_if_equal:
+                return np.random.choice([a, b])
+            else:
+                return None
+    elif method == 'smaller_is_better':
+        if a_val < b_val:
+            return a
+        elif a_val > b_val:
+            return b
+        else:
+            if return_random_if_equal:
+                return np.random.choice([a, b])
+            else:
+                return None
+
+
+class TournamentSelection:
+    """The Tournament selection simulates a tournament between individuals.
+
+    The pressure determines how greedy the selection (and thus the
+    genetic algorithm) will be.
+    """
+
+    def __init__(self, pressure=2, func_comp='binary_tournament'):
+        """
+
+        Parameters
+        ----------
+        func_comp: func
+            The function to compare two individuals.
+            It has the shape: comp(pop, indices) and returns the winner.
+
+        pressure: int
+            The selection pressure to be applied.
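+            With the default pressure of 2, each tournament draws two
+            individuals and keeps the winner (a binary tournament).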
+        """
+
+        # selection pressure to be applied
+        self.pressure = pressure
+        if func_comp == 'comp_by_cv_and_fitness':
+            self.f_comp = comp_by_cv_and_fitness
+        else:
+            self.f_comp = binary_tournament
+
+    def do(self, pop, n_select, n_parents=2, **kwargs):
+        # number of random individuals needed
+        n_random = n_select * n_parents * self.pressure
+
+        # number of permutations needed
+        n_perms = math.ceil(n_random / len(pop))
+
+        # get random permutations and reshape them
+        P = random_permuations(n_perms, len(pop))[:n_random]
+        P = np.reshape(P, (n_select * n_parents, self.pressure))
+
+        # compare using tournament function
+        S = self.f_comp(pop, P, **kwargs)
+
+        return np.reshape(S, (n_select, n_parents))
+
+
+class PointCrossover:
+
+    def __init__(self, n_points=2, n_parents=2, n_offsprings=2, prob=0.9):
+        self.n_points = n_points
+        self.prob = prob
+        self.n_parents = n_parents
+        self.n_offsprings = n_offsprings
+
+    def do(self, problem, pop, parents, **kwargs):
+        """
+
+        Parameters
+        ----------
+        problem: class
+            The problem to be solved.
+
+        pop : Population
+            The population as an object
+
+        parents: numpy.array
+            The selected parents of the population for the crossover
+
+        kwargs : dict
+            Any additional data that might be necessary.
+
+        Returns
+        -------
+        offsprings : Population
+            The offsprings as a matrix: n_children rows and the number of
+            columns is equal to the variable length of the problem.
+
+        """
+
+        if self.n_parents != parents.shape[1]:
+            raise ValueError(
+                'Exception during crossover: the number of parents differs '
+                'from the number defined for the crossover.')
+
+        # get the design space matrix from the population and parents
+        X = pop.get('X')[parents.T].copy()
+
+        # now apply the crossover probability
+        do_crossover = np.random.random(len(parents)) < self.prob
+
+        # execute the crossover
+        _X = self._do(problem, X, **kwargs)
+
+        X[:, do_crossover, :] = _X[:, do_crossover, :]
+
+        # flatten the array to become a 2d-array
+        X = X.reshape(-1, X.shape[-1])
+
+        # create a population object
+        off = pop.new('X', X)
+
+        return off
+
+    def _do(self, problem, X, **kwargs):
+
+        # get the X of parents and count the matings
+        _, n_matings, n_var = X.shape
+
+        # start point of crossover
+        r = np.row_stack([
+            np.random.permutation(n_var - 1) + 1 for _ in range(n_matings)
+        ])[:, :self.n_points]
+        r.sort(axis=1)
+        r = np.column_stack([r, np.full(n_matings, n_var)])
+
+        # the mask for the crossover
+        M = np.full((n_matings, n_var), False)
+
+        # create for each individual the crossover range
+        for Index in range(n_matings):
+
+            j = 0
+            while j < r.shape[1] - 1:
+                a, b = r[Index, j], r[Index, j + 1]
+                M[Index, a:b] = True
+                j += 2
+
+        _X = crossover_mask(X, M)
+
+        return _X
+
+
+class PolynomialMutation:
+
+    def __init__(self, eta=20, prob=None):
+        super().__init__()
+        self.eta = float(eta)
+
+        if prob is not None:
+            self.prob = float(prob)
+        else:
+            self.prob = None
+
+    def _do(self, problem, X, **kwargs):
+
+        Y = np.full(X.shape, np.inf)
+
+        if self.prob is None:
+            self.prob = 1.0 / problem.n_var
+
+        do_mutation = np.random.random(X.shape) < self.prob
+
+        Y[:, :] = X
+
+        xl = np.repeat(problem.xl[None, :], X.shape[0], axis=0)[do_mutation]
+        xu = np.repeat(problem.xu[None, :], X.shape[0], axis=0)[do_mutation]
+
+        X = X[do_mutation]
+
+        delta1 = (X - xl) / (xu - xl)
+        delta2 = (xu - X) / (xu - xl)
+
+        mut_pow = 1.0 / (self.eta + 1.0)
+
+        rand = np.random.random(X.shape)
+        mask = rand <= 0.5
+        mask_not = np.logical_not(mask)
+
+        deltaq = np.zeros(X.shape)
+
+        xy = 1.0 - delta1
+        val = 2.0 * rand + (1.0 - 2.0 * rand) * (
+            np.power(xy, (self.eta + 1.0)))
+        d = np.power(val, mut_pow) - 1.0
+        deltaq[mask] = d[mask]
+
+        xy = 1.0 - delta2
+        val = 2.0 * (1.0 - rand) + 2.0 * (rand - 0.5) * (
+            np.power(xy, (self.eta + 1.0)))
+        d = 1.0 - (np.power(val, mut_pow))
+        deltaq[mask_not] = d[mask_not]
+
+        # mutated values
+        _Y = X + deltaq * (xu - xl)
+
+        # back in bounds if necessary (floating point issues)
+        _Y[_Y < xl] = xl[_Y < xl]
+        _Y[_Y > xu] = xu[_Y > xu]
+
+        # set the values for output
+        Y[do_mutation] = _Y
+
+        # in case out of bounds repair (very unlikely)
+        # Y = bounce_back_by_problem(problem, Y)
+        Y = repair_out_of_bounds(problem, Y)
+
+        return Y
+
+    def do(self, problem, pop, **kwargs):
+        Y = self._do(problem, pop.get('X'), **kwargs)
+        return pop.new('X', Y)
+
+
+class IntegerFromFloatMutation:
+
+    def __init__(self, **kwargs):
+
+        self.mutation = PolynomialMutation(**kwargs)
+
+    def _do(self, problem, X, **kwargs):
+
+        def fun():
+            return self.mutation._do(problem, X, **kwargs)
+
+        # save the original bounds of the problem
+        _xl, _xu = problem.xl, problem.xu
+
+        # keep a working copy of the bounds
+        xl, xu = problem.xl, problem.xu
+
+        # widen the bounds so that rounding covers the full integer range
+        problem.xl = xl - (0.5 - 1e-16)
+        problem.xu = xu + (0.5 - 1e-16)
+
+        # perform the mutation in float space
+        off = fun()
+
+        # now round to nearest integer for all offsprings
+        off = np.rint(off)
+
+        # reset the original bounds of the problem and design space values
+        problem.xl = _xl
+        problem.xu = _xu
+
+        return off
+
+    def do(self, problem, pop, **kwargs):
+        """Mutate variables in a genetic way.
+
+        Parameters
+        ----------
+        problem : class
+            The problem instance
+        pop : Population
+            A population object
+
+        Returns
+        -------
+        Y : Population
+            The mutated population.
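+
+        Example (illustrative; assumes a `problem` with integer bounds
+        `xl`/`xu` and a `pop` holding integer design vectors):
+            >>> mutation = IntegerFromFloatMutation(eta=20, prob=0.1)
+            >>> mutated = mutation.do(problem, pop)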
+ """ + + Y = self._do(problem, pop.get('X'), **kwargs) + return pop.new('X', Y) + + +class Mating: + + def __init__(self, + selection, + crossover, + mutation, + repair=None, + eliminate_duplicates=None, + n_max_iterations=100): + + self.selection = selection + self.crossover = crossover + self.mutation = mutation + self.n_max_iterations = n_max_iterations + self.eliminate_duplicates = eliminate_duplicates + self.repair = repair + + def _do(self, problem, pop, n_offsprings, parents=None, **kwargs): + + # if the parents for the mating are not provided directly + if parents is None: + + # how many parents need to be select for the mating + n_select = math.ceil(n_offsprings / self.crossover.n_offsprings) + + # select the parents for the mating - just an index array + parents = self.selection.do(pop, n_select, + self.crossover.n_parents, **kwargs) + + # do the crossover using the parents index and the population + _off = self.crossover.do(problem, pop, parents, **kwargs) + + # do the mutation on the offsprings created through crossover + _off = self.mutation.do(problem, _off, **kwargs) + + return _off + + def do(self, problem, pop, n_offsprings, **kwargs): + + # the population object to be used + off = pop.new() + + # infill counter + # counts how often the mating needs to be done to fill up n_offsprings + n_infills = 0 + # iterate until enough offsprings are created + while len(off) < n_offsprings: + + # how many offsprings are remaining to be created + n_remaining = n_offsprings - len(off) + + # do the mating + _off = self._do(problem, pop, n_remaining, **kwargs) + + # repair the individuals if necessary + if self.repair: + _off = self.repair.do(problem, _off, **kwargs) + + if self.eliminate_duplicates is not None: + _off = self.eliminate_duplicates.do(_off, pop, off) + + # if more offsprings than necessary - truncate them randomly + if len(off) + len(_off) > n_offsprings: + n_remaining = n_offsprings - len(off) + _off = _off[:n_remaining] + + # add to the offsprings and increase the mating counter + off = off.merge(_off) + n_infills += 1 + + if n_infills > self.n_max_iterations: + break + + return off + + +class MySampling: + + def __init__(self): + pass + + def do(self, problem, n_samples, pop=Population(), **kwargs): + X = np.full((n_samples, problem.n_var), False, dtype=bool) + + for k in range(n_samples): + Index = np.random.permutation(problem.n_var)[:problem.n_max] + X[k, Index] = True + + if pop is None: + return X + return pop.new('X', X) + + +class BinaryCrossover(PointCrossover): + + def __init__(self): + super().__init__(n_parents=2, n_offsprings=1) + + def _do(self, problem, X, **kwargs): + n_parents, n_matings, n_var = X.shape + + _X = np.full((self.n_offsprings, n_matings, problem.n_var), False) + + for k in range(n_matings): + p1, p2 = X[0, k], X[1, k] + + both_are_true = np.logical_and(p1, p2) + _X[0, k, both_are_true] = True + + n_remaining = problem.n_max - np.sum(both_are_true) + + Index = np.where(np.logical_xor(p1, p2))[0] + + S = Index[np.random.permutation(len(Index))][:n_remaining] + _X[0, k, S] = True + + return _X + + +class MyMutation(PolynomialMutation): + + def _do(self, problem, X, **kwargs): + for Index in range(X.shape[0]): + X[Index, :] = X[Index, :] + is_false = np.where(np.logical_not(X[Index, :]))[0] + is_true = np.where(X[Index, :])[0] + try: + X[Index, np.random.choice(is_false)] = True + X[Index, np.random.choice(is_true)] = False + except ValueError: + pass + + return X diff --git a/mmrazor/models/task_modules/predictor/metric_predictor.py 
b/mmrazor/models/task_modules/predictor/metric_predictor.py index ae99ff091..79a915e4d 100644 --- a/mmrazor/models/task_modules/predictor/metric_predictor.py +++ b/mmrazor/models/task_modules/predictor/metric_predictor.py @@ -132,7 +132,7 @@ def vector2model(self, vector: np.array) -> Dict[str, str]: if self.encoding_type == 'onehot': index = np.where(vector[start:start + len(value[0].choices)] == 1)[0][0] - start += len(value) + start += len(value[0].choices) else: index = vector[start] start += 1 From 3d76a7277a5c36ded5b2fb719711c5c6c2c0ed4f Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Fri, 9 Dec 2022 16:15:25 +0800 Subject: [PATCH 02/12] fix bugs --- .../engine/runner/evolution_search_loop.py | 15 ++- .../engine/runner/nsganetv2_search_loop.py | 94 +++++++++++-------- mmrazor/models/task_modules/__init__.py | 2 +- .../multi_object_optimizer/__init__.py | 6 +- .../genetic_optimizer.py | 25 +++-- .../multi_object_optimizer/nsga2_optimizer.py | 5 +- 6 files changed, 85 insertions(+), 62 deletions(-) diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 12360e653..a0d01b7c7 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -62,8 +62,9 @@ class EvolutionSearchLoop(EpochBasedTrainLoop, CalibrateBNMixin): def __init__(self, runner, - dataloader: Union[DataLoader, Dict], evaluator: Union[Evaluator, Dict, List], + dataloader: Union[DataLoader, Dict], + finetune_dataloader: Union[DataLoader, Dict] = None, max_epochs: int = 20, max_keep_ckpts: int = 3, resume_from: Optional[str] = None, @@ -84,6 +85,13 @@ def __init__(self, self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore + + if isinstance(finetune_dataloader, dict): + self.finetune_dataloader = self.runner.build_dataloader( + finetune_dataloader) + else: + self.finetune_dataloader = finetune_dataloader + if hasattr(self.dataloader.dataset, 'metainfo'): self.evaluator.dataset_meta = self.dataloader.dataset.metainfo else: @@ -199,6 +207,8 @@ def sample_candidates(self) -> None: if self.runner.rank == 0: while len(self.candidates) < self.num_candidates: candidate = self.model.sample_subnet() + self.model.set_subnet(candidate) + self.finetune_step() is_pass, result = self._check_constraints( random_subnet=candidate) if is_pass: @@ -459,3 +469,6 @@ def _init_predictor(self): f'Predictor pre-trained, saved in {predictor_dir}.') self.use_predictor = True self.candidates = Candidates() + + def finetune_step(self, subnet): + pass diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py index 73aadd4f6..f67703b18 100644 --- a/mmrazor/engine/runner/nsganetv2_search_loop.py +++ b/mmrazor/engine/runner/nsganetv2_search_loop.py @@ -1,14 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os.path as osp -from copy import deepcopy import numpy as np from mmengine import fileio +from mmengine.optim import build_optim_wrapper from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting -from mmrazor.models.task_modules import (GeneticOptimizer, - NSGA2Optimizer, - AuxiliarySingleLevelProblem, +from mmrazor.models.task_modules import (AuxiliarySingleLevelProblem, + GeneticOptimizer, NSGA2Optimizer, SubsetProblem) from mmrazor.registry import LOOPS from mmrazor.structures import Candidates, export_fix_subnet @@ -56,7 +55,7 @@ def run_epoch(self) -> None: self.candidates.extend(self.top_k_candidates) self.sort_candidates() - self.top_k_candidates = Candidates(self.candidates[:self.top_k]) + self.top_k_candidates = Candidates(self.candidates.data[:self.top_k]) scores_after = self.top_k_candidates.scores self.runner.logger.info(f'top k scores after update: ' @@ -92,15 +91,19 @@ def sample_candidates(self, random: bool = True, archive=None) -> None: if len(candidates_resources) > 0: self.candidates.update_resources( candidates_resources, - start=len(self.candidates.data)-len(candidates_resources)) + start=len(self.candidates.data) - + len(candidates_resources)) - def sample_candidates_with_nsga2(self, archive: Candidates, num_candidates): + def sample_candidates_with_nsga2(self, archive: Candidates, + num_candidates): """Searching for candidates with high-fidelity evaluation.""" F = np.column_stack((archive.scores, archive.resources('flops'))) - front_index = NonDominatedSorting().do(F, only_non_dominated_front=True) + front_index = NonDominatedSorting().do( + F, only_non_dominated_front=True) fronts = np.array(archive.subnets)[front_index] - fronts = np.array([self.predictor.model2vector(cand) for cand in fronts]) + fronts = np.array( + [self.predictor.model2vector(cand) for cand in fronts]) fronts = self.predictor.preprocess(fronts) # initialize the candidate finding optimization problem @@ -120,24 +123,24 @@ def sample_candidates_with_nsga2(self, archive: Candidates, num_candidates): # check for duplicates check_list = [] for x in result['pop'].get('X'): - assert x is not None check_list.append(self.predictor.vector2model(x)) - + not_duplicate = np.logical_not( [x in archive.subnets for x in check_list]) # extra process after nsga2 search - sub_problem = SubsetProblem(result['pop'][not_duplicate].get('F')[:, 1], - F[front_index, 1], - num_candidates) - sub_method = GeneticOptimizer(pop_size=num_candidates, - eliminate_duplicates=True) + sub_problem = SubsetProblem( + result['pop'][not_duplicate].get('F')[:, 1], F[front_index, 1], + num_candidates) + sub_method = GeneticOptimizer( + pop_size=num_candidates, eliminate_duplicates=True) sub_method.initialize(sub_problem, n_gen=4, verbose=False) indices = sub_method.solve()['X'] - - candidates = Candidates() + + candidates = [] pop = result['pop'][not_duplicate][indices] for x in pop.get('X'): + x = x[0] if isinstance(x[0], list) else x candidates.append(self.predictor.vector2model(x)) return candidates @@ -146,20 +149,20 @@ def sort_candidates(self) -> None: """Support sort candidates in single and multiple-obj optimization.""" assert self.trade_off is not None, ( '`self.trade_off` is required when sorting candidates in ' - 'NSGA2SearchLoop. Got self.trade_off is None.') + 'NSGA2SearchLoop. 
Got `self.trade_off` is None.') ratio = self.trade_off.get('ratio', 1) - multiple_obj_score = [] + max_score_key = self.trade_off.get('max_score_key', 100) + + multi_obj_score = [] for score, flops in zip(self.candidates.scores, self.candidates.resources('flops')): - multiple_obj_score.append((score, flops)) - multiple_obj_score = np.array(multiple_obj_score) - max_score_key = self.trade_off.get('max_score_key', 100) + multi_obj_score.append((score, flops)) + multi_obj_score = np.array(multi_obj_score) if max_score_key != 0: - multiple_obj_score[:, 0] = \ - max_score_key - multiple_obj_score[:, 0] - sort_idx = np.argsort(multiple_obj_score[:, 0]) - F = multiple_obj_score[sort_idx] - dm = HighTradeoffPoints(ratio, n_survive=len(multiple_obj_score)) + multi_obj_score[:, 0] = max_score_key - multi_obj_score[:, 0] + sort_idx = np.argsort(multi_obj_score[:, 0]) + F = multi_obj_score[sort_idx] + dm = HighTradeoffPoints(ratio, n_survive=len(multi_obj_score)) candidate_index = dm.do(F) candidate_index = sort_idx[candidate_index] self.candidates = [self.candidates[idx] for idx in candidate_index] @@ -187,10 +190,9 @@ def _save_searcher_ckpt(self, archive=[]): f'search_epoch_{self._epoch}.pkl')) correlation_str = 'fitting ' - # correlation_str += f'{self.predictor.type}: ' correlation_str += f'RMSE = {rmse:.4f}, ' correlation_str += f'Spearmans Rho = {rho:.4f}, ' - correlation_str += f'num_candidatesendalls Tau = {tau:.4f}' + correlation_str += f'Kendalls Tau = {tau:.4f}' self.pareto_mode = False if self.pareto_mode: @@ -225,21 +227,31 @@ def fit_predictor(self, candidates): metrics[i] = self.max_score_key - metrics[i] return metrics - def finetune_step(self, model): + def finetune_step(self): """fintune before candidates evaluation.""" - # TODO (gaoyang): update with 2.0 version. self.runner.logger.info('start finetuning...') - model.train() - while self._fintune_epoch < self._max_finetune_epochs: + self.model.train() + + self._finetune_epoch = 0 + self._max_finetune_epochs = 1 + + optimizer_cfg = dict( + type='SGD', + lr=0.5, + momentum=0.9, + nesterov=True, + weight_decay=0.0001) + optim_wrapper_cfg = dict(optimizer=optimizer_cfg) + optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) + + while self._finetune_epoch < self._max_finetune_epochs: self.runner.call_hook('before_train_epoch') - for idx, data_batch in enumerate(self.dataloader): + for idx, data_batch in enumerate(self.finetune_dataloader): self.runner.call_hook( - 'before_train_iter', - batch_idx=idx, - data_batch=data_batch) + 'before_train_iter', batch_idx=idx, data_batch=data_batch) - outputs = model.train_step( - data_batch, optim_wrapper=self.optim_wrapper) + outputs = self.model.train_step( + data_batch, optim_wrapper=optim_wrapper) self.runner.call_hook( 'after_train_iter', @@ -250,4 +262,4 @@ def finetune_step(self, model): self.runner.call_hook('after_train_epoch') self._finetune_epoch += 1 - model.eval() + self.model.eval() diff --git a/mmrazor/models/task_modules/__init__.py b/mmrazor/models/task_modules/__init__.py index 4d152d383..7a691f78d 100644 --- a/mmrazor/models/task_modules/__init__.py +++ b/mmrazor/models/task_modules/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .delivery import * # noqa: F401,F403 from .estimators import ResourceEstimator +from .multi_object_optimizer import * # noqa: F401,F403 from .predictor import * # noqa: F401,F403 from .recorder import * # noqa: F401,F403 from .tracer import * # noqa: F401,F403 -from .multi_object_optimizer import * # noqa: F401,F403 __all__ = ['ResourceEstimator'] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/__init__.py index fc985e92c..710587f59 100644 --- a/mmrazor/models/task_modules/multi_object_optimizer/__init__.py +++ b/mmrazor/models/task_modules/multi_object_optimizer/__init__.py @@ -4,6 +4,6 @@ from .problem import AuxiliarySingleLevelProblem, SubsetProblem __all__ = [ - 'AuxiliarySingleLevelProblem', 'SubsetProblem', - 'GeneticOptimizer', 'NSGA2Optimizer' -] \ No newline at end of file + 'AuxiliarySingleLevelProblem', 'SubsetProblem', 'GeneticOptimizer', + 'NSGA2Optimizer' +] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py index 7af88ec79..5639bc912 100644 --- a/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py +++ b/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py @@ -7,26 +7,25 @@ from mmrazor.registry import TASK_UTILS from .nsga2_optimizer import NSGA2Optimizer from .utils.helper import Individual, Population -from .utils.selection import (BinaryCrossover, IntegerFromFloatMutation, Mating, - MyMutation, MySampling, PointCrossover, +from .utils.selection import (BinaryCrossover, MyMutation, MySampling, TournamentSelection) - @TASK_UTILS.register_module() class GeneticOptimizer(NSGA2Optimizer): """Genetic Algorithm.""" - def __init__(self, - pop_size=100, - sampling=MySampling(), - selection=TournamentSelection(func_comp='comp_by_cv_and_fitness'), - crossover=BinaryCrossover(), - mutation=MyMutation(), - eliminate_duplicates=True, - n_offsprings=None, - display=None, - **kwargs): + def __init__( + self, + pop_size=100, + sampling=MySampling(), + selection=TournamentSelection(func_comp='comp_by_cv_and_fitness'), + crossover=BinaryCrossover(), + mutation=MyMutation(), + eliminate_duplicates=True, + n_offsprings=None, + display=None, + **kwargs): """ Args: pop_size : {pop_size} diff --git a/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py index 9ebf36583..2c31dc741 100644 --- a/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py +++ b/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py @@ -4,11 +4,10 @@ from mmrazor.registry import TASK_UTILS from .base_optimizer import BaseOptimizer -from .utils.selection import (IntegerFromFloatMutation, Mating, - PointCrossover, TournamentSelection, - binary_tournament) from .utils.helper import (DefaultDuplicateElimination, Individual, Initialization, Survival) +from .utils.selection import (IntegerFromFloatMutation, Mating, PointCrossover, + TournamentSelection, binary_tournament) # from pymoo.algorithms.moo.nsga2 import binary_tournament # from pymoo.core.mating import Mating From 1733affd05970b19294c2fa260c0c2724396670f Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Thu, 5 Jan 2023 13:16:45 +0800 Subject: [PATCH 03/12] unify search_groups by alias --- .../engine/runner/evolution_search_loop.py | 23 +++++++---- .../engine/runner/nsganetv2_search_loop.py | 26 
++++++------- mmrazor/models/algorithms/nas/bignas.py | 23 +++-------- .../mutable_channel/base_mutable_channel.py | 3 +- mmrazor/models/mutators/group_mixin.py | 38 +++++++++---------- .../problem/auxiliary_singlelevel_problem.py | 9 ++++- .../predictor/metric_predictor.py | 20 ++++++++-- mmrazor/structures/subnet/candidate.py | 15 +------- 8 files changed, 78 insertions(+), 79 deletions(-) diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 30fd73841..9cb581f39 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -138,8 +138,15 @@ def __init__(self, self.predictor_cfg = predictor_cfg if self.predictor_cfg is not None: self.predictor_cfg['score_key'] = self.score_key - self.predictor_cfg['search_groups'] = \ - self.model.mutator.search_groups + if hasattr(self.model, 'mutators'): + self.predictor_cfg['search_groups'] = dict() + for mutator in self.model.mutators.values(): + self.predictor_cfg['search_groups'].update( + mutator.search_groups) + else: + assert hasattr(self.model, 'mutator') + self.predictor_cfg['search_groups'] = \ + self.model.mutator.search_groups self.predictor = TASK_UTILS.build(self.predictor_cfg) def run(self) -> None: @@ -181,7 +188,7 @@ def run_epoch(self) -> None: self.candidates.extend(self.top_k_candidates) self.candidates.sort_by(key_indicator='score', reverse=True) - self.top_k_candidates = Candidates(self.candidates.data[:self.top_k]) + self.top_k_candidates = Candidates(self.candidates[:self.top_k]) scores_after = self.top_k_candidates.scores self.runner.logger.info(f'top k scores after update: ' @@ -213,19 +220,19 @@ def sample_candidates(self) -> None: if is_pass: self.candidates.append(candidate) candidates_resources.append(result) - self.candidates = Candidates(self.candidates.data) + self.candidates = Candidates(self.candidates) else: self.candidates = Candidates([dict(a=0)] * self.num_candidates) if len(candidates_resources) > 0: self.candidates.update_resources( candidates_resources, - start=len(self.candidates.data) - len(candidates_resources)) + start=len(self.candidates) - len(candidates_resources)) assert init_candidates + len( candidates_resources) == self.num_candidates # broadcast candidates to val with multi-GPUs. - broadcast_object_list(self.candidates.data) + broadcast_object_list([self.candidates]) def update_candidates_scores(self) -> None: """Validate candicate one by one from the candicate pool, and update @@ -269,7 +276,7 @@ def gen_mutation_candidates(self): return mutation_candidates def gen_crossover_candidates(self): - """Generate specofied number of crossover candicates.""" + """Generate specified number of crossover candicates.""" crossover_resources = [] crossover_candidates: List = [] crossover_iter = 0 @@ -463,5 +470,5 @@ def _init_predictor(self): self.use_predictor = True self.candidates = Candidates() - def finetune_step(self, subnet): + def finetune_step(self): pass diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py index f67703b18..56eda84b0 100644 --- a/mmrazor/engine/runner/nsganetv2_search_loop.py +++ b/mmrazor/engine/runner/nsganetv2_search_loop.py @@ -37,14 +37,15 @@ def run_epoch(self) -> None: 4. Implement Mutation and crossover, generate better candidates. 
""" archive = Candidates() - for subnet, score, flops in zip(self.candidates.subnets, - self.candidates.scores, - self.candidates.resources('flops')): - if self.trade_off['max_score_key'] != 0: - score = self.trade_off['max_score_key'] - score - archive.append(subnet) - archive.set_score(-1, score) - archive.set_resource(-1, flops, 'flops') + if len(self.candidates) > 0: + for subnet, score, flops in zip( + self.candidates.subnets, self.candidates.scores, + self.candidates.resources('flops')): + if self.trade_off['max_score_key'] != 0: + score = self.trade_off['max_score_key'] - score + archive.append(subnet) + archive.set_score(-1, score) + archive.set_resource(-1, flops, 'flops') self.sample_candidates(random=(self._epoch == 0), archive=archive) self.update_candidates_scores() @@ -55,7 +56,7 @@ def run_epoch(self) -> None: self.candidates.extend(self.top_k_candidates) self.sort_candidates() - self.top_k_candidates = Candidates(self.candidates.data[:self.top_k]) + self.top_k_candidates = Candidates(self.candidates[:self.top_k]) scores_after = self.top_k_candidates.scores self.runner.logger.info(f'top k scores after update: ' @@ -91,8 +92,7 @@ def sample_candidates(self, random: bool = True, archive=None) -> None: if len(candidates_resources) > 0: self.candidates.update_resources( candidates_resources, - start=len(self.candidates.data) - - len(candidates_resources)) + start=len(self.candidates) - len(candidates_resources)) def sample_candidates_with_nsga2(self, archive: Candidates, num_candidates): @@ -111,13 +111,13 @@ def sample_candidates_with_nsga2(self, archive: Candidates, # initiate a multi-objective solver to optimize the problem method = NSGA2Optimizer( - pop_size=4, + pop_size=40, sampling=fronts, # initialize with current nd archs eliminate_duplicates=True, logger=self.runner.logger) # # kick-off the search - method.initialize(problem, n_gen=2, verbose=True) + method.initialize(problem, n_gen=20, verbose=True) result = method.solve() # check for duplicates diff --git a/mmrazor/models/algorithms/nas/bignas.py b/mmrazor/models/algorithms/nas/bignas.py index f75c60e57..7114a37c8 100644 --- a/mmrazor/models/algorithms/nas/bignas.py +++ b/mmrazor/models/algorithms/nas/bignas.py @@ -136,27 +136,16 @@ def _build_distiller( def sample_subnet(self, kind='random') -> Dict: """Random sample subnet by mutator.""" - value_subnet = dict() - channel_subnet = dict() - for name, mutator in self.mutators.items(): - if name == 'value_mutator': - value_subnet.update(mutator.sample_choices(kind)) - elif name == 'channel_mutator': - channel_subnet.update(mutator.sample_choices(kind)) - else: - raise NotImplementedError - return dict(value_subnet=value_subnet, channel_subnet=channel_subnet) + subnet = dict() + for mutator in self.mutators.values(): + subnet.update(mutator.sample_choices(kind)) + return subnet def set_subnet(self, subnet: Dict[str, Dict[int, Union[int, list]]]) -> None: """Set the subnet sampled by :meth:sample_subnet.""" - for name, mutator in self.mutators.items(): - if name == 'value_mutator': - mutator.set_choices(subnet['value_subnet']) - elif name == 'channel_mutator': - mutator.set_choices(subnet['channel_subnet']) - else: - raise NotImplementedError + for mutator in self.mutators.values(): + mutator.set_choices(subnet) def set_max_subnet(self) -> None: """Set max subnet.""" diff --git a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py index 65d5a44d6..ed3a9e617 100644 --- 
a/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/base_mutable_channel.py @@ -62,7 +62,8 @@ def fix_chosen(self, chosen=None): def dump_chosen(self) -> DumpChosen: """Dump chosen.""" - meta = dict(max_channels=self.mask.size(0)) + meta = dict( + max_channels=self.mask.size(0), all_choices=self.candidate_choices) chosen = self.export_chosen() return DumpChosen(chosen=chosen, meta=meta) diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py index f6b84aea2..401fd8e69 100644 --- a/mmrazor/models/mutators/group_mixin.py +++ b/mmrazor/models/mutators/group_mixin.py @@ -208,7 +208,15 @@ def build_search_groups(self, supernet: Module, support_mutables: Type, f'The duplicate keys are {duplicate_keys}. ' \ 'Please check if there are duplicate keys in the `custom_group`.' - return search_groups + # TODO: update search groups + new_search_groups = dict() + for group_id, module in search_groups.items(): + if module[0].alias: + new_search_groups[module[0].alias] = module + else: + new_search_groups[group_id] = module + + return new_search_groups def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], name2mutable: Dict[str, BaseMutable], @@ -259,26 +267,7 @@ def search_groups(self) -> Dict: ... -class OneShotSampleMixin: - """Sample mixin for one-shot mutators.""" - - def sample_choices(self: MutatorProtocol) -> Dict: - """Sample choices for each group in search_groups.""" - random_choices = dict() - for group_id, modules in self.search_groups.items(): - random_choices[group_id] = modules[0].sample_choice() - - return random_choices - - def set_choices(self: MutatorProtocol, choices: Dict) -> None: - """Set choices for each group in search_groups.""" - for group_id, modules in self.search_groups.items(): - choice = choices[group_id] - for module in modules: - module.current_choice = choice - - -class DynamicSampleMixin(OneShotSampleMixin): +class DynamicSampleMixin(): def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict: """Sample choices for each group in search_groups.""" @@ -292,6 +281,13 @@ def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict: random_choices[group_id] = modules[0].sample_choice() return random_choices + def set_choices(self: MutatorProtocol, choices: Dict) -> None: + """Set choices for each group in search_groups.""" + for group_id, modules in self.search_groups.items(): + choice = choices[modules[0].alias] + for module in modules: + module.current_choice = choice + @property def max_choice(self: MutatorProtocol) -> Dict: """Get max choices for each group in search_groups.""" diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py index 060b60a34..e1badd52b 100644 --- a/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py +++ b/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py @@ -17,9 +17,14 @@ def __init__(self, searcher, dim=15, sec_obj='flops'): self.xl = np.zeros(self.n_var) # upper bound for variable, automatically calculate by search space self.xu = [] + from mmrazor.models.mutables import OneShotMutableChannelUnit for mutable in self.predictor.search_groups.values(): - if mutable[0].num_choices > 0: - self.xu.append(mutable[0].num_choices - 1) + if isinstance(mutable[0], OneShotMutableChannelUnit): + 
if mutable[0].num_channels > 0: + self.xu.append(mutable[0].num_channels - 1) + else: + if mutable[0].num_choices > 0: + self.xu.append(mutable[0].num_choices - 1) self.xu = np.array(self.xu) def _evaluate(self, x, out, *args, **kwargs): diff --git a/mmrazor/models/task_modules/predictor/metric_predictor.py b/mmrazor/models/task_modules/predictor/metric_predictor.py index d07e17dff..8ad71e522 100644 --- a/mmrazor/models/task_modules/predictor/metric_predictor.py +++ b/mmrazor/models/task_modules/predictor/metric_predictor.py @@ -105,8 +105,16 @@ def model2vector( len(choice.meta['all_choices']), dtype=np.int) _chosen_index = choice.meta['all_choices'].index(choice.chosen) else: - assert len(self.search_groups[index]) == 1 - choices = self.search_groups[index][0].choices + if key is not None: + from mmrazor.models.mutables import MutableChannelUnit + if isinstance(self.search_groups[key][0], + MutableChannelUnit): + choices = self.search_groups[key][0].candidate_choices + else: + choices = self.search_groups[key][0].choices + else: + assert len(self.search_groups[index]) == 1 + choices = self.search_groups[index][0].choices onehot = np.zeros(len(choices), dtype=np.int) _chosen_index = choices.index(choice) onehot[_chosen_index] = 1 @@ -126,6 +134,8 @@ def vector2model(self, vector: np.array) -> Dict[str, str]: Returns: Dict[str, str]: converted model. """ + from mmrazor.models.mutables import OneShotMutableChannelUnit + start = 0 model = {} for key, value in self.search_groups.items(): @@ -136,7 +146,11 @@ def vector2model(self, vector: np.array) -> Dict[str, str]: else: index = vector[start] start += 1 - chosen = value[0].choices[int(index)] + + if isinstance(value[0], OneShotMutableChannelUnit): + chosen = value[0].candidate_choices[int(index)] + else: + chosen = value[0].choices[int(index)] model[key] = chosen return model diff --git a/mmrazor/structures/subnet/candidate.py b/mmrazor/structures/subnet/candidate.py index 9f0ebc344..e4ddd8326 100644 --- a/mmrazor/structures/subnet/candidate.py +++ b/mmrazor/structures/subnet/candidate.py @@ -66,21 +66,8 @@ def resources(self, key_indicator: str = 'flops') -> List[float]: @property def subnets(self) -> List[Dict]: """The subnets of candidates.""" - import copy assert len(self.data) > 0, ('Got empty candidates.') - if 'value_subnet' in self.data[0]: - subnets = [] - for data in self.data: - subnet = dict() - _data = copy.deepcopy(data) - for k1 in ['value_subnet', 'channel_subnet']: - for k2 in self._indicators: - _data[k1].pop(k2) - subnet[k1] = _data[k1] - subnets.append(subnet) - return subnets - else: - return [eval(key) for item in self.data for key, _ in item.items()] + return [eval(key) for item in self.data for key, _ in item.items()] def _format(self, data: _format_input) -> _format_return: """Transform [Dict, ...] 
to Union[Dict[str, Dict], List[Dict[str, From 83da8677f63cff7d9f7bb4288ee589afc983a605 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Fri, 6 Jan 2023 00:01:23 +0800 Subject: [PATCH 04/12] update algorithm & imagenet cfgs --- .../nsga_mobilenetv3_supernet.py | 106 ++++++++++++++++++ .../nsganetv2_mobilenet_search_8xb256_in1k.py | 25 +++++ ...sganetv2_mobilenet_supernet_8xb128_in1k.py | 48 ++++++++ mmrazor/models/algorithms/__init__.py | 4 +- mmrazor/models/algorithms/nas/__init__.py | 4 +- mmrazor/models/algorithms/nas/nsganetv2.py | 82 +++++++------- .../backbones/searchable_mobilenet_v3.py | 5 - .../predictor/metric_predictor.py | 15 +-- 8 files changed, 231 insertions(+), 58 deletions(-) create mode 100644 configs/_base_/nas_backbones/nsga_mobilenetv3_supernet.py create mode 100644 configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py create mode 100644 configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_8xb128_in1k.py diff --git a/configs/_base_/nas_backbones/nsga_mobilenetv3_supernet.py b/configs/_base_/nas_backbones/nsga_mobilenetv3_supernet.py new file mode 100644 index 000000000..4ecdae032 --- /dev/null +++ b/configs/_base_/nas_backbones/nsga_mobilenetv3_supernet.py @@ -0,0 +1,106 @@ +# search space +arch_setting = dict( + kernel_size=[ # [min_kernel_size, max_kernel_size, step] + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + [3, 7, 2], + ], + num_blocks=[ # [min_num_blocks, max_num_blocks, step] + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 1], + [0, 1, 1], + ], + expand_ratio=[ # [min_expand_ratio, max_expand_ratio, step] + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [3, 6, 1], + [6, 6, 1], # last layer + ], + num_out_channels=[ # [min_channel, max_channel, step] + [16, 16, 1], # first layer + [24, 24, 1], + [24, 24, 1], + [24, 24, 1], + [24, 24, 1], + [40, 40, 1], + [40, 40, 1], + [40, 40, 1], + [40, 40, 1], + [80, 80, 1], + [80, 80, 1], + [80, 80, 1], + [80, 80, 1], + [112, 112, 1], + [112, 112, 1], + [112, 112, 1], + [112, 112, 1], + [160, 160, 1], + [160, 160, 1], + [160, 160, 1], + [160, 160, 1], + [1280, 1280, 1], # last layer + ]) + +input_resizer_cfg = dict( + input_sizes=[[192, 192], [208, 208], [224, 224], [256, 256]]) + +nas_backbone = dict( + type='AttentiveMobileNetV3', + arch_setting=arch_setting, + out_indices=(20, ), + stride_list=[2, 1, 1, 1] * 3 + [1] * 4 + [2] + [1] * 3, + with_se_list=[False] * 4 + [True] * 4 + [False] * 4 + [True] * 8, + act_cfg_list=['ReLU'] * 9 + ['HSwish'] * 13, + conv_cfg=dict(type='OFAConv2d'), + norm_cfg=dict(type='DynamicBatchNorm2d', momentum=0.1)) diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py new file mode 100644 index 000000000..072885f13 --- /dev/null +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py @@ -0,0 +1,25 @@ +_base_ = 
['./nsganetv2_mobilenet_supernet_8xb128_in1k.py']
+
+model = dict(norm_training=True)
+
+train_cfg = dict(
+    _delete_=True,
+    type='mmrazor.NSGA2SearchLoop',
+    dataloader=_base_.val_dataloader,
+    finetune_dataloader=_base_.train_dataloader,
+    evaluator=_base_.val_evaluator,
+    max_epochs=4,
+    num_candidates=4,
+    top_k=2,
+    num_mutation=2,
+    num_crossover=2,
+    mutate_prob=0.1,
+    flops_range=(0., 330.),
+    score_key='accuracy/top1',
+    predictor_cfg=dict(
+        type='mmrazor.MetricPredictor',
+        encoding_type='normal',
+        train_samples=2,
+        handler_cfg=dict(type='mmrazor.GaussProcessHandler')),
+    trade_off=dict(sec_obj='flops', max_score_key=100),
+)
diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_8xb128_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_8xb128_in1k.py
new file mode 100644
index 000000000..e86e041e2
--- /dev/null
+++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_8xb128_in1k.py
@@ -0,0 +1,48 @@
+_base_ = [
+    'mmcls::_base_/default_runtime.py',
+    'mmrazor::_base_/settings/imagenet_bs2048_bignas.py',
+    'mmrazor::_base_/nas_backbones/nsga_mobilenetv3_supernet.py',
+]
+
+supernet = dict(
+    _scope_='mmrazor',
+    type='SearchableImageClassifier',
+    backbone=_base_.nas_backbone,
+    neck=dict(type='SqueezeMeanPoolingWithDropout', drop_ratio=0.2),
+    head=dict(
+        type='DynamicLinearClsHead',
+        num_classes=1000,
+        in_channels=1280,
+        loss=dict(
+            type='mmcls.LabelSmoothLoss',
+            num_classes=1000,
+            label_smooth_val=0.1,
+            mode='original',
+            loss_weight=1.0),
+        topk=(1, 5)),
+    input_resizer_cfg=_base_.input_resizer_cfg,
+    connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'),
+)
+
+model = dict(
+    _scope_='mmrazor',
+    type='NSGANetV2',
+    architecture=supernet,
+    data_preprocessor=_base_.data_preprocessor,
+    mutators=dict(
+        channel_mutator=dict(
+            type='mmrazor.OneShotChannelMutator',
+            channel_unit_cfg={
+                'type': 'OneShotMutableChannelUnit',
+                'default_args': {
+                    'unit_predefined': True
+                }
+            },
+            parse_cfg={'type': 'Predefined'}),
+        value_mutator=dict(type='DynamicValueMutator')))
+
+find_unused_parameters = True
+
+default_hooks = dict(
+    checkpoint=dict(
+        type='CheckpointHook', interval=1, max_keep_ckpts=1, save_best='auto'))
diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py
index 3f649c426..c885e745b 100644
--- a/mmrazor/models/algorithms/__init__.py
+++ b/mmrazor/models/algorithms/__init__.py
@@ -4,7 +4,7 @@
                          FpnTeacherDistill, OverhaulFeatureDistillation,
                          SelfDistill, SingleTeacherDistill)
 from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP,
-                  BigNAS, BigNASDDP, Darts, DartsDDP)
+                  BigNAS, BigNASDDP, Darts, DartsDDP, NSGANetV2)
 from .pruning import DCFF, SlimmableNetwork, SlimmableNetworkDDP
 from .pruning.ite_prune_algorithm import ItePruneAlgorithm
 
@@ -14,5 +14,5 @@
     'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation',
     'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
     'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS',
-    'BigNASDDP'
+    'BigNASDDP', 'NSGANetV2'
 ]
diff --git a/mmrazor/models/algorithms/nas/__init__.py b/mmrazor/models/algorithms/nas/__init__.py
index 6a9c29161..967156a12 100644
--- a/mmrazor/models/algorithms/nas/__init__.py
+++ b/mmrazor/models/algorithms/nas/__init__.py
@@ -4,9 +4,11 @@
 from .bignas import BigNAS, BigNASDDP
 from .darts import Darts, DartsDDP
 from .dsnas import DSNAS, DSNASDDP
+from .nsganetv2 import NSGANetV2
 from .spos import SPOS
 
 __all__ = [
     'SPOS', 'AutoSlim', 'AutoSlimDDP', 'BigNAS', 
'BigNASDDP', 'Darts', - 'DartsDDP', 'DSNAS', 'DSNASDDP', 'DSNAS', 'DSNASDDP', 'Autoformer' + 'DartsDDP', 'DSNAS', 'DSNASDDP', 'DSNAS', 'DSNASDDP', 'Autoformer', + 'NSGANetV2' ] diff --git a/mmrazor/models/algorithms/nas/nsganetv2.py b/mmrazor/models/algorithms/nas/nsganetv2.py index 8eaad2f38..24b5402c8 100644 --- a/mmrazor/models/algorithms/nas/nsganetv2.py +++ b/mmrazor/models/algorithms/nas/nsganetv2.py @@ -9,10 +9,8 @@ from mmrazor.models.distillers import ConfigurableDistiller from mmrazor.models.mutators.base_mutator import BaseMutator -from mmrazor.models.mutators import OneShotModuleMutator from mmrazor.registry import MODELS -from mmrazor.structures.subnet.fix_subnet import load_fix_subnet -from mmrazor.utils import SingleMutatorRandomSubnet, ValidFixMutable +from mmrazor.utils import ValidFixMutable from ..base import BaseAlgorithm, LossResults VALID_MUTATOR_TYPE = Union[BaseMutator, Dict] @@ -22,61 +20,60 @@ @MODELS.register_module() class NSGANetV2(BaseAlgorithm): - """ - - """ - - # TODO fix ea's name in doc-string. + """NSGANetV2 algorithm.""" def __init__(self, architecture: Union[BaseModel, Dict], - mutator: VALID_MUTATORS_TYPE, - # distiller: VALID_DISTILLER_TYPE, - # norm_training: bool = False, + mutators: VALID_MUTATORS_TYPE, fix_subnet: Optional[ValidFixMutable] = None, data_preprocessor: Optional[Union[dict, nn.Module]] = None, - init_cfg: Optional[dict] = None, - drop_prob: float = 0.2): + drop_path_rate: float = 0.2, + backbone_dropout_stages: List = [6, 7], + norm_training: bool = False, + init_cfg: Optional[dict] = None): super().__init__(architecture, data_preprocessor, init_cfg) + if isinstance(mutators, dict): + built_mutators: Dict = dict() + for name, mutator_cfg in mutators.items(): + if 'parse_cfg' in mutator_cfg and isinstance( + mutator_cfg['parse_cfg'], dict): + assert mutator_cfg['parse_cfg'][ + 'type'] == 'Predefined', \ + 'BigNAS only support predefined.' + mutator: BaseMutator = MODELS.build(mutator_cfg) + built_mutators[name] = mutator + mutator.prepare_from_supernet(self.architecture) + self.mutators = built_mutators + else: + raise TypeError('mutator should be a `dict` but got ' + f'{type(mutators)}') + + self.drop_path_rate = drop_path_rate + self.backbone_dropout_stages = backbone_dropout_stages + self.norm_training = norm_training + self.is_supernet = True + if fix_subnet: # Avoid circular import from mmrazor.structures import load_fix_subnet # According to fix_subnet, delete the unchosen part of supernet - load_fix_subnet(self.architecture, fix_subnet) + load_fix_subnet(self, fix_subnet) self.is_supernet = False - else: - # Mutator is an essential component of the NAS algorithm. It - # provides some APIs commonly used by NAS. - # Before using it, you must do some preparations according to - # the supernet. - self.mutator.prepare_from_supernet(self.architecture) - self.is_supernet = True - - self.drop_prob = drop_prob - - def _build_mutator(self, mutator: VALID_MUTATOR_TYPE) -> BaseMutator: - """build mutator.""" - assert mutator is not None, \ - 'mutator cannot be None when fix_subnet is None.' 
- if isinstance(mutator, OneShotModuleMutator): - self.mutator = mutator - elif isinstance(mutator, dict): - self.mutator = MODELS.build(mutator) - else: - raise TypeError('mutator should be a `dict` or ' - f'`OneShotModuleMutator` instance, but got ' - f'{type(mutator)}') - return mutator - def sample_subnet(self) -> SingleMutatorRandomSubnet: + def sample_subnet(self, kind='random') -> Dict: """Random sample subnet by mutator.""" - return self.mutator.sample_choices() + subnet = dict() + for mutator in self.mutators.values(): + subnet.update(mutator.sample_choices(kind)) + return subnet - def set_subnet(self, subnet: SingleMutatorRandomSubnet): + def set_subnet(self, subnet: Dict[str, Dict[int, Union[int, + list]]]) -> None: """Set the subnet sampled by :meth:sample_subnet.""" - self.mutator.set_choices(subnet) + for mutator in self.mutators.values(): + mutator.set_choices(subnet) def loss( self, @@ -85,8 +82,7 @@ def loss( ) -> LossResults: """Calculate losses from a batch of inputs and data samples.""" if self.is_supernet: - random_subnet = self.sample_subnet() - self.set_subnet(random_subnet) + self.set_subnet(self.sample_subnet()) return self.architecture(batch_inputs, data_samples, mode='loss') else: return self.architecture(batch_inputs, data_samples, mode='loss') diff --git a/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py b/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py index 5a6e15cbc..6e4b62805 100644 --- a/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py +++ b/mmrazor/models/architectures/backbones/searchable_mobilenet_v3.py @@ -92,14 +92,9 @@ def __init__(self, self.arch_setting = arch_setting self.widen_factor = widen_factor self.out_indices = out_indices - for index in out_indices: - if index not in range(0, 8): - raise ValueError('the item in out_indices must in ' - f'range(0, 8). But received {index}') if frozen_stages not in range(-1, 8): raise ValueError('frozen_stages must in range(-1, 8). 
' f'But received {frozen_stages}') - self.out_indices = out_indices self.frozen_stages = frozen_stages self.conv_cfg = conv_cfg diff --git a/mmrazor/models/task_modules/predictor/metric_predictor.py b/mmrazor/models/task_modules/predictor/metric_predictor.py index 8ad71e522..2a6bbb449 100644 --- a/mmrazor/models/task_modules/predictor/metric_predictor.py +++ b/mmrazor/models/task_modules/predictor/metric_predictor.py @@ -139,18 +139,19 @@ def vector2model(self, vector: np.array) -> Dict[str, str]: start = 0 model = {} for key, value in self.search_groups.items(): + if isinstance(value[0], OneShotMutableChannelUnit): + choices = value[0].candidate_choices + else: + choices = value[0].choices + if self.encoding_type == 'onehot': - index = np.where(vector[start:start + - len(value[0].choices)] == 1)[0][0] - start += len(value[0].choices) + index = np.where(vector[start:start + len(choices)] == 1)[0][0] + start += len(choices) else: index = vector[start] start += 1 - if isinstance(value[0], OneShotMutableChannelUnit): - chosen = value[0].candidate_choices[int(index)] - else: - chosen = value[0].choices[int(index)] + chosen = choices[int(index)] model[key] = chosen return model From 2ae2175be0ff328b0e72cdc595421e0d64501137 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Wed, 11 Jan 2023 11:09:20 +0800 Subject: [PATCH 05/12] update configs --- configs/_base_/settings/cifar10_bs96_nsga.py | 49 ++++++++++ ...anetv2_mobilenet_supernet_1xb96_cifar10.py | 49 ++++++++++ mmrazor/models/algorithms/nas/nsganetv2.py | 93 ++++++++++++------- 3 files changed, 156 insertions(+), 35 deletions(-) create mode 100644 configs/_base_/settings/cifar10_bs96_nsga.py create mode 100644 configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_1xb96_cifar10.py diff --git a/configs/_base_/settings/cifar10_bs96_nsga.py b/configs/_base_/settings/cifar10_bs96_nsga.py new file mode 100644 index 000000000..bfd69e45b --- /dev/null +++ b/configs/_base_/settings/cifar10_bs96_nsga.py @@ -0,0 +1,49 @@ +# dataset settings +dataset_type = 'mmcls.CIFAR10' +data_preprocessor = dict( + type='mmcls.ClsDataPreprocessor', + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type='mmcls.RandomCrop', crop_size=32, padding=4), + dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'), + dict(type='mmcls.Cutout', shape=16, pad_val=0, prob=1.0), + dict(type='mmcls.PackClsInputs'), +] + +test_pipeline = [ + dict(type='mmcls.PackClsInputs'), +] + +train_dataloader = dict( + batch_size=96, + num_workers=5, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10', + test_mode=False, + pipeline=train_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=True), + persistent_workers=True, +) + +val_dataloader = dict( + batch_size=96, + num_workers=5, + dataset=dict( + type=dataset_type, + data_prefix='data/cifar10/', + test_mode=True, + pipeline=test_pipeline), + sampler=dict(type='mmcls.DefaultSampler', shuffle=False), + persistent_workers=True, +) +val_evaluator = dict(type='mmcls.Accuracy', topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_1xb96_cifar10.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_1xb96_cifar10.py new file mode 100644 index 000000000..aee8c9b08 --- /dev/null +++ 
b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_supernet_1xb96_cifar10.py @@ -0,0 +1,49 @@ +_base_ = [ + 'mmcls::_base_/default_runtime.py', + 'mmcls::_base_/schedules/imagenet_bs2048.py', + 'mmrazor::_base_/settings/cifar10_bs96_nsga.py', + 'mmrazor::_base_/nas_backbones/nsga_mobilenetv3_supernet.py', +] + +supernet = dict( + _scope_='mmrazor', + type='SearchableImageClassifier', + backbone=_base_.nas_backbone, + neck=dict(type='SqueezeMeanPoolingWithDropout', drop_ratio=0.2), + head=dict( + type='DynamicLinearClsHead', + num_classes=1000, + in_channels=1280, + loss=dict( + type='mmcls.LabelSmoothLoss', + num_classes=1000, + label_smooth_val=0.1, + mode='original', + loss_weight=1.0), + topk=(1, 5)), + input_resizer_cfg=_base_.input_resizer_cfg, + connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'), +) + +model = dict( + _scope_='mmrazor', + type='NSGANetV2', + architecture=supernet, + data_preprocessor=_base_.data_preprocessor, + mutators=dict( + channel_mutator=dict( + type='mmrazor.OneShotChannelMutator', + channel_unit_cfg={ + 'type': 'OneShotMutableChannelUnit', + 'default_args': { + 'unit_predefined': True + } + }, + parse_cfg={'type': 'Predefined'}), + value_mutator=dict(type='DynamicValueMutator'))) + +find_unused_parameters = True + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', interval=1, max_keep_ckpts=1, save_best='auto')) diff --git a/mmrazor/models/algorithms/nas/nsganetv2.py b/mmrazor/models/algorithms/nas/nsganetv2.py index 24b5402c8..4ca9719c1 100644 --- a/mmrazor/models/algorithms/nas/nsganetv2.py +++ b/mmrazor/models/algorithms/nas/nsganetv2.py @@ -12,6 +12,7 @@ from mmrazor.registry import MODELS from mmrazor.utils import ValidFixMutable from ..base import BaseAlgorithm, LossResults +from ..space_mixin import SpaceMixin VALID_MUTATOR_TYPE = Union[BaseMutator, Dict] VALID_MUTATORS_TYPE = Dict[str, Union[BaseMutator, Dict]] @@ -19,8 +20,42 @@ @MODELS.register_module() -class NSGANetV2(BaseAlgorithm): - """NSGANetV2 algorithm.""" +class NSGANetV2(BaseAlgorithm, SpaceMixin): + """Implementation of `NSGANetV2 `_ + + NSGANetV2 generates task-specific models that are competitive under + multiple competing objectives. + + NSGANetV2 comprises of two surrogates, one at the architecture level to + improve sample efficiency and one at the weights level, through a supernet, + to improve gradient descent training efficiency. + + The logic of the search part is implemented in + :class:`mmrazor.engine.NSGA2SearchLoop` + + Args: + architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel` + or built model. Corresponding to supernet in NAS algorithm. + mutators (VALID_MUTATORS_TYPE): Configs to build different mutators. + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. Defaults to None. + data_preprocessor (Optional[Union[dict, nn.Module]]): The pre-process + config of :class:`BaseDataPreprocessor`. Defaults to None. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.2. + backbone_dropout_stages (List): Stages to be set dropout. Defaults to + [6, 7]. + norm_training (bool): Whether to set norm layers to training mode, + namely, not freeze running stats (mean and var). Note: Effect on + Batch Norm and its variants only. Defaults to False. + init_cfg (Optional[dict]): Init config for ``BaseModule``. + Defaults to None. + + Note: + NSGANetV2 uses two mutators which are ``DynamicValueMutator`` and + ``ChannelMutator``. 
``DynamicValueMutator`` handles the mutable object
+        ``OneShotMutableValue`` while ``ChannelMutator`` handles the mutable
+        object ``OneShotMutableChannel``.
+    """
 
     def __init__(self,
                  architecture: Union[BaseModel, Dict],
@@ -33,27 +68,6 @@ def __init__(self,
                  init_cfg: Optional[dict] = None):
         super().__init__(architecture, data_preprocessor, init_cfg)
 
-        if isinstance(mutators, dict):
-            built_mutators: Dict = dict()
-            for name, mutator_cfg in mutators.items():
-                if 'parse_cfg' in mutator_cfg and isinstance(
-                        mutator_cfg['parse_cfg'], dict):
-                    assert mutator_cfg['parse_cfg'][
-                        'type'] == 'Predefined', \
-                        'BigNAS only support predefined.'
-                mutator: BaseMutator = MODELS.build(mutator_cfg)
-                built_mutators[name] = mutator
-                mutator.prepare_from_supernet(self.architecture)
-            self.mutators = built_mutators
-        else:
-            raise TypeError('mutator should be a `dict` but got '
-                            f'{type(mutators)}')
-
-        self.drop_path_rate = drop_path_rate
-        self.backbone_dropout_stages = backbone_dropout_stages
-        self.norm_training = norm_training
-        self.is_supernet = True
-
         if fix_subnet:
             # Avoid circular import
             from mmrazor.structures import load_fix_subnet
@@ -61,19 +75,28 @@ def __init__(self,
             # According to fix_subnet, delete the unchosen part of supernet
             load_fix_subnet(self, fix_subnet)
             self.is_supernet = False
+        else:
+            if isinstance(mutators, dict):
+                built_mutators: Dict = dict()
+                for name, mutator_cfg in mutators.items():
+                    if 'parse_cfg' in mutator_cfg and isinstance(
+                            mutator_cfg['parse_cfg'], dict):
+                        assert mutator_cfg['parse_cfg'][
+                            'type'] == 'Predefined', \
+                            'NSGANetV2 only supports predefined.'
+                    mutator: BaseMutator = MODELS.build(mutator_cfg)
+                    built_mutators[name] = mutator
+                    mutator.prepare_from_supernet(self.architecture)
+                self.mutators = built_mutators
+            else:
+                raise TypeError('mutator should be a `dict` but got '
+                                f'{type(mutators)}')
+            self._build_search_space()
+            self.is_supernet = True
 
-    def sample_subnet(self, kind='random') -> Dict:
-        """Random sample subnet by mutator."""
-        subnet = dict()
-        for mutator in self.mutators.values():
-            subnet.update(mutator.sample_choices(kind))
-        return subnet
+        self.drop_path_rate = drop_path_rate
+        self.backbone_dropout_stages = backbone_dropout_stages
+        self.norm_training = norm_training
 
-    def set_subnet(self, subnet: Dict[str, Dict[int, Union[int,
-                                                           list]]]) -> None:
-        """Set the subnet sampled by :meth:sample_subnet."""
-        for mutator in self.mutators.values():
-            mutator.set_choices(subnet)
 
     def loss(
         self,
From 25e05d88e469bddc8d05b3e71da12e39879e5415 Mon Sep 17 00:00:00 2001
From: gaoyang07 <1546308416@qq.com>
Date: Thu, 12 Jan 2023 20:27:47 +0800
Subject: [PATCH 06/12] update finetune step

---
 .../engine/runner/evolution_search_loop.py    | 74 +++++++++++--------
 .../engine/runner/nsganetv2_search_loop.py    | 52 ++++---------
 mmrazor/engine/runner/utils/check.py          |  9 +--
 3 files changed, 62 insertions(+), 73 deletions(-)

diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py
index 9cb581f39..bbd8a1a91 100644
--- a/mmrazor/engine/runner/evolution_search_loop.py
+++ b/mmrazor/engine/runner/evolution_search_loop.py
@@ -11,7 +11,7 @@
 from mmengine import fileio
 from mmengine.dist import broadcast_object_list
 from mmengine.evaluator import Evaluator
-from mmengine.runner import EpochBasedTrainLoop
+from mmengine.runner import EpochBasedTrainLoop, Runner
 from mmengine.utils import is_list_of
 from torch.utils.data import DataLoader
 
@@ -63,7 +63,6 @@ def __init__(self,
                  runner,
                  evaluator: Union[Evaluator, Dict,
List], dataloader: Union[DataLoader, Dict], - finetune_dataloader: Union[DataLoader, Dict] = None, max_epochs: int = 20, max_keep_ckpts: int = 3, resume_from: Optional[str] = None, @@ -77,6 +76,7 @@ def __init__(self, constraints_range: Dict[str, Any] = dict(flops=(0., 330.)), estimator_cfg: Optional[Dict] = None, predictor_cfg: Optional[Dict] = None, + finetune_cfg: Optional[Dict] = None, score_key: str = 'accuracy/top1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -85,12 +85,6 @@ def __init__(self, else: self.evaluator = evaluator # type: ignore - if isinstance(finetune_dataloader, dict): - self.finetune_dataloader = self.runner.build_dataloader( - finetune_dataloader) - else: - self.finetune_dataloader = finetune_dataloader - if hasattr(self.dataloader.dataset, 'metainfo'): self.evaluator.dataset_meta = self.dataloader.dataset.metainfo else: @@ -138,17 +132,12 @@ def __init__(self, self.predictor_cfg = predictor_cfg if self.predictor_cfg is not None: self.predictor_cfg['score_key'] = self.score_key - if hasattr(self.model, 'mutators'): - self.predictor_cfg['search_groups'] = dict() - for mutator in self.model.mutators.values(): - self.predictor_cfg['search_groups'].update( - mutator.search_groups) - else: - assert hasattr(self.model, 'mutator') - self.predictor_cfg['search_groups'] = \ - self.model.mutator.search_groups + self.predictor_cfg['search_groups'] = self.model.search_space self.predictor = TASK_UTILS.build(self.predictor_cfg) + if finetune_cfg is not None: + self.finetune_runner = self.build_finetune_runner(finetune_cfg) + def run(self) -> None: """Launch searching.""" self.runner.call_hook('before_train') @@ -214,9 +203,10 @@ def sample_candidates(self) -> None: while len(self.candidates) < self.num_candidates: candidate = self.model.sample_subnet() self.model.set_subnet(candidate) - self.finetune_step() - is_pass, result = self._check_constraints( - random_subnet=candidate) + _, sliced_model = export_fix_subnet( + self.model, slice_weight=True) + self.finetune_step(sliced_model) + is_pass, result = self._check_constraints(sliced_model) if is_pass: self.candidates.append(candidate) candidates_resources.append(result) @@ -263,9 +253,10 @@ def gen_mutation_candidates(self): break mutation_candidate = self._mutation() + self.model.set_subnet(mutation_candidate) + _, sliced_model = export_fix_subnet(self.model, slice_weight=True) - is_pass, result = self._check_constraints( - random_subnet=mutation_candidate) + is_pass, result = self._check_constraints(sliced_model) if is_pass: mutation_candidates.append(mutation_candidate) mutation_resources.append(result) @@ -287,9 +278,10 @@ def gen_crossover_candidates(self): break crossover_candidate = self._crossover() + self.model.set_subnet(crossover_candidate) + _, sliced_model = export_fix_subnet(self.model, slice_weight=True) - is_pass, result = self._check_constraints( - random_subnet=crossover_candidate) + is_pass, result = self._check_constraints(sliced_model) if is_pass: crossover_candidates.append(crossover_candidate) crossover_resources.append(result) @@ -415,16 +407,14 @@ def _save_searcher_ckpt(self) -> None: if osp.isfile(ckpt_path): os.remove(ckpt_path) - def _check_constraints( - self, random_subnet: SupportRandomSubnet) -> Tuple[bool, Dict]: + def _check_constraints(self, model) -> Tuple[bool, Dict]: """Check whether is beyond constraints. Returns: bool, result: The result of checking. 
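
        Example:
            A minimal sketch of the new calling convention (it mirrors
            ``sample_candidates`` above; ``export_fix_subnet`` first slices
            the supernet weights into a standalone subnet):

            >>> self.model.set_subnet(self.model.sample_subnet())
            >>> _, sliced_model = export_fix_subnet(
            ...     self.model, slice_weight=True)
            >>> is_pass, results = self._check_constraints(sliced_model)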
""" is_pass, results = check_subnet_resources( - model=self.model, - subnet=random_subnet, + model=model, estimator=self.estimator, constraints_range=self.constraints_range) @@ -470,5 +460,31 @@ def _init_predictor(self): self.use_predictor = True self.candidates = Candidates() - def finetune_step(self): + def build_finetune_runner(self, finetune_cfg: Dict) -> Runner: + """Build a runner for finetuning the sliced_model.""" + finetune_cfg.update(work_dir=self.runner.work_dir) + finetune_cfg.update(env_cfg=dict(dist_cfg=dict(backend='nccl'))) + + runner = Runner.from_cfg(finetune_cfg) + + runner._train_loop = runner.build_train_loop( + runner._train_loop) # type: ignore + + runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper) + runner.scale_lr(runner.optim_wrapper, runner.auto_scale_lr) + + runner.param_schedulers = runner.build_param_scheduler( # type: ignore + runner.param_schedulers) # type: ignore + + from mmengine.hooks import CheckpointHook + + # remove CheckpointHook to avoid extra problems. + for hook in runner._hooks: + if isinstance(hook, CheckpointHook): + runner._hooks.remove(hook) + break + + return runner + + def finetune_step(self, model): pass diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py index 56eda84b0..145997d06 100644 --- a/mmrazor/engine/runner/nsganetv2_search_loop.py +++ b/mmrazor/engine/runner/nsganetv2_search_loop.py @@ -3,7 +3,6 @@ import numpy as np from mmengine import fileio -from mmengine.optim import build_optim_wrapper from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting from mmrazor.models.task_modules import (AuxiliarySingleLevelProblem, @@ -227,39 +226,18 @@ def fit_predictor(self, candidates): metrics[i] = self.max_score_key - metrics[i] return metrics - def finetune_step(self): - """fintune before candidates evaluation.""" - self.runner.logger.info('start finetuning...') - self.model.train() - - self._finetune_epoch = 0 - self._max_finetune_epochs = 1 - - optimizer_cfg = dict( - type='SGD', - lr=0.5, - momentum=0.9, - nesterov=True, - weight_decay=0.0001) - optim_wrapper_cfg = dict(optimizer=optimizer_cfg) - optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) - - while self._finetune_epoch < self._max_finetune_epochs: - self.runner.call_hook('before_train_epoch') - for idx, data_batch in enumerate(self.finetune_dataloader): - self.runner.call_hook( - 'before_train_iter', batch_idx=idx, data_batch=data_batch) - - outputs = self.model.train_step( - data_batch, optim_wrapper=optim_wrapper) - - self.runner.call_hook( - 'after_train_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - self.runner.call_hook('after_train_epoch') - self._finetune_epoch += 1 - - self.model.eval() + def finetune_step(self, model): + """Fintune before candidates evaluation.""" + self.runner.logger.info('Start finetuning...') + self.finetune_runner.model = model + self.finetune_runner.call_hook('before_run') + + self.finetune_runner.optim_wrapper.initialize_count_status( + self.finetune_runner.model, self.finetune_runner._train_loop.iter, + self.finetune_runner._train_loop.max_iters) + + self.model = self.finetune_runner.train_loop.run() + self.finetune_runner.train_loop._iter = 0 + + self.finetune_runner.call_hook('after_run') + self.runner.logger.info('End finetuning...') diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index eb49ede68..6c567e09d 100644 --- a/mmrazor/engine/runner/utils/check.py +++ 
b/mmrazor/engine/runner/utils/check.py @@ -4,8 +4,6 @@ import torch from mmrazor.models import ResourceEstimator -from mmrazor.structures import export_fix_subnet -from mmrazor.utils import SupportRandomSubnet try: from mmdet.models.detectors import BaseDetector @@ -17,7 +15,6 @@ @torch.no_grad() def check_subnet_resources( model, - subnet: SupportRandomSubnet, estimator: ResourceEstimator, constraints_range: Dict[str, Any] = dict(flops=(0, 330)) ) -> Tuple[bool, Dict]: @@ -29,11 +26,9 @@ def check_subnet_resources( if constraints_range is None: return True, dict() - assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture') - model.set_subnet(subnet) - _, sliced_model = export_fix_subnet(model, slice_weight=True) + assert hasattr(model, 'architecture') + model_to_check = model.architecture # type: ignore - model_to_check = sliced_model.architecture # type: ignore if isinstance(model_to_check, BaseDetector): results = estimator.estimate(model=model_to_check.backbone) else: From dd6fc3261869d74d75ccabd6e3a9ca49731237a7 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Thu, 12 Jan 2023 20:28:41 +0800 Subject: [PATCH 07/12] update search configs --- ...ganetv2_mobilenet_search_8xb256_cifar10.py | 42 +++++++++++++++++++ .../nsganetv2_mobilenet_search_8xb256_in1k.py | 16 ++++++- 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py new file mode 100644 index 000000000..619a6e050 --- /dev/null +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py @@ -0,0 +1,42 @@ +_base_ = ['./nsganetv2_mobilenet_supernet_1xb96_cifar10.py'] + +model = dict(norm_training=True) + +train_dataloader = dict(batch_size=256) +val_dataloader = dict(batch_size=256) +test_dataloader = val_dataloader + +train_cfg = dict( + _delete_=True, + type='mmrazor.NSGA2SearchLoop', + dataloader=_base_.val_dataloader, + evaluator=_base_.val_evaluator, + max_epochs=4, + num_candidates=4, + top_k=2, + num_mutation=2, + num_crossover=2, + mutate_prob=0.1, + flops_range=(0., 330.), + score_key='accuracy/top1', + predictor_cfg=dict( + type='mmrazor.MetricPredictor', + encoding_type='normal', + train_samples=2, + handler_cfg=dict(type='mmrazor.GaussProcessHandler')), + finetune_cfg=dict( + model=_base_.model, + train_dataloader=_base_.train_dataloader, + train_cfg=dict(by_epoch=True, max_epochs=1), + optim_wrapper=dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=0.025, + momentum=0.9, + weight_decay=3e-4, + nesterov=True)), + param_scheduler=_base_.param_scheduler, + default_hooks=_base_.default_hooks, + ), +) diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py index 072885f13..3cd9bb6f2 100644 --- a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py @@ -21,5 +21,19 @@ encoding_type='normal', train_samples=2, handler_cfg=dict(type='mmrazor.GaussProcessHandler')), - trade_off=dict(sec_obj='flops', max_score_key=100), + finetune_cfg=dict( + model=_base_.model, + train_dataloader=_base_.train_dataloader, + train_cfg=dict(by_epoch=True, max_epochs=1), + optim_wrapper=dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + 
lr=0.025, + momentum=0.9, + weight_decay=3e-4, + nesterov=True)), + param_scheduler=_base_.param_scheduler, + default_hooks=_base_.default_hooks, + ), ) From 4cc260312559b5adb801eb3612e6f04486fba901 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Thu, 2 Feb 2023 15:53:27 +0800 Subject: [PATCH 08/12] update search loops --- .../nsganetv2_mobilenet_search_8xb256_cifar10.py | 2 +- .../nsganetv2_mobilenet_search_8xb256_in1k.py | 2 +- mmrazor/engine/runner/evolution_search_loop.py | 16 ++++++++++------ mmrazor/engine/runner/nsganetv2_search_loop.py | 5 +++-- .../task_modules/predictor/metric_predictor.py | 4 ++++ 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py index 619a6e050..f998f2386 100644 --- a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py @@ -17,7 +17,7 @@ num_mutation=2, num_crossover=2, mutate_prob=0.1, - flops_range=(0., 330.), + constraints_range=dict(flops=(0., 360.)), score_key='accuracy/top1', predictor_cfg=dict( type='mmrazor.MetricPredictor', diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py index 3cd9bb6f2..252e08c88 100644 --- a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py @@ -14,7 +14,7 @@ num_mutation=2, num_crossover=2, mutate_prob=0.1, - flops_range=(0., 330.), + constraints_range=dict(flops=(0., 360.)), score_key='accuracy/top1', predictor_cfg=dict( type='mmrazor.MetricPredictor', diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index a4b3777c6..5fbe9ac36 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -53,6 +53,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop, CalibrateBNMixin): Defaults to None. predictor_cfg (dict, Optional): Used for building a metric predictor. Defaults to None. + finetune_cfg (dict, Optional): Used for building an extra runner to + finetune the searched model. Defaults to None. score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. 
init_candidates (str, optional): The candidates file path, which is @@ -133,7 +135,8 @@ def __init__(self, self.predictor_cfg = predictor_cfg if self.predictor_cfg is not None: self.predictor_cfg['score_key'] = self.score_key - self.predictor_cfg['search_groups'] = self.model.search_space + self.predictor_cfg['search_groups'] = \ + self.model.mutator.search_groups self.predictor = TASK_UTILS.build(self.predictor_cfg) if finetune_cfg is not None: @@ -203,8 +206,8 @@ def sample_candidates(self) -> None: init_candidates = len(self.candidates) if self.runner.rank == 0: while len(self.candidates) < self.num_candidates: - candidate = self.model.sample_subnet() - self.model.set_subnet(candidate) + candidate = self.model.mutator.sample_choices() + self.model.mutator.set_choices(candidate) _, sliced_model = export_fix_subnet( self.model, slice_weight=True) self.finetune_step(sliced_model) @@ -255,7 +258,7 @@ def gen_mutation_candidates(self): break mutation_candidate = self._mutation() - self.model.set_subnet(mutation_candidate) + self.model.mutator.set_choices(mutation_candidate) _, sliced_model = export_fix_subnet(self.model, slice_weight=True) is_pass, result = self._check_constraints(sliced_model) @@ -280,7 +283,7 @@ def gen_crossover_candidates(self): break crossover_candidate = self._crossover() - self.model.set_subnet(crossover_candidate) + self.model.mutator.set_choices(crossover_candidate) _, sliced_model = export_fix_subnet(self.model, slice_weight=True) is_pass, result = self._check_constraints(sliced_model) @@ -296,7 +299,7 @@ def gen_crossover_candidates(self): def _mutation(self) -> SupportRandomSubnet: """Mutate with the specified mutate_prob.""" candidate1 = random.choice(self.top_k_candidates.subnets) - candidate2 = self.model.sample_subnet() + candidate2 = self.model.mutator.sample_choices() candidate = crossover(candidate1, candidate2, prob=self.mutate_prob) return candidate @@ -454,6 +457,7 @@ def _init_predictor(self): def build_finetune_runner(self, finetune_cfg: Dict) -> Runner: """Build a runner for finetuning the sliced_model.""" finetune_cfg.update(work_dir=self.runner.work_dir) + finetune_cfg.update(launcher=self.runner.launcher) finetune_cfg.update(env_cfg=dict(dist_cfg=dict(backend='nccl'))) runner = Runner.from_cfg(finetune_cfg) diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py index 4503bcef0..0a060c384 100644 --- a/mmrazor/engine/runner/nsganetv2_search_loop.py +++ b/mmrazor/engine/runner/nsganetv2_search_loop.py @@ -111,13 +111,13 @@ def sample_candidates_with_nsga2(self, archive: Candidates, # initiate a multi-objective solver to optimize the problem method = NSGA2Optimizer( - pop_size=40, + pop_size=4, sampling=fronts, # initialize with current nd archs eliminate_duplicates=True, logger=self.runner.logger) # # kick-off the search - method.initialize(problem, n_gen=20, verbose=True) + method.initialize(problem, n_gen=2, verbose=True) result = method.solve() # check for duplicates @@ -244,6 +244,7 @@ def finetune_step(self, model): self.model = self.finetune_runner.train_loop.run() self.finetune_runner.train_loop._iter = 0 + self.finetune_runner.train_loop._epoch = 0 self.finetune_runner.call_hook('after_run') self.runner.logger.info('End finetuning...') diff --git a/mmrazor/models/task_modules/predictor/metric_predictor.py b/mmrazor/models/task_modules/predictor/metric_predictor.py index 2a6bbb449..1561f0776 100644 --- a/mmrazor/models/task_modules/predictor/metric_predictor.py +++ 
b/mmrazor/models/task_modules/predictor/metric_predictor.py @@ -96,6 +96,10 @@ def model2vector( vector_dict: Dict[str, list] = \ dict(normal_vector=[], onehot_vector=[]) + assert len(model.keys()) == len(self.search_groups.keys()), ( + f'Length mismatch for model({len(model.keys())}) and search_groups' + f'({len(self.search_groups.keys())}).') + for key, choice in model.items(): if isinstance(choice, DumpChosen): assert choice.meta is not None, ( From 95d8087b685ef8a94db24d360fd9fd34a5a33025 Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Fri, 3 Feb 2023 10:51:08 +0800 Subject: [PATCH 09/12] fix finetune_step model bug --- .../engine/runner/evolution_search_loop.py | 23 +++++++------------ mmrazor/engine/runner/utils/check.py | 9 ++++++-- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 5fbe9ac36..49b6af219 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -207,11 +207,8 @@ def sample_candidates(self) -> None: if self.runner.rank == 0: while len(self.candidates) < self.num_candidates: candidate = self.model.mutator.sample_choices() - self.model.mutator.set_choices(candidate) - _, sliced_model = export_fix_subnet( - self.model, slice_weight=True) - self.finetune_step(sliced_model) - is_pass, result = self._check_constraints(sliced_model) + self.finetune_step(self.model) + is_pass, result = self._check_constraints(candidate) if is_pass: self.candidates.append(candidate) candidates_resources.append(result) @@ -258,10 +255,7 @@ def gen_mutation_candidates(self): break mutation_candidate = self._mutation() - self.model.mutator.set_choices(mutation_candidate) - _, sliced_model = export_fix_subnet(self.model, slice_weight=True) - - is_pass, result = self._check_constraints(sliced_model) + is_pass, result = self._check_constraints(mutation_candidate) if is_pass: mutation_candidates.append(mutation_candidate) mutation_resources.append(result) @@ -283,10 +277,7 @@ def gen_crossover_candidates(self): break crossover_candidate = self._crossover() - self.model.mutator.set_choices(crossover_candidate) - _, sliced_model = export_fix_subnet(self.model, slice_weight=True) - - is_pass, result = self._check_constraints(sliced_model) + is_pass, result = self._check_constraints(crossover_candidate) if is_pass: crossover_candidates.append(crossover_candidate) crossover_resources.append(result) @@ -401,14 +392,16 @@ def _save_searcher_ckpt(self) -> None: if osp.isfile(ckpt_path): os.remove(ckpt_path) - def _check_constraints(self, model) -> Tuple[bool, Dict]: + def _check_constraints( + self, random_subnet: SupportRandomSubnet) -> Tuple[bool, Dict]: """Check whether is beyond constraints. Returns: bool, result: The result of checking. 
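
        Example:
            A sketch of the restored subnet-based call (as used in
            ``sample_candidates`` above):

            >>> random_subnet = self.model.mutator.sample_choices()
            >>> is_pass, results = self._check_constraints(random_subnet)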
""" is_pass, results = check_subnet_resources( - model=model, + model=self.model, + subnet=random_subnet, estimator=self.estimator, constraints_range=self.constraints_range) diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index 6c567e09d..ad774f647 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -4,6 +4,8 @@ import torch from mmrazor.models import ResourceEstimator +from mmrazor.structures import export_fix_subnet +from mmrazor.utils import SupportRandomSubnet try: from mmdet.models.detectors import BaseDetector @@ -15,6 +17,7 @@ @torch.no_grad() def check_subnet_resources( model, + subnet: SupportRandomSubnet, estimator: ResourceEstimator, constraints_range: Dict[str, Any] = dict(flops=(0, 330)) ) -> Tuple[bool, Dict]: @@ -26,9 +29,11 @@ def check_subnet_resources( if constraints_range is None: return True, dict() - assert hasattr(model, 'architecture') - model_to_check = model.architecture # type: ignore + assert hasattr(model, 'mutator') and hasattr(model, 'architecture') + model.mutator.set_choices(subnet) + _, sliced_model = export_fix_subnet(model, slice_weight=True) + model_to_check = sliced_model.architecture # type: ignore if isinstance(model_to_check, BaseDetector): results = estimator.estimate(model=model_to_check.backbone) else: From b04f8f25a3d3b09b0ce4c8b719971fa20c0bff9b Mon Sep 17 00:00:00 2001 From: gaoyang07 <1546308416@qq.com> Date: Tue, 7 Feb 2023 00:52:34 +0800 Subject: [PATCH 10/12] update nsga2 search with pymoo_v0.60 --- ...ganetv2_mobilenet_search_8xb256_cifar10.py | 20 +- .../nsganetv2_mobilenet_search_8xb256_in1k.py | 3 +- .../engine/runner/evolution_search_loop.py | 25 +- .../engine/runner/nsganetv2_search_loop.py | 131 ++-- .../runner/utils/pymoo_utils/__init__.py | 7 + .../{ => pymoo_utils}/high_tradeoff_points.py | 0 .../runner/utils/pymoo_utils/problems.py} | 42 +- mmrazor/models/task_modules/__init__.py | 1 - .../multi_object_optimizer/__init__.py | 9 - .../multi_object_optimizer/base_optimizer.py | 210 ------ .../genetic_optimizer.py | 87 --- .../multi_object_optimizer/nsga2_optimizer.py | 149 ---- .../problem/__init__.py | 5 - .../problem/base_problem.py | 327 --------- .../problem/subset_problem.py | 34 - .../multi_object_optimizer/utils/__init__.py | 1 - .../utils/domin_matrix.py | 134 ---- .../multi_object_optimizer/utils/helper.py | 668 ------------------ .../multi_object_optimizer/utils/selection.py | 490 ------------- .../predictor/metric_predictor.py | 13 +- 20 files changed, 143 insertions(+), 2213 deletions(-) create mode 100644 mmrazor/engine/runner/utils/pymoo_utils/__init__.py rename mmrazor/engine/runner/utils/{ => pymoo_utils}/high_tradeoff_points.py (100%) rename mmrazor/{models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py => engine/runner/utils/pymoo_utils/problems.py} (50%) delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/__init__.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py delete mode 100644 
mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py delete mode 100644 mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py index f998f2386..c731c5486 100644 --- a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py @@ -11,13 +11,13 @@ type='mmrazor.NSGA2SearchLoop', dataloader=_base_.val_dataloader, evaluator=_base_.val_evaluator, - max_epochs=4, - num_candidates=4, - top_k=2, - num_mutation=2, - num_crossover=2, - mutate_prob=0.1, - constraints_range=dict(flops=(0., 360.)), + max_epochs=30, + num_candidates=50, + top_k=10, + num_mutation=25, + num_crossover=25, + mutate_prob=0.3, + constraints_range=dict(flops=(0., 330.)), score_key='accuracy/top1', predictor_cfg=dict( type='mmrazor.MetricPredictor', @@ -27,14 +27,14 @@ finetune_cfg=dict( model=_base_.model, train_dataloader=_base_.train_dataloader, - train_cfg=dict(by_epoch=True, max_epochs=1), + train_cfg=dict(by_epoch=True, max_epochs=2), optim_wrapper=dict( type='OptimWrapper', optimizer=dict( type='SGD', - lr=0.025, + lr=0.1, momentum=0.9, - weight_decay=3e-4, + weight_decay=1e-4, nesterov=True)), param_scheduler=_base_.param_scheduler, default_hooks=_base_.default_hooks, diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py index 252e08c88..1a191bea2 100644 --- a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py @@ -1,4 +1,4 @@ -_base_ = ['./nsganetv2_mobilenet_supernet_8x128_in1k.py'] +_base_ = ['./nsganetv2_mobilenet_supernet_8xb128_in1k.py'] model = dict(norm_training=True) @@ -6,7 +6,6 @@ _delete_=True, type='mmrazor.NSGA2SearchLoop', dataloader=_base_.val_dataloader, - finetune_dataloader=_base_.train_dataloader, evaluator=_base_.val_evaluator, max_epochs=4, num_candidates=4, diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 49b6af219..2c6f6b6ff 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -9,7 +9,6 @@ import numpy as np import torch from mmengine import fileio -from mmengine.dist import broadcast_object_list from mmengine.evaluator import Evaluator from mmengine.runner import EpochBasedTrainLoop, Runner from mmengine.utils import is_list_of @@ -107,7 +106,7 @@ def __init__(self, self.crossover_prob = crossover_prob self.max_keep_ckpts = max_keep_ckpts self.resume_from = resume_from - self.trade_off = dict(max_score_key=40) + self.trade_off = dict(max_score_key=100) self.fp16 = False if init_candidates is None: @@ -204,17 +203,14 @@ def sample_candidates(self) -> None: """Update candidate pool contains specified number of candicates.""" candidates_resources = [] init_candidates = len(self.candidates) - if self.runner.rank == 0: - while len(self.candidates) < self.num_candidates: - candidate = self.model.mutator.sample_choices() - self.finetune_step(self.model) - is_pass, result = self._check_constraints(candidate) - if is_pass: - 
self.candidates.append(candidate) - candidates_resources.append(result) - self.candidates = Candidates(self.candidates.data) - else: - self.candidates = Candidates([dict(a=0)] * self.num_candidates) + while len(self.candidates) < self.num_candidates: + candidate = self.model.mutator.sample_choices() + self.finetune_step(self.model) + is_pass, result = self._check_constraints(candidate) + if is_pass: + self.candidates.append(candidate) + candidates_resources.append(result) + self.candidates = Candidates(self.candidates.data) if len(candidates_resources) > 0: self.candidates.update_resources( @@ -223,9 +219,6 @@ def sample_candidates(self) -> None: assert init_candidates + len( candidates_resources) == self.num_candidates - # broadcast candidates to val with multi-GPUs. - broadcast_object_list([self.candidates]) - def update_candidates_scores(self) -> None: """Validate candicate one by one from the candicate pool, and update top-k candicates.""" diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py index 0a060c384..9eaa1f3d3 100644 --- a/mmrazor/engine/runner/nsganetv2_search_loop.py +++ b/mmrazor/engine/runner/nsganetv2_search_loop.py @@ -1,21 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os import os.path as osp import numpy as np from mmengine import fileio -from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting -from mmrazor.models.task_modules import (AuxiliarySingleLevelProblem, - GeneticOptimizer, NSGA2Optimizer, - SubsetProblem) +try: + from pymoo.algorithms.moo.nsga2 import NSGA2 + from pymoo.algorithms.soo.nonconvex.ga import GA + from pymoo.optimize import minimize + from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting +except ImportError: + from mmrazor.utils import get_placeholder + NSGA2 = get_placeholder('pymoo') + GA = get_placeholder('pymoo') + minimize = get_placeholder('pymoo') + NonDominatedSorting = get_placeholder('pymoo') + from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, export_fix_subnet +from mmrazor.structures import Candidates from .attentive_search_loop import AttentiveSearchLoop -from .utils.high_tradeoff_points import HighTradeoffPoints - -# from pymoo.algorithms.moo.nsga2 import NSGA2 as NSGA2Optimizer -# from pymoo.algorithms.soo.nonconvex.ga import GA as GeneticOptimizer -# from pymoo.optimize import minimize +from .utils.pymoo_utils import (AuxiliarySingleLevelProblem, + HighTradeoffPoints, SubsetProblem) @LOOPS.register_module() @@ -35,18 +41,18 @@ def run_epoch(self) -> None: will be used in mutation and crossover. 4. Implement Mutation and crossover, generate better candidates. 
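
        For example, with ``trade_off=dict(max_score_key=100)``, a candidate
        scoring 76.2 top-1 accuracy enters the archive as the error
        ``100 - 76.2 = 23.8``, so both objectives (error and flops) are
        minimized consistently by NSGA-II.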
""" - archive = Candidates() + self.archive = Candidates() if len(self.candidates) > 0: for subnet, score, flops in zip( self.candidates.subnets, self.candidates.scores, self.candidates.resources('flops')): if self.trade_off['max_score_key'] != 0: score = self.trade_off['max_score_key'] - score - archive.append(subnet) - archive.set_score(-1, score) - archive.set_resource(-1, flops, 'flops') + self.archive.append(subnet) + self.archive.set_score(-1, score) + self.archive.set_resource(-1, flops, 'flops') - self.sample_candidates(random=(self._epoch == 0), archive=archive) + self.sample_candidates(random=(self._epoch == 0), archive=self.archive) self.update_candidates_scores() scores_before = self.top_k_candidates.scores @@ -110,36 +116,34 @@ def sample_candidates_with_nsga2(self, archive: Candidates, problem = AuxiliarySingleLevelProblem(self, len(fronts[0])) # initiate a multi-objective solver to optimize the problem - method = NSGA2Optimizer( + method = NSGA2( pop_size=4, sampling=fronts, # initialize with current nd archs eliminate_duplicates=True, logger=self.runner.logger) - # # kick-off the search - method.initialize(problem, n_gen=2, verbose=True) - result = method.solve() + result = minimize(problem, method, ('n_gen', 4), seed=1, verbose=True) - # check for duplicates check_list = [] - for x in result['pop'].get('X'): + for x in result.pop.get('X'): check_list.append(self.predictor.vector2model(x)) not_duplicate = np.logical_not( [x in archive.subnets for x in check_list]) - # extra process after nsga2 search - sub_problem = SubsetProblem( - result['pop'][not_duplicate].get('F')[:, 1], F[front_index, 1], - num_candidates) - sub_method = GeneticOptimizer( - pop_size=num_candidates, eliminate_duplicates=True) - sub_method.initialize(sub_problem, n_gen=4, verbose=False) - indices = sub_method.solve()['X'] + sub_problem = SubsetProblem(result.pop[not_duplicate].get('F')[:, 1], + F[front_index, 1], num_candidates) + + sub_method = GA(pop_size=num_candidates, eliminate_duplicates=True) + + sub_result = minimize( + sub_problem, sub_method, ('n_gen', 2), seed=1, verbose=False) + indices = sub_result.pop.get('X') + + _X = result.pop.get('X')[not_duplicate][indices] candidates = [] - pop = result['pop'][not_duplicate][indices] - for x in pop.get('X'): + for x in _X: x = x[0] if isinstance(x[0], list) else x candidates.append(self.predictor.vector2model(x)) @@ -150,29 +154,31 @@ def sort_candidates(self) -> None: assert self.trade_off is not None, ( '`self.trade_off` is required when sorting candidates in ' 'NSGA2SearchLoop. 
Got `self.trade_off` is None.') + ratio = self.trade_off.get('ratio', 1) max_score_key = self.trade_off.get('max_score_key', 100) max_score_key = np.array(max_score_key) - multi_obj_score = [] + patches = [] for score, flops in zip(self.candidates.scores, self.candidates.resources('flops')): - multi_obj_score.append((score, flops)) - multi_obj_score = np.array(multi_obj_score) + patches.append((score, flops)) + patches = np.array(patches) if max_score_key != 0: - multi_obj_score[:, 0] = max_score_key - multi_obj_score[:, 0] + patches[:, 0] = max_score_key - patches[:, 0] # type: ignore - sort_idx = np.argsort(multi_obj_score[:, 0]) - F = multi_obj_score[sort_idx] + sort_idx = np.argsort(patches[:, 0]) # type: ignore + F = patches[sort_idx] - dm = HighTradeoffPoints(ratio, n_survive=len(multi_obj_score)) + dm = HighTradeoffPoints(ratio, n_survive=len(patches)) candidate_index = dm.do(F) candidate_index = sort_idx[candidate_index] - self.candidates = [self.candidates[idx] for idx in candidate_index] + self.candidates = \ + [self.candidates[idx] for idx in candidate_index] # type: ignore - def _save_searcher_ckpt(self, archive=[]): + def _save_searcher_ckpt(self): """Save searcher ckpt, which is different from common ckpt. It mainly contains the candicate pool, the top-k candicates with scores @@ -180,10 +186,10 @@ def _save_searcher_ckpt(self, archive=[]): """ if self.runner.rank == 0: rmse, rho, tau = 0, 0, 0 - if len(archive) > 0: - top1_err_pred = self.fit_predictor(archive) + if len(self.archive) > 0: + top1_err_pred = self.fit_predictor(self.archive) rmse, rho, tau = self.predictor.get_correlation( - top1_err_pred, np.array([x[1] for x in archive])) + top1_err_pred, np.array(self.archive.scores)) save_for_resume = dict() save_for_resume['_epoch'] = self._epoch @@ -207,29 +213,38 @@ def _save_searcher_ckpt(self, archive=[]): step_str += f'step: {step}: ' step_str += f'{candidates[0][self.score_key]}\n' self.runner.logger.info( - f'Epoch:[{self._epoch + 1}/{self._max_epochs}], ' - f'top1_score: {step_str} ' + f'Epoch:[{self._epoch}/{self._max_epochs}] ' + f'Top1_score: {step_str} ' f'{correlation_str}') else: self.runner.logger.info( - f'Epoch:[{self._epoch + 1}/{self._max_epochs}], ' - f'top1_score: {self.top_k_candidates.scores[0]} ' + f'Epoch:[{self._epoch}/{self._max_epochs}] ' + f'Top1_score: {self.top_k_candidates.scores[0]} ' f'{correlation_str}') - def fit_predictor(self, candidates): - """anticipate testfn training(err rate).""" - inputs = [export_fix_subnet(x) for x in candidates.subnets] - inputs = np.array([self.predictor.model2vector(x) for x in inputs]) + if self.max_keep_ckpts > 0: + cur_ckpt = self._epoch + 1 + redundant_ckpts = range(1, cur_ckpt - self.max_keep_ckpts) + for _step in redundant_ckpts: + ckpt_path = osp.join(self.runner.work_dir, + f'search_epoch_{_step}.pkl') + if osp.isfile(ckpt_path): + os.remove(ckpt_path) - targets = np.array([x[1] for x in candidates]) + def fit_predictor(self, candidates): + """Predict performance using predictor.""" + assert self.predictor.initialize is True - if not self.predictor.pretrained: - self.predictor.fit(inputs, targets) + metrics = [] + for i, candidate in enumerate(candidates.subnets): + self.model.mutator.set_choices(candidate) + metric = self._val_candidate(use_predictor=True) + metrics.append(metric[self.score_key]) - metrics = self.predictor.predict(inputs) - if self.max_score_key != 0: - for i in range(len(metrics)): - metrics[i] = self.max_score_key - metrics[i] + max_score_key = self.trade_off.get('max_score_key', 
0.)
+        if max_score_key != 0:
+            for i in range(len(metrics)):
+                metrics[i] = max_score_key - metrics[i]
         return metrics
 
     def finetune_step(self, model):
diff --git a/mmrazor/engine/runner/utils/pymoo_utils/__init__.py b/mmrazor/engine/runner/utils/pymoo_utils/__init__.py
new file mode 100644
index 000000000..9fa1ac34b
--- /dev/null
+++ b/mmrazor/engine/runner/utils/pymoo_utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .high_tradeoff_points import HighTradeoffPoints
+from .problems import AuxiliarySingleLevelProblem, SubsetProblem
+
+__all__ = [
+    'AuxiliarySingleLevelProblem', 'SubsetProblem', 'HighTradeoffPoints'
+]
diff --git a/mmrazor/engine/runner/utils/high_tradeoff_points.py b/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py
similarity index 100%
rename from mmrazor/engine/runner/utils/high_tradeoff_points.py
rename to mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py b/mmrazor/engine/runner/utils/pymoo_utils/problems.py
similarity index 50%
rename from mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py
rename to mmrazor/engine/runner/utils/pymoo_utils/problems.py
index e1badd52b..afe831853 100644
--- a/mmrazor/models/task_modules/multi_object_optimizer/problem/auxiliary_singlelevel_problem.py
+++ b/mmrazor/engine/runner/utils/pymoo_utils/problems.py
@@ -1,17 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
-from pymoo.core.problem import Problem as BaseProblem
+from pymoo.core.problem import Problem
 
 
-class AuxiliarySingleLevelProblem(BaseProblem):
+class AuxiliarySingleLevelProblem(Problem):
     """The optimization problem for finding the next N candidate
     architectures."""
 
-    def __init__(self, searcher, dim=15, sec_obj='flops'):
+    def __init__(self, search_loop, dim=15, sec_obj='flops'):
         super().__init__(n_var=dim, n_obj=2, vtype=np.int32)
 
-        self.searcher = searcher
-        self.predictor = self.searcher.predictor
+        self.search_loop = search_loop
+        self.predictor = self.search_loop.predictor
         self.sec_obj = sec_obj
 
         self.xl = np.zeros(self.n_var)
@@ -34,8 +34,38 @@ def _evaluate(self, x, out, *args, **kwargs):
         top1_err = self.predictor.handler.predict(x)[:, 0]
         for i in range(len(x)):
             candidate = self.predictor.vector2model(x[i])
-            _, resource = self.searcher._check_constraints(candidate)
+            _, resource = self.search_loop._check_constraints(candidate)
             f[i, 0] = top1_err[i]
             f[i, 1] = resource[self.sec_obj]
 
         out['F'] = f
+
+
+class SubsetProblem(Problem):
+    """select a subset to diversify the pareto front."""
+
+    def __init__(self, candidates, archive, K):
+        super().__init__(
+            n_var=len(candidates),
+            n_obj=1,
+            n_constr=1,
+            xl=0,
+            xu=1,
+            type_var=bool)
+        self.archive = archive
+        self.candidates = candidates
+        self.n_max = K
+
+    def _evaluate(self, x, out, *args, **kwargs):
+        f = np.full((x.shape[0], 1), np.nan)
+        g = np.full((x.shape[0], 1), np.nan)
+
+        for i, _x in enumerate(x):
+            # append selected candidates to archive then sort
+            tmp = np.sort(np.concatenate((self.archive, self.candidates[_x])))
+            f[i, 0] = np.std(np.diff(tmp))
+            # we penalize if the number of selected candidates is not exactly K
+            g[i, 0] = (self.n_max - np.sum(_x))**2
+
+        out['F'] = f
+        out['G'] = g
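A small usage sketch of ``SubsetProblem`` under pymoo>=0.6 (dummy arrays
stand in for the real flops objectives; the solver calls mirror
``sample_candidates_with_nsga2`` above):

    import numpy as np
    from pymoo.algorithms.soo.nonconvex.ga import GA
    from pymoo.optimize import minimize

    from mmrazor.engine.runner.utils.pymoo_utils import SubsetProblem

    archive_flops = np.random.rand(10)    # flops of archived subnets
    candidate_flops = np.random.rand(30)  # flops of new candidates
    sub_problem = SubsetProblem(candidate_flops, archive_flops, 5)
    sub_method = GA(pop_size=5, eliminate_duplicates=True)
    sub_result = minimize(
        sub_problem, sub_method, ('n_gen', 2), seed=1, verbose=False)
    masks = sub_result.pop.get('X')  # boolean selections over candidates

diff --git a/mmrazor/models/task_modules/__init__.py b/mmrazor/models/task_modules/__init__.py
index a4690a683..931278b8a 100644
--- a/mmrazor/models/task_modules/__init__.py
+++ b/mmrazor/models/task_modules/__init__.py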
@@ -2,7 +2,6 @@ from .delivery import * # noqa: F401,F403 from .demo_inputs import * # noqa: F401,F403 from .estimators import ResourceEstimator -from .multi_object_optimizer import * # noqa: F401,F403 from .predictor import * # noqa: F401,F403 from .recorder import * # noqa: F401,F403 from .tracer import * # noqa: F401,F403 diff --git a/mmrazor/models/task_modules/multi_object_optimizer/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/__init__.py deleted file mode 100644 index 710587f59..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .genetic_optimizer import GeneticOptimizer -from .nsga2_optimizer import NSGA2Optimizer -from .problem import AuxiliarySingleLevelProblem, SubsetProblem - -__all__ = [ - 'AuxiliarySingleLevelProblem', 'SubsetProblem', 'GeneticOptimizer', - 'NSGA2Optimizer' -] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py deleted file mode 100644 index 796fa37a2..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# copied and modified from https://github.com/anyoptimization/pymoo -from abc import abstractmethod - -import numpy as np -from pymoo.util.optimum import filter_optimum - -# from pymoo.core.evaluator import Evaluator -# from pymoo.core.population import Population -from .utils.helper import Evaluator, Population - - -class BaseOptimizer(): - """This class represents the abstract class for any algorithm to be - implemented. The solve method provides a wrapper function which does - validate the input. - - Args: - problem : - Problem to be solved by the algorithm - verbose (bool): - If true information during the algorithm execution are displayed - save_history (bool): - If true, a current snapshot of each generation is saved. - pf (numpy.array): - The Pareto-front for the given problem. If provided performance - metrics are printed during execution. - return_least_infeasible (bool): - Whether the algorithm should return the least infeasible solution, - if no solution was found. - evaluator : :class:`~pymoo.model.evaluator.Evaluator` - The evaluator which can be used to make modifications before - calling the evaluate function of a problem. - """ - - def __init__(self, **kwargs): - # ! - # DEFAULT SETTINGS OF ALGORITHM - # ! - # set the display variable supplied to the algorithm - self.display = kwargs.get('display') - self.logger = kwargs.get('logger') - # ! - # Attributes to be set later on for each problem run - # ! 
- # the optimization problem as an instance - self.problem = None - - self.return_least_infeasible = None - # whether the history should be saved or not - self.save_history = None - # whether the algorithm should print output in this run or not - self.verbose = None - # the random seed that was used - self.seed = None - # the pareto-front of the problem - if it exist or passed - self.pf = None - # the function evaluator object (can be used to inject code) - self.evaluator = None - # the current number of generation or iteration - self.n_gen = None - # the history object which contains the list - self.history = None - # the current solutions stored - here considered as population - self.pop = None - # the optimum found by the algorithm - self.opt = None - # can be used to store additional data in submodules - self.data = {} - - def initialize( - self, - problem, - pf=True, - evaluator=None, - # START Default minimize - seed=None, - verbose=False, - save_history=False, - return_least_infeasible=False, - # END Default minimize - n_gen=1, - display=None, - # END Overwrite by minimize - **kwargs): - - # set the problem that is optimized for the current run - self.problem = problem - - # set the provided pareto front - self.pf = pf - - # by default make sure an evaluator exists if nothing is passed - if evaluator is None: - evaluator = Evaluator() - self.evaluator = evaluator - - # ! - # START Default minimize - # ! - # if this run should be verbose or not - self.verbose = verbose - # whether the least infeasible should be returned or not - self.return_least_infeasible = return_least_infeasible - # whether the history should be stored or not - self.save_history = save_history - - # set the random seed in the algorithm object - self.seed = seed - if self.seed is None: - self.seed = np.random.randint(0, 10000000) - np.random.seed(self.seed) - # ! - # END Default minimize - # ! 
- - if display is not None: - self.display = display - - # other run dependent variables that are reset - self.n_gen = n_gen - self.history = [] - self.pop = Population() - self.opt = None - - def solve(self): - - # the result object to be finally returned - res = {} - - # initialize the first population and evaluate it - self._initialize() - self._set_optimum() - - self.current_gen = 0 - # while termination criterion not fulfilled - while self.current_gen < self.n_gen: - self.current_gen += 1 - self.next() - - # store the resulting population - res['pop'] = self.pop - - # get the optimal solution found - opt = self.opt - - # if optimum is not set - if len(opt) == 0: - opt = None - - # if no feasible solution has been found - elif not np.any(opt.get('feasible')): - if self.return_least_infeasible: - opt = filter_optimum(opt, least_infeasible=True) - else: - opt = None - - # set the optimum to the result object - res['opt'] = opt - - # if optimum is set to none to not report anything - if opt is None: - X, F, CV, G = None, None, None, None - - # otherwise get the values from the population - else: - X, F, CV, G = self.opt.get('X', 'F', 'CV', 'G') - - # if single-objective problem and only one solution was found - if self.problem.n_obj == 1 and len(X) == 1: - X, F, CV, G = X[0], F[0], CV[0], G[0] - - # set all the individual values - res['X'], res['F'], res['CV'], res['G'] = X, F, CV, G - - # create the result object - res['problem'], res['pf'] = self.problem, self.pf - res['history'] = self.history - - return res - - def next(self): - # call next of the implementation of the algorithm - self._next() - - # set the optimum - only done if the algorithm did not do it yet - self._set_optimum() - - # do what needs to be done each generation - self._each_iteration() - - # method that is called each iteration to call some algorithms regularly - def _each_iteration(self): - - # display the output if defined by the algorithm - if self.logger: - self.logger.info(f'Generation:[{self.current_gen}/{self.n_gen}] ' - f'evaluate {self.evaluator.n_eval} solutions, ' - f'find {len(self.opt)} optimal solution.') - - def _finalize(self): - pass - - @abstractmethod - def _initialize(self): - pass - - @abstractmethod - def _next(self): - pass diff --git a/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py deleted file mode 100644 index 5639bc912..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
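For context, the removed optimizers were driven through a two-step entry point rather than pymoo's functional `minimize`: `initialize()` bound the problem and run settings, then `solve()` ran the generational loop shown above. A minimal sketch, assuming a `BaseProblem` subclass named `problem` and a hypothetical integer design matrix as the initial sampling, with the `NSGA2Optimizer` defined further below:

    import numpy as np

    init_X = np.random.randint(0, 3, size=(20, problem.n_var))  # hypothetical seed
    optimizer = NSGA2Optimizer(pop_size=20, sampling=init_X)
    optimizer.initialize(problem, n_gen=5, verbose=False)
    res = optimizer.solve()

    # res is a plain dict: res['pop'] is the final Population, res['opt'] the
    # non-dominated survivors, and res['X'] / res['F'] their design variables
    # and objective values (None when no feasible solution was found)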
-# copied and modified from https://github.com/anyoptimization/pymoo -import numpy as np -from pymoo.algorithms.soo.nonconvex.ga import FitnessSurvival -from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting - -from mmrazor.registry import TASK_UTILS -from .nsga2_optimizer import NSGA2Optimizer -from .utils.helper import Individual, Population -from .utils.selection import (BinaryCrossover, MyMutation, MySampling, - TournamentSelection) - - -@TASK_UTILS.register_module() -class GeneticOptimizer(NSGA2Optimizer): - """Genetic Algorithm.""" - - def __init__( - self, - pop_size=100, - sampling=MySampling(), - selection=TournamentSelection(func_comp='comp_by_cv_and_fitness'), - crossover=BinaryCrossover(), - mutation=MyMutation(), - eliminate_duplicates=True, - n_offsprings=None, - display=None, - **kwargs): - """ - Args: - pop_size : {pop_size} - sampling : {sampling} - selection : {selection} - crossover : {crossover} - mutation : {mutation} - eliminate_duplicates : {eliminate_duplicates} - n_offsprings : {n_offsprings} - - """ - - super().__init__( - pop_size=pop_size, - sampling=sampling, - selection=selection, - crossover=crossover, - mutation=mutation, - survival=FitnessSurvival(), - eliminate_duplicates=eliminate_duplicates, - n_offsprings=n_offsprings, - display=display, - **kwargs) - - def _set_optimum(self, force=False): - pop = self.pop - self.opt = filter_optimum(pop, least_infeasible=True) - - -def filter_optimum(pop, least_infeasible=False): - # first only choose feasible solutions - ret = pop[pop.get('feasible')[:, 0]] - - # if at least one feasible solution was found - if len(ret) > 0: - - # then check the objective values - F = ret.get('F') - - if F.shape[1] > 1: - Index = NonDominatedSorting().do(F, only_non_dominated_front=True) - ret = ret[Index] - - else: - ret = ret[np.argmin(F)] - - # no feasible solution was found - else: - # if flag enable report the least infeasible - if least_infeasible: - ret = pop[np.argmin(pop.get('CV'))] - # otherwise just return none - else: - ret = None - - if isinstance(ret, Individual): - ret = Population().create(ret) - - return ret diff --git a/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py deleted file mode 100644 index 2c31dc741..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# copied and modified from https://github.com/anyoptimization/pymoo -import numpy as np - -from mmrazor.registry import TASK_UTILS -from .base_optimizer import BaseOptimizer -from .utils.helper import (DefaultDuplicateElimination, Individual, - Initialization, Survival) -from .utils.selection import (IntegerFromFloatMutation, Mating, PointCrossover, - TournamentSelection, binary_tournament) - -# from pymoo.algorithms.moo.nsga2 import binary_tournament -# from pymoo.core.mating import Mating -# from pymoo.core.survival import Survival -# from pymoo.core.individual import Individual -# from pymoo.core.initialization import Initialization -# from pymoo.core.duplicate import DefaultDuplicateElimination -# from pymoo.operators.crossover.pntx import PointCrossover -# from pymoo.operators.selection.tournament import TournamentSelection -# from .packages.selection import IntegerFromFloatMutation - - -@TASK_UTILS.register_module() -class NSGA2Optimizer(BaseOptimizer): - """Implementation of NSGA2 search method. 
- - Args: - pop_size : {pop_size} - sampling : {sampling} - selection : {selection} - crossover : {crossover} - mutation : {mutation} - eliminate_duplicates : {eliminate_duplicates} - n_offsprings : {n_offsprings} - """ - - def __init__(self, - pop_size=100, - sampling=None, - selection=TournamentSelection(func_comp=binary_tournament), - crossover=PointCrossover(n_points=2), - mutation=IntegerFromFloatMutation(eta=1.0), - eliminate_duplicates=True, - n_offsprings=None, - display=None, - survival=Survival(), - repair=None, - **kwargs): - super().__init__( - pop_size=pop_size, - sampling=sampling, - selection=selection, - crossover=crossover, - mutation=mutation, - survival=survival, - eliminate_duplicates=eliminate_duplicates, - n_offsprings=n_offsprings, - display=display, - **kwargs) - - # the population size used - self.pop_size = pop_size - - # the survival for the genetic algorithm - self.survival = Survival() - - # number of offsprings to generate through recombination - self.n_offsprings = n_offsprings - - # if the number of offspring is not set - if self.n_offsprings is None: - self.n_offsprings = pop_size - - # the object to be used to represent an individual - self.individual = Individual(rank=np.inf, crowding=-1) - - # set the duplicate detection class - if isinstance(eliminate_duplicates, bool): - if eliminate_duplicates: - self.eliminate_duplicates = DefaultDuplicateElimination() - else: - self.eliminate_duplicates = None - else: - self.eliminate_duplicates = eliminate_duplicates - - self.initialization = Initialization( - sampling, - individual=self.individual, - repair=repair, - eliminate_duplicates=self.eliminate_duplicates) - - self.mating = Mating( - selection, - crossover, - mutation, - repair=repair, - eliminate_duplicates=self.eliminate_duplicates, - n_max_iterations=100) - - # other run specific data updated whenever solve is called - self.n_gen = None - self.pop = None - self.off = None - - def _initialize(self): - - # create the initial population - pop = self.initialization.do( - self.problem, self.pop_size, algorithm=self) - - # then evaluate using the objective function - self.evaluator.eval(self.problem, pop, algorithm=self) - - # that call is a dummy survival to set attributes - # that are necessary for the mating selection - if self.survival: - pop = self.survival.do(self.problem, pop, len(pop), algorithm=self) - - self.pop, self.off = pop, pop - - def _next(self): - - # do the mating using the current population - self.off = self.mating.do( - self.problem, - self.pop, - n_offsprings=self.n_offsprings, - algorithm=self) - - # if the mating could not generate any new offspring - if len(self.off) == 0: - return - - # evaluate the offspring - self.evaluator.eval(self.problem, self.off, algorithm=self) - - # merge the offsprings with the current population - self.pop = self.pop.merge(self.off) - - # the do survival selection - if self.survival: - self.pop = self.survival.do( - self.problem, self.pop, self.pop_size, algorithm=self) - - def _set_optimum(self, **kwargs): - if not np.any(self.pop.get('feasible')): - self.opt = self.pop[[np.argmin(self.pop.get('CV'))]] - else: - self.opt = self.pop[self.pop.get('rank') == 0] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py deleted file mode 100644 index 426e6283d..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. 
All rights reserved. -from .auxiliary_singlelevel_problem import AuxiliarySingleLevelProblem -from .subset_problem import SubsetProblem - -__all__ = ['AuxiliarySingleLevelProblem', 'SubsetProblem'] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py deleted file mode 100644 index b82f0be93..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/problem/base_problem.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# copied and modified from https://github.com/anyoptimization/pymoo -from abc import abstractmethod - -import numpy as np - - -def at_least_2d_array(x, extend_as='row'): - if not isinstance(x, np.ndarray): - x = np.array([x]) - - if x.ndim == 1: - if extend_as == 'row': - x = x[None, :] - elif extend_as == 'column': - x = x[:, None] - - return x - - -class BaseProblem(): - """Superclass for each problem that is defined. - - It provides attributes such as the number of variables, number of - objectives or constraints. Also, the lower and upper bounds are stored. If - available the Pareto-front, nadir point and ideal point are stored. - """ - - def __init__(self, - n_var=-1, - n_obj=-1, - n_constr=0, - xl=None, - xu=None, - type_var=np.double, - evaluation_of='auto', - parallelization=None, - elementwise_evaluation=False, - callback=None): - """ - Args: - n_var (int): - number of variables - n_obj (int): - number of objectives - n_constr (int): - number of constraints - xl (np.array or int): - lower bounds for the variables. - xu (np.array or int): - upper bounds for the variables. - type_var (numpy type): - type of the variable to be evaluated. - elementwise_evaluation (bool): - parallelization (str or tuple): - See :ref:`nb_parallelization` for guidance on parallelization. - - """ - - # number of variable for this problem - self.n_var = n_var - - # type of the variable to be evaluated - self.type_var = type_var - - # number of objectives - self.n_obj = n_obj - - # number of constraints - self.n_constr = n_constr - - # allow just an integer for xl and xu if all bounds are equal - if n_var > 0 and not isinstance(xl, np.ndarray) and xl is not None: - self.xl = np.ones(n_var) * xl - else: - self.xl = xl - - if n_var > 0 and not isinstance(xu, np.ndarray) and xu is not None: - self.xu = np.ones(n_var) * xu - else: - self.xu = xu - - # the pareto set and front will be calculated only once. - self._pareto_front = None - self._pareto_set = None - self._ideal_point, self._nadir_point = None, None - - # actually defines what _evaluate is setting during the evaluation - if evaluation_of == 'auto': - # by default F is set, and G if the problem does have constraints - self.evaluation_of = ['F'] - if self.n_constr > 0: - self.evaluation_of.append('G') - else: - self.evaluation_of = evaluation_of - - self.elementwise_evaluation = elementwise_evaluation - - self.parallelization = parallelization - - # store the callback if defined - self.callback = callback - - def nadir_point(self): - """Return nadir_point (np.array): - - The nadir point for a multi-objective problem. 
- """ - # if the ideal point has not been calculated yet - if self._nadir_point is None: - - # calculate the pareto front if not happened yet - if self._pareto_front is None: - self.pareto_front() - - # if already done or it was successful - calculate the ideal point - if self._pareto_front is not None: - self._ideal_point = np.max(self._pareto_front, axis=0) - - return self._nadir_point - - def ideal_point(self): - """ - Returns - ------- - ideal_point (np.array): - The ideal point for a multi-objective problem. If single-objective - it returns the best possible solution. - """ - - # if the ideal point has not been calculated yet - if self._ideal_point is None: - - # calculate the pareto front if not happened yet - if self._pareto_front is None: - self.pareto_front() - - # if already done or it was successful - calculate the ideal point - if self._pareto_front is not None: - self._ideal_point = np.min(self._pareto_front, axis=0) - - return self._ideal_point - - def pareto_front(self, - *args, - use_cache=True, - exception_if_failing=True, - **kwargs): - """ - Args: - args : Same problem implementation need some more information to - create the Pareto front. - exception_if_failing (bool): - Whether to throw an exception when generating the Pareto front - has failed. - use_cache (bool): - Whether to use the cache if the Pareto front. - - Returns: - P (np.array): - The Pareto front of a given problem. - - """ - if not use_cache or self._pareto_front is None: - try: - pf = self._calc_pareto_front(*args, **kwargs) - if pf is not None: - self._pareto_front = at_least_2d_array(pf) - - except Exception as e: - if exception_if_failing: - raise e - - return self._pareto_front - - def pareto_set(self, *args, use_cache=True, **kwargs): - """ - Returns: - S (np.array): - Returns the pareto set for a problem. - """ - if not use_cache or self._pareto_set is None: - self._pareto_set = at_least_2d_array( - self._calc_pareto_set(*args, **kwargs)) - - return self._pareto_set - - def evaluate(self, - X, - *args, - return_values_of='auto', - return_as_dictionary=False, - **kwargs): - """Evaluate the given problem. - - The function values set as defined in the function. - The constraint values are meant to be positive if infeasible. - - Args: - - X (np.array): - A two dimensional matrix where each row is a point to evaluate - and each column a variable. - - return_as_dictionary (bool): - If this is true than only one object, a dictionary, - is returned. - return_values_of (list of strings): - Allowed is ["F", "CV", "G", "dF", "dG", "dCV", "feasible"] - where the d stands for derivative and h stands for hessian - matrix. - - - Returns: - A dictionary, if return_as_dictionary enabled, or a list of values - as defined in return_values_of. - """ - - # call the callback of the problem - if self.callback is not None: - self.callback(X) - - only_single_value = len(np.shape(X)) == 1 - X = np.atleast_2d(X) - - # check the dimensionality of the problem and the given input - if X.shape[1] != self.n_var: - raise Exception('Input dimension %s are not equal to n_var %s!' 
% - (X.shape[1], self.n_var)) - - if type(return_values_of) == str and return_values_of == 'auto': - return_values_of = ['F'] - if self.n_constr > 0: - return_values_of.append('CV') - - out = {} - for val in return_values_of: - out[val] = None - - out = self._evaluate_batch(X, False, out, *args, **kwargs) - - # if constraint violation should be returned as well - if self.n_constr == 0: - CV = np.zeros([X.shape[0], 1]) - else: - CV = self.calc_constraint_violation(out['G']) - - if 'CV' in return_values_of: - out['CV'] = CV - - # if an additional boolean flag for feasibility should be returned - if 'feasible' in return_values_of: - out['feasible'] = (CV <= 0) - - # if asked for a value but not set in the evaluation set to None - for val in return_values_of: - if val not in out: - out[val] = None - - if only_single_value: - for key in out.keys(): - if out[key] is not None: - out[key] = out[key][0, :] - - if return_as_dictionary: - return out - else: - - # if just a single value do not return a tuple - if len(return_values_of) == 1: - return out[return_values_of[0]] - else: - return tuple([out[val] for val in return_values_of]) - - def _evaluate_batch(self, X, calc_gradient, out, *args, **kwargs): - self._evaluate(X, out, *args, **kwargs) - for key in out.keys(): - if len(np.shape(out[key])) == 1: - out[key] = out[key][:, None] - - return out - - @abstractmethod - def _evaluate(self, x, out, *args, **kwargs): - pass - - def has_bounds(self): - return self.xl is not None and self.xu is not None - - def bounds(self): - return self.xl, self.xu - - def name(self): - return self.__class__.__name__ - - def _calc_pareto_front(self, *args, **kwargs): - """Method that either loads or calculates the pareto front. This is - only done ones and the pareto front is stored. - - Returns: - pf (np.array): Pareto front as array. - """ - pass - - def _calc_pareto_set(self, *args, **kwargs): - pass - - # some problem information - def __str__(self): - s = '# name: %s\n' % self.name() - s += '# n_var: %s\n' % self.n_var - s += '# n_obj: %s\n' % self.n_obj - s += '# n_constr: %s\n' % self.n_constr - s += '# f(xl): %s\n' % self.evaluate(self.xl)[0] - s += '# f((xl+xu)/2): %s\n' % self.evaluate( - (self.xl + self.xu) / 2.0)[0] - s += '# f(xu): %s\n' % self.evaluate(self.xu)[0] - return s - - @staticmethod - def calc_constraint_violation(G): - if G is None: - return None - elif G.shape[1] == 0: - return np.zeros(G.shape[0])[:, None] - else: - return np.sum(G * (G > 0), axis=1)[:, None] diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py deleted file mode 100644 index 2a0af040a..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/problem/subset_problem.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
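The `evaluate()` plumbing above is easiest to see on a toy subclass; constraints follow the convention that a positive `G` means infeasible, so `CV` is the summed positive part of `G`. A hypothetical sketch against the `BaseProblem` listed above, not part of the codebase:

    import numpy as np

    class ToyProblem(BaseProblem):

        def __init__(self):
            # one variable, one objective, one constraint: x >= 1
            super().__init__(n_var=1, n_obj=1, n_constr=1, xl=0.0, xu=2.0)

        def _evaluate(self, x, out, *args, **kwargs):
            out['F'] = x**2
            out['G'] = 1.0 - x  # positive (violated) whenever x < 1

    F, CV = ToyProblem().evaluate(np.array([[0.5], [1.5]]))
    # F = [[0.25], [2.25]], CV = [[0.5], [0.0]]: only the second point is feasible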
-import numpy as np - -from .base_problem import BaseProblem - - -class SubsetProblem(BaseProblem): - """select a subset to diversify the pareto front.""" - - def __init__(self, candidates, archive, K): - super().__init__( - n_var=len(candidates), - n_obj=1, - n_constr=1, - xl=0, - xu=1, - type_var=bool) - self.archive = archive - self.candidates = candidates - self.n_max = K - - def _evaluate(self, x, out, *args, **kwargs): - f = np.full((x.shape[0], 1), np.nan) - g = np.full((x.shape[0], 1), np.nan) - - for i, _x in enumerate(x): - # append selected candidates to archive then sort - tmp = np.sort(np.concatenate((self.archive, self.candidates[_x]))) - f[i, 0] = np.std(np.diff(tmp)) - # we penalize if the number of selected candidates is not exactly K - g[i, 0] = (self.n_max - np.sum(_x))**2 - - out['F'] = f - out['G'] = g diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py deleted file mode 100644 index ef101fec6..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py deleted file mode 100644 index cf06da663..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/utils/domin_matrix.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np - - -def get_relation(a, b, cva=None, cvb=None): - - if cva is not None and cvb is not None: - if cva < cvb: - return 1 - elif cvb < cva: - return -1 - - val = 0 - for i in range(len(a)): - if a[i] < b[i]: - # indifferent because once better and once worse - if val == -1: - return 0 - val = 1 - elif b[i] < a[i]: - # indifferent because once better and once worse - if val == 1: - return 0 - val = -1 - return val - - -def calc_domination_matrix_loop(F, G): - n = F.shape[0] - CV = np.sum(G * (G > 0).astype(np.float32), axis=1) - M = np.zeros((n, n)) - for i in range(n): - for j in range(i + 1, n): - M[i, j] = get_relation(F[i, :], F[j, :], CV[i], CV[j]) - M[j, i] = -M[i, j] - - return M - - -def calc_domination_matrix(F, _F=None, epsilon=0.0): - """ - if G is None or len(G) == 0: - constr = np.zeros((F.shape[0], F.shape[0])) - else: - # consider the constraint violation - # CV = Problem.calc_constraint_violation(G) - # constr = (CV < CV) * 1 + (CV > CV) * -1 - - CV = Problem.calc_constraint_violation(G)[:, 0] - constr = (CV[:, None] < CV) * 1 + (CV[:, None] > CV) * -1 - """ - - if _F is None: - _F = F - - # look at the obj for dom - n = F.shape[0] - m = _F.shape[0] - - L = np.repeat(F, m, axis=0) - R = np.tile(_F, (n, 1)) - - smaller = np.reshape(np.any(L + epsilon < R, axis=1), (n, m)) - larger = np.reshape(np.any(L > R + epsilon, axis=1), (n, m)) - - M = np.logical_and(smaller, np.logical_not(larger)) * 1 \ - + np.logical_and(larger, np.logical_not(smaller)) * -1 - - # if cv equal then look at dom - # M = constr + (constr == 0) * dom - - return M - - -def fast_non_dominated_sort(F, **kwargs): - M = calc_domination_matrix(F) - - # calculate the dominance matrix - n = M.shape[0] - - fronts = [] - - if n == 0: - return fronts - - # final rank that will be returned - n_ranked = 0 - ranked = np.zeros(n, dtype=np.int32) - is_dominating = [[] for _ in range(n)] - - # storage for the number of solutions dominated this one - n_dominated = 
np.zeros(n) - - current_front = [] - - for i in range(n): - - for j in range(i + 1, n): - rel = M[i, j] - if rel == 1: - is_dominating[i].append(j) - n_dominated[j] += 1 - elif rel == -1: - is_dominating[j].append(i) - n_dominated[i] += 1 - - if n_dominated[i] == 0: - current_front.append(i) - ranked[i] = 1.0 - n_ranked += 1 - - # append the first front to the current front - fronts.append(current_front) - - # while not all solutions are assigned to a pareto front - while n_ranked < n: - - next_front = [] - - # for each individual in the current front - for i in current_front: - - # all solutions that are dominated by this individuals - for j in is_dominating[i]: - n_dominated[j] -= 1 - if n_dominated[j] == 0: - next_front.append(j) - ranked[j] = 1.0 - n_ranked += 1 - - fronts.append(next_front) - current_front = next_front - - return fronts diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py deleted file mode 100644 index b0362f333..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/utils/helper.py +++ /dev/null @@ -1,668 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# copied from https://github.com/anyoptimization/pymoo -import copy - -import numpy as np -import scipy -from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting - - -def default_attr(pop): - return pop.get('X') - - -def cdist(x, y): - return scipy.spatial.distance.cdist(x, y) - - -class DuplicateElimination: - """Implementation of Elimination. - - func: function to execute. - """ - - def __init__(self, func=None) -> None: - super().__init__() - self.func = func - - if self.func is None: - self.func = default_attr - - def do(self, pop, *args, return_indices=False, to_itself=True): - original = pop - - if to_itself: - pop = pop[~self._do(pop, None, np.full(len(pop), False))] - - for arg in args: - if len(arg) > 0: - - if len(pop) == 0: - break - elif len(arg) == 0: - continue - else: - pop = pop[~self._do(pop, arg, np.full(len(pop), False))] - - if return_indices: - no_duplicate, is_duplicate = [], [] - H = set(pop) - - for Index, ind in enumerate(original): - if ind in H: - no_duplicate.append(Index) - else: - is_duplicate.append(Index) - - return pop, no_duplicate, is_duplicate - else: - return pop - - -class DefaultDuplicateElimination(DuplicateElimination): - """Implementation of DefaultDuplicate Elimination. - - epsilon(float): smallest dist for judge duplication. 
- """ - - def __init__(self, epsilon=1e-16, **kwargs) -> None: - super().__init__(**kwargs) - self.epsilon = epsilon - - def calc_dist(self, pop, other=None): - X = self.func(pop) - - if other is None: - D = cdist(X, X) - D[np.triu_indices(len(X))] = np.inf - else: - _X = self.func(other) - D = cdist(X, _X) - - return D - - def _do(self, pop, other, is_duplicate): - D = self.calc_dist(pop, other) - D[np.isnan(D)] = np.inf - - is_duplicate[np.any(D < self.epsilon, axis=1)] = True - return is_duplicate - - -class Individual: - """Class for each individual in search step.""" - - def __init__(self, - X=None, - F=None, - CV=None, - G=None, - feasible=None, - **kwargs) -> None: - self.X = X - self.F = F - self.CV = CV - self.G = G - self.feasible = feasible - self.data = kwargs - self.attr = set(self.__dict__.keys()) - - def has(self, key): - return key in self.attr or key in self.data - - def set(self, key, value): - if key in self.attr: - self.__dict__[key] = value - else: - self.data[key] = value - - def copy(self): - ind = copy.copy(self) - ind.data = self.data.copy() - return ind - - def get(self, keys): - if keys in self.data: - return self.data[keys] - elif keys in self.attr: - return self.__dict__[keys] - else: - return None - - -class Population(np.ndarray): - """Class for all the population in search step.""" - - def __new__(cls, n_individuals=0, individual=Individual()): - obj = super(Population, cls).__new__( - cls, n_individuals, dtype=individual.__class__).view(cls) - for Index in range(n_individuals): - obj[Index] = individual.copy() - obj.individual = individual - return obj - - def merge(self, a, b=None): - if b: - a, b = pop_from_array_or_individual(a), \ - pop_from_array_or_individual(b) - a.merge(b) - else: - other = pop_from_array_or_individual(a) - if len(self) == 0: - return other - else: - obj = np.concatenate([self, other]).view(Population) - obj.individual = self.individual - return obj - - def copy(self): - pop = Population(n_individuals=len(self), individual=self.individual) - for Index in range(len(self)): - pop[Index] = self[Index] - return pop - - def has(self, key): - return all([ind.has(key) for ind in self]) - - def __deepcopy__(self, memo): - return self.copy() - - @classmethod - def create(cls, *args): - pop = np.concatenate([ - pop_from_array_or_individual(arg) for arg in args - ]).view(Population) - pop.individual = Individual() - return pop - - def new(self, *args): - - if len(args) == 1: - return Population( - n_individuals=args[0], individual=self.individual) - else: - n = len(args[1]) if len(args) > 0 else 0 - pop = Population(n_individuals=n, individual=self.individual) - if len(args) > 0: - pop.set(*args) - return pop - - def collect(self, func, to_numpy=True): - val = [] - for Index in range(len(self)): - val.append(func(self[Index])) - if to_numpy: - val = np.array(val) - return val - - def set(self, *args): - - for Index in range(int(len(args) / 2)): - - key, values = args[Index * 2], args[Index * 2 + 1] - is_iterable = hasattr(values, - '__len__') and not isinstance(values, str) - - if is_iterable and len(values) != len(self): - raise Exception( - 'Population Set Attribute Error: ' - 'Number of values and population size do not match!') - - for Index in range(len(self)): - val = values[Index] if is_iterable else values - self[Index].set(key, val) - - return self - - def get(self, *args, to_numpy=True): - - val = {} - for c in args: - val[c] = [] - - for Index in range(len(self)): - - for c in args: - val[c].append(self[Index].get(c)) - - res = 
[val[c] for c in args] - if to_numpy: - res = [np.array(e) for e in res] - - if len(args) == 1: - return res[0] - else: - return tuple(res) - - def __array_finalize__(self, obj): - if obj is None: - return - self.individual = getattr(obj, 'individual', None) - - -def pop_from_array_or_individual(array, pop=None): - # the population type can be different - if pop is None: - pop = Population() - - # provide a whole population object - if isinstance(array, Population): - pop = array - elif isinstance(array, np.ndarray): - pop = pop.new('X', np.atleast_2d(array)) - elif isinstance(array, Individual): - pop = Population(1) - pop[0] = array - else: - return None - - return pop - - -class Initialization: - """Initiallize step.""" - - def __init__(self, - sampling, - individual=Individual(), - repair=None, - eliminate_duplicates=None) -> None: - - super().__init__() - self.sampling = sampling - self.individual = individual - self.repair = repair - self.eliminate_duplicates = eliminate_duplicates - - def do(self, problem, n_samples, **kwargs): - - # provide a whole population object - if isinstance(self.sampling, Population): - pop = self.sampling - - else: - pop = Population(0, individual=self.individual) - if isinstance(self.sampling, np.ndarray): - pop = pop.new('X', self.sampling) - else: - pop = self.sampling.do(problem, n_samples, pop=pop, **kwargs) - - # repair all solutions that are not already evaluated - if self.repair: - Index = [k for k in range(len(pop)) if pop[k].F is None] - pop = self.repair.do(problem, pop[Index], **kwargs) - - if self.eliminate_duplicates is not None: - pop = self.eliminate_duplicates.do(pop) - - return pop - - -def split_by_feasibility(pop, sort_infeasbible_by_cv=True): - CV = pop.get('CV') - - b = (CV <= 0) - - feasible = np.where(b)[0] - infeasible = np.where(np.logical_not(b))[0] - - if sort_infeasbible_by_cv: - infeasible = infeasible[np.argsort(CV[infeasible, 0])] - - return feasible, infeasible - - -class Survival: - """The survival process is implemented inheriting from this class, which - selects from a population only specific individuals to survive. 
- - Parameters - ---------- - filter_infeasible : bool - Whether for the survival infeasible solutions should be - filtered first - """ - - def __init__(self, filter_infeasible=True): - self.filter_infeasible = filter_infeasible - - def do(self, problem, pop, n_survive, return_indices=False, **kwargs): - - # if the split should be done beforehand - if self.filter_infeasible and problem.n_constr > 0: - feasible, infeasible = split_by_feasibility( - pop, sort_infeasbible_by_cv=True) - - # if there was no feasible solution was added at all - if len(feasible) == 0: - survivors = pop[infeasible[:n_survive]] - - # if there are feasible solutions in the population - else: - survivors = pop.new() - - # if feasible solution do exist - if len(feasible) > 0: - survivors = self._do(problem, pop[feasible], - min(len(feasible), n_survive), - **kwargs) - - # if infeasible solutions needs to be added - if len(survivors) < n_survive: - least_infeasible = infeasible[:n_survive - len(feasible)] - survivors = survivors.merge(pop[least_infeasible]) - - else: - survivors = self._do(problem, pop, n_survive, **kwargs) - - if return_indices: - H = {} - for k, ind in enumerate(pop): - H[ind] = k - return [H[survivor] for survivor in survivors] - else: - return survivors - - def _do(self, problem, pop, n_survive, D=None, **kwargs): - - # get the objective space values and objects - F = pop.get('F').astype(np.float, copy=False) - - # the final indices of surviving individuals - survivors = [] - - # do the non-dominated sorting until splitting front - fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive) - - for k, front in enumerate(fronts): - - # calculate the crowding distance of the front - crowding_of_front = calc_crowding_distance(F[front, :]) - # save rank and crowding in the individual class - for j, Index in enumerate(front): - pop[Index].set('rank', k) - pop[Index].set('crowding', crowding_of_front[j]) - - # current front sorted by crowding distance if splitting - if len(survivors) + len(front) > n_survive: - Index = randomized_argsort( - crowding_of_front, order='descending', method='numpy') - Index = Index[:(n_survive - len(survivors))] - - # otherwise take the whole front unsorted - else: - Index = np.arange(len(front)) - - # extend the survivors by all or selected individuals - survivors.extend(front[Index]) - - return pop[survivors] - - -class FitnessSurvival(Survival): - """Survival class for Fitness.""" - - def __init__(self) -> None: - super().__init__(True) - - def _do(self, problem, pop, n_survive, out=None, **kwargs): - F = pop.get('F') - - if F.shape[1] != 1: - raise ValueError( - 'FitnessSurvival can only used for single objective single!') - - return pop[np.argsort(F[:, 0])[:n_survive]] - - -def find_duplicates(X, epsilon=1e-16): - # calculate the distance matrix from each point to another - D = cdist(X, X) - - # set the diagonal to infinity - D[np.triu_indices(len(X))] = np.inf - - # set as duplicate if a point is really close to this one - is_duplicate = np.any(D < epsilon, axis=1) - - return is_duplicate - - -def calc_crowding_distance(F, filter_out_duplicates=True): - n_points, n_obj = F.shape - - if n_points <= 2: - return np.full(n_points, np.inf) - - else: - - if filter_out_duplicates: - # filter out solutions which are duplicates - is_unique = np.where( - np.logical_not(find_duplicates(F, epsilon=1e-24)))[0] - else: - # set every point to be unique without checking it - is_unique = np.arange(n_points) - - # index the unique points of the array - _F = F[is_unique] - - # sort 
each column and get index
-        Index = np.argsort(_F, axis=0, kind='mergesort')
-
-        # sort the objective space values for the whole matrix
-        _F = _F[Index, np.arange(n_obj)]
-
-        # calculate the distance from each point to the last and next
-        dist = np.row_stack([_F, np.full(n_obj, np.inf)]) - np.row_stack(
-            [np.full(n_obj, -np.inf), _F])
-
-        # calculate the norm for each objective
-        norm = np.max(_F, axis=0) - np.min(_F, axis=0)
-        norm[norm == 0] = np.nan
-
-        # prepare the distance to last and next vectors
-        dist_to_last, dist_to_next = dist, np.copy(dist)
-        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[
-            1:] / norm
-
-        dist_to_last[np.isnan(dist_to_last)] = 0.0
-        dist_to_next[np.isnan(dist_to_next)] = 0.0
-
-        # sum up the distance to next and last and norm by objectives
-        J = np.argsort(Index, axis=0)
-        _cd = np.sum(
-            dist_to_last[J, np.arange(n_obj)] +
-            dist_to_next[J, np.arange(n_obj)],
-            axis=1) / n_obj
-
-        # save the final vector; duplicates keep a crowding distance of zero
-        crowding = np.zeros(n_points)
-        crowding[is_unique] = _cd
-
-        # crowding[np.isinf(crowding)] = 1e+14
-        return crowding
-
-
-def randomized_argsort(A, method='numpy', order='ascending'):
-    if method == 'numpy':
-        P = np.random.permutation(len(A))
-        Index = np.argsort(A[P], kind='quicksort')
-        Index = P[Index]
-
-    elif method == 'quicksort':
-        Index = quicksort(A)
-
-    else:
-        raise Exception('Randomized sort method not known.')
-
-    if order == 'ascending':
-        return Index
-    elif order == 'descending':
-        return np.flip(Index, axis=0)
-    else:
-        raise Exception('Unknown sorting order: ascending or descending.')
-
-
-def swap(M, a, b):
-    tmp = M[a]
-    M[a] = M[b]
-    M[b] = tmp
-
-
-def quicksort(A):
-    Index = np.arange(len(A))
-    _quicksort(A, Index, 0, len(A) - 1)
-    return Index
-
-
-def _quicksort(A, Index, left, right):
-    if left < right:
-
-        index = np.random.randint(left, right + 1)
-        swap(Index, right, index)
-
-        pivot = A[Index[right]]
-
-        # position of the last element known to be <= pivot; this counter
-        # must not shadow the permutation array `Index`
-        i = left - 1
-
-        for j in range(left, right):
-
-            if A[Index[j]] <= pivot:
-                i += 1
-                swap(Index, i, j)
-
-        index = i + 1
-        swap(Index, right, index)
-
-        _quicksort(A, Index, left, index - 1)
-        _quicksort(A, Index, index + 1, right)
-
-
-def random_permuations(n, input):
-    perms = []
-    for _ in range(n):
-        perms.append(np.random.permutation(input))
-    P = np.concatenate(perms)
-    return P
-
-
-def crossover_mask(X, M):
-    # swap the masked genes between the two parent rows
-    _X = np.copy(X)
-    _X[0][M] = X[1][M]
-    _X[1][M] = X[0][M]
-
-    return _X
-
-
-def at_least_2d_array(x, extend_as='row'):
-    if not isinstance(x, np.ndarray):
-        x = np.array([x])
-
-    if x.ndim == 1:
-        if extend_as == 'row':
-            x = x[None, :]
-        elif extend_as == 'column':
-            x = x[:, None]
-
-    return x
-
-
-def repair_out_of_bounds(problem, X):
-    xl, xu = problem.xl, problem.xu
-
-    only_1d = (X.ndim == 1)
-    X = at_least_2d_array(X)
-
-    if xl is not None:
-        xl = np.repeat(xl[None, :], X.shape[0], axis=0)
-        X[X < xl] = xl[X < xl]
-
-    if xu is not None:
-        xu = np.repeat(xu[None, :], X.shape[0], axis=0)
-        X[X > xu] = xu[X > xu]
-
-    if only_1d:
-        return X[0, :]
-    else:
-        return X
-
-
-def denormalize(x, x_min, x_max):
-
-    if x_max is None:
-        _range = 1
-    else:
-        _range = (x_max - x_min)
-
-    return x * _range + x_min
-
-
-class Evaluator:
-    """The evaluator class which is used during the algorithm execution to
-    limit the number of evaluations."""
-
-    def __init__(self, evaluate_values_of=['F', 'CV', 'G']):
-        self.n_eval = 0
-        self.evaluate_values_of = evaluate_values_of
-
-    
def eval(self, problem, pop, **kwargs): - """This function is used to return the result of one valid evaluation. - - Parameters - ---------- - problem : class - The problem which is used to be evaluated - pop : np.array or Population object - kwargs : dict - Additional arguments which might be necessary for the problem to - evaluate. - """ - - is_individual = isinstance(pop, Individual) - is_numpy_array = isinstance( - pop, np.ndarray) and not isinstance(pop, Population) - - # make sure the object is a population - if is_individual or is_numpy_array: - pop = Population().create(pop) - - # find indices to be evaluated - Index = [k for k in range(len(pop)) if pop[k].F is None] - - # update the function evaluation counter - self.n_eval += len(Index) - - # actually evaluate all solutions using the function - if len(Index) > 0: - self._eval(problem, pop[Index], **kwargs) - - # set the feasibility attribute if cv exists - for ind in pop[Index]: - cv = ind.get('CV') - if cv is not None: - ind.set('feasible', cv <= 0) - - if is_individual: - return pop[0] - elif is_numpy_array: - if len(pop) == 1: - pop = pop[0] - return tuple([pop.get(e) for e in self.evaluate_values_of]) - else: - return pop - - def _eval(self, problem, pop, **kwargs): - - out = problem.evaluate( - pop.get('X'), - return_values_of=self.evaluate_values_of, - return_as_dictionary=True, - **kwargs) - - for key, val in out.items(): - if val is None: - continue - else: - pop.set(key, val) diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py deleted file mode 100644 index 3ec1c644e..000000000 --- a/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py +++ /dev/null @@ -1,490 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
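A usage note on the `Evaluator` above: it only evaluates individuals whose `F` is still unset, so `n_eval` counts true function evaluations across a run. A small sketch, reusing the hypothetical toy problem from earlier:

    import numpy as np

    problem = ToyProblem()                   # hypothetical problem from above
    pop = Population().new('X', np.random.random((8, problem.n_var)))

    evaluator = Evaluator()
    evaluator.eval(problem, pop)             # fills F/CV/G, sets 'feasible'
    print(evaluator.n_eval)                  # 8; a second call would add 0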
-import math - -import numpy as np - -from .domin_matrix import get_relation -# from pymoo.core.population import Population -# from pymoo.util.misc import crossover_mask, random_permuations -# from pymoo.operators.repair.bounce_back import bounce_back_by_problem -from .helper import (Population, crossover_mask, random_permuations, - repair_out_of_bounds) - - -def binary_tournament(pop, P, algorithm, **kwargs): - if P.shape[1] != 2: - raise ValueError('Only implemented for binary tournament!') - - S = np.full(P.shape[0], np.nan) - - for Index in range(P.shape[0]): - - a, b = P[Index, 0], P[Index, 1] - - # if at least one solution is infeasible - if pop[a].CV > 0.0 or pop[b].CV > 0.0: - S[Index] = compare( - a, - pop[a].CV, - b, - pop[b].CV, - method='smaller_is_better', - return_random_if_equal=True) - else: - rel = get_relation(pop[a].F, pop[b].F) - if rel == 1: - S[Index] = a - elif rel == -1: - S[Index] = b - # if rank or domination relation didn't make a decision - if np.isnan(S[Index]): - S[Index] = compare( - a, - pop[a].get('crowding'), - b, - pop[b].get('crowding'), - method='larger_is_better', - return_random_if_equal=True) - - return S[:, None].astype(int, copy=False) - - -def comp_by_cv_and_fitness(pop, P, **kwargs): - S = np.full(P.shape[0], np.nan) - - for Index in range(P.shape[0]): - a, b = P[Index, 0], P[Index, 1] - - # if at least one solution is infeasible - if pop[a].CV > 0.0 or pop[b].CV > 0.0: - S[Index] = compare( - a, - pop[a].CV, - b, - pop[b].CV, - method='smaller_is_better', - return_random_if_equal=True) - - # both solutions are feasible just set random - else: - S[Index] = compare( - a, - pop[a].F, - b, - pop[b].F, - method='smaller_is_better', - return_random_if_equal=True) - - return S[:, None].astype(int) - - -def compare(a, a_val, b, b_val, method, return_random_if_equal=False): - if method == 'larger_is_better': - if a_val > b_val: - return a - elif a_val < b_val: - return b - else: - if return_random_if_equal: - return np.random.choice([a, b]) - else: - return None - elif method == 'smaller_is_better': - if a_val < b_val: - return a - elif a_val > b_val: - return b - else: - if return_random_if_equal: - return np.random.choice([a, b]) - else: - return None - - -class TournamentSelection: - """The Tournament selection is used to simulated a tournament between - individuals. - - The pressure balances greedy the genetic algorithm will be. - """ - - def __init__(self, pressure=2, func_comp='binary_tournament'): - """ - - Parameters - ---------- - func_comp: func - The function to compare two individuals. - It has the shape: comp(pop, indices) and returns the winner. - - pressure: int - The selection pressure to bie applied. 
- """ - - # selection pressure to be applied - self.pressure = pressure - if func_comp == 'comp_by_cv_and_fitness': - self.f_comp = comp_by_cv_and_fitness - else: - self.f_comp = binary_tournament - - def do(self, pop, n_select, n_parents=2, **kwargs): - # number of random individuals needed - n_random = n_select * n_parents * self.pressure - - # number of permutations needed - n_perms = math.ceil(n_random / len(pop)) - - # get random permutations and reshape them - P = random_permuations(n_perms, len(pop))[:n_random] - P = np.reshape(P, (n_select * n_parents, self.pressure)) - - # compare using tournament function - S = self.f_comp(pop, P, **kwargs) - - return np.reshape(S, (n_select, n_parents)) - - -class PointCrossover: - - def __init__(self, n_points=2, n_parents=2, n_offsprings=2, prob=0.9): - self.n_points = n_points - self.prob = prob - self.n_parents = n_parents - self.n_offsprings = n_offsprings - - def do(self, problem, pop, parents, **kwargs): - """ - - Parameters - ---------- - problem: class - The problem to be solved. - - pop : Population - The population as an object - - parents: numpy.array - The select parents of the population for the crossover - - kwargs : dict - Any additional data that might be necessary. - - Returns - ------- - offsprings : Population - The off as a matrix. n_children rows and the number of columns is - equal to the variable length of the problem. - - """ - - if self.n_parents != parents.shape[1]: - raise ValueError( - 'Exception during crossover: ' - 'Number of parents differs from defined at crossover.') - - # get the design space matrix form the population and parents - X = pop.get('X')[parents.T].copy() - - # now apply the crossover probability - do_crossover = np.random.random(len(parents)) < self.prob - - # execute the crossover - _X = self._do(problem, X, **kwargs) - - X[:, do_crossover, :] = _X[:, do_crossover, :] - - # flatten the array to become a 2d-array - X = X.reshape(-1, X.shape[-1]) - - # create a population object - off = pop.new('X', X) - - return off - - def _do(self, problem, X, **kwargs): - - # get the X of parents and count the matings - _, n_matings, n_var = X.shape - - # start point of crossover - r = np.row_stack([ - np.random.permutation(n_var - 1) + 1 for _ in range(n_matings) - ])[:, :self.n_points] - r.sort(axis=1) - r = np.column_stack([r, np.full(n_matings, n_var)]) - - # the mask do to the crossover - M = np.full((n_matings, n_var), False) - - # create for each individual the crossover range - for Index in range(n_matings): - - j = 0 - while j < r.shape[1] - 1: - a, b = r[Index, j], r[Index, j + 1] - M[Index, a:b] = True - j += 2 - - _X = crossover_mask(X, M) - - return _X - - -class PolynomialMutation: - - def __init__(self, eta=20, prob=None): - super().__init__() - self.eta = float(eta) - - if prob is not None: - self.prob = float(prob) - else: - self.prob = None - - def _do(self, problem, X, **kwargs): - - Y = np.full(X.shape, np.inf) - - if self.prob is None: - self.prob = 1.0 / problem.n_var - - do_mutation = np.random.random(X.shape) < self.prob - - Y[:, :] = X - - xl = np.repeat(problem.xl[None, :], X.shape[0], axis=0)[do_mutation] - xu = np.repeat(problem.xu[None, :], X.shape[0], axis=0)[do_mutation] - - X = X[do_mutation] - - delta1 = (X - xl) / (xu - xl) - delta2 = (xu - X) / (xu - xl) - - mut_pow = 1.0 / (self.eta + 1.0) - - rand = np.random.random(X.shape) - mask = rand <= 0.5 - mask_not = np.logical_not(mask) - - deltaq = np.zeros(X.shape) - - xy = 1.0 - delta1 - val = 2.0 * rand + (1.0 - 2.0 * rand) * ( 
-            np.power(xy, (self.eta + 1.0)))
-        d = np.power(val, mut_pow) - 1.0
-        deltaq[mask] = d[mask]
-
-        xy = 1.0 - delta2
-        val = 2.0 * (1.0 - rand) + 2.0 * (rand - 0.5) * (
-            np.power(xy, (self.eta + 1.0)))
-        d = 1.0 - (np.power(val, mut_pow))
-        deltaq[mask_not] = d[mask_not]
-
-        # mutated values
-        _Y = X + deltaq * (xu - xl)
-
-        # back in bounds if necessary (floating point issues)
-        _Y[_Y < xl] = xl[_Y < xl]
-        _Y[_Y > xu] = xu[_Y > xu]
-
-        # set the values for output
-        Y[do_mutation] = _Y
-
-        # in case out of bounds repair (very unlikely)
-        # Y = bounce_back_by_problem(problem, Y)
-        Y = repair_out_of_bounds(problem, Y)
-
-        return Y
-
-    def do(self, problem, pop, **kwargs):
-        Y = self._do(problem, pop.get('X'), **kwargs)
-        return pop.new('X', Y)
-
-
-class IntegerFromFloatMutation:
-
-    def __init__(self, **kwargs):
-
-        self.mutation = PolynomialMutation(**kwargs)
-
-    def _do(self, problem, X, **kwargs):
-
-        def fun():
-            return self.mutation._do(problem, X, **kwargs)
-
-        # save the original bounds of the problem
-        _xl, _xu = problem.xl, problem.xu
-
-        # keep references to the bounds that will be widened
-        xl, xu = problem.xl, problem.xu
-
-        # widen the bounds so that rounding covers each integer evenly
-        problem.xl = xl - (0.5 - 1e-16)
-        problem.xu = xu + (0.5 - 1e-16)
-
-        # perform the underlying float mutation
-        off = fun()
-
-        # now round to nearest integer for all offspring
-        off = np.rint(off)
-
-        # reset the original bounds of the problem and design space values
-        problem.xl = _xl
-        problem.xu = _xu
-
-        return off
-
-    def do(self, problem, pop, **kwargs):
-        """Mutate variables in a genetic way.
-
-        Parameters
-        ----------
-        problem : class
-            The problem instance
-        pop : Population
-            A population object
-
-        Returns
-        -------
-        Y : Population
-            The mutated population.
- """ - - Y = self._do(problem, pop.get('X'), **kwargs) - return pop.new('X', Y) - - -class Mating: - - def __init__(self, - selection, - crossover, - mutation, - repair=None, - eliminate_duplicates=None, - n_max_iterations=100): - - self.selection = selection - self.crossover = crossover - self.mutation = mutation - self.n_max_iterations = n_max_iterations - self.eliminate_duplicates = eliminate_duplicates - self.repair = repair - - def _do(self, problem, pop, n_offsprings, parents=None, **kwargs): - - # if the parents for the mating are not provided directly - if parents is None: - - # how many parents need to be select for the mating - n_select = math.ceil(n_offsprings / self.crossover.n_offsprings) - - # select the parents for the mating - just an index array - parents = self.selection.do(pop, n_select, - self.crossover.n_parents, **kwargs) - - # do the crossover using the parents index and the population - _off = self.crossover.do(problem, pop, parents, **kwargs) - - # do the mutation on the offsprings created through crossover - _off = self.mutation.do(problem, _off, **kwargs) - - return _off - - def do(self, problem, pop, n_offsprings, **kwargs): - - # the population object to be used - off = pop.new() - - # infill counter - # counts how often the mating needs to be done to fill up n_offsprings - n_infills = 0 - # iterate until enough offsprings are created - while len(off) < n_offsprings: - - # how many offsprings are remaining to be created - n_remaining = n_offsprings - len(off) - - # do the mating - _off = self._do(problem, pop, n_remaining, **kwargs) - - # repair the individuals if necessary - if self.repair: - _off = self.repair.do(problem, _off, **kwargs) - - if self.eliminate_duplicates is not None: - _off = self.eliminate_duplicates.do(_off, pop, off) - - # if more offsprings than necessary - truncate them randomly - if len(off) + len(_off) > n_offsprings: - n_remaining = n_offsprings - len(off) - _off = _off[:n_remaining] - - # add to the offsprings and increase the mating counter - off = off.merge(_off) - n_infills += 1 - - if n_infills > self.n_max_iterations: - break - - return off - - -class MySampling: - - def __init__(self): - pass - - def do(self, problem, n_samples, pop=Population(), **kwargs): - X = np.full((n_samples, problem.n_var), False, dtype=bool) - - for k in range(n_samples): - Index = np.random.permutation(problem.n_var)[:problem.n_max] - X[k, Index] = True - - if pop is None: - return X - return pop.new('X', X) - - -class BinaryCrossover(PointCrossover): - - def __init__(self): - super().__init__(n_parents=2, n_offsprings=1) - - def _do(self, problem, X, **kwargs): - n_parents, n_matings, n_var = X.shape - - _X = np.full((self.n_offsprings, n_matings, problem.n_var), False) - - for k in range(n_matings): - p1, p2 = X[0, k], X[1, k] - - both_are_true = np.logical_and(p1, p2) - _X[0, k, both_are_true] = True - - n_remaining = problem.n_max - np.sum(both_are_true) - - Index = np.where(np.logical_xor(p1, p2))[0] - - S = Index[np.random.permutation(len(Index))][:n_remaining] - _X[0, k, S] = True - - return _X - - -class MyMutation(PolynomialMutation): - - def _do(self, problem, X, **kwargs): - for Index in range(X.shape[0]): - X[Index, :] = X[Index, :] - is_false = np.where(np.logical_not(X[Index, :]))[0] - is_true = np.where(X[Index, :])[0] - try: - X[Index, np.random.choice(is_false)] = True - X[Index, np.random.choice(is_true)] = False - except ValueError: - pass - - return X diff --git a/mmrazor/models/task_modules/predictor/metric_predictor.py 
b/mmrazor/models/task_modules/predictor/metric_predictor.py
index 1561f0776..796c3ac5b 100644
--- a/mmrazor/models/task_modules/predictor/metric_predictor.py
+++ b/mmrazor/models/task_modules/predictor/metric_predictor.py
@@ -142,11 +142,12 @@ def vector2model(self, vector: np.array) -> Dict[str, str]:
         start = 0
         model = {}
 
-        for key, value in self.search_groups.items():
-            if isinstance(value[0], OneShotMutableChannelUnit):
-                choices = value[0].candidate_choices
+        vector = np.squeeze(vector)
+        for name, mutables in self.search_groups.items():
+            if isinstance(mutables[0], OneShotMutableChannelUnit):
+                choices = mutables[0].candidate_choices
             else:
-                choices = value[0].choices
+                choices = mutables[0].choices
 
             if self.encoding_type == 'onehot':
                 index = np.where(vector[start:start + len(choices)] == 1)[0][0]
@@ -155,8 +156,8 @@ def vector2model(self, vector: np.array) -> Dict[str, str]:
                 index = vector[start]
                 start += 1
 
-            chosen = choices[int(index)]
-            model[key] = chosen
+            chosen = choices[int(index)] if len(choices) > 1 else choices[0]
+            model[name] = chosen
 
         return model
 

From 2e7ce1e4f6faff3e1634a10a6a9bfffb785643b7 Mon Sep 17 00:00:00 2001
From: gaoyang07 <1546308416@qq.com>
Date: Fri, 17 Feb 2023 09:31:05 +0800
Subject: [PATCH 11/12] bump nsga2 search to pymoo_v0.41 for stability

---
 .../engine/runner/evolution_search_loop.py    | 46 +++++++++---
 .../engine/runner/nsganetv2_search_loop.py    | 65 +++++++++--------
 .../utils/pymoo_utils/high_tradeoff_points.py | 60 +++++++---------
 .../runner/utils/pymoo_utils/problems.py      | 71 ++++++++++++++++++-
 mmrazor/models/algorithms/nas/nsganetv2.py    |  1 +
 5 files changed, 161 insertions(+), 82 deletions(-)

diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py
index 2c6f6b6ff..fbc9cfe31 100644
--- a/mmrazor/engine/runner/evolution_search_loop.py
+++ b/mmrazor/engine/runner/evolution_search_loop.py
@@ -138,6 +138,7 @@ def __init__(self,
             self.model.mutator.search_groups
         self.predictor = TASK_UTILS.build(self.predictor_cfg)
 
+        self.finetune_cfg = finetune_cfg
         if finetune_cfg is not None:
             self.finetune_runner = self.build_finetune_runner(finetune_cfg)
 
@@ -205,7 +206,6 @@ def sample_candidates(self) -> None:
             init_candidates = len(self.candidates)
             while len(self.candidates) < self.num_candidates:
                 candidate = self.model.mutator.sample_choices()
-                self.finetune_step(self.model)
                 is_pass, result = self._check_constraints(candidate)
                 if is_pass:
                     self.candidates.append(candidate)
@@ -224,6 +224,8 @@ def update_candidates_scores(self) -> None:
         top-k candidates."""
         for i, candidate in enumerate(self.candidates.subnets):
             self.model.mutator.set_choices(candidate)
+            if self.finetune_cfg:
+                self.finetune_step(self.model)
             metrics = self._val_candidate(use_predictor=self.use_predictor)
             score = round(metrics[self.score_key], 2) \
                 if len(metrics) != 0 else 0.
@@ -448,15 +450,6 @@ def build_finetune_runner(self, finetune_cfg: Dict) -> Runner:
 
         runner = Runner.from_cfg(finetune_cfg)
 
-        runner._train_loop = runner.build_train_loop(
-            runner._train_loop)  # type: ignore
-
-        runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
-        runner.scale_lr(runner.optim_wrapper, runner.auto_scale_lr)
-
-        runner.param_schedulers = runner.build_param_scheduler(  # type: ignore
-            runner.param_schedulers)  # type: ignore
-
         from mmengine.hooks import CheckpointHook
 
         # remove CheckpointHook to avoid extra problems.
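Since `finetune_step` (added just below) rebuilds the train loop, optimizer wrapper and parameter schedulers from the raw config on every call, the `finetune_cfg` handed to the search loop has to keep those entries around. A hypothetical minimal shape, following mmengine `Runner.from_cfg` conventions; only `train_cfg`, `optim_wrapper` and `param_scheduler` are actually read back by `finetune_step`, and all concrete values here are placeholders:

    finetune_cfg = dict(
        # model / dataloader entries omitted: the search loop is assumed to
        # inject the sampled supernet before Runner.from_cfg is called
        train_cfg=dict(by_epoch=True, max_epochs=1),
        optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01, momentum=0.9)),
        param_scheduler=[dict(type='ConstantLR', factor=1.0)],
    )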
@@ -468,4 +461,35 @@ def build_finetune_runner(self, finetune_cfg: Dict) -> Runner:
         return runner
 
     def finetune_step(self, model):
-        pass
+        """Finetune before candidate evaluation."""
+        self.finetune_runner._train_loop = \
+            self.finetune_runner.build_train_loop(
+                self.finetune_cfg.train_cfg)  # type: ignore
+
+        self.finetune_runner.optim_wrapper = \
+            self.finetune_runner.build_optim_wrapper(
+                self.finetune_cfg.optim_wrapper)
+
+        self.finetune_runner.scale_lr(self.finetune_runner.optim_wrapper,
+                                      self.finetune_runner.auto_scale_lr)
+
+        self.finetune_runner.param_schedulers = \
+            self.finetune_runner.build_param_scheduler(
+                self.finetune_cfg.param_scheduler)  # type: ignore
+
+        model.train()
+        self.runner.logger.info('Start finetuning...')
+        self.finetune_runner.model = model
+        self.finetune_runner.call_hook('before_run')
+
+        self.finetune_runner.optim_wrapper.initialize_count_status(
+            self.finetune_runner.model, self.finetune_runner._train_loop.iter,
+            self.finetune_runner._train_loop.max_iters)
+
+        self.model = self.finetune_runner.train_loop.run()
+        self.finetune_runner.train_loop._iter = 0
+        self.finetune_runner.train_loop._epoch = 0
+
+        self.finetune_runner.call_hook('after_run')
+        self.runner.logger.info('End finetuning...')
+        model.eval()
diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py
index 9eaa1f3d3..62f418abd 100644
--- a/mmrazor/engine/runner/nsganetv2_search_loop.py
+++ b/mmrazor/engine/runner/nsganetv2_search_loop.py
@@ -6,14 +6,16 @@
 from mmengine import fileio
 
 try:
-    from pymoo.algorithms.moo.nsga2 import NSGA2
-    from pymoo.algorithms.soo.nonconvex.ga import GA
+    from pymoo.algorithms.so_genetic_algorithm import GA
+    from pymoo.factory import get_algorithm, get_crossover, get_mutation
     from pymoo.optimize import minimize
     from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
 except ImportError:
     from mmrazor.utils import get_placeholder
-    NSGA2 = get_placeholder('pymoo')
     GA = get_placeholder('pymoo')
+    get_algorithm = get_placeholder('pymoo')
+    get_crossover = get_placeholder('pymoo')
+    get_mutation = get_placeholder('pymoo')
     minimize = get_placeholder('pymoo')
     NonDominatedSorting = get_placeholder('pymoo')
 
@@ -22,6 +24,8 @@
 from .attentive_search_loop import AttentiveSearchLoop
 from .utils.pymoo_utils import (AuxiliarySingleLevelProblem,
                                 HighTradeoffPoints, SubsetProblem)
+from .utils.pymoo_utils.problems import (BinaryCrossover, RandomMutation,
+                                         RandomSampling)
 
 
 @LOOPS.register_module()
@@ -46,6 +50,7 @@ def run_epoch(self) -> None:
         for subnet, score, flops in zip(
                 self.candidates.subnets, self.candidates.scores,
                 self.candidates.resources('flops')):
+            self.update_candidates_scores()
             if self.trade_off['max_score_key'] != 0:
                 score = self.trade_off['max_score_key'] - score
             self.archive.append(subnet)
@@ -116,11 +121,13 @@ def sample_candidates_with_nsga2(self, archive: Candidates,
         problem = AuxiliarySingleLevelProblem(self, len(fronts[0]))
 
         # initiate a multi-objective solver to optimize the problem
-        method = NSGA2(
-            pop_size=4,
-            sampling=fronts,  # initialize with current nd archs
-            eliminate_duplicates=True,
-            logger=self.runner.logger)
+        method = get_algorithm(
+            'nsga2',
+            pop_size=40,
+            sampling=fronts,
+            crossover=get_crossover('int_two_point', prob=0.9),
+            mutation=get_mutation('int_pm', eta=1.0),
+            eliminate_duplicates=True)
 
         result = minimize(problem, method, ('n_gen', 4), seed=1, verbose=True)
 
@@ -134,11 +141,16 @@ def sample_candidates_with_nsga2(self, archive: Candidates,
         sub_problem =
         sub_problem = SubsetProblem(result.pop[not_duplicate].get('F')[:, 1],
                                     F[front_index, 1], num_candidates)
-        sub_method = GA(pop_size=num_candidates, eliminate_duplicates=True)
+        sub_method = GA(
+            pop_size=100,
+            sampling=RandomSampling(),
+            crossover=BinaryCrossover(),
+            mutation=RandomMutation(),
+            eliminate_duplicates=True)
         sub_result = minimize(
             sub_problem, sub_method, ('n_gen', 2), seed=1, verbose=False)
-        indices = sub_result.pop.get('X')
+        indices = sub_result.X
 
         _X = result.pop.get('X')[not_duplicate][indices]
 
@@ -155,7 +167,6 @@ def sort_candidates(self) -> None:
                 '`self.trade_off` is required when sorting candidates in '
                 'NSGA2SearchLoop. Got `self.trade_off` is None.')
 
-        ratio = self.trade_off.get('ratio', 1)
         max_score_key = self.trade_off.get('max_score_key', 100)
         max_score_key = np.array(max_score_key)
 
@@ -171,7 +182,7 @@ def sort_candidates(self) -> None:
         sort_idx = np.argsort(patches[:, 0])  # type: ignore
         F = patches[sort_idx]
 
-        dm = HighTradeoffPoints(ratio, n_survive=len(patches))
+        dm = HighTradeoffPoints(n_survive=len(patches))
         candidate_index = dm.do(F)
 
         candidate_index = sort_idx[candidate_index]
@@ -188,8 +199,13 @@ def _save_searcher_ckpt(self):
         rmse, rho, tau = 0, 0, 0
         if len(self.archive) > 0:
             top1_err_pred = self.fit_predictor(self.archive)
+
+            self.candidates = self.archive
+            self.use_predictor = False
+            self.update_candidates_scores()
+
             rmse, rho, tau = self.predictor.get_correlation(
-                top1_err_pred, np.array(self.archive.scores))
+                top1_err_pred, np.array(self.candidates.scores))
 
         save_for_resume = dict()
         save_for_resume['_epoch'] = self._epoch
@@ -238,28 +254,11 @@ def fit_predictor(self, candidates):
         metrics = []
         for i, candidate in enumerate(candidates.subnets):
             self.model.mutator.set_choices(candidate)
+            self.finetune_step(self.model)
             metric = self._val_candidate(use_predictor=True)
             metrics.append(metric[self.score_key])
 
         max_score_key = self.trade_off.get('max_score_key', 0.)
-        if max_score_key != 0:
-            for m in metrics:
-                m = max_score_key - m
+        assert max_score_key > 0.
+        metrics = max_score_key - np.array(metrics)
 
         return metrics
-
-    def finetune_step(self, model):
-        """Fintune before candidates evaluation."""
-        self.runner.logger.info('Start finetuning...')
-        self.finetune_runner.model = model
-        self.finetune_runner.call_hook('before_run')
-
-        self.finetune_runner.optim_wrapper.initialize_count_status(
-            self.finetune_runner.model, self.finetune_runner._train_loop.iter,
-            self.finetune_runner._train_loop.max_iters)
-
-        self.model = self.finetune_runner.train_loop.run()
-        self.finetune_runner.train_loop._iter = 0
-        self.finetune_runner.train_loop._epoch = 0
-
-        self.finetune_runner.call_hook('after_run')
-        self.runner.logger.info('End finetuning...')
diff --git a/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py b/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py
index 38a627c18..021fa3e74 100644
--- a/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py
+++ b/mmrazor/engine/runner/utils/pymoo_utils/high_tradeoff_points.py
@@ -1,12 +1,19 @@
 # Copyright (c) OpenMMLab. All rights reserved.
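# --- editor's note: illustrative sketch, not part of the patch --------------
# Why ``fit_predictor`` above now computes ``max_score_key - metrics`` as a
# vectorized expression: pymoo minimizes every objective, so a top-1 accuracy
# must become an error before it can share an objective vector with flops.
# The removed loop (``for m in metrics: m = max_score_key - m``) only rebound
# the loop variable and never modified the list. The scores are hypothetical.
import numpy as np

max_score_key = 100.
accuracies = np.array([73.2, 69.8, 71.5])   # predicted subnet top-1 scores
errors = max_score_key - accuracies         # what the predictor is fitted on
assert errors.argmin() == accuracies.argmax()
# --- end editor's note -------------------------------------------------------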
import numpy as np -from pymoo.config import Config -from pymoo.core.decision_making import (DecisionMaking, NeighborFinder, - find_outliers_upper_tail) -from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting -from pymoo.util.normalization import normalize -Config.warnings['not_compiled'] = False +try: + from pymoo.configuration import Configuration + from pymoo.model.decision_making import (DecisionMaking, NeighborFinder, + find_outliers_upper_tail, + normalize) + Configuration.show_compile_hint = False +except ImportError: + from mmrazor.utils import get_placeholder + DecisionMaking = get_placeholder('pymoo') + NeighborFinder = get_placeholder('pymoo') + normalize = get_placeholder('pymoo') + find_outliers_upper_tail = get_placeholder('pymoo') + Configuration = get_placeholder('pymoo') class HighTradeoffPoints(DecisionMaking): @@ -19,23 +26,20 @@ class HighTradeoffPoints(DecisionMaking): n_survive(int): how many high-tradeoff points will return finally. """ - def __init__(self, - ratio=1, - epsilon=0.125, - n_survive=None, - **kwargs) -> None: + def __init__(self, epsilon=0.125, n_survive=None, **kwargs) -> None: super().__init__(**kwargs) self.epsilon = epsilon - self.n_survive = n_survive - self.ratio = ratio - - def _do(self, data, **kwargs): - front = NonDominatedSorting().do(data, only_non_dominated_front=True) - F = data[front, :] + self.n_survive = n_survive # number of points to be selected + def _do(self, F, **kwargs): n, m = F.shape - F = normalize(F, self.ideal, self.nadir) - F[:, 1] = F[:, 1] * self.ratio + + if self.normalize: + F = normalize( + F, + self.ideal_point, + self.nadir_point, + estimate_bounds_if_none=True) neighbors_finder = NeighborFinder( F, epsilon=0.125, n_min_neigbors='auto', consider_2d=False) @@ -60,22 +64,8 @@ def _do(self, data, **kwargs): # otherwise find the one with the smalled one mu[i] = np.nanmin(tradeoff) - # if given topk if self.n_survive is not None: - n_survive = min(self.n_survive, len(mu)) - index = np.argsort(mu)[-n_survive:][::-1] - front_survive = front[index] - - self.n_survive -= n_survive - if self.n_survive == 0: - return front_survive - # in case the survived in front is not enough for topk - index = np.array(list(set(np.arange(len(data))) - set(front))) - unused_data = data[index] - no_front_survive = index[self._do(unused_data)] - - return np.concatenate([front_survive, no_front_survive]) + return np.argsort(mu)[-self.n_survive:] else: # return points with trade-off > 2*sigma - mu = find_outliers_upper_tail(mu) - return mu if len(mu) else [] + return find_outliers_upper_tail(mu) diff --git a/mmrazor/engine/runner/utils/pymoo_utils/problems.py b/mmrazor/engine/runner/utils/pymoo_utils/problems.py index afe831853..38fcf5fa1 100644 --- a/mmrazor/engine/runner/utils/pymoo_utils/problems.py +++ b/mmrazor/engine/runner/utils/pymoo_utils/problems.py @@ -1,14 +1,79 @@ # Copyright (c) OpenMMLab. All rights reserved. 
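# --- editor's note: illustrative sketch, not part of the patch --------------
# The intuition behind ``HighTradeoffPoints._do`` above: on a normalized,
# minimized front, a knee point is one whose neighbors must sacrifice a lot
# in one objective to gain a little in the other. Below is a toy variant that
# only looks at consecutive neighbors of a sorted 2-objective front; the real
# class uses an epsilon-neighborhood via ``NeighborFinder``. All numbers are
# made up.
import numpy as np

F = np.array([[0.0, 1.0], [0.1, 0.4], [0.5, 0.3], [1.0, 0.0]])
mu = np.full(len(F), np.nan)
for i in range(1, len(F) - 1):
    trade = []
    for j in (i - 1, i + 1):
        diff = F[j] - F[i]
        sacrifice = np.maximum(diff, 0).sum()   # how much the neighbor loses
        gain = np.maximum(-diff, 0).sum()       # how much the neighbor wins
        trade.append(sacrifice / max(gain, 1e-12))
    mu[i] = np.nanmin(trade)                    # worst-case trade-off

assert int(np.nanargmax(mu)) == 1               # (0.1, 0.4) is the knee point
# --- end editor's note -------------------------------------------------------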
import numpy as np
-from pymoo.core.problem import Problem
+
+try:
+    from pymoo.model.crossover import Crossover
+    from pymoo.model.mutation import Mutation
+    from pymoo.model.problem import Problem
+    from pymoo.model.sampling import Sampling
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    Crossover = get_placeholder('pymoo')
+    Mutation = get_placeholder('pymoo')
+    Problem = get_placeholder('pymoo')
+    Sampling = get_placeholder('pymoo')
+
+
+class RandomSampling(Sampling):
+
+    def _do(self, problem, n_samples, **kwargs):
+        X = np.full((n_samples, problem.n_var), False, dtype=bool)
+
+        for k in range(n_samples):
+            index = np.random.permutation(problem.n_var)[:problem.n_max]
+            X[k, index] = True
+
+        return X
+
+
+class BinaryCrossover(Crossover):
+
+    def __init__(self):
+        super().__init__(2, 1)
+
+    def _do(self, problem, X, **kwargs):
+        n_parents, n_matings, n_var = X.shape
+
+        _X = np.full((self.n_offsprings, n_matings, problem.n_var), False)
+
+        for k in range(n_matings):
+            p1, p2 = X[0, k], X[1, k]
+
+            both_are_true = np.logical_and(p1, p2)
+            _X[0, k, both_are_true] = True
+
+            n_remaining = problem.n_max - np.sum(both_are_true)
+
+            index = np.where(np.logical_xor(p1, p2))[0]
+
+            S = index[np.random.permutation(len(index))][:n_remaining]
+            _X[0, k, S] = True
+
+        return _X
+
+
+class RandomMutation(Mutation):
+
+    def _do(self, problem, X, **kwargs):
+        for i in range(X.shape[0]):
+            X[i, :] = X[i, :]
+            is_false = np.where(np.logical_not(X[i, :]))[0]
+            is_true = np.where(X[i, :])[0]
+            try:
+                X[i, np.random.choice(is_false)] = True
+                X[i, np.random.choice(is_true)] = False
+            except ValueError:
+                pass
+
+        return X
 
 
 class AuxiliarySingleLevelProblem(Problem):
     """The optimization problem for finding the next N candidate
     architectures."""
 
-    def __init__(self, search_loop, dim=15, sec_obj='flops'):
-        super().__init__(n_var=dim, n_obj=2, vtype=np.int32)
+    def __init__(self, search_loop, n_var=15, sec_obj='flops'):
+        super().__init__(n_var=n_var, n_obj=2, n_constr=0, type_var=int)
 
         self.search_loop = search_loop
         self.predictor = self.search_loop.predictor
diff --git a/mmrazor/models/algorithms/nas/nsganetv2.py b/mmrazor/models/algorithms/nas/nsganetv2.py
index 8e191e5b7..fad3d99d0 100644
--- a/mmrazor/models/algorithms/nas/nsganetv2.py
+++ b/mmrazor/models/algorithms/nas/nsganetv2.py
@@ -72,6 +72,7 @@ def __init__(self,
 
         self.mutator = self._build_mutator(mutator)
         self.mutator.prepare_from_supernet(self.architecture)
+        self.sample_kinds = ['max', 'min']
 
         self.is_supernet = True
         self.drop_path_rate = drop_path_rate

From bd9384ccd688ad74d8190123295198ea7a6efed8 Mon Sep 17 00:00:00 2001
From: gaoyang07 <1546308416@qq.com>
Date: Fri, 17 Feb 2023 09:35:04 +0800
Subject: [PATCH 12/12] fix lr during finetune step

---
 .../nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py    | 4 +++-
 .../mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py
index c731c5486..9d0991195 100644
--- a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py
+++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_cifar10.py
@@ -6,6 +6,8 @@
 val_dataloader = dict(batch_size=256)
 test_dataloader = val_dataloader
 
+param_scheduler = [dict(type='ConstantLR', factor=1.0)]
+
 train_cfg = dict(
     _delete_=True,
     type='mmrazor.NSGA2SearchLoop',
@@ -36,7 +38,7 @@
             momentum=0.9,
weight_decay=1e-4, nesterov=True)), - param_scheduler=_base_.param_scheduler, + param_scheduler=param_scheduler, default_hooks=_base_.default_hooks, ), ) diff --git a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py index 1a191bea2..969b9a663 100644 --- a/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py +++ b/configs/nas/mmcls/nsganetv2/nsganetv2_mobilenet_search_8xb256_in1k.py @@ -2,6 +2,8 @@ model = dict(norm_training=True) +param_scheduler = [dict(type='ConstantLR', factor=1.0)] + train_cfg = dict( _delete_=True, type='mmrazor.NSGA2SearchLoop', @@ -32,7 +34,7 @@ momentum=0.9, weight_decay=3e-4, nesterov=True)), - param_scheduler=_base_.param_scheduler, + param_scheduler=param_scheduler, default_hooks=_base_.default_hooks, ), )
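# --- editor's note: illustrative sketch, not part of the patch --------------
# What the ``ConstantLR`` override in the two configs above achieves, shown
# with the plain torch scheduler as an assumed analogue of the mmengine one
# (the lr value and the parameter are stand-ins): with factor=1.0, every
# per-candidate finetune step runs at the optimizer's base lr instead of
# inheriting the supernet's decayed schedule.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR

param = torch.nn.Parameter(torch.zeros(1))
optim = SGD([param], lr=0.01, momentum=0.9, weight_decay=1e-4, nesterov=True)
sched = ConstantLR(optim, factor=1.0)

for _ in range(3):                  # a few finetune iterations
    param.grad = torch.ones(1)      # dummy gradient
    optim.step()
    sched.step()

assert optim.param_groups[0]['lr'] == 0.01   # lr never decays
# --- end editor's note -------------------------------------------------------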