From b409e980dd7af8f0bff8fa28c07ce31ac6833447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E6=8C=AF?= Date: Wed, 9 Jun 2021 11:28:44 +0800 Subject: [PATCH] add RGPR data augmentation for person reid(An Effective Data Augmentation for person re-identification(https://arxiv.org/abs/2101.08533)) --- configs/DukeMTMC/sbs_R50_rgpr.yml | 14 ++++++++ fastreid/config/defaults.py | 4 +++ fastreid/data/transforms/build.py | 9 +++++ fastreid/data/transforms/trans_gray.py | 50 ++++++++++++++++++++++++++ 4 files changed, 77 insertions(+) create mode 100644 configs/DukeMTMC/sbs_R50_rgpr.yml create mode 100644 fastreid/data/transforms/trans_gray.py diff --git a/configs/DukeMTMC/sbs_R50_rgpr.yml b/configs/DukeMTMC/sbs_R50_rgpr.yml new file mode 100644 index 000000000..e86535d5a --- /dev/null +++ b/configs/DukeMTMC/sbs_R50_rgpr.yml @@ -0,0 +1,14 @@ +_BASE_: ../Base-SBS.yml + +DATASETS: + NAMES: ("DukeMTMC",) + TESTS: ("DukeMTMC",) + +INPUT: + AUTOAUG: + ENABLED: False + RGPR: + ENABLED: True + PROB: 0.4 + +OUTPUT_DIR: logs/dukemtmc/sbs_R50_rgpr diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index e6bc403dd..d80977cd1 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -193,6 +193,10 @@ _C.INPUT.RPT = CN({"ENABLED": False}) _C.INPUT.RPT.PROB = 0.5 +# Random Grayscale Patch Replace +_C.INPUT.RGPR = CN({"ENABLED": False}) +_C.INPUT.RGPR.PROB = 0.4 + # ----------------------------------------------------------------------------- # Dataset # ----------------------------------------------------------------------------- diff --git a/fastreid/data/transforms/build.py b/fastreid/data/transforms/build.py index 5ca06a751..80a283ef6 100644 --- a/fastreid/data/transforms/build.py +++ b/fastreid/data/transforms/build.py @@ -8,6 +8,7 @@ from .transforms import * from .autoaugment import AutoAugment +from .trans_gray import RandomGrayscalePatchReplace def build_transforms(cfg, is_train=True): @@ -59,12 +60,20 @@ def build_transforms(cfg, is_train=True): do_rpt = cfg.INPUT.RPT.ENABLED rpt_prob = cfg.INPUT.RPT.PROB + # Random Grayscale Patch Replace + do_rgpr = cfg.INPUT.RGPR.ENABLED + rgpr_prob = cfg.INPUT.RGPR.PROB + if do_autoaug: res.append(T.RandomApply([AutoAugment()], p=autoaug_prob)) if size_train[0] > 0: res.append(T.Resize(size_train[0] if len(size_train) == 1 else size_train, interpolation=3)) + # after resize and before crop or flip + if do_rgpr: + res.append(RandomGrayscalePatchReplace(rgpr_prob)) + if do_crop: res.append(T.RandomResizedCrop(size=crop_size[0] if len(crop_size) == 1 else crop_size, interpolation=3, diff --git a/fastreid/data/transforms/trans_gray.py b/fastreid/data/transforms/trans_gray.py new file mode 100644 index 000000000..617747b8a --- /dev/null +++ b/fastreid/data/transforms/trans_gray.py @@ -0,0 +1,50 @@ +# encoding: utf-8 + +import math +from PIL import Image +import random +import numpy as np +import random +# This is the code of Random Grayscale Patch Replace + + +class RandomGrayscalePatchReplace(object): + + def __init__(self, probability=0.2, sl=0.02, sh=0.4, r1=0.3): + self.probability = probability + self.sl = sl + self.sh = sh + self.r1 = r1 + + def __call__(self, img): + + new = img.convert("L") # Convert from here to the corresponding grayscale image + np_img = np.array(new, dtype=np.uint8) + img_gray = np.dstack([np_img, np_img, np_img]) + + if random.uniform(0, 1) >= self.probability: + return img + + for attempt in range(100): + area = img.size[0] * img.size[1] + target_area = random.uniform(self.sl, self.sh) * area + aspect_ratio = random.uniform(self.r1, 1 / self.r1) + + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + + if w < img.size[1] and h < img.size[0]: + x1 = random.randint(0, img.size[0] - h) + y1 = random.randint(0, img.size[1] - w) + img = np.asarray(img).astype('float') + + img[y1:y1 + h, x1:x1 + w, 0] = img_gray[y1:y1 + h, x1:x1 + w, 0] + img[y1:y1 + h, x1:x1 + w, 1] = img_gray[y1:y1 + h, x1:x1 + w, 1] + img[y1:y1 + h, x1:x1 + w, 2] = img_gray[y1:y1 + h, x1:x1 + w, 2] + + img = Image.fromarray(img.astype('uint8')) + + return img + + return img +