Skip to content

Commit

Permalink
Merge branch 'release/1.0.2'
Browse files Browse the repository at this point in the history
  • Loading branch information
shkarupa-alex committed Mar 18, 2022
2 parents a7a0ff9 + af75a8a commit 5f2bf80
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ Based on [Official Pytorch implementation](https://github.com/DingXiaoH/RepLKNet

Supports variable-shape inference.

## Installation

```bash
pip install tfreplknet
```

## Examples

Default usage (without preprocessing):
Expand Down Expand Up @@ -42,11 +48,17 @@ with [ImageNet-v2 test set](https://www.tensorflow.org/datasets/catalog/imagenet
```python
import tensorflow as tf
import tensorflow_datasets as tfds
from tfreplknet import RepLKNet31B224K1, preprocess_input
from tfreplknet import RepLKNet31B224K1, RepLKNet31B384K1, preprocess_input

def _prepare(example):
image = tf.image.resize(example['image'], (256, 256), method=tf.image.ResizeMethod.BICUBIC, antialias=False)
# For RepLKNet31B224K1
image = tf.image.resize(example['image'], (256, 256), method=tf.image.ResizeMethod.BICUBIC)
image = tf.image.central_crop(image, 0.875)

# For RepLKNet31B384K1
# image = tf.image.resize(example['image'], (438, 438), method=tf.image.ResizeMethod.BICUBIC)
# image = tf.image.central_crop(image, 0.877)

image = preprocess_input(image)

return image, example['label']
Expand All @@ -64,8 +76,8 @@ print(history)

| name | original acc@1 | ported acc@1 | original acc@5 | ported acc@5 |
| :---: | :---: | :---: | :---: | :---: |
| RepLKNet31B 224 1K | ? | ? | ? | ? |
| RepLKNet31B 384 1K | ? | ? | ? | ? |
| RepLKNet31B 224 1K | 75.29 | 75.13 | 92.60 | 92.88 |
| RepLKNet31B 384 1K | ? | 76.46 | ? | 93.37 |

## Citation

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name='tfreplknet',
version='1.0.1',
version='1.0.2',
description='Keras (TensorFlow v2) reimplementation of RepLKNet model.',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
12 changes: 6 additions & 6 deletions tfreplknet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,20 @@ def RepLKNet(filters, kernel_sizes=(31, 29, 27, 13), small_kernel=5, depths=(2,
x = SamePad(3)(x)
x = layers.Conv2D(filters[0], 3, strides=2, use_bias=False, name='stem/0/conv')(x)
x = layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name='stem/0/bn')(x)
x = layers.ReLU()(x)
x = layers.ReLU(name='stem/0/relu')(x)

x = layers.DepthwiseConv2D(3, padding='same', use_bias=False, name='stem/1/conv')(x)
x = layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name='stem/1/bn')(x)
x = layers.ReLU()(x)
x = layers.ReLU(name='stem/1/relu')(x)

x = layers.Conv2D(filters[0], 1, use_bias=False, name='stem/2/conv')(x)
x = layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name='stem/2/bn')(x)
x = layers.ReLU()(x)
x = layers.ReLU(name='stem/2/relu')(x)

x = SamePad(3)(x)
x = layers.DepthwiseConv2D(3, strides=2, use_bias=False, name='stem/3/conv')(x)
x = layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name='stem/3/bn')(x)
x = layers.ReLU()(x)
x = layers.ReLU(name='stem/3/relu')(x)

path_drops = np.linspace(0., path_drop, sum(depths))

Expand All @@ -118,12 +118,12 @@ def RepLKNet(filters, kernel_sizes=(31, 29, 27, 13), small_kernel=5, depths=(2,
if not_last:
x = layers.Conv2D(filters[i + 1], 1, use_bias=False, name=f'transitions/{i}/0/conv')(x)
x = layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name=f'transitions/{i}/0/bn')(x)
x = layers.ReLU()(x)
x = layers.ReLU(name=f'transitions/{i}/0/relu')(x)

x = SamePad(3)(x)
x = layers.DepthwiseConv2D(3, strides=2, use_bias=False, name=f'transitions/{i}/1/conv')(x)
x = layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name=f'transitions/{i}/1/bn')(x)
x = layers.ReLU()(x)
x = layers.ReLU(name=f'transitions/{i}/1/relu')(x)

x = layers.BatchNormalization(momentum=0.1, epsilon=1.001e-5, name='norm')(x)

Expand Down

0 comments on commit 5f2bf80

Please sign in to comment.