diff --git a/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py b/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py
new file mode 100644
index 000000000..a2ab4cf07
--- /dev/null
+++ b/configs/nas/mmcls/spos/spos_shufflenet_search_nsga2_predictor_8xb128_in1k.py
@@ -0,0 +1,22 @@
+_base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']
+
+model = dict(norm_training=True)
+
+train_cfg = dict(
+    _delete_=True,
+    type='mmrazor.NSGA2SearchLoop',
+    dataloader=_base_.val_dataloader,
+    evaluator=_base_.val_evaluator,
+    max_epochs=20,
+    num_candidates=50,
+    top_k=10,
+    num_mutation=25,
+    num_crossover=25,
+    mutate_prob=0.1,
+    constraints_range=dict(flops=(0., 360.)),
+    predictor_cfg=dict(
+        type='mmrazor.MetricPredictor',
+        encoding_type='normal',
+        train_samples=2,
+        handler_cfg=dict(type='mmrazor.GaussProcessHandler')),
+)
diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py
index f2df86a83..973b9c761 100644
--- a/mmrazor/engine/__init__.py
+++ b/mmrazor/engine/__init__.py
@@ -4,12 +4,13 @@
 from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
                      DartsIterBasedTrainLoop, EvolutionSearchLoop,
                      GreedySamplerTrainLoop, SelfDistillValLoop,
-                     SingleTeacherDistillValLoop, SlimmableValLoop)
+                     SingleTeacherDistillValLoop, SlimmableValLoop,
+                     NSGA2SearchLoop)
 
 __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
-    'SelfDistillValLoop'
+    'SelfDistillValLoop', 'NSGA2SearchLoop'
 ]
diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py
index 9715a4e6b..6f1befa3b 100644
--- a/mmrazor/engine/runner/__init__.py
+++ b/mmrazor/engine/runner/__init__.py
@@ -3,11 +3,13 @@
 from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
 from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
 from .evolution_search_loop import EvolutionSearchLoop
+from .nsganetv2_search_loop import NSGA2SearchLoop
 from .slimmable_val_loop import SlimmableValLoop
 from .subnet_sampler_loop import GreedySamplerTrainLoop
 
 __all__ = [
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
+    'NSGA2SearchLoop'
 ]
