From 0b78ebabfaec0edabb4e6ca629fc51130f1505d5 Mon Sep 17 00:00:00 2001 From: Florian Bruggisser Date: Tue, 7 Feb 2023 03:22:37 +0100 Subject: [PATCH] Implemented yolo dataset support (#487) * implemented yolo data loader * added yolo example configuration * fixed super call for yolo data loader * converted normalized values to pixels for yolo dataset * run pre-commit and fixed coordinate bug * fixed yolo categories indexed by zero * added readme hint for yolo format --- README.md | 2 + config/nanodet-plus-m_416-yolo.yml | 134 ++++++++++++++++++++++ nanodet/data/dataset/__init__.py | 5 + nanodet/data/dataset/yolo.py | 173 +++++++++++++++++++++++++++++ 4 files changed, 314 insertions(+) create mode 100644 config/nanodet-plus-m_416-yolo.yml create mode 100644 nanodet/data/dataset/yolo.py diff --git a/README.md b/README.md index 5ae2b4e08..ba22331d8 100644 --- a/README.md +++ b/README.md @@ -220,6 +220,8 @@ NanoDet-RepVGG | RepVGG-A0 | 416*416 | 27.8 | 11.3G | 6.75M | If your dataset annotations are pascal voc xml format, refer to [config/nanodet_custom_xml_dataset.yml](config/nanodet_custom_xml_dataset.yml) + Otherwise, if your dataset annotations are YOLO format ([Darknet TXT](https://github.com/AlexeyAB/Yolo_mark/issues/60#issuecomment-401854885)), refer to [config/nanodet-plus-m_416-yolo.yml](config/nanodet-plus-m_416-yolo.yml) + Or convert your dataset annotations to MS COCO format[(COCO annotation format details)](https://cocodataset.org/#format-data). 2. **Prepare config file** diff --git a/config/nanodet-plus-m_416-yolo.yml b/config/nanodet-plus-m_416-yolo.yml new file mode 100644 index 000000000..3b74246fd --- /dev/null +++ b/config/nanodet-plus-m_416-yolo.yml @@ -0,0 +1,134 @@ +# nanodet-plus-m_416 +# COCO mAP(0.5:0.95) = 0.304 +# AP_50 = 0.459 +# AP_75 = 0.317 +# AP_small = 0.106 +# AP_m = 0.322 +# AP_l = 0.477 +save_dir: workspace/nanodet-plus-m_416 +model: + weight_averager: + name: ExpMovingAverager + decay: 0.9998 + arch: + name: NanoDetPlus + detach_epoch: 10 + backbone: + name: ShuffleNetV2 + model_size: 1.0x + out_stages: [2,3,4] + activation: LeakyReLU + fpn: + name: GhostPAN + in_channels: [116, 232, 464] + out_channels: 96 + kernel_size: 5 + num_extra_level: 1 + use_depthwise: True + activation: LeakyReLU + head: + name: NanoDetPlusHead + num_classes: 80 + input_channel: 96 + feat_channels: 96 + stacked_convs: 2 + kernel_size: 5 + strides: [8, 16, 32, 64] + activation: LeakyReLU + reg_max: 7 + norm_cfg: + type: BN + loss: + loss_qfl: + name: QualityFocalLoss + use_sigmoid: True + beta: 2.0 + loss_weight: 1.0 + loss_dfl: + name: DistributionFocalLoss + loss_weight: 0.25 + loss_bbox: + name: GIoULoss + loss_weight: 2.0 + # Auxiliary head, only use in training time. + aux_head: + name: SimpleConvHead + num_classes: 80 + input_channel: 192 + feat_channels: 192 + stacked_convs: 4 + strides: [8, 16, 32, 64] + activation: LeakyReLU + reg_max: 7 + +class_names: &class_names ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant', + 'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat', + 'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket', + 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'] + +data: + train: + name: YoloDataset + img_path: coco/train2017 + ann_path: coco/train2017 + class_names: *class_names + input_size: [416,416] #[w,h] + keep_ratio: False + pipeline: + perspective: 0.0 + scale: [0.6, 1.4] + stretch: [[0.8, 1.2], [0.8, 1.2]] + rotation: 0 + shear: 0 + translate: 0.2 + flip: 0.5 + brightness: 0.2 + contrast: [0.6, 1.4] + saturation: [0.5, 1.2] + normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]] + val: + name: YoloDataset + img_path: coco/val2017 + ann_path: coco/val2017 + class_names: *class_names + input_size: [416,416] #[w,h] + keep_ratio: False + pipeline: + normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]] +device: + gpu_ids: [0] + workers_per_gpu: 10 + batchsize_per_gpu: 96 +schedule: +# resume: +# load_model: + optimizer: + name: AdamW + lr: 0.001 + weight_decay: 0.05 + warmup: + name: linear + steps: 500 + ratio: 0.0001 + total_epochs: 300 + lr_schedule: + name: CosineAnnealingLR + T_max: 300 + eta_min: 0.00005 + val_intervals: 10 +grad_clip: 35 +evaluator: + name: CocoDetectionEvaluator + save_key: mAP +log: + interval: 50 diff --git a/nanodet/data/dataset/__init__.py b/nanodet/data/dataset/__init__.py index 92c405b28..f7dbdef5d 100644 --- a/nanodet/data/dataset/__init__.py +++ b/nanodet/data/dataset/__init__.py @@ -17,6 +17,7 @@ from .coco import CocoDataset from .xml_dataset import XMLDataset +from .yolo import YoloDataset def build_dataset(cfg, mode): @@ -27,6 +28,8 @@ def build_dataset(cfg, mode): "Dataset name coco has been deprecated. Please use CocoDataset instead." ) return CocoDataset(mode=mode, **dataset_cfg) + elif name == "yolo": + return YoloDataset(mode=mode, **dataset_cfg) elif name == "xml_dataset": warnings.warn( "Dataset name xml_dataset has been deprecated. " @@ -35,6 +38,8 @@ def build_dataset(cfg, mode): return XMLDataset(mode=mode, **dataset_cfg) elif name == "CocoDataset": return CocoDataset(mode=mode, **dataset_cfg) + elif name == "YoloDataset": + return YoloDataset(mode=mode, **dataset_cfg) elif name == "XMLDataset": return XMLDataset(mode=mode, **dataset_cfg) else: diff --git a/nanodet/data/dataset/yolo.py b/nanodet/data/dataset/yolo.py new file mode 100644 index 000000000..8a7baef1b --- /dev/null +++ b/nanodet/data/dataset/yolo.py @@ -0,0 +1,173 @@ +# Copyright 2023 cansik. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import time +from collections import defaultdict +from typing import Optional, Sequence + +import cv2 +import numpy as np +from pycocotools.coco import COCO + +from .coco import CocoDataset +from .xml_dataset import get_file_list + + +class CocoYolo(COCO): + def __init__(self, annotation): + """ + Constructor of Microsoft COCO helper class for + reading and visualizing annotations. + :param annotation: annotation dict + :return: + """ + # load dataset + super().__init__() + self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict() + self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) + dataset = annotation + assert type(dataset) == dict, "annotation file format {} not supported".format( + type(dataset) + ) + self.dataset = dataset + self.createIndex() + + +class YoloDataset(CocoDataset): + def __init__(self, class_names, **kwargs): + self.class_names = class_names + super(YoloDataset, self).__init__(**kwargs) + + @staticmethod + def _find_image( + image_prefix: str, + image_types: Sequence[str] = (".png", ".jpg", ".jpeg", ".bmp", ".tiff"), + ) -> Optional[str]: + for image_type in image_types: + path = f"{image_prefix}{image_type}" + if os.path.exists(path): + return path + return None + + def yolo_to_coco(self, ann_path): + """ + convert xml annotations to coco_api + :param ann_path: + :return: + """ + logging.info("loading annotations into memory...") + tic = time.time() + ann_file_names = get_file_list(ann_path, type=".txt") + logging.info("Found {} annotation files.".format(len(ann_file_names))) + image_info = [] + categories = [] + annotations = [] + for idx, supercat in enumerate(self.class_names): + categories.append( + {"supercategory": supercat, "id": idx + 1, "name": supercat} + ) + ann_id = 1 + + for idx, txt_name in enumerate(ann_file_names): + ann_file = os.path.join(ann_path, txt_name) + image_file = self._find_image(os.path.splitext(ann_file)[0]) + + if image_file is None: + logging.warning(f"Could not find image for {ann_file}") + continue + + with open(ann_file, "r") as f: + lines = f.readlines() + + image = cv2.imread(image_file) + height, width = image.shape[:2] + + file_name = os.path.basename(image_file) + info = { + "file_name": file_name, + "height": height, + "width": width, + "id": idx + 1, + } + image_info.append(info) + for line in lines: + data = [float(t) for t in line.split(" ")] + cat_id = int(data[0]) + locations = np.array(data[1:]).reshape((len(data) // 2, 2)) + bbox = locations[0:2] + + bbox[0] -= bbox[1] * 0.5 + + bbox = np.round(bbox * np.array([width, height])).astype(int) + x, y = bbox[0][0], bbox[0][1] + w, h = bbox[1][0], bbox[1][1] + + if cat_id >= len(self.class_names): + logging.warning( + f"Category {cat_id} is not defined in config ({txt_name})" + ) + continue + + if w < 0 or h < 0: + logging.warning( + "WARNING! Find error data in file {}! Box w and " + "h should > 0. Pass this box annotation.".format(txt_name) + ) + continue + + coco_box = [max(x, 0), max(y, 0), min(w, width), min(h, height)] + ann = { + "image_id": idx + 1, + "bbox": coco_box, + "category_id": cat_id + 1, + "iscrowd": 0, + "id": ann_id, + "area": coco_box[2] * coco_box[3], + } + annotations.append(ann) + ann_id += 1 + + coco_dict = { + "images": image_info, + "categories": categories, + "annotations": annotations, + } + logging.info( + "Load {} txt files and {} boxes".format(len(image_info), len(annotations)) + ) + logging.info("Done (t={:0.2f}s)".format(time.time() - tic)) + return coco_dict + + def get_data_info(self, ann_path): + """ + Load basic information of dataset such as image path, label and so on. + :param ann_path: coco json file path + :return: image info: + [{'file_name': '000000000139.jpg', + 'height': 426, + 'width': 640, + 'id': 139}, + ... + ] + """ + coco_dict = self.yolo_to_coco(ann_path) + self.coco_api = CocoYolo(coco_dict) + self.cat_ids = sorted(self.coco_api.getCatIds()) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cats = self.coco_api.loadCats(self.cat_ids) + self.img_ids = sorted(self.coco_api.imgs.keys()) + img_info = self.coco_api.loadImgs(self.img_ids) + return img_info