Skip to content

Commit

Permalink
Merge branch 'release/1.1.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
shkarupa-alex committed Mar 30, 2022
2 parents 5f2bf80 + 9f2e51f commit 1772ae2
Show file tree
Hide file tree
Showing 15 changed files with 234 additions and 509 deletions.
43 changes: 30 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,31 @@ Supports variable-shape inference.
pip install tfreplknet
```

## Available models and weights

| Model name | Pretrain size | Preprocessing function | Description |
| :---: | :---: | :---: | :---: |
| RepLKNet | - | - | General RepLKNet architecture |
| RepLKNetB | - | - | Base model size preset |
| RepLKNetL | - | - | Large model size preset |
| RepLKNetXL | - | - | Extra large model size preset |
| RepLKNetB224In1k | 224 | preprocess_input_bl | Base model with weighs pretrained on ImageNet 21k and finetuned to 1k |
| RepLKNetB224In21k | 224 | preprocess_input_bl | Base model with weighs pretrained on ImageNet 21k |
| RepLKNetB384In1k | 384 | preprocess_input_bl | Base model with weighs pretrained on ImageNet 21k and finetuned to 1k |
| RepLKNetL384In1k | 384 | preprocess_input_bl | Large model with weighs pretrained on ImageNet 21k and finetuned to 1k |
| RepLKNetL384In21k | 384 | preprocess_input_bl | Large model with weighs pretrained on ImageNet 21k |
| RepLKNetXL320In1k | 320 | preprocess_input_xl | Extra large model with weighs pretrained on MegData-73M and finetuned to 1k |
| RepLKNetXL320In21k | 320 | preprocess_input_xl | Extra large model with weighs pretrained on MegData-73M (21k head) |


## Examples

Default usage (without preprocessing):

```python
from tfreplknet import RepLKNet31B224K1 # + 4 other variants and input preprocessing
from tfreplknet import RepLKNetB224In1k # + 4 other variants and input preprocessing

model = RepLKNet31B224K1() # by default will download imagenet{1k, 21k}-pretrained weights
model = RepLKNetB224In1k() # by default will download imagenet{1k, 21k}-pretrained weights
model.compile(...)
model.fit(...)
```
Expand All @@ -28,11 +45,11 @@ Custom classification (with preprocessing):

```python
from keras import layers, models
from tfreplknet import RepLKNet31B224K1, preprocess_input
from tfreplknet import RepLKNetB224In1k, preprocess_input_bl

inputs = layers.Input(shape=(224, 224, 3), dtype='uint8')
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = RepLKNet31B224K1(include_top=False)(outputs)
outputs = layers.Lambda(preprocess_input_bl)(inputs)
outputs = RepLKNetB224In1k(include_top=False)(outputs)
outputs = layers.Dense(100, activation='softmax')(outputs)

model = models.Model(inputs=inputs, outputs=outputs)
Expand All @@ -42,32 +59,32 @@ model.fit(...)

## Evaluation

