MCAUNet

This repo is the official implementation of "Survival analysis of patients with liver cirrhosis based on deep learning to quantify body composition".

We propose a Channel Transformer module (CTrans) that replaces the skip connections of the original U-Net; we name the resulting network "MCAU-Net".

Requirements

Install the dependencies listed in requirements.txt:

pip install -r requirements.txt

Usage

1. Data Preparation

Prepare the datasets in the following format for easy use of the code:

datasets
├── Test_Folder
│   ├── img
│   └── labelcol
├── Train_Folder
│   ├── img
│   └── labelcol
└── Val_Folder
    ├── img
    └── labelcol
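Before training, it can help to confirm that the folders are where the code expects them. The following is a minimal sketch, not part of this repository: the helper names `missing_dataset_dirs` and `create_dataset_dirs` are hypothetical, and only the folder names come from the layout above.

```python
from pathlib import Path

# Folder names taken from the layout above.
SPLITS = ("Train_Folder", "Val_Folder", "Test_Folder")
SUBDIRS = ("img", "labelcol")


def missing_dataset_dirs(root):
    """Return the expected sub-directories that do not exist under `root`."""
    root = Path(root)
    return [
        root / split / sub
        for split in SPLITS
        for sub in SUBDIRS
        if not (root / split / sub).is_dir()
    ]


def create_dataset_dirs(root):
    """Create the empty folder skeleton; existing folders are left untouched."""
    for path in missing_dataset_dirs(root):
        path.mkdir(parents=True, exist_ok=True)
```

Calling `missing_dataset_dirs("datasets")` returns an empty list when all six folders are present.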

2. Training

As mentioned in the paper, we introduce two strategies to optimize MCAUNet.

First, adjust the settings in Config.py, which holds all configuration options, including the learning rate and batch size.

2.1 Joint Training

We optimize the convolution parameters in U-Net and the CTrans parameters together with a single loss. Run:

python train_model.py

2.2 Pre-training

Our method only replaces the skip connections in U-Net, so the U-Net parameters can be reused as part of the pretrained weights.

By first training a classical U-Net using /nets/UNet.py and then using the pretrained weights to train the UCTransNet, the CTrans module gets better initial features.

This strategy can improve the convergence speed and may improve the final segmentation performance in some cases.
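One common way to reuse weights between two networks that share only some layers is to copy the entries whose name and shape both match. The sketch below illustrates that idea under assumptions: `select_matching_weights` is a hypothetical helper, not a function from this repository, and the file name in the comment is made up. The repository's own loading logic may differ.

```python
def select_matching_weights(pretrained, target):
    """Keep pretrained entries whose name and shape both match the target.

    Both arguments are mappings from parameter name to an object with a
    `.shape` attribute (for example, a PyTorch state dict).
    """
    matched = {}
    for name, value in pretrained.items():
        if name not in target:
            continue  # layer exists only in the pretrained network
        if getattr(value, "shape", None) != getattr(target[name], "shape", None):
            continue  # same name but incompatible size
        matched[name] = value
    return matched


# Hypothetical use with PyTorch (file name and `model` are assumptions):
#   unet_weights = torch.load("unet_pretrained.pth", map_location="cpu")
#   new_state = model.state_dict()
#   new_state.update(select_matching_weights(unet_weights, new_state))
#   model.load_state_dict(new_state)
```

Parameters that have no match, such as those of the CTrans module, keep their fresh initialization.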

3. Testing

3.1. Get Pre-trained Models

We provide weights pre-trained on GlaS and MoNuSeg. If you do not want to train the models yourself, you can download them from the following links:

3.2. Test the Model and Visualize the Segmentation Results

First, set the session name in Config.py to the one used in the training phase. Then run:

python test_model.py

This reports the Dice and IoU scores and produces the visualization results.
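For reference, the two scores have the standard definitions Dice = 2|A∩B| / (|A| + |B|) and IoU = |A∩B| / |A∪B| for a predicted mask A and a ground-truth mask B. The following is a minimal pure-Python sketch of those definitions for flat binary masks; the function name `dice_and_iou` is hypothetical and this is not the code used by test_model.py.

```python
def dice_and_iou(pred, target):
    """Dice and IoU for two flat binary masks (sequences of 0/1)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    pred_area = sum(1 for p in pred if p)
    target_area = sum(1 for t in target if t)
    union = pred_area + target_area - intersection
    if union == 0:
        # Both masks are empty; treating this as a perfect match is a convention.
        return 1.0, 1.0
    dice = 2 * intersection / (pred_area + target_area)
    iou = intersection / union
    return dice, iou


dice, iou = dice_and_iou([1, 1, 0, 0], [1, 0, 1, 0])
print(dice, iou)  # one shared pixel out of three labelled pixels
```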

4. Reproducibility

In our code, we set the random seed and put cuDNN in 'deterministic' mode to reduce randomness. However, some factors can still cause different training results, e.g., the CUDA version, the GPU type, and the number of GPUs. The GPU used in our experiments is an NVIDIA A40 (48 GB) and the CUDA version is 11.2.

In multi-GPU setups in particular, the upsampling operation is a notable source of nondeterminism. See https://pytorch.org/docs/stable/notes/randomness.html for more details.
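The seeding and cuDNN settings described above typically look like the sketch below. This is an illustration of the usual PyTorch recipe, not a copy of this repository's code; the function name `seed_everything` is an assumption, and NumPy and PyTorch are imported lazily so the snippet also runs where they are not installed.

```python
import random


def seed_everything(seed):
    """Seed the common random number generators and make cuDNN deterministic."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
    except ImportError:
        return  # PyTorch not installed; only the generators above are seeded
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # ignored when CUDA is unavailable
    # Ask cuDNN for deterministic kernels and disable its auto-tuner,
    # which may otherwise pick different algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Call `seed_everything(...)` once at the start of the script, before the model and data loaders are created.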

We suggest training the model twice to verify whether the randomness has been eliminated. Because we use early stopping, the final performance may change significantly due to any remaining randomness.
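One simple way to compare two runs is to log a per-epoch value (for example, the validation loss) in each run and find the first epoch where the logs differ. The sketch below assumes such logs are available as plain lists of floats; `first_divergence` is a hypothetical helper and the repository does not necessarily write logs in this form.

```python
import math


def first_divergence(run_a, run_b, rel_tol=0.0, abs_tol=0.0):
    """Return the index of the first epoch where two logs differ, or None.

    With the default tolerances the values must be exactly equal, which is
    what a fully deterministic setup should produce.
    """
    for epoch, (a, b) in enumerate(zip(run_a, run_b)):
        if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
            return epoch
    if len(run_a) != len(run_b):
        # One run stopped earlier, e.g. early stopping triggered differently.
        return min(len(run_a), len(run_b))
    return None
```

A return value of `None` means the two logs are identical; any integer points at the epoch where the runs start to drift apart.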

Contact

YinHan Zhang ([email protected])