diff --git a/mmrazor/engine/runner/attentive_search_loop.py b/mmrazor/engine/runner/attentive_search_loop.py
new file mode 100644
index 000000000..ceb830a60
--- /dev/null
+++ b/mmrazor/engine/runner/attentive_search_loop.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.registry import LOOPS
+from .evolution_search_loop import EvolutionSearchLoop
+
+
+@LOOPS.register_module()
+class AttentiveSearchLoop(EvolutionSearchLoop):
+    """Loop for evolution searching with attentive tricks from AttentiveNAS.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): A dataloader object or a dict to
+            build a dataloader.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        max_epochs (int): Total searching epochs. Defaults to 20.
+        max_keep_ckpts (int): The maximum checkpoints of searcher to keep.
+            Defaults to 3.
+        resume_from (str, optional): Specify the path of saved .pkl file for
+            resuming searching.
+        num_candidates (int): The length of candidate pool. Defaults to 50.
+        top_k (int): Specify top k candidates based on scores. Defaults to 10.
+        num_mutation (int): The number of candidates got by mutation.
+            Defaults to 25.
+        num_crossover (int): The number of candidates got by crossover.
+            Defaults to 25.
+        mutate_prob (float): The probability of mutation. Defaults to 0.1.
+        flops_range (tuple, optional): It is used for screening candidates.
+        resource_estimator_cfg (dict): The config for building estimator, which
+            is used to estimate the flops of sampled subnet. Defaults to
+            None, which means default config is used.
+        score_key (str): Specify one metric in evaluation results to score
+            candidates. Defaults to 'accuracy_top-1'.
+        init_candidates (str, optional): The candidates file path, which is
+            used to init `self.candidates`. Its format is usually in .yaml
+            format. Defaults to None.
+    """
+
+    def _init_pareto(self):
+        """Create the empty Pareto buckets stored in `pareto_candidates`.
+
+        Every non-sequence constraint is first widened to ``(0, value)``.
+        Bucket keys are multiples of ``pareto_mode['discretize_step']`` that
+        cover the range of the single supported constraint.
+        """
+        # TODO (gaoyang): Fix apis with mmrazor2.0
+        for k, v in self.constraints.items():
+            if not isinstance(v, (list, tuple)):
+                self.constraints[k] = (0, v)
+
+        assert len(self.constraints) == 1, 'Only accept one kind constrain.'
+        self.pareto_candidates = dict()
+        constraints = list(self.constraints.items())[0]
+        discretize_step = self.pareto_mode['discretize_step']
+        ds = discretize_step
+        # find the left bound
+        while ds + 0.5 * discretize_step < constraints[1][0]:
+            ds += discretize_step
+        self.pareto_candidates[ds] = []
+        # find the right bound
+        while ds - 0.5 * discretize_step < constraints[1][1]:
+            self.pareto_candidates[ds] = []
+            ds += discretize_step
diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py
index fc907f3aa..6d62d99fa 100644
--- a/mmrazor/engine/runner/evolution_search_loop.py
+++ b/mmrazor/engine/runner/evolution_search_loop.py
@@ -96,6 +96,7 @@ def __init__(self,
         self.crossover_prob = crossover_prob
         self.max_keep_ckpts = max_keep_ckpts
         self.resume_from = resume_from
+        self.trade_off = dict(max_score_key=40)
 
         if init_candidates is None:
             self.candidates = Candidates()
diff --git a/mmrazor/engine/runner/nsganetv2_search_loop.py b/mmrazor/engine/runner/nsganetv2_search_loop.py
new file mode 100644
index 000000000..73aadd4f6
--- /dev/null
+++ b/mmrazor/engine/runner/nsganetv2_search_loop.py
@@ -0,0 +1,253 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from copy import deepcopy
+
+import numpy as np
+from mmengine import fileio
+from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
+
+from mmrazor.models.task_modules import (GeneticOptimizer,
+                                         NSGA2Optimizer,
+                                         AuxiliarySingleLevelProblem,
+                                         SubsetProblem)
+from mmrazor.registry import LOOPS
+from mmrazor.structures import Candidates, export_fix_subnet
+from .attentive_search_loop import AttentiveSearchLoop
+from .utils.high_tradeoff_points import HighTradeoffPoints
+
+
+@LOOPS.register_module()
+class NSGA2SearchLoop(AttentiveSearchLoop):
+    """Evolution search loop with NSGA2 optimizer."""
+
+    def run_epoch(self) -> None:
+        """Iterate one epoch.
+
+        Steps:
+            0. Collect archives and predictor.
+            1. Sample some new candidates from the supernet.Then Append them
+                to the candidates, Thus make its number equal to the specified
+                number.
+            2. Validate these candidates(step 1) and update their scores.
+            3. Pick the top k candidates based on the scores(step 2), which
+                will be used in mutation and crossover.
+            4. Implement Mutation and crossover, generate better candidates.
+        """
+        archive = Candidates()
+        for subnet, score, flops in zip(self.candidates.subnets,
+                                        self.candidates.scores,
+                                        self.candidates.resources('flops')):
+            if self.trade_off['max_score_key'] != 0:
+                score = self.trade_off['max_score_key'] - score
+            archive.append(subnet)
+            archive.set_score(-1, score)
+            archive.set_resource(-1, flops, 'flops')
+
+        self.sample_candidates(random=(self._epoch == 0), archive=archive)
+        self.update_candidates_scores()
+
+        scores_before = self.top_k_candidates.scores
+        self.runner.logger.info(f'top k scores before update: '
+                                f'{scores_before}')
+
+        self.candidates.extend(self.top_k_candidates)
+        self.sort_candidates()
+        self.top_k_candidates = Candidates(self.candidates[:self.top_k])
+
+        scores_after = self.top_k_candidates.scores
+        self.runner.logger.info(f'top k scores after update: '
+                                f'{scores_after}')
+
+        mutation_candidates = self.gen_mutation_candidates()
+        self.candidates_mutator_crossover = Candidates(mutation_candidates)
+        crossover_candidates = self.gen_crossover_candidates()
+        self.candidates_mutator_crossover.extend(crossover_candidates)
+
+        assert len(self.candidates_mutator_crossover
+                   ) <= self.num_candidates, 'Total of mutation and \
+            crossover should be less than the number of candidates.'
+
+        self.candidates = self.candidates_mutator_crossover
+        self._epoch += 1
+
+    def sample_candidates(self, random: bool = True, archive=None) -> None:
+        """Sample randomly via the parent loop, or with the NSGA2 search."""
+        if random:
+            super().sample_candidates()
+        else:
+            candidates = self.sample_candidates_with_nsga2(
+                archive, self.num_candidates)
+            new_candidates = []
+            candidates_resources = []
+            for candidate in candidates:
+                is_pass, result = self._check_constraints(candidate)
+                if is_pass:
+                    new_candidates.append(candidate)
+                    candidates_resources.append(result)
+            self.candidates = Candidates(new_candidates)
+
+            if len(candidates_resources) > 0:
+                self.candidates.update_resources(
+                    candidates_resources,
+                    start=len(self.candidates.data)-len(candidates_resources))
+
+    def sample_candidates_with_nsga2(self, archive: Candidates, num_candidates):
+        """Searching for candidates with high-fidelity evaluation."""
+        F = np.column_stack((archive.scores, archive.resources('flops')))
+        front_index = NonDominatedSorting().do(F, only_non_dominated_front=True)
+
+        fronts = np.array(archive.subnets)[front_index]
+        fronts = np.array([self.predictor.model2vector(cand) for cand in fronts])
+        fronts = self.predictor.preprocess(fronts)
+
+        # initialize the candidate finding optimization problem
+        problem = AuxiliarySingleLevelProblem(self, len(fronts[0]))
+
+        # initiate a multi-objective solver to optimize the problem
+        method = NSGA2Optimizer(
+            pop_size=4,
+            sampling=fronts,  # initialize with current nd archs
+            eliminate_duplicates=True,
+            logger=self.runner.logger)
+
+        # # kick-off the search
+        method.initialize(problem, n_gen=2, verbose=True)
+        result = method.solve()
+
+        # check for duplicates
+        check_list = []
+        for x in result['pop'].get('X'):
+            assert x is not None
+            check_list.append(self.predictor.vector2model(x))
+
+        not_duplicate = np.logical_not(
+            [x in archive.subnets for x in check_list])
+
+        # extra process after nsga2 search
+        sub_problem = SubsetProblem(result['pop'][not_duplicate].get('F')[:, 1],
+                                    F[front_index, 1],
+                                    num_candidates)
+        sub_method = GeneticOptimizer(pop_size=num_candidates,
+                                      eliminate_duplicates=True)
+        sub_method.initialize(sub_problem, n_gen=4, verbose=False)
+        indices = sub_method.solve()['X']
+
+        candidates = Candidates()
+        pop = result['pop'][not_duplicate][indices]
+        for x in pop.get('X'):
+            candidates.append(self.predictor.vector2model(x))
+
+        return candidates
+
+    def sort_candidates(self) -> None:
+        """Support sort candidates in single and multiple-obj optimization."""
+        assert self.trade_off is not None, (
+            '`self.trade_off` is required when sorting candidates in '
+            'NSGA2SearchLoop. Got self.trade_off is None.')
+        ratio = self.trade_off.get('ratio', 1)
+        multiple_obj_score = []
+        for score, flops in zip(self.candidates.scores,
+                                self.candidates.resources('flops')):
+            multiple_obj_score.append((score, flops))
+        multiple_obj_score = np.array(multiple_obj_score)
+        max_score_key = self.trade_off.get('max_score_key', 100)
+        if max_score_key != 0:
+            multiple_obj_score[:, 0] = \
+                max_score_key - multiple_obj_score[:, 0]
+        sort_idx = np.argsort(multiple_obj_score[:, 0])
+        F = multiple_obj_score[sort_idx]
+        dm = HighTradeoffPoints(ratio, n_survive=len(multiple_obj_score))
+        candidate_index = dm.do(F)
+        candidate_index = sort_idx[candidate_index]
+        self.candidates = [self.candidates[idx] for idx in candidate_index]
+
+    def _save_searcher_ckpt(self, archive=()):
+        """Save searcher ckpt, which is different from common ckpt.
+
+        It mainly contains the candidate pool, the top-k candidates with scores
+        and the current epoch.
+        """
+        if self.runner.rank == 0:
+            rmse, rho, tau = 0, 0, 0
+            if len(archive) > 0:
+                top1_err_pred = self.fit_predictor(archive)
+                rmse, rho, tau = self.predictor.get_correlation(
+                    top1_err_pred, np.array([x[1] for x in archive]))
+
+            save_for_resume = dict()
+            save_for_resume['_epoch'] = self._epoch
+            for k in ['candidates', 'top_k_candidates']:
+                save_for_resume[k] = getattr(self, k)
+            fileio.dump(
+                save_for_resume,
+                osp.join(self.runner.work_dir,
+                         f'search_epoch_{self._epoch}.pkl'))
+
+            correlation_str = 'fitting '
+            # correlation_str += f'{self.predictor.type}: '
+            correlation_str += f'RMSE = {rmse:.4f}, '
+            correlation_str += f'Spearmans Rho = {rho:.4f}, '
+            correlation_str += f'Kendalls Tau = {tau:.4f}'
+
+            # NOTE(review): hard-coded False makes the `if` branch unreachable.
+            self.pareto_mode = False
+            if self.pareto_mode:
+                step_str = '\n'
+                for step, candidates in self.pareto_candidates.items():
+                    if len(candidates) > 0:
+                        step_str += f'step: {step}: '
+                        step_str += f'{candidates[0][self.score_key]}\n'
+                self.runner.logger.info(
+                    f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
+                    f'top1_score: {step_str} '
+                    f'{correlation_str}')
+            else:
+                self.runner.logger.info(
+                    f'Epoch:[{self._epoch + 1}/{self._max_epochs}], '
+                    f'top1_score: {self.top_k_candidates.scores[0]} '
+                    f'{correlation_str}')
+
+    def fit_predictor(self, candidates):
+        """anticipate testfn training(err rate)."""
+        inputs = [export_fix_subnet(x) for x in candidates.subnets]
+        inputs = np.array([self.predictor.model2vector(x) for x in inputs])
+
+        targets = np.array([x[1] for x in candidates])
+
+        if not self.predictor.pretrained:
+            self.predictor.fit(inputs, targets)
+
+        metrics = self.predictor.predict(inputs)
+        # NOTE(review): `self.max_score_key` is not assigned in this class,
+        # while `sort_candidates` reads `self.trade_off`; confirm the source.
+        if self.max_score_key != 0:
+            for i in range(len(metrics)):
+                metrics[i] = self.max_score_key - metrics[i]
+        return metrics
+
+    def finetune_step(self, model):
+        """finetune before candidates evaluation."""
+        # TODO (gaoyang): update with 2.0 version.
+        self.runner.logger.info('start finetuning...')
+        model.train()
+        while self._finetune_epoch < self._max_finetune_epochs:
+            self.runner.call_hook('before_train_epoch')
+            for idx, data_batch in enumerate(self.dataloader):
+                self.runner.call_hook(
+                    'before_train_iter',
+                    batch_idx=idx,
+                    data_batch=data_batch)
+
+                outputs = model.train_step(
+                    data_batch, optim_wrapper=self.optim_wrapper)
+
+                self.runner.call_hook(
+                    'after_train_iter',
+                    batch_idx=idx,
+                    data_batch=data_batch,
+                    outputs=outputs)
+
+            self.runner.call_hook('after_train_epoch')
+            self._finetune_epoch += 1
+
+        model.eval()
diff --git a/mmrazor/engine/runner/utils/high_tradeoff_points.py b/mmrazor/engine/runner/utils/high_tradeoff_points.py
new file mode 100644
index 000000000..38a627c18
--- /dev/null
+++ b/mmrazor/engine/runner/utils/high_tradeoff_points.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+from pymoo.config import Config
+from pymoo.core.decision_making import (DecisionMaking, NeighborFinder,
+                                        find_outliers_upper_tail)
+from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
+from pymoo.util.normalization import normalize
+
+Config.warnings['not_compiled'] = False
+
+
+class HighTradeoffPoints(DecisionMaking):
+    """Method for multi-object optimization.
+
+    Args:
+        ratio(float): weight between score_key and sec_obj, details in
+        demo/nas/demo.ipynb.
+        epsilon(float): specific a radius for each neighbour.
+        n_survive(int): how many high-tradeoff points will return finally.
+    """
+
+    def __init__(self,
+                 ratio=1,
+                 epsilon=0.125,
+                 n_survive=None,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+        self.n_survive = n_survive
+        self.ratio = ratio
+
+    def _do(self, data, **kwargs):
+        """Rank the points of ``data`` by their trade-off value.
+
+        With ``n_survive`` set, indices into ``data`` are returned: first
+        the non-dominated front ordered by decreasing trade-off, then, if
+        more points are requested, the result of a recursive call on the
+        remaining points. Without it, the output of
+        ``find_outliers_upper_tail`` on the trade-off values is returned.
+
+        NOTE: ``self.n_survive`` is decremented in place, so an instance
+        should be used for a single ``do`` call.
+        """
+        front = NonDominatedSorting().do(data, only_non_dominated_front=True)
+        F = data[front, :]
+
+        n, m = F.shape
+        F = normalize(F, self.ideal, self.nadir)
+        F[:, 1] = F[:, 1] * self.ratio
+
+        # use the radius given to the constructor (default 0.125)
+        neighbors_finder = NeighborFinder(
+            F, epsilon=self.epsilon, n_min_neigbors='auto', consider_2d=False)
+
+        mu = np.full(n, -np.inf)
+
+        for i in range(n):
+
+            # for each neighbour in a specific radius of that solution
+            neighbors = neighbors_finder.find(i)
+
+            # calculate the trade-off to all neighbours
+            diff = F[neighbors] - F[i]
+
+            # calculate sacrifice and gain
+            sacrifice = np.maximum(0, diff).sum(axis=1)
+            gain = np.maximum(0, -diff).sum(axis=1)
+
+            # division by a zero gain is expected here; silence the
+            # resulting warnings locally instead of process-wide
+            with warnings.catch_warnings():
+                warnings.simplefilter('ignore')
+                tradeoff = sacrifice / gain
+
+                # keep the smallest trade-off over all neighbours
+                mu[i] = np.nanmin(tradeoff)
+
+        # if given topk
+        if self.n_survive is not None:
+            n_survive = min(self.n_survive, len(mu))
+            index = np.argsort(mu)[-n_survive:][::-1]
+            front_survive = front[index]
+
+            self.n_survive -= n_survive
+            if self.n_survive == 0:
+                return front_survive
+            # in case the survived in front is not enough for topk
+            index = np.array(list(set(np.arange(len(data))) - set(front)))
+            unused_data = data[index]
+            no_front_survive = index[self._do(unused_data)]
+
+            return np.concatenate([front_survive, no_front_survive])
+        else:
+            # return points with trade-off > 2*sigma
+            mu = find_outliers_upper_tail(mu)
+            return mu if len(mu) else []
diff --git a/mmrazor/models/algorithms/nas/nsganetv2.py b/mmrazor/models/algorithms/nas/nsganetv2.py
new file mode 100644
index 000000000..8eaad2f38
--- /dev/null
+++ b/mmrazor/models/algorithms/nas/nsganetv2.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
+import torch
+from mmengine.model import BaseModel
+from mmengine.structures import BaseDataElement
+from torch import nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmrazor.models.distillers import ConfigurableDistiller
+from mmrazor.models.mutators.base_mutator import BaseMutator
+from mmrazor.models.mutators import OneShotModuleMutator
+from mmrazor.registry import MODELS
+from mmrazor.structures.subnet.fix_subnet import load_fix_subnet
+from mmrazor.utils import SingleMutatorRandomSubnet, ValidFixMutable
+from ..base import BaseAlgorithm, LossResults
+
+VALID_MUTATOR_TYPE = Union[BaseMutator, Dict]
+VALID_MUTATORS_TYPE = Dict[str, Union[BaseMutator, Dict]]
+VALID_DISTILLER_TYPE = Union[ConfigurableDistiller, Dict]
+
+
+@MODELS.register_module()
+class NSGANetV2(BaseAlgorithm):
+    """One-shot NAS algorithm that samples subnets of a supernet.
+
+    A random subnet is drawn per ``loss`` call unless ``fix_subnet`` is set.
+    """
+
+    def __init__(self,
+                 architecture: Union[BaseModel, Dict],
+                 mutator: VALID_MUTATORS_TYPE,
+                 fix_subnet: Optional[ValidFixMutable] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
+                 init_cfg: Optional[dict] = None,
+                 drop_prob: float = 0.2,
+                 norm_training: bool = False):
+        super().__init__(architecture, data_preprocessor, init_cfg)
+        self.norm_training = norm_training  # read by `train()`
+
+        if fix_subnet:
+            # Avoid circular import
+            from mmrazor.structures import load_fix_subnet
+
+            # According to fix_subnet, delete the unchosen part of supernet
+            load_fix_subnet(self.architecture, fix_subnet)
+            self.is_supernet = False
+        else:
+            # Mutator is an essential component of the NAS algorithm. It
+            # provides some APIs commonly used by NAS.
+            # Before using it, you must do some preparations according to
+            # the supernet.
+            self.mutator = self._build_mutator(mutator)
+            self.mutator.prepare_from_supernet(self.architecture)
+            self.is_supernet = True
+
+        self.drop_prob = drop_prob
+
+    def _build_mutator(self, mutator: VALID_MUTATOR_TYPE) -> BaseMutator:
+        """build mutator."""
+        assert mutator is not None, \
+            'mutator cannot be None when fix_subnet is None.'
+        if isinstance(mutator, OneShotModuleMutator):
+            self.mutator = mutator
+        elif isinstance(mutator, dict):
+            self.mutator = MODELS.build(mutator)
+        else:
+            raise TypeError('mutator should be a `dict` or '
+                            f'`OneShotModuleMutator` instance, but got '
+                            f'{type(mutator)}')
+        return self.mutator
+
+    def sample_subnet(self) -> SingleMutatorRandomSubnet:
+        """Random sample subnet by mutator."""
+        return self.mutator.sample_choices()
+
+    def set_subnet(self, subnet: SingleMutatorRandomSubnet):
+        """Set the subnet sampled by :meth:sample_subnet."""
+        self.mutator.set_choices(subnet)
+
+    def loss(
+        self,
+        batch_inputs: torch.Tensor,
+        data_samples: Optional[List[BaseDataElement]] = None,
+    ) -> LossResults:
+        """Calculate losses from a batch of inputs and data samples."""
+        if self.is_supernet:
+            random_subnet = self.sample_subnet()
+            self.set_subnet(random_subnet)
+            return self.architecture(batch_inputs, data_samples, mode='loss')
+        else:
+            return self.architecture(batch_inputs, data_samples, mode='loss')
+
+    def train(self, mode=True):
+        """Convert the model into eval mode while keep normalization layer
+        unfreezed."""
+
+        super().train(mode)
+        if self.norm_training and not mode:
+            for module in self.architecture.modules():
+                if isinstance(module, _BatchNorm):
+                    module.training = True
diff --git a/mmrazor/models/task_modules/__init__.py b/mmrazor/models/task_modules/__init__.py
index b86bebbb9..4d152d383 100644
--- a/mmrazor/models/task_modules/__init__.py
+++ b/mmrazor/models/task_modules/__init__.py
@@ -4,5 +4,6 @@
 from .predictor import *  # noqa: F401,F403
 from .recorder import *  # noqa: F401,F403
 from .tracer import *  # noqa: F401,F403
+from .multi_object_optimizer import *  # noqa: F401,F403
 
 __all__ = ['ResourceEstimator']
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/__init__.py
new file mode 100644
index 000000000..fc985e92c
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .genetic_optimizer import GeneticOptimizer
+from .nsga2_optimizer import NSGA2Optimizer
+from .problem import AuxiliarySingleLevelProblem, SubsetProblem
+
+__all__ = [
+    'AuxiliarySingleLevelProblem', 'SubsetProblem',
+    'GeneticOptimizer', 'NSGA2Optimizer'
+]
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py
new file mode 100644
index 000000000..796fa37a2
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/base_optimizer.py
@@ -0,0 +1,216 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# copied and modified from https://github.com/anyoptimization/pymoo
+from abc import abstractmethod
+
+import numpy as np
+from pymoo.util.optimum import filter_optimum
+
+# from pymoo.core.evaluator import Evaluator
+# from pymoo.core.population import Population
+from .utils.helper import Evaluator, Population
+
+
+class BaseOptimizer():
+    """This class represents the abstract class for any algorithm to be
+    implemented. The solve method provides a wrapper function which does
+    validate the input.
+
+    Args:
+        problem :
+            Problem to be solved by the algorithm
+        verbose (bool):
+            If true information during the algorithm execution are displayed
+        save_history (bool):
+            If true, a current snapshot of each generation is saved.
+        pf (numpy.array):
+            The Pareto-front for the given problem. If provided performance
+            metrics are printed during execution.
+        return_least_infeasible (bool):
+            Whether the algorithm should return the least infeasible solution,
+            if no solution was found.
+        evaluator : :class:`~pymoo.model.evaluator.Evaluator`
+            The evaluator which can be used to make modifications before
+            calling the evaluate function of a problem.
+    """
+
+    def __init__(self, **kwargs):
+        # !
+        # DEFAULT SETTINGS OF ALGORITHM
+        # !
+        # set the display variable supplied to the algorithm
+        self.display = kwargs.get('display')
+        self.logger = kwargs.get('logger')
+        # !
+        # Attributes to be set later on for each problem run
+        # !
+        # the optimization problem as an instance
+        self.problem = None
+
+        self.return_least_infeasible = None
+        # whether the history should be saved or not
+        self.save_history = None
+        # whether the algorithm should print output in this run or not
+        self.verbose = None
+        # the random seed that was used
+        self.seed = None
+        # the pareto-front of the problem - if it exist or passed
+        self.pf = None
+        # the function evaluator object (can be used to inject code)
+        self.evaluator = None
+        # the current number of generation or iteration
+        self.n_gen = None
+        # the history object which contains the list
+        self.history = None
+        # the current solutions stored - here considered as population
+        self.pop = None
+        # the optimum found by the algorithm
+        self.opt = None
+        # can be used to store additional data in submodules
+        self.data = {}
+
+    def initialize(
+            self,
+            problem,
+            pf=True,
+            evaluator=None,
+            # START Default minimize
+            seed=None,
+            verbose=False,
+            save_history=False,
+            return_least_infeasible=False,
+            # END Default minimize
+            n_gen=1,
+            display=None,
+            # END Overwrite by minimize
+            **kwargs):
+        """Bind ``problem`` and reset all per-run state, seeding NumPy's RNG."""
+
+        # set the problem that is optimized for the current run
+        self.problem = problem
+
+        # set the provided pareto front
+        self.pf = pf
+
+        # by default make sure an evaluator exists if nothing is passed
+        if evaluator is None:
+            evaluator = Evaluator()
+        self.evaluator = evaluator
+
+        # !
+        # START Default minimize
+        # !
+        # if this run should be verbose or not
+        self.verbose = verbose
+        # whether the least infeasible should be returned or not
+        self.return_least_infeasible = return_least_infeasible
+        # whether the history should be stored or not
+        self.save_history = save_history
+
+        # set the random seed in the algorithm object
+        self.seed = seed
+        if self.seed is None:
+            self.seed = np.random.randint(0, 10000000)
+        np.random.seed(self.seed)
+        # !
+        # END Default minimize
+        # !
+
+        if display is not None:
+            self.display = display
+
+        # other run dependent variables that are reset
+        self.n_gen = n_gen
+        self.history = []
+        self.pop = Population()
+        self.opt = None
+
+    def solve(self):
+        """Run ``n_gen`` generations and return the results as a dict."""
+
+        # the result object to be finally returned
+        res = {}
+
+        # initialize the first population and evaluate it
+        self._initialize()
+        self._set_optimum()
+
+        self.current_gen = 0
+        # while termination criterion not fulfilled
+        while self.current_gen < self.n_gen:
+            self.current_gen += 1
+            self.next()
+
+        # store the resulting population
+        res['pop'] = self.pop
+
+        # get the optimal solution found
+        opt = self.opt
+
+        # if optimum is not set
+        if len(opt) == 0:
+            opt = None
+
+        # if no feasible solution has been found
+        elif not np.any(opt.get('feasible')):
+            if self.return_least_infeasible:
+                opt = filter_optimum(opt, least_infeasible=True)
+            else:
+                opt = None
+
+        # set the optimum to the result object
+        res['opt'] = opt
+
+        # if optimum is set to none to not report anything
+        if opt is None:
+            X, F, CV, G = None, None, None, None
+
+        # otherwise get the values from the population
+        else:
+            X, F, CV, G = self.opt.get('X', 'F', 'CV', 'G')
+
+            # if single-objective problem and only one solution was found
+            if self.problem.n_obj == 1 and len(X) == 1:
+                X, F, CV, G = X[0], F[0], CV[0], G[0]
+
+        # set all the individual values
+        res['X'], res['F'], res['CV'], res['G'] = X, F, CV, G
+
+        # create the result object
+        res['problem'], res['pf'] = self.problem, self.pf
+        res['history'] = self.history
+
+        return res
+
+    def next(self):
+        """Advance the search by one generation."""
+        # call next of the implementation of the algorithm
+        self._next()
+
+        # set the optimum - only done if the algorithm did not do it yet
+        self._set_optimum()
+
+        # do what needs to be done each generation
+        self._each_iteration()
+
+    # method that is called each iteration to call some algorithms regularly
+    def _each_iteration(self):
+
+        # display the output if defined by the algorithm
+        if self.logger:
+            self.logger.info(f'Generation:[{self.current_gen}/{self.n_gen}] '
+                             f'evaluate {self.evaluator.n_eval} solutions, '
+                             f'find {len(self.opt)} optimal solution.')
+
+    def _finalize(self):
+        """Hook for subclasses; does nothing by default."""
+        pass
+
+    @abstractmethod
+    def _initialize(self):
+        """Create and evaluate the first population."""
+        pass
+
+    @abstractmethod
+    def _next(self):
+        """Produce the next generation."""
+        pass
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py
new file mode 100644
index 000000000..7af88ec79
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/genetic_optimizer.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# copied and modified from https://github.com/anyoptimization/pymoo
+import numpy as np
+from pymoo.algorithms.soo.nonconvex.ga import FitnessSurvival
+from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
+
+from mmrazor.registry import TASK_UTILS
+from .nsga2_optimizer import NSGA2Optimizer
+from .utils.helper import Individual, Population
+from .utils.selection import (BinaryCrossover, IntegerFromFloatMutation, Mating,
+                              MyMutation, MySampling, PointCrossover,
+                              TournamentSelection)
+
+
+
+@TASK_UTILS.register_module()
+class GeneticOptimizer(NSGA2Optimizer):
+    """Genetic Algorithm."""
+
+    def __init__(self,
+                 pop_size=100,
+                 sampling=MySampling(),
+                 selection=TournamentSelection(func_comp='comp_by_cv_and_fitness'),
+                 crossover=BinaryCrossover(),
+                 mutation=MyMutation(),
+                 eliminate_duplicates=True,
+                 n_offsprings=None,
+                 display=None,
+                 **kwargs):
+        """
+        Args:
+            pop_size : {pop_size}
+            sampling : {sampling}
+            selection : {selection}
+            crossover : {crossover}
+            mutation : {mutation}
+            eliminate_duplicates : {eliminate_duplicates}
+            n_offsprings : {n_offsprings}
+
+        """
+
+        super().__init__(
+            pop_size=pop_size,
+            sampling=sampling,
+            selection=selection,
+            crossover=crossover,
+            mutation=mutation,
+            survival=FitnessSurvival(),
+            eliminate_duplicates=eliminate_duplicates,
+            n_offsprings=n_offsprings,
+            display=display,
+            **kwargs)
+
+    def _set_optimum(self, force=False):
+        pop = self.pop
+        self.opt = filter_optimum(pop, least_infeasible=True)
+
+
+def filter_optimum(pop, least_infeasible=False):
+    # first only choose feasible solutions
+    ret = pop[pop.get('feasible')[:, 0]]
+
+    # if at least one feasible solution was found
+    if len(ret) > 0:
+
+        # then check the objective values
+        F = ret.get('F')
+
+        if F.shape[1] > 1:
+            Index = NonDominatedSorting().do(F, only_non_dominated_front=True)
+            ret = ret[Index]
+
+        else:
+            ret = ret[np.argmin(F)]
+
+    # no feasible solution was found
+    else:
+        # if flag enable report the least infeasible
+        if least_infeasible:
+            ret = pop[np.argmin(pop.get('CV'))]
+        # otherwise just return none
+        else:
+            ret = None
+
+    if isinstance(ret, Individual):
+        ret = Population().create(ret)
+
+    return ret
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py b/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py
new file mode 100644
index 000000000..9ebf36583
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/nsga2_optimizer.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# copied and modified from https://github.com/anyoptimization/pymoo
+import numpy as np
+
+from mmrazor.registry import TASK_UTILS
+from .base_optimizer import BaseOptimizer
+from .utils.selection import (IntegerFromFloatMutation, Mating,
+                              PointCrossover, TournamentSelection,
+                              binary_tournament)
+from .utils.helper import (DefaultDuplicateElimination, Individual,
+                           Initialization, Survival)
+
+# from pymoo.algorithms.moo.nsga2 import binary_tournament
+# from pymoo.core.mating import Mating
+# from pymoo.core.survival import Survival
+# from pymoo.core.individual import Individual
+# from pymoo.core.initialization import Initialization
+# from pymoo.core.duplicate import DefaultDuplicateElimination
+# from pymoo.operators.crossover.pntx import PointCrossover
+# from pymoo.operators.selection.tournament import TournamentSelection
+# from .packages.selection import IntegerFromFloatMutation
+
+
+@TASK_UTILS.register_module()
+class NSGA2Optimizer(BaseOptimizer):
+    """Implementation of NSGA2 search method.
+
+    Args:
+        pop_size (int): Size the population is cut back to by the survival
+            step of every generation.
+        sampling: Forwarded to ``Initialization`` to create the first
+            population.
+        selection: Selection operator forwarded to ``Mating``.
+        crossover: Crossover operator forwarded to ``Mating``.
+        mutation: Mutation operator forwarded to ``Mating``.
+        eliminate_duplicates (bool or object): ``True`` installs
+            ``DefaultDuplicateElimination``, ``False`` disables duplicate
+            elimination, any other object is used as given.
+        n_offsprings (int, optional): Offspring requested from ``Mating``
+            per generation. Defaults to ``pop_size``.
+    """
+
+    def __init__(self,
+                 pop_size=100,
+                 sampling=None,
+                 selection=TournamentSelection(func_comp=binary_tournament),
+                 crossover=PointCrossover(n_points=2),
+                 mutation=IntegerFromFloatMutation(eta=1.0),
+                 eliminate_duplicates=True,
+                 n_offsprings=None,
+                 display=None,
+                 survival=Survival(),
+                 repair=None,
+                 **kwargs):
+        super().__init__(
+            pop_size=pop_size,
+            sampling=sampling,
+            selection=selection,
+            crossover=crossover,
+            mutation=mutation,
+            survival=survival,
+            eliminate_duplicates=eliminate_duplicates,
+            n_offsprings=n_offsprings,
+            display=display,
+            **kwargs)
+
+        # the population size used
+        self.pop_size = pop_size
+
+        # the survival for the genetic algorithm
+        # NOTE(review): the `survival` argument is not used here; a fresh
+        # `Survival()` is always installed. Confirm whether that is intended.
+        self.survival = Survival()
+
+        # number of offsprings to generate through recombination
+        self.n_offsprings = n_offsprings
+
+        # if the number of offspring is not set
+        if self.n_offsprings is None:
+            self.n_offsprings = pop_size
+
+        # the object to be used to represent an individual
+        self.individual = Individual(rank=np.inf, crowding=-1)
+
+        # set the duplicate detection class
+        if isinstance(eliminate_duplicates, bool):
+            if eliminate_duplicates:
+                self.eliminate_duplicates = DefaultDuplicateElimination()
+            else:
+                self.eliminate_duplicates = None
+        else:
+            self.eliminate_duplicates = eliminate_duplicates
+
+        self.initialization = Initialization(
+            sampling,
+            individual=self.individual,
+            repair=repair,
+            eliminate_duplicates=self.eliminate_duplicates)
+
+        self.mating = Mating(
+            selection,
+            crossover,
+            mutation,
+            repair=repair,
+            eliminate_duplicates=self.eliminate_duplicates,
+            n_max_iterations=100)
+
+        # other run specific data updated whenever solve is called
+        self.n_gen = None
+        self.pop = None
+        self.off = None
+
+    def _initialize(self):
+        """Create, evaluate and rank the initial population."""
+
+        # create the initial population
+        pop = self.initialization.do(
+            self.problem, self.pop_size, algorithm=self)
+
+        # then evaluate using the objective function
+        self.evaluator.eval(self.problem, pop, algorithm=self)
+
+        # that call is a dummy survival to set attributes
+        # that are necessary for the mating selection
+        if self.survival:
+            pop = self.survival.do(self.problem, pop, len(pop), algorithm=self)
+
+        self.pop, self.off = pop, pop
+
+    def _next(self):
+        """Mate, evaluate the offspring, merge and apply survival."""
+
+        # do the mating using the current population
+        self.off = self.mating.do(
+            self.problem,
+            self.pop,
+            n_offsprings=self.n_offsprings,
+            algorithm=self)
+
+        # if the mating could not generate any new offspring
+        if len(self.off) == 0:
+            return
+
+        # evaluate the offspring
+        self.evaluator.eval(self.problem, self.off, algorithm=self)
+
+        # merge the offsprings with the current population
+        self.pop = self.pop.merge(self.off)
+
+        # the do survival selection
+        if self.survival:
+            self.pop = self.survival.do(
+                self.problem, self.pop, self.pop_size, algorithm=self)
+
+    def _set_optimum(self, **kwargs):
+        """Store rank-0 individuals, or the least infeasible one, in `opt`."""
+        if not np.any(self.pop.get('feasible')):
+            self.opt = self.pop[[np.argmin(self.pop.get('CV'))]]
+        else:
+            self.opt = self.pop[self.pop.get('rank') == 0]
diff --git a/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py b/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py
new file mode 100644
index 000000000..426e6283d
--- /dev/null
+++ b/mmrazor/models/task_modules/multi_object_optimizer/problem/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from pymoo.core.problem import Problem as BaseProblem


class AuxiliarySingleLevelProblem(BaseProblem):
    """Two-objective problem used to propose the next candidate
    architectures.

    Objective 0 is the error predicted by ``searcher.predictor`` and
    objective 1 is the resource named by ``sec_obj`` as reported by
    ``searcher._check_constraints``.

    Args:
        searcher: Object exposing ``predictor`` and ``_check_constraints``.
        dim (int): Number of decision variables. Defaults to 15.
        sec_obj (str): Key of the resource used as second objective.
            Defaults to 'flops'.
    """

    def __init__(self, searcher, dim=15, sec_obj='flops'):
        super().__init__(n_var=dim, n_obj=2, vtype=np.int32)

        self.searcher = searcher
        self.predictor = self.searcher.predictor
        self.sec_obj = sec_obj

        # every variable starts at choice index 0
        self.xl = np.zeros(self.n_var)
        # the upper bound of a variable is the last valid choice index of
        # its search group; groups without choices contribute no bound
        upper_bounds = [
            group[0].num_choices - 1
            for group in self.predictor.search_groups.values()
            if group[0].num_choices > 0
        ]
        self.xu = np.array(upper_bounds)

    def _evaluate(self, x, out, *args, **kwargs):
        """Fill ``out['F']`` with (predicted error, resource) per row of x."""
        n_rows = x.shape[0]
        objectives = np.full((n_rows, self.n_obj), np.nan)
        # predicted top1 error for the whole batch
        predicted_err = self.predictor.handler.predict(x)[:, 0]
        for row, encoding in enumerate(x):
            subnet = self.predictor.vector2model(encoding)
            _, resources = self.searcher._check_constraints(subnet)
            objectives[row, 0] = predicted_err[row]
            objectives[row, 1] = resources[self.sec_obj]

        out['F'] = objectives
from abc import abstractmethod

import numpy as np


def at_least_2d_array(x, extend_as='row'):
    """Return ``x`` as an ndarray with at least two dimensions.

    Args:
        x: An ndarray, or any object that is wrapped as ``np.array([x])``.
        extend_as (str): For one-dimensional input, 'row' yields shape
            ``(1, n)`` and 'column' yields shape ``(n, 1)``.

    Returns:
        np.ndarray: The (possibly reshaped) array.
    """
    if not isinstance(x, np.ndarray):
        x = np.array([x])

    if x.ndim == 1:
        if extend_as == 'row':
            x = x[None, :]
        elif extend_as == 'column':
            x = x[:, None]

    return x


class BaseProblem():
    """Superclass for each problem that is defined.

    It provides attributes such as the number of variables, number of
    objectives or constraints. Also, the lower and upper bounds are stored. If
    available the Pareto-front, nadir point and ideal point are stored.
    """

    def __init__(self,
                 n_var=-1,
                 n_obj=-1,
                 n_constr=0,
                 xl=None,
                 xu=None,
                 type_var=np.double,
                 evaluation_of='auto',
                 parallelization=None,
                 elementwise_evaluation=False,
                 callback=None):
        """
        Args:
            n_var (int): Number of variables.
            n_obj (int): Number of objectives.
            n_constr (int): Number of constraints.
            xl (np.array or int): Lower bounds for the variables.
            xu (np.array or int): Upper bounds for the variables.
            type_var (numpy type): Type of the variable to be evaluated.
            evaluation_of (str or list): 'auto', or the list of keys that
                ``_evaluate`` sets.
            parallelization (str or tuple): Stored as-is on the instance.
            elementwise_evaluation (bool): Stored as-is on the instance.
            callback (callable, optional): Called with ``X`` at the start
                of every :meth:`evaluate`.
        """
        # number of variable for this problem
        self.n_var = n_var

        # type of the variable to be evaluated
        self.type_var = type_var

        # number of objectives
        self.n_obj = n_obj

        # number of constraints
        self.n_constr = n_constr

        # allow just an integer for xl and xu if all bounds are equal
        if n_var > 0 and not isinstance(xl, np.ndarray) and xl is not None:
            self.xl = np.ones(n_var) * xl
        else:
            self.xl = xl

        if n_var > 0 and not isinstance(xu, np.ndarray) and xu is not None:
            self.xu = np.ones(n_var) * xu
        else:
            self.xu = xu

        # the pareto set and front will be calculated only once.
        self._pareto_front = None
        self._pareto_set = None
        self._ideal_point, self._nadir_point = None, None

        # actually defines what _evaluate is setting during the evaluation
        if evaluation_of == 'auto':
            # by default F is set, and G if the problem does have constraints
            self.evaluation_of = ['F']
            if self.n_constr > 0:
                self.evaluation_of.append('G')
        else:
            self.evaluation_of = evaluation_of

        self.elementwise_evaluation = elementwise_evaluation

        self.parallelization = parallelization

        # store the callback if defined
        self.callback = callback

    def nadir_point(self):
        """Return the nadir point of the problem.

        Returns:
            np.array or None: Column-wise maximum of the Pareto front, or
            None when no Pareto front is available.
        """
        # if the nadir point has not been calculated yet
        if self._nadir_point is None:

            # calculate the pareto front if not happened yet
            if self._pareto_front is None:
                self.pareto_front()

            # store into the nadir cache; writing to ``_ideal_point`` here
            # would leave this method returning None and corrupt the
            # value returned by ``ideal_point``
            if self._pareto_front is not None:
                self._nadir_point = np.max(self._pareto_front, axis=0)

        return self._nadir_point

    def ideal_point(self):
        """Return the ideal point of the problem.

        Returns:
            np.array or None: Column-wise minimum of the Pareto front, or
            None when no Pareto front is available.
        """
        # if the ideal point has not been calculated yet
        if self._ideal_point is None:

            # calculate the pareto front if not happened yet
            if self._pareto_front is None:
                self.pareto_front()

            # if already done or it was successful - calculate the ideal point
            if self._pareto_front is not None:
                self._ideal_point = np.min(self._pareto_front, axis=0)

        return self._ideal_point

    def pareto_front(self,
                     *args,
                     use_cache=True,
                     exception_if_failing=True,
                     **kwargs):
        """Return the (cached) Pareto front.

        Args:
            args: Forwarded to ``_calc_pareto_front``.
            use_cache (bool): Whether a previously computed front may be
                returned without recomputation.
            exception_if_failing (bool): Whether an exception raised by
                ``_calc_pareto_front`` is propagated.

        Returns:
            np.array or None: The Pareto front as a 2-D array, or None if
            it could not be computed.
        """
        if not use_cache or self._pareto_front is None:
            try:
                pf = self._calc_pareto_front(*args, **kwargs)
                if pf is not None:
                    self._pareto_front = at_least_2d_array(pf)

            except Exception:
                if exception_if_failing:
                    raise

        return self._pareto_front

    def pareto_set(self, *args, use_cache=True, **kwargs):
        """Return the (cached) Pareto set as a 2-D array."""
        if not use_cache or self._pareto_set is None:
            self._pareto_set = at_least_2d_array(
                self._calc_pareto_set(*args, **kwargs))

        return self._pareto_set

    def evaluate(self,
                 X,
                 *args,
                 return_values_of='auto',
                 return_as_dictionary=False,
                 **kwargs):
        """Evaluate the given problem.

        The function values set as defined in the function.
        The constraint values are meant to be positive if infeasible.

        Args:
            X (np.array): A two dimensional matrix where each row is a
                point to evaluate and each column a variable. A single
                1-D point is accepted as well.
            return_as_dictionary (bool): If true, a dictionary is returned.
            return_values_of (list of str or 'auto'): Keys to return, e.g.
                "F", "CV", "G", "feasible". 'auto' selects "F", plus "CV"
                when the problem has constraints.

        Returns:
            A dictionary if ``return_as_dictionary`` is enabled, a single
            value if exactly one key was requested, otherwise a tuple of
            values ordered like ``return_values_of``.

        Raises:
            Exception: If the number of columns of ``X`` differs from
                ``n_var``.
        """
        # call the callback of the problem
        if self.callback is not None:
            self.callback(X)

        only_single_value = len(np.shape(X)) == 1
        X = np.atleast_2d(X)

        # check the dimensionality of the problem and the given input
        if X.shape[1] != self.n_var:
            raise Exception('Input dimension %s are not equal to n_var %s!' %
                            (X.shape[1], self.n_var))

        if isinstance(return_values_of, str) and return_values_of == 'auto':
            return_values_of = ['F']
            if self.n_constr > 0:
                return_values_of.append('CV')

        out = dict.fromkeys(return_values_of)

        out = self._evaluate_batch(X, False, out, *args, **kwargs)

        # the constraint violation is always derived, returned on request
        if self.n_constr == 0:
            CV = np.zeros([X.shape[0], 1])
        else:
            CV = self.calc_constraint_violation(out['G'])

        if 'CV' in return_values_of:
            out['CV'] = CV

        # if an additional boolean flag for feasibility should be returned
        if 'feasible' in return_values_of:
            out['feasible'] = (CV <= 0)

        # if asked for a value but not set in the evaluation set to None
        for val in return_values_of:
            if val not in out:
                out[val] = None

        # a single input point yields 1-D results instead of one-row matrices
        if only_single_value:
            for key in out.keys():
                if out[key] is not None:
                    out[key] = out[key][0, :]

        if return_as_dictionary:
            return out

        # if just a single value do not return a tuple
        if len(return_values_of) == 1:
            return out[return_values_of[0]]
        return tuple([out[val] for val in return_values_of])

    def _evaluate_batch(self, X, calc_gradient, out, *args, **kwargs):
        """Run ``_evaluate`` and turn 1-D outputs into column vectors."""
        self._evaluate(X, out, *args, **kwargs)
        for key in out.keys():
            if len(np.shape(out[key])) == 1:
                out[key] = out[key][:, None]

        return out

    @abstractmethod
    def _evaluate(self, x, out, *args, **kwargs):
        """Write the results for ``x`` into ``out`` (e.g. ``out['F']``)."""
        pass

    def has_bounds(self):
        """Return True if both lower and upper bounds are set."""
        return self.xl is not None and self.xu is not None

    def bounds(self):
        """Return the tuple ``(xl, xu)``."""
        return self.xl, self.xu

    def name(self):
        """Return the class name of the problem."""
        return self.__class__.__name__

    def _calc_pareto_front(self, *args, **kwargs):
        """Method that either loads or calculates the pareto front. This is
        only done once and the pareto front is stored.

        Returns:
            pf (np.array): Pareto front as array.
        """
        pass

    def _calc_pareto_set(self, *args, **kwargs):
        """Hook for subclasses that can provide the Pareto set."""
        pass

    # some problem information
    def __str__(self):
        s = '# name: %s\n' % self.name()
        s += '# n_var: %s\n' % self.n_var
        s += '# n_obj: %s\n' % self.n_obj
        s += '# n_constr: %s\n' % self.n_constr
        s += '# f(xl): %s\n' % self.evaluate(self.xl)[0]
        s += '# f((xl+xu)/2): %s\n' % self.evaluate(
            (self.xl + self.xu) / 2.0)[0]
        s += '# f(xu): %s\n' % self.evaluate(self.xu)[0]
        return s

    @staticmethod
    def calc_constraint_violation(G):
        """Sum the positive entries of each row of ``G``.

        Returns:
            np.array or None: Column vector of violations, or None when
            ``G`` is None.
        """
        if G is None:
            return None
        elif G.shape[1] == 0:
            return np.zeros(G.shape[0])[:, None]
        else:
            return np.sum(G * (G > 0), axis=1)[:, None]
import numpy as np

from .base_problem import BaseProblem


class SubsetProblem(BaseProblem):
    """Select a subset of candidates to diversify the pareto front.

    Args:
        candidates: Array of candidate values; one boolean decision
            variable is created per candidate.
        archive: Array of values that are already selected.
        K (int): Number of candidates that should be selected.
    """

    def __init__(self, candidates, archive, K):
        super().__init__(
            n_var=len(candidates),
            n_obj=1,
            n_constr=1,
            xl=0,
            xu=1,
            type_var=bool)
        self.n_max = K
        self.candidates = candidates
        self.archive = archive

    def _evaluate(self, x, out, *args, **kwargs):
        """Score every selection mask (row of ``x``).

        ``out['F']`` is the standard deviation of the gaps between the
        sorted archive and selected values; ``out['G']`` is the squared
        difference between ``K`` and the number of selected candidates.
        """
        n_rows = x.shape[0]
        spread = np.full((n_rows, 1), np.nan)
        violation = np.full((n_rows, 1), np.nan)

        for row, mask in enumerate(x):
            # append selected candidates to archive then sort
            merged = np.concatenate((self.archive, self.candidates[mask]))
            gaps = np.diff(np.sort(merged))
            spread[row, 0] = np.std(gaps)
            # we penalize if the number of selected candidates is not
            # exactly K
            violation[row, 0] = (self.n_max - np.sum(mask))**2

        out['F'] = spread
        out['G'] = violation
# ---- utils/domin_matrix.py ----
import numpy as np


def get_relation(a, b, cva=None, cvb=None):
    """Compare two objective vectors (smaller is better).

    Args:
        a, b: Objective vectors of equal length.
        cva, cvb: Optional constraint violations. When both are given and
            differ, the smaller violation wins regardless of objectives.

    Returns:
        int: 1 if ``a`` dominates ``b``, -1 if ``b`` dominates ``a``,
        0 if they are equal or mutually non-dominated.
    """
    if cva is not None and cvb is not None:
        if cva < cvb:
            return 1
        elif cvb < cva:
            return -1

    val = 0
    for i in range(len(a)):
        if a[i] < b[i]:
            # indifferent because once better and once worse
            if val == -1:
                return 0
            val = 1
        elif b[i] < a[i]:
            # indifferent because once better and once worse
            if val == 1:
                return 0
            val = -1
    return val


def calc_domination_matrix_loop(F, G):
    """Build the pairwise domination matrix with an explicit double loop.

    Args:
        F (np.array): Objective values, one row per solution.
        G (np.array): Constraint values, one row per solution; positive
            entries count as violation.

    Returns:
        np.array: Matrix ``M`` with ``M[i, j]`` as returned by
        :func:`get_relation` and ``M[j, i] == -M[i, j]``.
    """
    n = F.shape[0]
    CV = np.sum(G * (G > 0).astype(np.float32), axis=1)
    M = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            M[i, j] = get_relation(F[i, :], F[j, :], CV[i], CV[j])
            M[j, i] = -M[i, j]

    return M


def calc_domination_matrix(F, _F=None, epsilon=0.0):
    """Vectorised pairwise domination matrix on objective values only.

    Args:
        F (np.array): Objective values, shape ``(n, n_obj)``.
        _F (np.array, optional): Second set to compare against, shape
            ``(m, n_obj)``. Defaults to ``F``.
        epsilon (float): Margin by which a value must be smaller/larger to
            count as better/worse.

    Returns:
        np.array: ``(n, m)`` integer matrix with 1 where row ``i`` of ``F``
        dominates row ``j`` of ``_F``, -1 for the opposite, 0 otherwise.
    """
    if _F is None:
        _F = F

    # look at the obj for dom
    n = F.shape[0]
    m = _F.shape[0]

    # all (i, j) pairs laid out row-wise: L holds F[i], R holds _F[j]
    L = np.repeat(F, m, axis=0)
    R = np.tile(_F, (n, 1))

    smaller = np.reshape(np.any(L + epsilon < R, axis=1), (n, m))
    larger = np.reshape(np.any(L > R + epsilon, axis=1), (n, m))

    M = np.logical_and(smaller, np.logical_not(larger)) * 1 \
        + np.logical_and(larger, np.logical_not(smaller)) * -1

    return M


def fast_non_dominated_sort(F, **kwargs):
    """Split the rows of ``F`` into successive non-dominated fronts.

    Args:
        F (np.array): Objective values, one row per solution.

    Returns:
        list[list[int]]: Row indices grouped by front, best front first.
        An empty list is returned for empty input.
    """
    # calculate the dominance matrix
    M = calc_domination_matrix(F)
    n = M.shape[0]

    fronts = []

    if n == 0:
        return fronts

    # number of solutions already assigned to a front
    n_ranked = 0
    is_dominating = [[] for _ in range(n)]

    # storage for the number of solutions dominated this one
    n_dominated = np.zeros(n)

    current_front = []

    for i in range(n):

        for j in range(i + 1, n):
            rel = M[i, j]
            if rel == 1:
                is_dominating[i].append(j)
                n_dominated[j] += 1
            elif rel == -1:
                is_dominating[j].append(i)
                n_dominated[i] += 1

        if n_dominated[i] == 0:
            current_front.append(i)
            n_ranked += 1

    # append the first front to the current front
    fronts.append(current_front)

    # while not all solutions are assigned to a pareto front
    while n_ranked < n:

        next_front = []

        # for each individual in the current front
        for i in current_front:

            # all solutions that are dominated by this individuals
            for j in is_dominating[i]:
                n_dominated[j] -= 1
                if n_dominated[j] == 0:
                    next_front.append(j)
                    n_ranked += 1

        fronts.append(next_front)
        current_front = next_front

    return fronts


# ---- utils/helper.py (copied from https://github.com/anyoptimization/pymoo)
import copy

import scipy
# import the submodule function explicitly: a bare ``import scipy`` does not
# guarantee that ``scipy.spatial.distance`` has been loaded
from scipy.spatial.distance import cdist as _scipy_cdist


def default_attr(pop):
    """Return the decision variables ``X`` of a population."""
    return pop.get('X')


def cdist(x, y):
    """Return the pairwise distance matrix between rows of x and y."""
    return _scipy_cdist(x, y)


class DuplicateElimination:
    """Implementation of Elimination.

    Subclasses provide ``_do(pop, other, is_duplicate)`` which returns a
    boolean mask over ``pop``.

    Args:
        func (callable, optional): Maps a population to the attribute that
            is compared. Defaults to :func:`default_attr`.
    """

    def __init__(self, func=None) -> None:
        super().__init__()
        self.func = func

        if self.func is None:
            self.func = default_attr

    def do(self, pop, *args, return_indices=False, to_itself=True):
        """Remove duplicates from ``pop``.

        Args:
            pop: Population to filter.
            *args: Further populations; members of ``pop`` that duplicate
                any of their members are removed as well.
            return_indices (bool): Also return the positions (in the
                original ``pop``) of kept and removed individuals.
            to_itself (bool): Whether duplicates inside ``pop`` itself are
                removed.
        """
        original = pop

        if to_itself:
            pop = pop[~self._do(pop, None, np.full(len(pop), False))]

        for arg in args:
            # nothing to compare against
            if len(arg) == 0:
                continue
            # nothing left to filter
            if len(pop) == 0:
                break
            pop = pop[~self._do(pop, arg, np.full(len(pop), False))]

        if return_indices:
            no_duplicate, is_duplicate = [], []
            H = set(pop)

            for Index, ind in enumerate(original):
                if ind in H:
                    no_duplicate.append(Index)
                else:
                    is_duplicate.append(Index)

            return pop, no_duplicate, is_duplicate
        else:
            return pop


class DefaultDuplicateElimination(DuplicateElimination):
    """Implementation of DefaultDuplicate Elimination.

    Args:
        epsilon (float): Distance below which two individuals are judged
            to be duplicates.
    """

    def __init__(self, epsilon=1e-16, **kwargs) -> None:
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def calc_dist(self, pop, other=None):
        """Distance matrix of ``pop`` to itself or to ``other``."""
        X = self.func(pop)

        if other is None:
            D = cdist(X, X)
            # only compare each individual with the ones before it
            D[np.triu_indices(len(X))] = np.inf
        else:
            _X = self.func(other)
            D = cdist(X, _X)

        return D

    def _do(self, pop, other, is_duplicate):
        D = self.calc_dist(pop, other)
        D[np.isnan(D)] = np.inf

        is_duplicate[np.any(D < self.epsilon, axis=1)] = True
        return is_duplicate


class Individual:
    """Class for each individual in search step.

    ``X``, ``F``, ``CV``, ``G`` and ``feasible`` are stored as attributes;
    any further keyword argument is kept in the ``data`` dict.
    """

    def __init__(self,
                 X=None,
                 F=None,
                 CV=None,
                 G=None,
                 feasible=None,
                 **kwargs) -> None:
        self.X = X
        self.F = F
        self.CV = CV
        self.G = G
        self.feasible = feasible
        self.data = kwargs
        # names of the attributes defined above (``attr`` itself excluded)
        self.attr = set(self.__dict__.keys())

    def has(self, key):
        return key in self.attr or key in self.data

    def set(self, key, value):
        if key in self.attr:
            self.__dict__[key] = value
        else:
            self.data[key] = value

    def copy(self):
        """Shallow copy with an independent ``data`` dict."""
        ind = copy.copy(self)
        ind.data = self.data.copy()
        return ind

    def get(self, keys):
        """Return the value stored under ``keys`` or None if unknown."""
        if keys in self.data:
            return self.data[keys]
        elif keys in self.attr:
            return self.__dict__[keys]
        else:
            return None


class Population(np.ndarray):
    """Class for all the population in search step.

    An object array whose elements are :class:`Individual` instances.
    """

    def __new__(cls, n_individuals=0, individual=Individual()):
        obj = super(Population, cls).__new__(
            cls, n_individuals, dtype=individual.__class__).view(cls)
        for Index in range(n_individuals):
            obj[Index] = individual.copy()
        obj.individual = individual
        return obj

    def merge(self, a, b=None):
        """Merge populations.

        With one argument, return this population extended by ``a``. With
        two arguments, return ``a`` extended by ``b``.
        """
        # ``b`` may be an array, so it must be compared to None explicitly;
        # its truth value is ambiguous for more than one element
        if b is not None:
            first = pop_from_array_or_individual(a)
            second = pop_from_array_or_individual(b)
            return first.merge(second)

        other = pop_from_array_or_individual(a)
        if len(self) == 0:
            return other

        obj = np.concatenate([self, other]).view(Population)
        obj.individual = self.individual
        return obj

    def copy(self):
        """New population referencing the same individuals."""
        pop = Population(n_individuals=len(self), individual=self.individual)
        for Index in range(len(self)):
            pop[Index] = self[Index]
        return pop

    def has(self, key):
        return all([ind.has(key) for ind in self])

    def __deepcopy__(self, memo):
        return self.copy()

    @classmethod
    def create(cls, *args):
        """Concatenate populations, arrays or individuals into one."""
        pop = np.concatenate([
            pop_from_array_or_individual(arg) for arg in args
        ]).view(Population)
        pop.individual = Individual()
        return pop

    def new(self, *args):
        """Create a population sharing this one's prototype individual.

        ``new(n)`` creates ``n`` individuals; ``new(key, values, ...)``
        creates ``len(values)`` individuals and sets the given attributes;
        ``new()`` creates an empty population.
        """
        if len(args) == 1:
            return Population(
                n_individuals=args[0], individual=self.individual)
        else:
            n = len(args[1]) if len(args) > 0 else 0
            pop = Population(n_individuals=n, individual=self.individual)
            if len(args) > 0:
                pop.set(*args)
            return pop

    def collect(self, func, to_numpy=True):
        val = []
        for Index in range(len(self)):
            val.append(func(self[Index]))
        if to_numpy:
            val = np.array(val)
        return val

    def set(self, *args):
        """Set attributes from alternating ``key, values`` arguments.

        Raises:
            Exception: If a sized ``values`` does not match the population
                size.
        """
        for pair in range(int(len(args) / 2)):

            key, values = args[pair * 2], args[pair * 2 + 1]
            is_iterable = hasattr(values,
                                  '__len__') and not isinstance(values, str)

            if is_iterable and len(values) != len(self):
                raise Exception(
                    'Population Set Attribute Error: '
                    'Number of values and population size do not match!')

            for pos in range(len(self)):
                val = values[pos] if is_iterable else values
                self[pos].set(key, val)

        return self

    def get(self, *args, to_numpy=True):
        """Collect the named attributes of all individuals.

        Returns one array (or list) for a single key, otherwise a tuple
        with one entry per key.
        """
        val = {}
        for c in args:
            val[c] = []

        for Index in range(len(self)):

            for c in args:
                val[c].append(self[Index].get(c))

        res = [val[c] for c in args]
        if to_numpy:
            res = [np.array(e) for e in res]

        if len(args) == 1:
            return res[0]
        else:
            return tuple(res)

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.individual = getattr(obj, 'individual', None)


def pop_from_array_or_individual(array, pop=None):
    """Convert a Population, ndarray or Individual into a Population.

    Returns None for any other input type.
    """
    # the population type can be different
    if pop is None:
        pop = Population()

    # provide a whole population object
    if isinstance(array, Population):
        pop = array
    elif isinstance(array, np.ndarray):
        pop = pop.new('X', np.atleast_2d(array))
    elif isinstance(array, Individual):
        pop = Population(1)
        pop[0] = array
    else:
        return None

    return pop


class Initialization:
    """Initialization step: build the first population from ``sampling``."""

    def __init__(self,
                 sampling,
                 individual=Individual(),
                 repair=None,
                 eliminate_duplicates=None) -> None:

        super().__init__()
        self.sampling = sampling
        self.individual = individual
        self.repair = repair
        self.eliminate_duplicates = eliminate_duplicates

    def do(self, problem, n_samples, **kwargs):

        # provide a whole population object
        if isinstance(self.sampling, Population):
            pop = self.sampling

        else:
            pop = Population(0, individual=self.individual)
            if isinstance(self.sampling, np.ndarray):
                pop = pop.new('X', self.sampling)
            else:
                pop = self.sampling.do(problem, n_samples, pop=pop, **kwargs)

        # repair all solutions that are not already evaluated
        if self.repair:
            Index = [k for k in range(len(pop)) if pop[k].F is None]
            pop = self.repair.do(problem, pop[Index], **kwargs)

        if self.eliminate_duplicates is not None:
            pop = self.eliminate_duplicates.do(pop)

        return pop


def split_by_feasibility(pop, sort_infeasbible_by_cv=True):
    """Return index arrays ``(feasible, infeasible)`` based on ``CV <= 0``.

    Infeasible indices are optionally sorted by increasing violation.
    """
    CV = pop.get('CV')

    b = (CV <= 0)

    feasible = np.where(b)[0]
    infeasible = np.where(np.logical_not(b))[0]

    if sort_infeasbible_by_cv:
        infeasible = infeasible[np.argsort(CV[infeasible, 0])]

    return feasible, infeasible


class Survival:
    """The survival process is implemented inheriting from this class, which
    selects from a population only specific individuals to survive.

    Args:
        filter_infeasible (bool): Whether for the survival infeasible
            solutions should be filtered first.
    """

    def __init__(self, filter_infeasible=True):
        self.filter_infeasible = filter_infeasible

    def do(self, problem, pop, n_survive, return_indices=False, **kwargs):

        # if the split should be done beforehand
        if self.filter_infeasible and problem.n_constr > 0:
            feasible, infeasible = split_by_feasibility(
                pop, sort_infeasbible_by_cv=True)

            # if there was no feasible solution was added at all
            if len(feasible) == 0:
                survivors = pop[infeasible[:n_survive]]

            # if there are feasible solutions in the population
            else:
                survivors = pop.new()

                # if feasible solution do exist
                if len(feasible) > 0:
                    survivors = self._do(problem, pop[feasible],
                                         min(len(feasible), n_survive),
                                         **kwargs)

                # if infeasible solutions needs to be added
                if len(survivors) < n_survive:
                    least_infeasible = infeasible[:n_survive - len(feasible)]
                    survivors = survivors.merge(pop[least_infeasible])

        else:
            survivors = self._do(problem, pop, n_survive, **kwargs)

        if return_indices:
            H = {}
            for k, ind in enumerate(pop):
                H[ind] = k
            return [H[survivor] for survivor in survivors]
        else:
            return survivors

    def _do(self, problem, pop, n_survive, D=None, **kwargs):
        """Rank-and-crowding selection of ``n_survive`` individuals."""
        # imported here because this method is the only user of pymoo
        from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting

        # get the objective space values and objects; ``np.float`` was
        # removed from NumPy, the builtin is the same type
        F = pop.get('F').astype(float, copy=False)

        # the final indices of surviving individuals
        survivors = []

        # do the non-dominated sorting until splitting front
        fronts = NonDominatedSorting().do(F, n_stop_if_ranked=n_survive)

        for k, front in enumerate(fronts):

            # calculate the crowding distance of the front
            crowding_of_front = calc_crowding_distance(F[front, :])
            # save rank and crowding in the individual class
            for j, Index in enumerate(front):
                pop[Index].set('rank', k)
                pop[Index].set('crowding', crowding_of_front[j])

            # current front sorted by crowding distance if splitting
            if len(survivors) + len(front) > n_survive:
                Index = randomized_argsort(
                    crowding_of_front, order='descending', method='numpy')
                Index = Index[:(n_survive - len(survivors))]

            # otherwise take the whole front unsorted
            else:
                Index = np.arange(len(front))

            # extend the survivors by all or selected individuals
            survivors.extend(front[Index])

        return pop[survivors]


class FitnessSurvival(Survival):
    """Survival class for Fitness: keep the smallest single objective."""

    def __init__(self) -> None:
        super().__init__(True)

    def _do(self, problem, pop, n_survive, out=None, **kwargs):
        F = pop.get('F')

        if F.shape[1] != 1:
            raise ValueError(
                'FitnessSurvival can only used for single objective single!')

        return pop[np.argsort(F[:, 0])[:n_survive]]


def find_duplicates(X, epsilon=1e-16):
    """Return a boolean mask marking rows that repeat an earlier row."""
    # calculate the distance matrix from each point to another
    D = cdist(X, X)

    # ignore the diagonal and the upper triangle
    D[np.triu_indices(len(X))] = np.inf

    # set as duplicate if a point is really close to this one
    is_duplicate = np.any(D < epsilon, axis=1)

    return is_duplicate


def calc_crowding_distance(F, filter_out_duplicates=True):
    """Compute the crowding distance of every row of ``F``.

    Args:
        F (np.array): Objective values, shape ``(n_points, n_obj)``.
        filter_out_duplicates (bool): Whether duplicated rows get a
            crowding distance of 0.

    Returns:
        np.array: One value per row; all ``inf`` for two or fewer points.
    """
    n_points, n_obj = F.shape

    if n_points <= 2:
        return np.full(n_points, np.inf)

    else:

        if filter_out_duplicates:
            # filter out solutions which are duplicates
            is_unique = np.where(
                np.logical_not(find_duplicates(F, epsilon=1e-24)))[0]
        else:
            # set every point to be unique without checking it
            is_unique = np.arange(n_points)

        # index the unique points of the array
        _F = F[is_unique]

        # sort each column and get index
        Index = np.argsort(_F, axis=0, kind='mergesort')

        # sort the objective space values for the whole matrix
        _F = _F[Index, np.arange(n_obj)]

        # calculate the distance from each point to the last and next
        # (``np.vstack`` is the non-deprecated name of ``np.row_stack``)
        dist = np.vstack([_F, np.full(n_obj, np.inf)]) - np.vstack(
            [np.full(n_obj, -np.inf), _F])

        # calculate the norm for each objective
        norm = np.max(_F, axis=0) - np.min(_F, axis=0)
        norm[norm == 0] = np.nan

        # prepare the distance to last and next vectors
        dist_to_last, dist_to_next = dist, np.copy(dist)
        dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[
            1:] / norm

        dist_to_last[np.isnan(dist_to_last)] = 0.0
        dist_to_next[np.isnan(dist_to_next)] = 0.0

        # sum up the distance to next and last and norm by objectives
        J = np.argsort(Index, axis=0)
        _cd = np.sum(
            dist_to_last[J, np.arange(n_obj)] +
            dist_to_next[J, np.arange(n_obj)],
            axis=1) / n_obj

    # save the final vector which sets the crowding distance for duplicates
    crowding = np.zeros(n_points)
    crowding[is_unique] = _cd

    return crowding


def randomized_argsort(A, method='numpy', order='ascending'):
    """Argsort ``A`` after a random shuffle, so the order among equal
    values is random.

    Args:
        A (np.array): Values to sort.
        method (str): 'numpy' or 'quicksort'.
        order (str): 'ascending' or 'descending'.
    """
    if method == 'numpy':
        P = np.random.permutation(len(A))
        Index = np.argsort(A[P], kind='quicksort')
        Index = P[Index]

    elif method == 'quicksort':
        Index = quicksort(A)

    else:
        raise Exception('Randomized sort method not known.')

    if order == 'ascending':
        return Index
    elif order == 'descending':
        return np.flip(Index, axis=0)
    else:
        raise Exception('Unknown sorting order: ascending or descending.')


def swap(M, a, b):
    """Exchange the entries at positions ``a`` and ``b`` of ``M``."""
    tmp = M[a]
    M[a] = M[b]
    M[b] = tmp


def quicksort(A):
    """Return the indices that sort ``A`` using a randomized quicksort."""
    Index = np.arange(len(A))
    _quicksort(A, Index, 0, len(A) - 1)
    return Index
+def _quicksort(A, Index, left, right): + if left < right: + + index = np.random.randint(left, right + 1) + swap(Index, right, index) + + pivot = A[Index[right]] + + Index = left - 1 + + for j in range(left, right): + + if A[Index[j]] <= pivot: + Index += 1 + swap(Index, Index, j) + + index = Index + 1 + swap(Index, right, index) + + _quicksort(A, Index, left, index - 1) + _quicksort(A, Index, index + 1, right) + + +def random_permuations(n, input): + perms = [] + for _ in range(n): + perms.append(np.random.permutation(input)) + P = np.concatenate(perms) + return P + + +def crossover_mask(X, M): + # convert input to output by flatting along the first axis + _X = np.copy(X) + _X[0][M] = X[1][M] + _X[1][M] = X[0][M] + + return _X + + +def at_least_2d_array(x, extend_as='row'): + if not isinstance(x, np.ndarray): + x = np.array([x]) + + if x.ndim == 1: + if extend_as == 'row': + x = x[None, :] + elif extend_as == 'column': + x = x[:, None] + + return x + + +def repair_out_of_bounds(problem, X): + xl, xu = problem.xl, problem.xu + + only_1d = (X.ndim == 1) + X = at_least_2d_array(X) + + if xl is not None: + xl = np.repeat(xl[None, :], X.shape[0], axis=0) + X[X < xl] = xl[X < xl] + + if xu is not None: + xu = np.repeat(xu[None, :], X.shape[0], axis=0) + X[X > xu] = xu[X > xu] + + if only_1d: + return X[0, :] + else: + return X + + +def denormalize(x, x_min, x_max): + + if x_max is None: + _range = 1 + else: + _range = (x_max - x_min) + + return x * _range + x_min + + +class Evaluator: + """The evaluator class which is used during the algorithm execution to + limit the number of evaluations.""" + + def __init__(self, evaluate_values_of=['F', 'CV', 'G']): + self.n_eval = 0 + self.evaluate_values_of = evaluate_values_of + + def eval(self, problem, pop, **kwargs): + """This function is used to return the result of one valid evaluation. 
+ + Parameters + ---------- + problem : class + The problem which is used to be evaluated + pop : np.array or Population object + kwargs : dict + Additional arguments which might be necessary for the problem to + evaluate. + """ + + is_individual = isinstance(pop, Individual) + is_numpy_array = isinstance( + pop, np.ndarray) and not isinstance(pop, Population) + + # make sure the object is a population + if is_individual or is_numpy_array: + pop = Population().create(pop) + + # find indices to be evaluated + Index = [k for k in range(len(pop)) if pop[k].F is None] + + # update the function evaluation counter + self.n_eval += len(Index) + + # actually evaluate all solutions using the function + if len(Index) > 0: + self._eval(problem, pop[Index], **kwargs) + + # set the feasibility attribute if cv exists + for ind in pop[Index]: + cv = ind.get('CV') + if cv is not None: + ind.set('feasible', cv <= 0) + + if is_individual: + return pop[0] + elif is_numpy_array: + if len(pop) == 1: + pop = pop[0] + return tuple([pop.get(e) for e in self.evaluate_values_of]) + else: + return pop + + def _eval(self, problem, pop, **kwargs): + + out = problem.evaluate( + pop.get('X'), + return_values_of=self.evaluate_values_of, + return_as_dictionary=True, + **kwargs) + + for key, val in out.items(): + if val is None: + continue + else: + pop.set(key, val) diff --git a/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py b/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py new file mode 100644 index 000000000..3ec1c644e --- /dev/null +++ b/mmrazor/models/task_modules/multi_object_optimizer/utils/selection.py @@ -0,0 +1,490 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
import math

import numpy as np

from .domin_matrix import get_relation
# from pymoo.core.population import Population
# from pymoo.util.misc import crossover_mask, random_permuations
# from pymoo.operators.repair.bounce_back import bounce_back_by_problem
from .helper import (Population, crossover_mask, random_permuations,
                     repair_out_of_bounds)


def binary_tournament(pop, P, algorithm, **kwargs):
    """Pick one winner from each pair of competitors.

    Parameters
    ----------
    pop : Population
        Individuals exposing ``CV``, ``F`` and ``get('crowding')``.
    P : np.ndarray
        Integer array with exactly two columns; each row holds the indices
        of the two competitors.
    algorithm : object
        Not used by this function.

    Returns
    -------
    np.ndarray
        Integer column vector (one row per row of ``P``) of winner indices.

    Raises
    ------
    ValueError
        If ``P`` does not have exactly two columns.
    """
    if P.shape[1] != 2:
        raise ValueError('Only implemented for binary tournament!')

    S = np.full(P.shape[0], np.nan)

    for k in range(P.shape[0]):

        a, b = P[k, 0], P[k, 1]

        # if at least one solution is infeasible
        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
            S[k] = compare(
                a,
                pop[a].CV,
                b,
                pop[b].CV,
                method='smaller_is_better',
                return_random_if_equal=True)
        else:
            # NOTE(review): get_relation presumably reports Pareto dominance
            # (1: a wins, -1: b wins) - confirm in domin_matrix.
            rel = get_relation(pop[a].F, pop[b].F)
            if rel == 1:
                S[k] = a
            elif rel == -1:
                S[k] = b

            # if the relation didn't make a decision, prefer the larger
            # crowding value and break exact ties randomly
            if np.isnan(S[k]):
                S[k] = compare(
                    a,
                    pop[a].get('crowding'),
                    b,
                    pop[b].get('crowding'),
                    method='larger_is_better',
                    return_random_if_equal=True)

    return S[:, None].astype(int, copy=False)


def comp_by_cv_and_fitness(pop, P, **kwargs):
    """Pick one winner per pair by constraint violation, then by ``F``.

    For each row ``(a, b)`` of ``P``: if either individual has ``CV > 0``
    the one with the smaller ``CV`` wins, otherwise the one with the smaller
    ``F`` wins. Exact ties are broken randomly.

    Returns
    -------
    np.ndarray
        Integer column vector of winner indices.
    """
    S = np.full(P.shape[0], np.nan)

    for k in range(P.shape[0]):
        a, b = P[k, 0], P[k, 1]

        # if at least one solution is infeasible
        if pop[a].CV > 0.0 or pop[b].CV > 0.0:
            S[k] = compare(
                a,
                pop[a].CV,
                b,
                pop[b].CV,
                method='smaller_is_better',
                return_random_if_equal=True)

        # both solutions are feasible: the smaller objective value wins
        else:
            S[k] = compare(
                a,
                pop[a].F,
                b,
                pop[b].F,
                method='smaller_is_better',
                return_random_if_equal=True)

    return S[:, None].astype(int)


def compare(a, a_val, b, b_val, method, return_random_if_equal=False):
    """Return ``a`` or ``b`` depending on which value is better.

    Parameters
    ----------
    a, b : object
        The candidates (indices in this module).
    a_val, b_val : comparable
        The values the candidates are judged by.
    method : str
        ``'larger_is_better'`` or ``'smaller_is_better'``.
    return_random_if_equal : bool
        On a tie return a random candidate instead of ``None``.

    Returns
    -------
    object or None
        The winner, or ``None`` on a tie when ``return_random_if_equal`` is
        false.

    Raises
    ------
    ValueError
        If ``method`` is not one of the two supported names (previously this
        fell through and silently returned ``None``).
    """
    if method == 'larger_is_better':
        a_wins, b_wins = a_val > b_val, a_val < b_val
    elif method == 'smaller_is_better':
        a_wins, b_wins = a_val < b_val, a_val > b_val
    else:
        raise ValueError(f'Unknown method: {method!r}')

    if a_wins:
        return a
    if b_wins:
        return b
    if return_random_if_equal:
        return np.random.choice([a, b])
    return None


class TournamentSelection:
    """The Tournament selection is used to simulate a tournament between
    individuals.

    The pressure balances how greedy the genetic algorithm will be.
    """

    def __init__(self, pressure=2, func_comp='binary_tournament'):
        """

        Parameters
        ----------
        func_comp: str
            Name of the function used to compare two individuals.
            ``'comp_by_cv_and_fitness'`` selects that function; any other
            value selects ``binary_tournament``.

        pressure: int
            The selection pressure to be applied.
        """

        # selection pressure to be applied
        self.pressure = pressure
        if func_comp == 'comp_by_cv_and_fitness':
            self.f_comp = comp_by_cv_and_fitness
        else:
            self.f_comp = binary_tournament

    def do(self, pop, n_select, n_parents=2, **kwargs):
        """Select parents by tournament.

        Returns
        -------
        np.ndarray
            Array of shape ``(n_select, n_parents)`` with the indices of the
            tournament winners.
        """
        # number of random individuals needed
        n_random = n_select * n_parents * self.pressure

        # number of permutations needed
        n_perms = math.ceil(n_random / len(pop))

        # get random permutations and reshape them
        P = random_permuations(n_perms, len(pop))[:n_random]
        P = np.reshape(P, (n_select * n_parents, self.pressure))

        # compare using tournament function
        S = self.f_comp(pop, P, **kwargs)

        return np.reshape(S, (n_select, n_parents))


class PointCrossover:
    """N-point crossover: offspring swap alternating segments of their
    parents' variables."""

    def __init__(self, n_points=2, n_parents=2, n_offsprings=2, prob=0.9):
        self.n_points = n_points
        self.prob = prob
        self.n_parents = n_parents
        self.n_offsprings = n_offsprings

    def do(self, problem, pop, parents, **kwargs):
        """

        Parameters
        ----------
        problem: class
            The problem to be solved.

        pop : Population
            The population as an object

        parents: numpy.array
            The select parents of the population for the crossover

        kwargs : dict
            Any additional data that might be necessary.

        Returns
        -------
        offsprings : Population
            The off as a matrix. n_children rows and the number of columns is
            equal to the variable length of the problem.

        Raises
        ------
        ValueError
            If ``parents`` does not have ``n_parents`` columns.
        """

        if self.n_parents != parents.shape[1]:
            raise ValueError(
                'Exception during crossover: '
                'Number of parents differs from defined at crossover.')

        # get the design space matrix from the population and parents
        X = pop.get('X')[parents.T].copy()

        # decide per mating whether the crossover result is kept
        do_crossover = np.random.random(len(parents)) < self.prob

        # execute the crossover
        _X = self._do(problem, X, **kwargs)

        # matings that skip crossover keep the parents' variables
        X[:, do_crossover, :] = _X[:, do_crossover, :]

        # flatten the array to become a 2d-array
        X = X.reshape(-1, X.shape[-1])

        # create a population object
        return pop.new('X', X)

    def _do(self, problem, X, **kwargs):
        """Build a crossover mask per mating and apply it to ``X``.

        ``X`` has shape ``(n_parents, n_matings, n_var)``.
        """

        # get the X of parents and count the matings
        _, n_matings, n_var = X.shape

        # sorted cut points per mating, with n_var appended as final end.
        # np.vstack replaces np.row_stack, deprecated in NumPy 2.0.
        r = np.vstack([
            np.random.permutation(n_var - 1) + 1 for _ in range(n_matings)
        ])[:, :self.n_points]
        r.sort(axis=1)
        r = np.column_stack([r, np.full(n_matings, n_var)])

        # the mask to do the crossover
        M = np.full((n_matings, n_var), False)

        # mark every second segment [r[j], r[j + 1]) for exchange
        for k in range(n_matings):
            for j in range(0, r.shape[1] - 1, 2):
                a, b = r[k, j], r[k, j + 1]
                M[k, a:b] = True

        return crossover_mask(X, M)


class PolynomialMutation:
    """Polynomial mutation of real-valued variables within the problem
    bounds.

    Parameters
    ----------
    eta : float
        Exponent of the mutation polynomial.
    prob : float, optional
        Per-variable mutation probability. When ``None`` it is set to
        ``1 / problem.n_var`` on the first call and then kept.
    """

    def __init__(self, eta=20, prob=None):
        super().__init__()
        self.eta = float(eta)

        if prob is not None:
            self.prob = float(prob)
        else:
            self.prob = None

    def _do(self, problem, X, **kwargs):
        """Return a mutated float copy of ``X`` (one solution per row)."""

        Y = np.full(X.shape, np.inf)

        if self.prob is None:
            self.prob = 1.0 / problem.n_var

        do_mutation = np.random.random(X.shape) < self.prob

        # start from a float copy; only the selected entries are replaced
        Y[:, :] = X

        xl = np.repeat(problem.xl[None, :], X.shape[0], axis=0)[do_mutation]
        xu = np.repeat(problem.xu[None, :], X.shape[0], axis=0)[do_mutation]

        # from here on X holds only the entries that will be mutated
        X = X[do_mutation]

        delta1 = (X - xl) / (xu - xl)
        delta2 = (xu - X) / (xu - xl)

        mut_pow = 1.0 / (self.eta + 1.0)

        rand = np.random.random(X.shape)
        mask = rand <= 0.5
        mask_not = np.logical_not(mask)

        deltaq = np.zeros(X.shape)

        # rand <= 0.5: move towards the lower bound
        xy = 1.0 - delta1
        val = 2.0 * rand + (1.0 - 2.0 * rand) * (
            np.power(xy, (self.eta + 1.0)))
        d = np.power(val, mut_pow) - 1.0
        deltaq[mask] = d[mask]

        # rand > 0.5: move towards the upper bound
        xy = 1.0 - delta2
        val = 2.0 * (1.0 - rand) + 2.0 * (rand - 0.5) * (
            np.power(xy, (self.eta + 1.0)))
        d = 1.0 - (np.power(val, mut_pow))
        deltaq[mask_not] = d[mask_not]

        # mutated values
        _Y = X + deltaq * (xu - xl)

        # back in bounds if necessary (floating point issues)
        _Y[_Y < xl] = xl[_Y < xl]
        _Y[_Y > xu] = xu[_Y > xu]

        # set the values for output
        Y[do_mutation] = _Y

        # in case out of bounds repair (very unlikely)
        # Y = bounce_back_by_problem(problem, Y)
        Y = repair_out_of_bounds(problem, Y)

        return Y

    def do(self, problem, pop, **kwargs):
        """Return a new population built from the mutated ``X`` of ``pop``."""
        Y = self._do(problem, pop.get('X'), **kwargs)
        return pop.new('X', Y)


class IntegerFromFloatMutation:
    """Mutate integer variables by running :class:`PolynomialMutation` on
    bounds widened by almost 0.5 and rounding the result.

    Keyword arguments are forwarded to ``PolynomialMutation``.
    """

    def __init__(self, **kwargs):

        self.mutation = PolynomialMutation(**kwargs)

    def _do(self, problem, X, **kwargs):
        """Return the mutated, rounded values for ``X``.

        ``problem.xl`` / ``problem.xu`` are modified temporarily and always
        restored, even if the underlying mutation raises.
        """

        # save the original bounds of the problem
        _xl, _xu = problem.xl, problem.xu

        # widen the bounds so every integer gets an (almost) equal interval
        problem.xl = _xl - (0.5 - 1e-16)
        problem.xu = _xu + (0.5 - 1e-16)

        try:
            off = self.mutation._do(problem, X, **kwargs)
        finally:
            # never leave the shared problem object with widened bounds
            problem.xl = _xl
            problem.xu = _xu

        # now round to nearest integer for all offsprings
        off = np.rint(off)

        # np.rint rounds halves to the even integer, so a value sitting on a
        # widened bound can land one step outside the original range
        return np.clip(off, _xl, _xu)

    def do(self, problem, pop, **kwargs):
        """Mutate variables in a genetic way.

        Parameters
        ----------
        problem : class
            The problem instance
        pop : Population
            A population object

        Returns
        -------
        Y : Population
            The mutated population.
        """

        Y = self._do(problem, pop.get('X'), **kwargs)
        return pop.new('X', Y)


class Mating:
    """Create offspring by chaining selection, crossover and mutation.

    Parameters
    ----------
    selection, crossover, mutation : object
        Operators providing the ``do`` methods used below.
    repair : object, optional
        If truthy, ``repair.do(problem, off)`` is applied to each batch.
    eliminate_duplicates : object, optional
        If given, ``eliminate_duplicates.do(off, pop, existing)`` filters
        each batch.
    n_max_iterations : int
        Upper bound on the number of mating rounds in :meth:`do`.
    """

    def __init__(self,
                 selection,
                 crossover,
                 mutation,
                 repair=None,
                 eliminate_duplicates=None,
                 n_max_iterations=100):

        self.selection = selection
        self.crossover = crossover
        self.mutation = mutation
        self.n_max_iterations = n_max_iterations
        self.eliminate_duplicates = eliminate_duplicates
        self.repair = repair

    def _do(self, problem, pop, n_offsprings, parents=None, **kwargs):
        """Run one round of selection (unless ``parents`` is given),
        crossover and mutation and return the resulting offspring."""

        # if the parents for the mating are not provided directly
        if parents is None:

            # how many parents need to be select for the mating
            n_select = math.ceil(n_offsprings / self.crossover.n_offsprings)

            # select the parents for the mating - just an index array
            parents = self.selection.do(pop, n_select,
                                        self.crossover.n_parents, **kwargs)

        # do the crossover using the parents index and the population
        _off = self.crossover.do(problem, pop, parents, **kwargs)

        # do the mutation on the offsprings created through crossover
        _off = self.mutation.do(problem, _off, **kwargs)

        return _off

    def do(self, problem, pop, n_offsprings, **kwargs):
        """Mate until ``n_offsprings`` offspring exist or the iteration
        limit is exceeded; the result may therefore be shorter than
        requested."""

        # the population object to be used
        off = pop.new()

        # infill counter
        # counts how often the mating needs to be done to fill up n_offsprings
        n_infills = 0
        # iterate until enough offsprings are created
        while len(off) < n_offsprings:

            # how many offsprings are remaining to be created
            n_remaining = n_offsprings - len(off)

            # do the mating
            _off = self._do(problem, pop, n_remaining, **kwargs)

            # repair the individuals if necessary
            if self.repair:
                _off = self.repair.do(problem, _off, **kwargs)

            if self.eliminate_duplicates is not None:
                _off = self.eliminate_duplicates.do(_off, pop, off)

            # if more offsprings than necessary - keep only the first ones
            if len(off) + len(_off) > n_offsprings:
                n_remaining = n_offsprings - len(off)
                _off = _off[:n_remaining]

            # add to the offsprings and increase the mating counter
            off = off.merge(_off)
            n_infills += 1

            if n_infills > self.n_max_iterations:
                break

        return off


class MySampling:
    """Sample boolean vectors with ``problem.n_max`` randomly placed
    ``True`` entries."""

    def __init__(self):
        pass

    def do(self, problem, n_samples, pop=Population(), **kwargs):
        """Draw ``n_samples`` boolean rows of length ``problem.n_var``.

        Returns the raw array when ``pop`` is ``None``; otherwise a new
        population created via ``pop.new('X', X)``. The default ``pop`` is a
        single ``Population()`` built at definition time and is only used as
        a factory here.
        """
        X = np.full((n_samples, problem.n_var), False, dtype=bool)

        for k in range(n_samples):
            chosen = np.random.permutation(problem.n_var)[:problem.n_max]
            X[k, chosen] = True

        if pop is None:
            return X
        return pop.new('X', X)


class BinaryCrossover(PointCrossover):
    """Crossover for boolean vectors producing one child per mating."""

    def __init__(self):
        super().__init__(n_parents=2, n_offsprings=1)

    def _do(self, problem, X, **kwargs):
        """Combine two boolean parents per mating into one child.

        The child is ``True`` wherever both parents are; the remaining
        ``problem.n_max - (#shared)`` entries are drawn at random from the
        positions where exactly one parent is ``True``.
        """
        _, n_matings, n_var = X.shape

        _X = np.full((self.n_offsprings, n_matings, problem.n_var), False)

        for k in range(n_matings):
            p1, p2 = X[0, k], X[1, k]

            both_are_true = np.logical_and(p1, p2)
            _X[0, k, both_are_true] = True

            n_remaining = problem.n_max - np.sum(both_are_true)

            differing = np.where(np.logical_xor(p1, p2))[0]

            S = differing[np.random.permutation(len(differing))][:n_remaining]
            _X[0, k, S] = True

        return _X


class MyMutation(PolynomialMutation):
    """Mutation for boolean vectors that moves one ``True`` entry."""

    def _do(self, problem, X, **kwargs):
        """Per row, set one random ``False`` entry and clear one random
        previously-``True`` entry, in place.

        Rows that are entirely ``True`` or entirely ``False`` are left
        untouched, so the number of ``True`` entries per row never changes.
        (Relying on ``np.random.choice`` raising ``ValueError`` would leave
        an all-``False`` row with one bit set and none cleared.)
        """
        for k in range(X.shape[0]):
            is_false = np.where(np.logical_not(X[k, :]))[0]
            is_true = np.where(X[k, :])[0]
            if is_false.size == 0 or is_true.size == 0:
                continue
            X[k, np.random.choice(is_false)] = True
            X[k, np.random.choice(is_true)] = False

        return X


# diff --git a/mmrazor/models/task_modules/predictor/metric_predictor.py b/mmrazor/models/task_modules/predictor/metric_predictor.py
# index ae99ff091..79a915e4d 100644
# --- a/mmrazor/models/task_modules/predictor/metric_predictor.py
# +++ b/mmrazor/models/task_modules/predictor/metric_predictor.py
# @@ -132,7 +132,7 @@ def vector2model(self, vector: np.array) -> Dict[str, str]:
#              if self.encoding_type == 'onehot':
#                  index = np.where(vector[start:start + len(value[0].choices)] == 1)[0][0]
# -                start += len(value)
# +                start += len(value[0].choices)
#              else:
#                  index = vector[start]
#                  start += 1