This is the PyTorch implementation of our paper:

**Collect-and-Distribute Transformer for 3D Point Cloud Analysis**
Haibo Qiu, Baosheng Yu, and Dacheng Tao
## Abstract
Although remarkable advancements have been made recently in point cloud analysis through the exploration of transformer architecture, it remains challenging to effectively learn local and global structures within point clouds. In this paper, we propose a new transformer architecture equipped with a collect-and-distribute mechanism to communicate short- and long-range contexts of point clouds, which we refer to as CDFormer. Specifically, we first utilize self-attention to capture short-range interactions within each local patch, and the updated local features are then collected into a set of proxy reference points from which we can extract long-range contexts. Afterward, we distribute the learned long-range contexts back to local points via cross-attention. To address the position clues for short- and long-range contexts, we also introduce context-aware position encoding to facilitate position-aware communications between points. We perform experiments on four popular point cloud datasets, namely ModelNet40, ScanObjectNN, S3DIS, and ShapeNetPart, for classification and segmentation. Results show the effectiveness of the proposed CDFormer, delivering several new state-of-the-art performances on point cloud classification and segmentation tasks. The code is available in the supplementary material and will be made publicly available.
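The collect-and-distribute mechanism described above can be sketched in a few lines of PyTorch. This is a minimal illustrative sketch, not the repo's actual implementation: the module name, mean-pooling "collect" step, and the omission of context-aware position encoding are all simplifications of mine.

```python
import torch
import torch.nn as nn


class CollectDistributeBlock(nn.Module):
    """Illustrative sketch of one collect-and-distribute step:
    1) self-attention within each local patch (short-range),
    2) collect patch features into proxy reference points,
    3) self-attention among proxies (long-range),
    4) distribute long-range context back via cross-attention.
    """

    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.proxy_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_patches, points_per_patch, dim) -- patches treated as the batch
        # 1) short-range: self-attention within each local patch
        local, _ = self.local_attn(x, x, x)
        # 2) collect: one proxy per patch (mean pooling here as a stand-in)
        proxies = local.mean(dim=1, keepdim=True)   # (P, 1, C)
        p = proxies.transpose(0, 1)                 # (1, P, C): proxies attend across patches
        # 3) long-range: self-attention among proxy reference points
        p, _ = self.proxy_attn(p, p, p)
        p = p.transpose(0, 1)                       # back to (P, 1, C)
        # 4) distribute: local points query the proxies via cross-attention
        out, _ = self.cross_attn(local, p, p)
        return local + out
```

In the paper, the proxies additionally interact through context-aware position encoding; here they only exchange features, which keeps the sketch short.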
## Installation

- Create a conda env with:

  ```bash
  conda env create -f environment.yml
  ```

- Compile `pointops`: please make sure `gcc` and `nvcc` work normally. Then, compile and install pointops2 by:

  ```bash
  cd lib/pointops2
  python setup.py install
  ```

  (Note that we compiled successfully under `gcc=7.5.0, cuda=11.3` and `gcc=7.4.0, cuda=10.2`.)

- Compile `emd` (optional, for classification):

  ```bash
  cd lib/emd
  python setup.py install
  ```
## Datasets

### S3DIS

Please refer to Pointnet_Pointnet2_pytorch for preprocessing, and put the processed data into `dataset/s3dis/stanford_indoor3d`.
### ShapeNetPart

We follow PointNeXt to uniformly sample 2048 points. You can also use the preprocessed data provided below:

```bash
cd dataset && mkdir shapenetpart && cd shapenetpart
gdown https://drive.google.com/uc?id=1W3SEE-dY1sxvlECcOwWSDYemwHEUbJIS
tar -xvf shapenetcore_partanno_segmentation_benchmark_v0_normal.tar
```
### ModelNet40

Following PointNeXt, the ModelNet40 dataset will be downloaded automatically.
### ScanObjectNN

Download from the official website, or use the processed dataset from PointNeXt:

```bash
cd dataset && mkdir scanobjectnn
gdown https://drive.google.com/uc?id=1iM3mhMJ_N0x5pytcP831l3ZFwbLmbwzi
tar -xvf ScanObjectNN.tar --directory=scanobjectnn
```
Finally, the entire dataset folder structure will be like:

```
dataset
|--- s3dis
|    |--- s3dis_names.txt
|    |--- stanford_indoor3d
|         |--- Area_1_conferenceRoom_1.npy
|         |--- Area_1_conferenceRoom_2.npy
|         |--- ...
|--- shapenetpart
|    |--- shapenetcore_partanno_segmentation_benchmark_v0_normal
|         |--- train_test_split
|         |    |--- shuffled_train_file_list.json
|         |    |--- ...
|         |--- 02691156
|         |    |--- 1a04e3eab45ca15dd86060f189eb133.txt
|         |    |--- ...
|         |--- 02773838
|         |--- synsetoffset2category.txt
|         |--- processed
|              |--- test_2048_fps.pkl
|--- modelnet40ply2048
|    |--- modelnet40_ply_hdf5_2048
|         |--- ply_data_test0.h5
|         |--- ...
|         |--- ply_data_train4.h5
|--- scanobjectnn
     |--- h5_files
          |--- main_split
               |--- training_objectdataset_augmentedrot_scale75.h5
               |--- test_objectdataset_augmentedrot_scale75.h5
```
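Before launching training, it can help to verify the layout above is in place. A small sanity-check helper (my own, not part of the repo) could look like:

```python
from pathlib import Path

# Key directories expected under dataset/ per the tree above.
EXPECTED = [
    "s3dis/stanford_indoor3d",
    "shapenetpart/shapenetcore_partanno_segmentation_benchmark_v0_normal",
    "modelnet40ply2048/modelnet40_ply_hdf5_2048",
    "scanobjectnn/h5_files/main_split",
]


def missing_dataset_dirs(root: str = "dataset") -> list[str]:
    """Return the expected dataset subdirectories that are absent under root."""
    base = Path(root)
    return [p for p in EXPECTED if not (base / p).is_dir()]
```

An empty return value means all four benchmark directories are present; otherwise the list names what still needs to be downloaded or preprocessed.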
## Training

The training commands for S3DIS, ShapeNetPart, ModelNet40, and ScanObjectNN are summarized as follows:

```bash
# S3DIS
./scripts/train_s3dis.sh

# ShapeNetPart
./scripts/train_shapepart.sh

# ModelNet40 (on GPU 0)
./scripts/train_modelnet40.sh 0

# ScanObjectNN (on GPU 0)
./scripts/train_scanobjnn.sh 0
```
## Testing

Pretrained models for S3DIS, ShapeNetPart, ModelNet40, and ScanObjectNN are available on Google Drive. Download the models into `checkpoints/` and run the testing scripts under `scripts/`. Taking ScanObjectNN as an example:

- Download scanobjectnn_cdformer.pth into `checkpoints/`.
- Evaluate on ScanObjectNN by simply running:

  ```bash
  # using GPU 0
  ./scripts/test_scanobjnn.sh 0
  ```
Other trained models can be similarly evaluated.
## Acknowledgment

This repo is built based on Stratified Transformer, Point Transformer, and PointNeXt. Thanks to the contributors of these repos!
## Citation

If you find our paper or code helpful for your research, please consider citing us with:

```bibtex
@article{qiu2023collect,
  title={Collect-and-Distribute Transformer for 3D Point Cloud Analysis},
  author={Qiu, Haibo and Yu, Baosheng and Tao, Dacheng},
  journal={arXiv preprint arXiv:2306.01257},
  year={2023}
}
```