For correctness, `RepLKNet31B224K1` and `RepLKNet31B384K1` models (original and ported) tested
For correctness, `RepLKNetB224In1k` and `RepLKNetB384In1k` models (original and ported) tested
with [ImageNet-v2 test set](https://www.tensorflow.org/datasets/catalog/imagenet_v2).

```python
import tensorflow as tf
import tensorflow_datasets as tfds
from tfreplknet import RepLKNet31B224K1, RepLKNet31B384K1, preprocess_input
from tfreplknet import RepLKNetB224In1k, RepLKNetB384In1k, preprocess_input_bl

def _prepare(example):
# For RepLKNet31B224K1
# For RepLKNetB224In1k
image = tf.image.resize(example['image'], (256, 256), method=tf.image.ResizeMethod.BICUBIC)
image = tf.image.central_crop(image, 0.875)

# For RepLKNet31B384K1
# For RepLKNetB384In1k
# image = tf.image.resize(example['image'], (438, 438), method=tf.image.ResizeMethod.BICUBIC)
# image = tf.image.central_crop(image, 0.877)

image = preprocess_input(image)
image = preprocess_input_bl(image)

return image, example['label']

imagenet2 = tfds.load('imagenet_v2', split='test', shuffle_files=True)
imagenet2 = imagenet2.map(_prepare, num_parallel_calls=tf.data.AUTOTUNE)
imagenet2 = imagenet2.batch(8)

model = RepLKNet31B224K1()
model = RepLKNetB224In1k()
model.compile('sgd', 'sparse_categorical_crossentropy', ['accuracy', 'sparse_top_k_categorical_accuracy'])
history = model.evaluate(imagenet2)

Expand All @@ -76,8 +93,8 @@ print(history)

| name | original acc@1 | ported acc@1 | original acc@5 | ported acc@5 |
| :---: | :---: | :---: | :---: | :---: |
| RepLKNet31B 224 1K | 75.29 | 75.13 | 92.60 | 92.88 |
| RepLKNet31B 384 1K | ? | 76.46 | ? | 93.37 |
| RepLKNetB 224 1K | 75.29 | 75.13 | 92.60 | 92.88 |
| RepLKNetB 384 1K | 72.77 | 76.46 | 89.91 | 93.37 |

## Citation

Expand Down
20 changes: 14 additions & 6 deletions convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
'rep_l_k_net_31_b_384_k1': 'https://drive.google.com/file/d/1Sc46BWdXXm2fVP-K_hKKU_W8vAB-0duX/view?usp=sharing',
# 'rep_l_k_net_31_b_384_k21': '',
'rep_l_k_net_31_l_384_k1': 'https://drive.google.com/file/d/1JYXoNHuRvC33QV1pmpzMTKEni1hpWfBl/view?usp=sharing',
'rep_l_k_net_31_l_384_k21': 'https://drive.google.com/file/d/16jcPsPwo5rko7ojWS9k_W-svHX-iFknY/view?usp=sharing'
'rep_l_k_net_31_l_384_k21': 'https://drive.google.com/file/d/16jcPsPwo5rko7ojWS9k_W-svHX-iFknY/view?usp=sharing',
'rep_l_k_net_27_xl_320_k1': 'https://drive.google.com/file/d/1tPC60El34GntXByIRHb-z-Apm4Y5LX1T/view?usp=sharing',
'rep_l_k_net_27_xl_320_m73': 'https://drive.google.com/file/d/1CBHAEUlCzoHfiAQmMIjZhDMAIyHUmAAj/view?usp=sharing',
}
TF_MODELS = {
'rep_l_k_net_31_b_224_k1': tfreplknet.RepLKNet31B224K1,
'rep_l_k_net_31_b_224_k21': tfreplknet.RepLKNet31B224K21,
'rep_l_k_net_31_b_384_k1': tfreplknet.RepLKNet31B384K1,
'rep_l_k_net_31_b_224_k1': tfreplknet.RepLKNetB224In1k,
'rep_l_k_net_31_b_224_k21': tfreplknet.RepLKNetB224In21k,
'rep_l_k_net_31_b_384_k1': tfreplknet.RepLKNetB384In1k,
# 'rep_l_k_net_31_b_384_k21': '',
'rep_l_k_net_31_l_384_k1': tfreplknet.RepLKNet31L384K1,
'rep_l_k_net_31_l_384_k21': tfreplknet.RepLKNet31L384K21
'rep_l_k_net_31_l_384_k1': tfreplknet.RepLKNetL384In1k,
'rep_l_k_net_31_l_384_k21': tfreplknet.RepLKNetL384In1k,
'rep_l_k_net_27_xl_320_k1': tfreplknet.RepLKNetXL320In1k,
'rep_l_k_net_27_xl_320_k21': tfreplknet.RepLKNetXL320In21k
}


Expand Down Expand Up @@ -70,7 +74,11 @@ def convert_weight(weight, name):

weights_tf = []
for w in model.weights:

name = convert_name(w.name)
if name.startswith('head.') and \
('head.cls_22k.weight' in weights_torch or 'head.cls_22k.bias' in weights_torch):
name = name.replace('head.', 'head.cls_22k.')
assert name in weights_torch, f'Can\'t find weight {name} in checkpoint'

weight = weights_torch.pop(name).numpy()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name='tfreplknet',
version='1.0.2',
version='1.1.0',
description='Keras (TensorFlow v2) reimplementation of RepLKNet model.',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
11 changes: 6 additions & 5 deletions tfreplknet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tfreplknet.model import RepLKNet, RepLKNet31B, RepLKNet31L
from tfreplknet.model import RepLKNet31B224K1, RepLKNet31B224K21
from tfreplknet.model import RepLKNet31B384K1 # , RepLKNet31B384K21
from tfreplknet.model import RepLKNet31L384K1, RepLKNet31L384K21
from tfreplknet.prep import preprocess_input
from tfreplknet.model import RepLKNet, RepLKNetB, RepLKNetL, RepLKNetXL
from tfreplknet.model import RepLKNetB224In1k, RepLKNetB224In21k
from tfreplknet.model import RepLKNetB384In1k # , RepLKNetB384In21k
from tfreplknet.model import RepLKNetL384In1k, RepLKNetL384In21k
from tfreplknet.model import RepLKNetXL320In1k, RepLKNetXL320In21k
from tfreplknet.prep import preprocess_input_bl, preprocess_input_xl
6 changes: 4 additions & 2 deletions tfreplknet/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

@register_keras_serializable(package='TFRepLKNet')
class Block(layers.Layer):
def __init__(self, kernel_size, small_kernel, dropout, **kwargs):
def __init__(self, kernel_size, small_kernel, ratio, dropout, **kwargs):
super().__init__(**kwargs)
self.input_spec = layers.InputSpec(ndim=4)

self.kernel_size = kernel_size
self.small_kernel = small_kernel
self.ratio = ratio
self.dropout = dropout

@shape_type_conversion
Expand All @@ -27,7 +28,7 @@ def build(self, input_shape):

# noinspection PyAttributeOutsideInit
self.pw1 = models.Sequential([
layers.Conv2D(channels, 1, use_bias=False, name=f'{self.name}/pw1/conv'),
layers.Conv2D(int(channels * self.ratio), 1, use_bias=False, name=f'{self.name}/pw1/conv'),
layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name=f'{self.name}/pw1/bn'),
layers.ReLU()
], name='pw1')
Expand Down Expand Up @@ -65,6 +66,7 @@ def get_config(self):
config.update({
'kernel_size': self.kernel_size,
'small_kernel': self.small_kernel,
'ratio': self.ratio,
'dropout': self.dropout
})

Expand Down
72 changes: 61 additions & 11 deletions tfreplknet/large.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from keras import layers, models
import tensorflow as tf
from keras import backend, layers, models
from keras.mixed_precision.autocast_variable import AutoCastVariable
from keras.utils.control_flow_util import smart_cond
from keras.utils.generic_utils import register_keras_serializable
from keras.utils.tf_utils import shape_type_conversion


@register_keras_serializable(package='TFRepLKNet')
class LargeConv(layers.Layer):
def __init__(self, kernel_size, small_kernel, **kwargs):
def __init__(self, kernel_size, small_kernel, fused=True, **kwargs):
super().__init__(**kwargs)
self.input_spec = layers.InputSpec(ndim=4)

self.kernel_size = kernel_size
self.small_kernel = small_kernel
self.fused = fused

@shape_type_conversion
def build(self, input_shape):
Expand All @@ -25,19 +29,64 @@ def build(self, input_shape):
self.kernel_size, padding='same', use_bias=False, name=f'{self.name}/lkb_origin/conv'),
layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name=f'{self.name}/lkb_origin/bn')
], name='lkb_origin')
self.big.build(input_shape)

# noinspection PyAttributeOutsideInit
self.small = models.Sequential([
layers.DepthwiseConv2D(
self.small_kernel, padding='same', use_bias=False, name=f'{self.name}/small_conv/conv'),
layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name=f'{self.name}/small_conv/bn')
], name='small_conv')
if self.small_kernel is not None:
# noinspection PyAttributeOutsideInit
self.small = models.Sequential([
layers.DepthwiseConv2D(
self.small_kernel, padding='same', use_bias=False, name=f'{self.name}/small_conv/conv'),
layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name=f'{self.name}/small_conv/bn')
], name='small_conv')
self.small.build(input_shape)

super().build(input_shape)

def call(self, inputs, *args, **kwargs):
def call(self, inputs, training=None, *args, **kwargs):
if self.small_kernel is None or not self.fused:
return self._branched(inputs)

if not self.trainable:
training = False
elif training is None:
training = backend.learning_phase()

outputs = smart_cond(training, lambda: self._branched(inputs), lambda: self._merged(inputs))

return outputs

def _branched(self, inputs):
outputs = self.big(inputs)
outputs += self.small(inputs)

if self.small_kernel is not None:
outputs += self.small(inputs)

return outputs

def _merged(self, inputs):
def _raw_var(var):
return var if not isinstance(var, AutoCastVariable) else var._variable

def _fuse_conv_bn(conv, bn):
fused_scale = bn.gamma / tf.sqrt(bn.moving_variance + bn.epsilon)
fused_kernel = _raw_var(conv.depthwise_kernel) * fused_scale[None, None, ..., None]

fused_bias = bn.beta - bn.moving_mean * fused_scale

return fused_kernel, fused_bias

big_kernel, big_bias = _fuse_conv_bn(*self.big.layers)
small_kernel, small_bias = _fuse_conv_bn(*self.small.layers)

small_pad = [[(self.kernel_size - self.small_kernel) // 2] * 2] * 2 + [[0, 0]] * 2
kernel = big_kernel + tf.pad(small_kernel, small_pad)
kernel = tf.cast(kernel, self.compute_dtype)

bias = big_bias + small_bias
bias = tf.cast(bias, self.compute_dtype)

outputs = backend.depthwise_conv2d(inputs, kernel, padding='same', data_format=backend.image_data_format())
outputs = backend.bias_add(outputs, bias, data_format=backend.image_data_format())

return outputs

Expand All @@ -49,7 +98,8 @@ def get_config(self):
config = super().get_config()
config.update({
'kernel_size': self.kernel_size,
'small_kernel': self.small_kernel
'small_kernel': self.small_kernel,
'fused': self.fused
})

return config
Loading

0 comments on commit 1772ae2

Please sign in to comment.