
Zero mAP during validation when using ViT backbone in Faster R-CNN #12017

Open
chan1031 opened this issue Oct 28, 2024 · 4 comments
chan1031 commented Oct 28, 2024

Describe the bug
When using ViT as the backbone in Faster R-CNN, the bbox mAP is always 0 during validation even though training appears to proceed normally: the loss values (loss_rpn_cls, loss_rpn_bbox, loss_cls, loss_bbox) decrease as expected, but the validation results show all zeros for the mAP metrics.

  1. Here is my config file:
custom_imports = dict(
    imports=['mmpretrain.models'],
    allow_failed_imports=False)

_base_ = [
    '/home/skku//mmdetection/configs/_base_/datasets/coco_detection.py',
    '/home/skku//mmdetection/configs/_base_/schedules/schedule_1x.py',
    '/home/skku//mmdetection/configs/_base_/default_runtime.py'
]

data_root = "/home/skku/mm_test/data/cargox/"
metainfo = {
    "classes": ("knife1-1", "knife1-2", "knife2-1", "knife2-2", "knife3-1", "knife3-2", "knife4-1", "knife4-2"),
    "palette": [
        (225, 0, 0), (0, 0, 255), (0, 255, 0), (255, 255, 0),
        (0, 255, 255), (255, 0, 255), (255, 165, 0), (0, 0, 128),
    ],
    "class_id_offset": 1
}

model = dict(
    type='FasterRCNN',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    backbone=dict(
        type='mmpretrain.VisionTransformer',
        arch='base',         # use the 'base' ViT architecture
        img_size=384,        # input image size
        patch_size=16,       # patch size
        out_indices=(2, 5, 8, 11),  # indices of the transformer layers whose features are output
        drop_rate=0.0,       
        drop_path_rate=0.1, 
        norm_cfg=dict(type='LN', eps=1e-6), 
        out_type='featmap',  
        with_cls_token=True, 
        final_norm=True,   
    ),
    neck=dict(
        type='FPN',
        in_channels=[768, 768, 768, 768],
        out_channels=256,
        num_outs=5
    ),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]
        ),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]
        ),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0
        ),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)
    ),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]
        ),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=8,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]
            ),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0
            ),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)
        )
    ),
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')
    )
]

train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/train.json',
        data_prefix=dict(img='train/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline
    )
)

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/val.json',
        data_prefix=dict(img='val/'),
        test_mode=True,
        pipeline=test_pipeline
    )
)

test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/val.json',
    metric='bbox',
    format_only=False)
test_evaluator = val_evaluator

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=12,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1)
]

optim_wrapper = dict(
    _delete_=True,  
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.00001, weight_decay=0.05),
    clip_grad=dict(max_norm=1.0, norm_type=2)
)

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

log_level = 'INFO'
load_from = None
resume = False
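
For reference, here is a minimal standalone sketch (assuming mmpretrain is importable) that builds the backbone block above and prints its output feature shapes, which is what the FPN setting in_channels=[768, 768, 768, 768] relies on. The expected shapes in the comments come from the config itself, not from any log in this issue.

import torch
from mmpretrain.models import VisionTransformer

# Mirror the backbone section of the config above.
backbone = VisionTransformer(
    arch='base',
    img_size=384,
    patch_size=16,
    out_indices=(2, 5, 8, 11),
    out_type='featmap',
    with_cls_token=True,
    final_norm=True)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 384, 384))

for feat in feats:
    # Expected: (1, 768, 24, 24) for every level. A plain ViT keeps a single
    # stride of 16, so all four maps fed to the FPN share the same resolution.
    print(feat.shape)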

  2. Did you make any modifications on the code or config? Did you understand what you have modified?

Modified the Faster R-CNN config to use ViT as the backbone instead of ResNet. The main modifications are:
Changed the backbone to ViT
Adjusted the FPN in_channels for the ViT output
Modified the optimizer settings (lr)

  3. What dataset did you use?
    Custom COCO-format dataset with 8 classes (CargoX dataset)

Environment

Please run python mmdet/utils/collect_env.py to collect necessary environment information and paste it here.
sys.platform: linux
Python: 3.8.19 (default, Mar 20 2024, 19:58:24) [GCC 11.2.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1: NVIDIA GeForce RTX 3090
CUDA_HOME: /usr/local/cuda-12.1
NVCC: Cuda compilation tools, release 12.1, V12.1.66
GCC: gcc (Ubuntu 13.2.0-23ubuntu4) 13.2.0
PyTorch: 2.1.1
PyTorch compiling details: PyTorch built with:

  • GCC 9.3
  • C++ Version: 201703
  • Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • LAPACK is enabled (usually provided by MKL)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 12.1
  • NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  • CuDNN 8.9.2
  • Magma 2.6.1
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.16.1
OpenCV: 4.9.0
MMEngine: 0.10.3
MMDetection: 3.3.0+cfd5d3a

Error traceback
If applicable, paste the error traceback here.

2024/10/28 13:04:15 - mmengine - INFO - Epoch(val)  [1][5800/6000]    eta: 0:00:07  time: 0.0351  data_time: 0.0007  memory: 1789  
2024/10/28 13:04:17 - mmengine - INFO - Epoch(val)  [1][5850/6000]    eta: 0:00:05  time: 0.0351  data_time: 0.0007  memory: 1789  
2024/10/28 13:04:19 - mmengine - INFO - Epoch(val)  [1][5900/6000]    eta: 0:00:03  time: 0.0350  data_time: 0.0007  memory: 1789  
2024/10/28 13:04:21 - mmengine - INFO - Epoch(val)  [1][5950/6000]    eta: 0:00:01  time: 0.0350  data_time: 0.0007  memory: 1789  
2024/10/28 13:04:22 - mmengine - INFO - Epoch(val)  [1][6000/6000]    eta: 0:00:00  time: 0.0342  data_time: 0.0007  memory: 1789  
2024/10/28 13:04:23 - mmengine - INFO - Evaluating bbox...
2024/10/28 13:04:28 - mmengine - INFO - bbox_mAP_copypaste: 0.000 0.000 0.000 0.000 0.001 0.000
2024/10/28 13:04:28 - mmengine - INFO - Epoch(val) [1][6000/6000]    coco/bbox_mAP: 0.0000  coco/bbox_mAP_50: 0.0000  coco/bbox_mAP_75: 0.0000  coco/bbox_mAP_s: 0.0000  coco/bbox_mAP_m: 0.0010  coco/bbox_mAP_l: 0.0000  data_time: 0.0008  time: 0.0354
2024/10/28 13:04:40 - mmengine - INFO - Epoch(train)  [2][  50/8750]  lr: 1.0000e-05  eta: 6:09:23  time: 0.2298  data_time: 0.0019  memory: 4745  grad_norm: 39.0173  loss: 0.2024  loss_rpn_cls: 0.0789  loss_rpn_bbox: 0.0484  loss_cls: 0.0580  acc: 99.7070  loss_bbox: 0.0171
2024/10/28 13:04:51 - mmengine - INFO - Epoch(train)  [2][ 100/8750]  lr: 1.0000e-05  eta: 6:09:11  time: 0.2297  data_time: 0.0017  memory: 4746  grad_norm: 34.3565  loss: 0.2102  loss_rpn_cls: 0.0703  loss_rpn_bbox: 0.0345  loss_cls: 0.0799  acc: 99.5117  loss_bbox: 0.0255
2024/10/28 13:05:03 - mmengine - INFO - Epoch(train)  [2][ 150/8750]  lr: 1.0000e-05  eta: 6:09:00  time: 0.2306  data_time: 0.0018  memory: 4746  grad_norm: 32.8533  loss: 0.2172  loss_rpn_cls: 0.0786  loss_rpn_bbox: 0.0419  loss_cls: 0.0728  acc: 99.7070  loss_bbox: 0.0239
2024/10/28 13:05:14 - mmengine - INFO - Epoch(train)  [2][ 200/8750]  lr: 1.0000e-05  eta: 6:08:48  time: 0.2308  data_time: 0.0018  memory: 4746  grad_norm: 27.1379  loss: 0.1880  loss_rpn_cls: 0.0704  loss_rpn_bbox: 0.0343  loss_cls: 0.0625  acc: 99.7070  loss_bbox: 0.0208

After the first epoch of training, bbox_mAP is 0 in validation.

2024/10/28 13:04:28 - mmengine - INFO - Epoch(val) [1][6000/6000] coco/bbox_mAP: 0.0000 coco/bbox_mAP_50: 0.0000 coco/bbox_mAP_75: 0.0000 coco/bbox_mAP_s: 0.0000 coco/bbox_mAP_m: 0.0010 coco/bbox_mAP_l: 0.0000 data_time: 0.0008 time: 0.0354


tojimahammatov commented Oct 29, 2024

Hi @chan1031, I guess you need to change the order of the lines in your test_pipeline, specifically Resize and LoadAnnotations: call LoadAnnotations before calling Resize.

The annotations also have to be resized after loading, so you need to load them first and then resize them together with the image; see the sketch below. This is why you might be getting 0 mAP only at evaluation time. Let us know if it helps.
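
For concreteness, the reordered test_pipeline would look like this (same scale and meta_keys as in the config above):

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),              # load GT boxes first
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),   # then resize image and boxes together
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')
    )
]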


chan1031 commented Oct 29, 2024

@tojimahammatov
Thank you for the reply.
I modified the code as you suggested and retrained it, but the mAP value is still coming out as 0.

@tojimahammatov

How about removing LoadAnnotations completely and keeping only the Resize step in the pipeline?

In fact, you don't need LoadAnnotations in validation and test, since the annotation file (ann_file) already contains everything; see the sketch below.
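
A sketch of that variant, again reusing the scale and meta_keys from the config above, with LoadAnnotations dropped; the ground truth is then taken only from the ann_file given to CocoMetric:

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),   # no LoadAnnotations step
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')
    )
]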


twisti14 commented Nov 7, 2024

I have the same problem when I use vitdet-mask; my mAP is also 0.
