From a534bc45ed57269a48cfe52fca97c1abfdbd42d1 Mon Sep 17 00:00:00 2001 From: Aakanksha Rana <40461936+Aakanksha-Rana@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:38:43 -0600 Subject: [PATCH] [WIP] ConvneXT Models for Classification and Segmentation (#210) * Update CHANGELOG.md [skip ci] * Create DepthwiseConv3d.py * for classification * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update convnext.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update DepthwiseConv3d.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update convnext.py * Update convnext.py fix typo --------- Co-authored-by: Satrajit Ghosh Co-authored-by: Nobrainer Bot Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: H Gazula Co-authored-by: Aakanksha Rana --- nobrainer/layers/DepthwiseConv3d.py | 338 ++++++++++++++++++++++++++++ nobrainer/models/convnext.py | 196 ++++++++++++++++ 2 files changed, 534 insertions(+) create mode 100644 nobrainer/layers/DepthwiseConv3d.py create mode 100644 nobrainer/models/convnext.py diff --git a/nobrainer/layers/DepthwiseConv3d.py b/nobrainer/layers/DepthwiseConv3d.py new file mode 100644 index 00000000..2cfd0efa --- /dev/null +++ b/nobrainer/layers/DepthwiseConv3d.py @@ -0,0 +1,338 @@ +""" +Directly taken from +https://github.com/alexandrosstergiou/keras-DepthwiseConv3D/blob/master/DepthwiseConv3D.py +This is a modification of the SeparableConv3D code in Keras, +to perform just the Depthwise Convolution (1st step) of the +Depthwise Separable Convolution layer. +""" +from __future__ import absolute_import + +from keras import backend as K +from keras import constraints, initializers, regularizers +from keras.backend.tensorflow_backend import ( + _preprocess_conv3d_input, + _preprocess_padding, +) +from keras.engine import InputSpec +from keras.layers import Conv3D +from keras.legacy.interfaces import conv3d_args_preprocessor +from keras.utils import conv_utils +import tensorflow as tf + + +def depthwise_conv3d_args_preprocessor(args, kwargs): + converted = [] + + if "init" in kwargs: + init = kwargs.pop("init") + kwargs["depthwise_initializer"] = init + converted.append(("init", "depthwise_initializer")) + + args, kwargs, _converted = conv3d_args_preprocessor(args, kwargs) + return args, kwargs, converted + _converted + + +# legacy_depthwise_conv3d_support = generate_legacy_interface( +# allowed_positional_args=["filters", "kernel_size"], +# conversions=[ +# ("nb_filter", "filters"), +# ("subsample", "strides"), +# ("border_mode", "padding"), +# ("dim_ordering", "data_format"), +# ("b_regularizer", "bias_regularizer"), +# ("b_constraint", "bias_constraint"), +# ("bias", "use_bias"), +# ], +# value_conversions={ +# "dim_ordering": { +# "tf": "channels_last", +# "th": "channels_first", +# "default": None, +# } +# }, +# preprocessor=depthwise_conv3d_args_preprocessor, +# ) + + +class DepthwiseConv3D(Conv3D): + """Depthwise 3D convolution. + Depth-wise part of separable convolutions consist in performing + just the first step/operation + (which acts on each input channel separately). + It does not perform the pointwise convolution (second step). + The `depth_multiplier` argument controls how many + output channels are generated per input channel in the depthwise step. + # Arguments + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, width and height of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, width and height. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: one of `"valid"` or `"same"` (case-insensitive). + depth_multiplier: The number of depthwise convolution output channels + for each input channel. + The total number of depthwise convolution output + channels will be equal to `filterss_in * depth_multiplier`. + groups: The depth size of the convolution (as a variant of the original Depthwise conv) + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, height, width)`. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be "channels_last". + activation: Activation function to use + (see [activations](../activations.md)). + If you don't specify anything, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + depthwise_initializer: Initializer for the depthwise kernel matrix + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + depthwise_regularizer: Regularizer function applied to + the depthwise kernel matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + dialation_rate: List of ints. + Defines the dilation factor for each dimension in the + input. Defaults to (1,1,1) + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). + depthwise_constraint: Constraint function applied to + the depthwise kernel matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + # Input shape + 5D tensor with shape: + `(batch, depth, channels, rows, cols)` if data_format='channels_first' + or 5D tensor with shape: + `(batch, depth, rows, cols, channels)` if data_format='channels_last'. + # Output shape + 5D tensor with shape: + `(batch, filters * depth, new_depth, new_rows, new_cols)` if data_format='channels_first' + or 4D tensor with shape: + `(batch, new_depth, new_rows, new_cols, filters * depth)` if data_format='channels_last'. + `rows` and `cols` values might have changed due to padding. + """ + + # @legacy_depthwise_conv3d_support + def __init__( + self, + kernel_size, + strides=(1, 1, 1), + padding="valid", + depth_multiplier=1, + groups=None, + data_format=None, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + dilation_rate=(1, 1, 1), + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs + ): + super(DepthwiseConv3D, self).__init__( + filters=None, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + bias_regularizer=bias_regularizer, + dilation_rate=dilation_rate, + activity_regularizer=activity_regularizer, + bias_constraint=bias_constraint, + **kwargs + ) + self.depth_multiplier = depth_multiplier + self.groups = groups + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.bias_initializer = initializers.get(bias_initializer) + self.dilation_rate = dilation_rate + self._padding = _preprocess_padding(self.padding) + self._strides = (1,) + self.strides + (1,) + self._data_format = "NDHWC" + self.input_dim = None + + def build(self, input_shape): + if len(input_shape) < 5: + raise ValueError( + "Inputs to `DepthwiseConv3D` should have rank 5. " + "Received input shape:", + str(input_shape), + ) + if self.data_format == "channels_first": + channel_axis = 1 + else: + channel_axis = -1 + if input_shape[channel_axis] is None: + raise ValueError( + "The channel dimension of the inputs to " + "`DepthwiseConv3D` " + "should be defined. Found `None`." + ) + self.input_dim = int(input_shape[channel_axis]) + + if self.groups is None: + self.groups = self.input_dim + + if self.groups > self.input_dim: + raise ValueError( + "The number of groups cannot exceed the number of channels" + ) + + if self.input_dim % self.groups != 0: + raise ValueError( + "Warning! The channels dimension is not divisible by the group size chosen" + ) + + depthwise_kernel_shape = ( + self.kernel_size[0], + self.kernel_size[1], + self.kernel_size[2], + self.input_dim, + self.depth_multiplier, + ) + + self.depthwise_kernel = self.add_weight( + shape=depthwise_kernel_shape, + initializer=self.depthwise_initializer, + name="depthwise_kernel", + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint, + ) + + if self.use_bias: + self.bias = self.add_weight( + shape=(self.groups * self.depth_multiplier,), + initializer=self.bias_initializer, + name="bias", + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + else: + self.bias = None + # Set input spec. + self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim}) + self.built = True + + def call(self, inputs, training=None): + inputs = _preprocess_conv3d_input(inputs, self.data_format) + + if self.data_format == "channels_last": + dilation = (1,) + self.dilation_rate + (1,) + else: + dilation = self.dilation_rate + (1,) + (1,) + + if self._data_format == "NCDHW": + outputs = tf.concat( + [ + tf.nn.conv3d( + inputs[0][:, i : i + self.input_dim // self.groups, :, :, :], + self.depthwise_kernel[ + :, :, :, i : i + self.input_dim // self.groups, : + ], + strides=self._strides, + padding=self._padding, + dilations=dilation, + data_format=self._data_format, + ) + for i in range(0, self.input_dim, self.input_dim // self.groups) + ], + axis=1, + ) + + else: + outputs = tf.concat( + [ + tf.nn.conv3d( + inputs[0][:, :, :, :, i : i + self.input_dim // self.groups], + self.depthwise_kernel[ + :, :, :, i : i + self.input_dim // self.groups, : + ], + strides=self._strides, + padding=self._padding, + dilations=dilation, + data_format=self._data_format, + ) + for i in range(0, self.input_dim, self.input_dim // self.groups) + ], + axis=-1, + ) + + if self.bias is not None: + outputs = K.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + depth = input_shape[2] + rows = input_shape[3] + cols = input_shape[4] + out_filters = self.groups * self.depth_multiplier + elif self.data_format == "channels_last": + depth = input_shape[1] + rows = input_shape[2] + cols = input_shape[3] + out_filters = self.groups * self.depth_multiplier + + depth = conv_utils.conv_output_length( + depth, self.kernel_size[0], self.padding, self.strides[0] + ) + + rows = conv_utils.conv_output_length( + rows, self.kernel_size[1], self.padding, self.strides[1] + ) + + cols = conv_utils.conv_output_length( + cols, self.kernel_size[2], self.padding, self.strides[2] + ) + + if self.data_format == "channels_first": + return (input_shape[0], out_filters, depth, rows, cols) + + elif self.data_format == "channels_last": + return (input_shape[0], depth, rows, cols, out_filters) + + def get_config(self): + config = super(DepthwiseConv3D, self).get_config() + config.pop("filters") + config.pop("kernel_initializer") + config.pop("kernel_regularizer") + config.pop("kernel_constraint") + config["depth_multiplier"] = self.depth_multiplier + config["depthwise_initializer"] = initializers.serialize( + self.depthwise_initializer + ) + config["depthwise_regularizer"] = regularizers.serialize( + self.depthwise_regularizer + ) + config["depthwise_constraint"] = constraints.serialize( + self.depthwise_constraint + ) + return config + + +DepthwiseConvolution3D = DepthwiseConv3D diff --git a/nobrainer/models/convnext.py b/nobrainer/models/convnext.py new file mode 100644 index 00000000..b327670f --- /dev/null +++ b/nobrainer/models/convnext.py @@ -0,0 +1,196 @@ +import numpy as np +import tensorflow as tf +from tensorflow.keras import layers + +from ..layers.DepthwiseConv3d import DepthwiseConv3D + + +def drop_path(inputs, drop_prob, is_training): + # https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py + if (not is_training) or (drop_prob == 0.0): + return inputs + + # Compute keep_prob + keep_prob = 1.0 - drop_prob + + # Compute drop_connect tensor + random_tensor = keep_prob + shape = (tf.shape(inputs)[0],) + (1,) * (len(tf.shape(inputs)) - 1) + random_tensor += tf.random.uniform(shape, dtype=inputs.dtype) + binary_tensor = tf.floor(random_tensor) + output = tf.math.divide(inputs, keep_prob) * binary_tensor + return output + + +class DropPath(tf.keras.layers.Layer): + # https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def call(self, x, training=None): + return drop_path(x, self.drop_prob, training) + + +class Block(layers.Layer): + """ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) + -> 1x1x1 Conv -> GELU -> 1x1x1 Conv; all in (N, C, H, W, D) + (2) DwConv -> Permute to (N, H, W, D, C); + LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6, prefix=""): + super().__init__() + self.dwconv = DepthwiseConv3D(kernel_size=7, padding="same") # depthwise conv + self.norm = layers.LayerNormalization(epsilon=1e-6) + # pointwise/1x1x1 convs, implemented with linear layers + self.pwconv1 = layers.Dense(4 * dim) + self.act = tf.keras.activations.gelu + self.pwconv2 = layers.Dense(dim) + self.drop_path = DropPath(drop_path) + self.dim = dim + self.layer_scale_init_value = layer_scale_init_value + self.prefix = prefix + + def build(self, input_shape): + self.gamma = tf.Variable( + initial_value=self.layer_scale_init_value * tf.ones((self.dim)), + trainable=True, + name=f"{self.prefix}/gamma", + ) + self.built = True + + def call(self, x): + input = x + x = self.dwconv(x) + # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXt(tf.keras.Model): + """3D ConvNeXt Classification Model. + + Adapted from 2D Tensorflow keras impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + Args: + num_classes (int): Number of classes for classification head. Default: 1 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + include_top (bool): whether to add head or + just use it as feature extractor. Default: True + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for + classifier weights and biases. Default: 1. + """ + + def __init__( + self, + num_classes=1, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + include_top=True, + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + head_init_scale=1.0, + ): + super().__init__() + self.include_top = include_top + self.downsample_layers = [] # stem and 3 intermediate downsampling conv layers + stem = tf.keras.Sequential( + [ + layers.Conv3D(dims[0], kernel_size=4, strides=4, padding="same"), + layers.LayerNormalization(epsilon=1e-6), + ] + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = tf.keras.Sequential( + [ + layers.LayerNormalization(epsilon=1e-6), + layers.Conv3D( + dims[i + 1], kernel_size=2, strides=2, padding="same" + ), + ] + ) + self.downsample_layers.append(downsample_layer) + + self.stages = ( + [] + ) # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [x for x in np.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = tf.keras.Sequential( + [ + Block( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + prefix=f"block{i}", + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + if self.include_top: + self.avg = layers.GlobalAveragePooling3D() + self.norm = layers.LayerNormalization(epsilon=1e-6) # final norm layer + self.head = layers.Dense(num_classes) + else: + self.avg = None + self.norm = None + self.head = None + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return x + + def call(self, x): + x = self.forward_features(x) + if self.include_top: + x = self.avg(x) + x = self.norm(x) + x = self.head(x) + return x + + +model_configs = dict( + convnext_tiny=dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]), + convnext_small=dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]), + convnext_base=dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]), + convnext_large=dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]), + convnext_xlarge=dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048]), +) + + +def create_model( + input_shape=(128, 128, 128, 1), + num_classes=1, + include_top=True, + model_name="convnext_tiny_1k", + **kwargs, +): + cfg = model_configs["_".join(model_name.split("_")[:2])] + net = ConvNeXt(num_classes, cfg["depths"], cfg["dims"], include_top, **kwargs) + net(tf.keras.Input(shape=input_shape)) + return net