Add meta ops in mmrazors/models #450

Open · wants to merge 1 commit into base: 0.x
223 changes: 223 additions & 0 deletions mmrazor/models/architectures/meta_ops/meta_base/meta_mixin.py
@@ -0,0 +1,223 @@
from abc import ABC, abstractmethod
from itertools import repeat
from typing import Callable, Dict, Iterable, Optional, Set, Tuple

from torch import nn
from torch.nn.modules.conv import _ConvNd

from mmrazor.models.mutables.base_mutable import BaseMutable
# NOTE: assumed to be importable from mmrazor's dynamic ops; it provides
# `check_mutable_channels`, which `MetaConvMixin` relies on below.
from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin



def _ntuple(n: int) -> Callable:  # pragma: no cover
    """Repeat a number n times."""

    def parse(x):
        if isinstance(x, Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def _get_current_kernel_pos(source_kernel_size: int,
                            target_kernel_size: int) -> Tuple[int, int]:
    """Get the position of the current kernel size within the source kernel.

    Returns:
        Tuple[int, int]: (upper left position, bottom right position)
    """
    assert source_kernel_size >= target_kernel_size, \
        '`source_kernel_size` must be greater than or equal to ' \
        '`target_kernel_size`'

    center = source_kernel_size >> 1
    current_offset = target_kernel_size >> 1

    start_offset = center - current_offset
    end_offset = center + current_offset + 1

    return start_offset, end_offset
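
# A quick worked example of the cropping arithmetic above (illustrative only):
# cropping a 7x7 kernel down to 3x3 keeps rows/cols [2, 5), i.e. the centered
# 3x3 patch, since center = 7 >> 1 = 3 and offset = 3 >> 1 = 1.
#
#   start, end = _get_current_kernel_pos(source_kernel_size=7,
#                                        target_kernel_size=3)
#   assert (start, end) == (2, 5)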


def _get_same_padding(kernel_size: int, n_dims: int) -> Tuple[int]:
    """Get 'same' padding according to the kernel size."""
    assert kernel_size & 1, 'kernel size must be odd'
    _pair = _ntuple(n_dims)

    return _pair(kernel_size >> 1)




class MetaMixin(ABC):
    """Base class for dynamic OPs. A dynamic OP usually consists of a normal
    static OP and mutables, where mutables are used to control the searchable
    (mutable) part of the dynamic OP.

    Note:
        When the dynamic OP has just been initialized, its forward propagation
        logic should be the same as the corresponding static OP. Only after
        the searchable part accepts a specific mutable through the
        corresponding interface does that part really become dynamic.

    Note:
        All subclasses should implement the ``to_static_op`` and
        ``static_op_factory`` APIs.

    Args:
        accepted_mutable_attrs (set): The string set of all accepted mutables.
    """
    accepted_mutable_attrs: Set[str] = set()
    attr_mappings: Dict[str, str] = dict()

    @abstractmethod
    def register_mutable_attr(self, attr: str, mutable: BaseMutable):
        """Register a mutable that controls the attribute ``attr``."""

    def get_mutable_attr(self, attr: str) -> BaseMutable:
        """Get the mutable registered for ``attr``, if any."""
        self.check_mutable_attr_valid(attr)
        if attr in self.attr_mappings:
            attr_map = self.attr_mappings[attr]
            return getattr(self.mutable_attrs, attr_map, None)  # type:ignore
        else:
            return getattr(self.mutable_attrs, attr, None)  # type:ignore

    @classmethod
    @abstractmethod
    def convert_from(cls, module):
        """Convert an instance of a Pytorch module to a new instance of a
        dynamic module."""

    @property
    @abstractmethod
    def static_op_factory(self):
        """Corresponding Pytorch OP."""

    @abstractmethod
    def to_static_op(self) -> nn.Module:
        """Convert the dynamic OP to a static OP.

        Note:
            For the same input, the forward result of the dynamic OP and its
            corresponding static OP must be the same.

        Returns:
            nn.Module: Corresponding static OP.
        """

    def check_if_mutables_fixed(self):
        """Check if all mutables are fixed.

        Raises:
            RuntimeError: Error if an existing mutable is not fixed.
        """

        def check_fixed(mutable: Optional[BaseMutable]) -> None:
            if mutable is not None and not mutable.is_fixed:
                raise RuntimeError(f'Mutable {type(mutable)} is not fixed.')

        for mutable in self.mutable_attrs.values():  # type: ignore
            check_fixed(mutable)

    def check_mutable_attr_valid(self, attr):
        """Check that ``attr`` is an accepted mutable attribute."""
        assert attr in self.attr_mappings or \
            attr in self.accepted_mutable_attrs

    @staticmethod
    def get_current_choice(mutable: BaseMutable):
        """Get current choice of given mutable.

        Args:
            mutable (BaseMutable): Given mutable.

        Raises:
            RuntimeError: Error if ``current_choice`` is None.

        Returns:
            Any: Current choice of given mutable.
        """
        current_choice = mutable.current_choice
        if current_choice is None:
            raise RuntimeError(f'current choice of mutable {type(mutable)} '
                               'can not be None at runtime')

        return current_choice


class MetaConvMixin(DynamicChannelMixin):
    """A mixin class for Pytorch conv, which can mutate ``in_channels`` and
    ``out_channels``.

    Note:
        All subclasses should implement the ``conv_func`` API.
    """

    @property
    @abstractmethod
    def conv_func(self: _ConvNd):
        """The function that will be used in ``forward_mixin``."""
        pass

    def register_mutable_attr(self, attr, mutable):
        """Register a mutable for ``in_channels`` or ``out_channels``."""
        if attr == 'in_channels':
            self._register_mutable_in_channels(mutable)
        elif attr == 'out_channels':
            self._register_mutable_out_channels(mutable)
        else:
            raise NotImplementedError

    def _register_mutable_in_channels(self: _ConvNd,
                                      mutable_in_channels: BaseMutable):
        """Mutate ``in_channels`` with the given mutable.

        Args:
            mutable_in_channels (BaseMutable): Mutable for controlling
                ``in_channels``.

        Raises:
            ValueError: Error if the size of the mask is not the same as
                ``in_channels``.
        """
        assert hasattr(self, 'mutable_attrs')
        self.check_mutable_channels(mutable_in_channels)
        mask_size = mutable_in_channels.current_mask.size(0)
        if mask_size != self.in_channels:
            raise ValueError(
                f'Expect mask size of mutable to be {self.in_channels} as '
                f'`in_channels`, but got: {mask_size}.')

        self.mutable_attrs['in_channels'] = mutable_in_channels

    def _register_mutable_out_channels(self: _ConvNd,
                                       mutable_out_channels: BaseMutable):
        """Mutate ``out_channels`` with the given mutable.

        Args:
            mutable_out_channels (BaseMutable): Mutable for controlling
                ``out_channels``.

        Raises:
            ValueError: Error if the size of the mask is not the same as
                ``out_channels``.
        """
        assert hasattr(self, 'mutable_attrs')
        self.check_mutable_channels(mutable_out_channels)
        mask_size = mutable_out_channels.current_mask.size(0)
        if mask_size != self.out_channels:
            raise ValueError(
                f'Expect mask size of mutable to be {self.out_channels} as '
                f'`out_channels`, but got: {mask_size}.')

        self.mutable_attrs['out_channels'] = mutable_out_channels

    @property
    def mutable_in_channels(self: _ConvNd):
        """Mutable related to input."""
        assert hasattr(self, 'mutable_attrs')
        return getattr(self.mutable_attrs, 'in_channels', None)  # type:ignore

    @property
    def mutable_out_channels(self: _ConvNd):
        """Mutable related to output."""
        assert hasattr(self, 'mutable_attrs')
        return getattr(self.mutable_attrs, 'out_channels', None)  # type:ignore

    def forward_inpoup(self: _ConvNd):
        """Get the currently activated numbers of input and output channels.

        Falls back to the static ``in_channels``/``out_channels`` when the
        corresponding mutable has not been registered.
        """
        inp, oup = self.in_channels, self.out_channels
        if 'in_channels' in self.mutable_attrs:
            mutable_in_channels = self.mutable_attrs['in_channels']
            inp = mutable_in_channels.activated_channels
        if 'out_channels' in self.mutable_attrs:
            mutable_out_channels = self.mutable_attrs['out_channels']
            oup = mutable_out_channels.activated_channels
        return inp, oup
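
# A minimal usage sketch (illustrative only; `meta_conv` and
# `mutable_out_channels` are hypothetical objects, the latter standing in for
# one of mmrazor's channel mutables): an op mixing in `MetaConvMixin` keeps
# its static behaviour until a channel mutable is registered.
#
#   meta_conv.register_mutable_attr('out_channels', mutable_out_channels)
#   assert meta_conv.get_mutable_attr('out_channels') is mutable_out_channels
#   inp, oup = meta_conv.forward_inpoup()  # activated in/out channel counts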

87 changes: 87 additions & 0 deletions mmrazor/models/architectures/meta_ops/meta_bircks/meta_conv.py
@@ -0,0 +1,87 @@
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmrazor.models.mutables.base_mutable import BaseMutable
# NOTE: relative import assumed from this file's location alongside meta_base/
from ..meta_base.meta_mixin import MetaConvMixin


def groups_channels(in_channels, groups):
    """Return ``(in_channels_per_group, groups)``, rounding ``in_channels``
    down to a multiple of ``groups`` when it is not divisible (or up to
    ``groups`` when it is smaller than ``groups``)."""
    if in_channels % groups == 0:
        return int(in_channels / groups), groups
    else:
        num_mul = in_channels // groups
        in_channels = groups * num_mul if num_mul > 0 else groups * (num_mul + 1)
        in_channels = in_channels / groups
        return int(in_channels), groups


def groups_out_channels(out_channels, groups):
    """Return ``(out_channels, groups)`` with ``out_channels`` rounded down to
    a multiple of ``groups`` when it is not divisible (or up to ``groups``
    when it is smaller than ``groups``)."""
    if out_channels % groups == 0:
        return out_channels, groups
    else:
        num_mul = out_channels // groups
        out_channels = groups * num_mul if num_mul > 0 else groups * (num_mul + 1)
        return int(out_channels), groups
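
# A worked example of the rounding above (illustrative only): with
# in_channels=30 and groups=4, 30 is not divisible by 4, so it is floored to
# 28 and split into 7 channels per group; an out_channels of 30 is likewise
# floored to 28.
#
#   assert groups_channels(30, 4) == (7, 4)
#   assert groups_out_channels(30, 4) == (28, 4)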



class MetaConv2d(nn.Conv2d, MetaConvMixin):
    """Conv2d whose weight is generated by a small hyper-network conditioned
    on the currently activated input/output channel ratios."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode)

        self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size if not isinstance(kernel_size, int) \
            else [kernel_size, kernel_size]
        self.base_oup = out_channels
        self.base_inp = in_channels

        self.groups_ = groups
        self.bias_ = bias is not False
        self.max_oup_channel = self.base_oup
        if in_channels / groups == 1:
            # depth-wise case: each filter only sees one input channel
            self.max_inp_channel = 1
        else:
            self.max_inp_channel = self.base_inp

        # hyper-network that generates the conv weight (and optionally the
        # bias) from the (input-ratio, output-ratio) scale vector
        self.fc11 = nn.Linear(2, 64)
        self.fc12 = nn.Linear(64, self.max_oup_channel * self.max_inp_channel
                              * self.kernel_size[0] * self.kernel_size[1])
        if self.bias_:
            self.fc_bias = nn.Sequential(
                nn.Linear(2, 16),
                nn.ReLU(),
                nn.Linear(16, self.max_oup_channel))

    def forward(self, x: Tensor):
        """Generate the weight for the currently activated (inp, oup) setting
        and run a standard conv2d with it."""
        inp, oup = self.forward_inpoup()
        group_sample_num = self.base_inp / self.groups_
        group_sample_num = inp if group_sample_num > inp else group_sample_num
        groups_new = int(inp / group_sample_num) if int(inp / group_sample_num) > 0 else 1
        inp, _ = groups_channels(inp, groups_new)
        oup, _ = groups_out_channels(oup, groups_new)

        scale_tensor = torch.FloatTensor(
            [inp / self.max_inp_channel, oup / self.max_oup_channel]).to(x.device)
        fc11_out = F.relu(self.fc11(scale_tensor))

        vggconv3x3_weight = self.fc12(fc11_out).view(
            self.max_oup_channel,
            self.max_inp_channel,
            self.kernel_size[0],
            self.kernel_size[1])
        bias = None
        if self.bias_:
            bias = self.fc_bias(scale_tensor)
            bias = bias[:oup]

        out = F.conv2d(x, vggconv3x3_weight[:oup, :inp, :, :], bias=bias,
                       stride=self.stride, padding=self.padding,
                       dilation=self.dilation, groups=groups_new)
        return out
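
# A minimal usage sketch (illustrative only; `mutable_out_channels` is a
# hypothetical stand-in for one of mmrazor's channel mutables exposing
# `current_mask`/`activated_channels`):
#
#   conv = MetaConv2d(in_channels=16, out_channels=32, kernel_size=3,
#                     stride=1, padding=1)
#   conv.register_mutable_attr('out_channels', mutable_out_channels)
#   y = conv(torch.rand(2, 16, 8, 8))  # weight is generated on the fly
#   print(y.shape)                     # (2, activated out channels, 8, 8)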