Skip to content

Commit

Permalink
Merge pull request #5 from rajpurkarlab/ars
Browse files Browse the repository at this point in the history
Ars
  • Loading branch information
asaporta authored Jul 10, 2022
2 parents 1fb3f7f + d9c2fad commit 14f83ae
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 71 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ If you downloaded the CheXlocalize dataset, then these pickle files are in `/che

If using your own saliency maps, please be sure to save them as pickle files using the above formatting.

`<threshold_path>` is an optional csv file path that you can pass in to use your own thresholds to binarize the heatmaps. As an example, we provided `./tuning_results.csv` which saves the threshold for each pathology that maximize mIoU on the validation set. When passing in your own csv file, make sure to follow the same formatting as our example csv. When no threshold path is passed (default), we will apply Otsu's method (an automatic global thresholding algorithm provided by the cv2 package).
`<threshold_path>` is an optional csv file path that you can pass in to use your own thresholds to binarize the heatmaps. As an example, we provide [`./sample/tuning_results.csv`](https://github.com/rajpurkarlab/cheXlocalize/blob/master/sample/tuning_results.csv), which saves the threshold for each pathology that maximizes mIoU on the validation set. When passing in your own csv file, make sure to follow the same formatting as this example csv. By defaul, no threshold path is passed in, in which case we will apply Otsu's method (an automatic global thresholding algorithm provided by the cv2 package).

`<output_path>` is the json file path used for saving the encoded segmentation masks. The json file is formatted such that it can be used as input to `eval.py` (see [_Evaluate localization performance_](#eval) for formatting details).

Expand All @@ -126,17 +126,19 @@ Running this script on the validation set heatmaps from the CheXlocalize dataset

<a name="threshold"></a>
### Fine tune segmentation thresholds
To find the thresholds that maximize mIoU for each pathology on the validation set, run
To find the thresholds that maximize mIoU for each pathology on the validation set, run:

```
(chexlocalize) > python tune_heatmap_threshold.py --map_dir <map_dir> --save_dir <save_dir>
(chexlocalize) > python tune_heatmap_threshold.py --map_dir <map_dir> --gt_path <gt_path> --save_dir <save_dir>
```

`<map_dir>` is the directory with pickle files containing the heatmaps.
`<map_dir>` is the directory with pickle files containing the heatmaps.

`<gt_path>` is the json file where ground-truth segmentations are saved (encoded).

`<save_dir>` is the directory to save the csv file that stores the tuned thresholds. Default is current directory.

This script will replicate 'tuning_results.csv' when you use `/cheXlocalize_dataset/gradcam_maps_val/` as `<map_dir>`. Running this script should take about 30 minutes.
This script will replicate './sample/tuning_results.csv' when you use the CheXlocalize validation set DenseNet121 + Grad-CAM heatmaps in `/cheXlocalize_dataset/gradcam_maps_val/` as `<map_dir>` and the validation set ground-truth pixel-level segmentations in `/cheXlocalize_dataset/gt_segmentations_val.json`. Running this script should take about one hour.

<a name="ann_to_segm"></a>
## Generate segmentations from human annotations
Expand Down
6 changes: 3 additions & 3 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def evaluate(gt_path, pred_path, save_dir, metric, true_pos_only):
-- `{miou/hitrate}_summary_results.csv`: mIoU or hit rate 95% bootstrap confidence intervals for each pathology.
"""
# create save_dir if it does not already exist
Path(save_dir).mkdir(exist_ok=True,parents=True)
Path(save_dir).mkdir(exist_ok=True, parents=True)

if metric == 'miou':
ious, cxr_ids = get_ious(gt_path, pred_path, true_pos_only)
Expand Down Expand Up @@ -244,11 +244,11 @@ def evaluate(gt_path, pred_path, save_dir, metric, true_pos_only):
help='json path where predicted segmentations are saved \
(if metric = miou) or directory with pickle files \
containing heat maps (if metric = hitrate)')
parser.add_argument('--true_pos_only', default="True",
parser.add_argument('--true_pos_only', default='True',
help='if true, run evaluation only on the true positive \
slice of the dataset (CXRs that contain predicted and \
ground-truth segmentations)')
parser.add_argument('--save_dir', default=".",
parser.add_argument('--save_dir', default='.',
help='where to save evaluation results')
parser.add_argument('--seed', type=int, default=0,
help='random seed to fix')
Expand Down
52 changes: 23 additions & 29 deletions heatmap_to_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""
Converts saliency heat maps to binary segmentations and encodes segmentations
Converts saliency heatmaps to binary segmentations and encodes segmentations
using RLE formats using the pycocotools Mask API. The final output is stored in
a json file.
The default thresholding used in this code is Otsu's method (an automatic global thresholding algorithm provided by cv2).
Users can also pass in their self-defined thresholds to binarize the heatmaps through --threshold_path.
Make sure the input is a csv file with the same format as the tuning_results.csv file we provided.
The default thresholding used in this code is Otsu's method (an automatic global
thresholding algorithm provided by cv2). Users can also pass in their own
self-defined thresholds to binarize the heatmaps through --threshold_path. If
doing this, make sure the input is a csv file with the same format as the
provided file sample/tuning_results.csv.
"""
from argparse import ArgumentParser
import cv2
Expand All @@ -25,14 +27,14 @@

def cam_to_segmentation(cam_mask, threshold=np.nan):
"""
Threshold a saliency heat map to binary segmentation mask.
Threshold a saliency heatmap to binary segmentation mask.
Args:
cam_mask (torch.Tensor): heat map in the original image size (H x W).
Will squeeze the tensor if there are more than two dimensions.
threshold (np.float64): threshold to use
Returns:
segmentation_output (np.array): binary segmentation output
segmentation (np.ndarray): binary segmentation output
"""
if (len(cam_mask.size()) > 2):
cam_mask = cam_mask.squeeze()
Expand All @@ -44,14 +46,11 @@ def cam_to_segmentation(cam_mask, threshold=np.nan):
mask = mask.div(mask.max()).data
mask = mask.cpu().detach().numpy()

# use otsu's method if no threshold is passed in
# use Otsu's method to find threshold if no threshold is passed in
if np.isnan(threshold):
mask = np.uint8(255 * mask)

# Use Otsu's method to find threshold
maxval = np.max(mask)
segmentation = cv2.threshold(mask, 0, maxval, cv2.THRESH_OTSU)[1]

else:
segmentation = np.array(mask > threshold, dtype="int")

Expand All @@ -65,7 +64,10 @@ def pkl_to_mask(pkl_path, threshold=np.nan):
Args:
pkl_path (str): path to the model output pickle file
task (str): localization task
threshold (np.float64): threshold to use
Returns:
segmentation (np.ndarray): binary segmentation output
"""
# load pickle file, get saliency map and resize
info = pickle.load(open(pkl_path, 'rb'))
Expand All @@ -82,7 +84,7 @@ def pkl_to_mask(pkl_path, threshold=np.nan):
return segmentation


def heatmap_to_mask(map_dir, output_path, threshold_path=''):
def heatmap_to_mask(map_dir, output_path, threshold_path):
"""
Converts all saliency maps to segmentations and stores segmentations in a
json file.
Expand All @@ -103,8 +105,7 @@ def heatmap_to_mask(map_dir, output_path, threshold_path=''):
# get encoded segmentation mask
if threshold_path:
tuning_results = pd.read_csv(threshold_path)
best_threshold = tuning_results[tuning_results['task'] ==
'Edema']['threshold'].values[0]
best_threshold = tuning_results[tuning_results['task'] == task]['threshold'].values[0]
else:
best_threshold = np.nan

Expand All @@ -123,29 +124,22 @@ def heatmap_to_mask(map_dir, output_path, threshold_path=''):
results[img_id][task] = encoded_mask

# save to json
Path(os.path.dirname(output_path)).mkdir(exist_ok=True, parents=True)
with open(output_path, 'w') as f:
json.dump(results, f)
print(f'Segmentation masks (in RLE format) saved to {output_path}')


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument(
'--map_dir',
type=str,
help='directory with pickle files containing heat maps')
parser.add_argument(
'--threshold_path',
type=str,
default='',
help=
'csv file that stores the threshold tuned on the validation set. Use Otsu'
's method if no path is given.')
parser.add_argument('--output_path',
type=str,
parser.add_argument('--map_dir', type=str,
help='directory with pickle files containing heatmaps')
parser.add_argument('--threshold_path', type=str,
help="csv file that stores pre-defined threshold values. \
If no path is given, script uses Otsu's.")
parser.add_argument('--output_path', type=str,
default='./saliency_segmentations.json',
help='json file path for saving encoded segmentations')

args = parser.parse_args()

heatmap_to_mask(args.map_dir, args.output_path, args.threshold_path)
File renamed without changes.
56 changes: 22 additions & 34 deletions tune_heatmap_threshold.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,35 @@
"""
Find thresholds (used to binarize the heatmaps) that maximizes mIoU on the validation set.
Pass in a list of thresholds [0.2.0.3,...,0.8]. Save the best threshold for each pathology in a csv file.
Find thresholds (used to binarize the heatmaps) that maximize mIoU on the
validation set. Pass in a list of potential thresholds [0.2, 0.3, ... , 0.8].
Save the best threshold for each pathology in a csv file.
"""

import pickle
from argparse import ArgumentParser
import glob
import json
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
from pycocotools import mask
import torch.nn.functional as F
from tqdm import tqdm

from pycocotools import mask
from eval import calculate_iou
from heatmap_to_segmentation import pkl_to_mask
from eval_constants import LOCALIZATION_TASKS

from argparse import ArgumentParser
from heatmap_to_segmentation import pkl_to_mask


def compute_miou(threshold, cam_pkls, gt):
"""
Given a threshold and a list of heatmap pickle files, return the miou
Given a threshold and a list of heatmap pickle files, return the mIoU.
Args:
threshould (double): the threshold used to convert heatmaps to segmentations
cam_pkls (list): a list of heatmap pickle files (for a pathology)
threshold (double): the threshold used to convert heatmaps to segmentations
cam_pkls (list): a list of heatmap pickle files (for a given pathology)
gt (dict): dictionary of ground truth segmentation masks
"""

ious = []

for pkl_path in tqdm(cam_pkls):

# break down path to image name and task
path = str(pkl_path).split('/')
task = path[-1].split('_')[-2]
Expand All @@ -56,11 +52,11 @@ def compute_miou(threshold, cam_pkls, gt):

def tune_threshold(task, gt, cam_dir):
"""
For a given pathology, find the threshold that maximizes mIoU
For a given pathology, find the threshold that maximizes mIoU.
Args:
task (str): localizatoin task
gt (dict): dictionary of the ground truth segmentaiton masks
task (str): localization task
gt (dict): dictionary of the ground truth segmentation masks
cam_dir (str): directory with pickle files containing heat maps
"""
cam_pkls = sorted(list(Path(cam_dir).rglob(f"*{task}_map.pkl")))
Expand All @@ -71,23 +67,15 @@ def tune_threshold(task, gt, cam_dir):


if __name__ == '__main__':

parser = ArgumentParser()
parser.add_argument(
'--map_dir',
type=str,
default='gradcam',
help='directory with pickle files containing heat maps')
parser.add_argument('--gt_path',
type=str,
help='directory where ground-truth segmentations are \
parser.add_argument('--map_dir', type=str,
help='directory with pickle files containing heat maps')
parser.add_argument('--gt_path', type=str,
help='json file where ground-truth segmentations are \
saved (encoded)')
parser.add_argument(
'--save_dir',
type=str,
default='.',
help='where to save the best thresholds tuned on the validation set')

parser.add_argument('--save_dir', type=str, default='.',
help='where to save the best thresholds tuned on the \
validation set')
args = parser.parse_args()

with open(args.gt_path) as f:
Expand Down

0 comments on commit 14f83ae

Please sign in to comment.