diff --git a/README.md b/README.md index 0dbfad3..c92d2f7 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ If you downloaded the CheXlocalize dataset, then these pickle files are in `/che If using your own saliency maps, please be sure to save them as pickle files using the above formatting. -`` is an optional csv file path that you can pass in to use your own thresholds to binarize the heatmaps. As an example, we provided `./tuning_results.csv` which saves the threshold for each pathology that maximize mIoU on the validation set. When passing in your own csv file, make sure to follow the same formatting as our example csv. When no threshold path is passed (default), we will apply Otsu's method (an automatic global thresholding algorithm provided by the cv2 package). +`` is an optional csv file path that you can pass in to use your own thresholds to binarize the heatmaps. As an example, we provide [`./sample/tuning_results.csv`](https://github.com/rajpurkarlab/cheXlocalize/blob/master/sample/tuning_results.csv), which saves the threshold for each pathology that maximizes mIoU on the validation set. When passing in your own csv file, make sure to follow the same formatting as this example csv. By defaul, no threshold path is passed in, in which case we will apply Otsu's method (an automatic global thresholding algorithm provided by the cv2 package). `` is the json file path used for saving the encoded segmentation masks. The json file is formatted such that it can be used as input to `eval.py` (see [_Evaluate localization performance_](#eval) for formatting details). @@ -126,17 +126,19 @@ Running this script on the validation set heatmaps from the CheXlocalize dataset ### Fine tune segmentation thresholds -To find the thresholds that maximize mIoU for each pathology on the validation set, run +To find the thresholds that maximize mIoU for each pathology on the validation set, run: ``` -(chexlocalize) > python tune_heatmap_threshold.py --map_dir --save_dir +(chexlocalize) > python tune_heatmap_threshold.py --map_dir --gt_path --save_dir ``` -`` is the directory with pickle files containing the heatmaps. +`` is the directory with pickle files containing the heatmaps. + +`` is the json file where ground-truth segmentations are saved (encoded). `` is the directory to save the csv file that stores the tuned thresholds. Default is current directory. -This script will replicate 'tuning_results.csv' when you use `/cheXlocalize_dataset/gradcam_maps_val/` as ``. Running this script should take about 30 minutes. +This script will replicate './sample/tuning_results.csv' when you use the CheXlocalize validation set DenseNet121 + Grad-CAM heatmaps in `/cheXlocalize_dataset/gradcam_maps_val/` as `` and the validation set ground-truth pixel-level segmentations in `/cheXlocalize_dataset/gt_segmentations_val.json`. Running this script should take about one hour. ## Generate segmentations from human annotations diff --git a/eval.py b/eval.py index 101b692..195b155 100644 --- a/eval.py +++ b/eval.py @@ -206,7 +206,7 @@ def evaluate(gt_path, pred_path, save_dir, metric, true_pos_only): -- `{miou/hitrate}_summary_results.csv`: mIoU or hit rate 95% bootstrap confidence intervals for each pathology. """ # create save_dir if it does not already exist - Path(save_dir).mkdir(exist_ok=True,parents=True) + Path(save_dir).mkdir(exist_ok=True, parents=True) if metric == 'miou': ious, cxr_ids = get_ious(gt_path, pred_path, true_pos_only) @@ -244,11 +244,11 @@ def evaluate(gt_path, pred_path, save_dir, metric, true_pos_only): help='json path where predicted segmentations are saved \ (if metric = miou) or directory with pickle files \ containing heat maps (if metric = hitrate)') - parser.add_argument('--true_pos_only', default="True", + parser.add_argument('--true_pos_only', default='True', help='if true, run evaluation only on the true positive \ slice of the dataset (CXRs that contain predicted and \ ground-truth segmentations)') - parser.add_argument('--save_dir', default=".", + parser.add_argument('--save_dir', default='.', help='where to save evaluation results') parser.add_argument('--seed', type=int, default=0, help='random seed to fix') diff --git a/heatmap_to_segmentation.py b/heatmap_to_segmentation.py index 134bc18..9f4e97c 100644 --- a/heatmap_to_segmentation.py +++ b/heatmap_to_segmentation.py @@ -1,11 +1,13 @@ """ -Converts saliency heat maps to binary segmentations and encodes segmentations +Converts saliency heatmaps to binary segmentations and encodes segmentations using RLE formats using the pycocotools Mask API. The final output is stored in a json file. -The default thresholding used in this code is Otsu's method (an automatic global thresholding algorithm provided by cv2). -Users can also pass in their self-defined thresholds to binarize the heatmaps through --threshold_path. -Make sure the input is a csv file with the same format as the tuning_results.csv file we provided. +The default thresholding used in this code is Otsu's method (an automatic global +thresholding algorithm provided by cv2). Users can also pass in their own +self-defined thresholds to binarize the heatmaps through --threshold_path. If +doing this, make sure the input is a csv file with the same format as the +provided file sample/tuning_results.csv. """ from argparse import ArgumentParser import cv2 @@ -25,14 +27,14 @@ def cam_to_segmentation(cam_mask, threshold=np.nan): """ - Threshold a saliency heat map to binary segmentation mask. - + Threshold a saliency heatmap to binary segmentation mask. Args: cam_mask (torch.Tensor): heat map in the original image size (H x W). Will squeeze the tensor if there are more than two dimensions. + threshold (np.float64): threshold to use Returns: - segmentation_output (np.array): binary segmentation output + segmentation (np.ndarray): binary segmentation output """ if (len(cam_mask.size()) > 2): cam_mask = cam_mask.squeeze() @@ -44,14 +46,11 @@ def cam_to_segmentation(cam_mask, threshold=np.nan): mask = mask.div(mask.max()).data mask = mask.cpu().detach().numpy() - # use otsu's method if no threshold is passed in + # use Otsu's method to find threshold if no threshold is passed in if np.isnan(threshold): mask = np.uint8(255 * mask) - - # Use Otsu's method to find threshold maxval = np.max(mask) segmentation = cv2.threshold(mask, 0, maxval, cv2.THRESH_OTSU)[1] - else: segmentation = np.array(mask > threshold, dtype="int") @@ -65,7 +64,10 @@ def pkl_to_mask(pkl_path, threshold=np.nan): Args: pkl_path (str): path to the model output pickle file - task (str): localization task + threshold (np.float64): threshold to use + + Returns: + segmentation (np.ndarray): binary segmentation output """ # load pickle file, get saliency map and resize info = pickle.load(open(pkl_path, 'rb')) @@ -82,7 +84,7 @@ def pkl_to_mask(pkl_path, threshold=np.nan): return segmentation -def heatmap_to_mask(map_dir, output_path, threshold_path=''): +def heatmap_to_mask(map_dir, output_path, threshold_path): """ Converts all saliency maps to segmentations and stores segmentations in a json file. @@ -103,8 +105,7 @@ def heatmap_to_mask(map_dir, output_path, threshold_path=''): # get encoded segmentation mask if threshold_path: tuning_results = pd.read_csv(threshold_path) - best_threshold = tuning_results[tuning_results['task'] == - 'Edema']['threshold'].values[0] + best_threshold = tuning_results[tuning_results['task'] == task]['threshold'].values[0] else: best_threshold = np.nan @@ -123,6 +124,7 @@ def heatmap_to_mask(map_dir, output_path, threshold_path=''): results[img_id][task] = encoded_mask # save to json + Path(os.path.dirname(output_path)).mkdir(exist_ok=True, parents=True) with open(output_path, 'w') as f: json.dump(results, f) print(f'Segmentation masks (in RLE format) saved to {output_path}') @@ -130,22 +132,14 @@ def heatmap_to_mask(map_dir, output_path, threshold_path=''): if __name__ == '__main__': parser = ArgumentParser() - parser.add_argument( - '--map_dir', - type=str, - help='directory with pickle files containing heat maps') - parser.add_argument( - '--threshold_path', - type=str, - default='', - help= - 'csv file that stores the threshold tuned on the validation set. Use Otsu' - 's method if no path is given.') - parser.add_argument('--output_path', - type=str, + parser.add_argument('--map_dir', type=str, + help='directory with pickle files containing heatmaps') + parser.add_argument('--threshold_path', type=str, + help="csv file that stores pre-defined threshold values. \ + If no path is given, script uses Otsu's.") + parser.add_argument('--output_path', type=str, default='./saliency_segmentations.json', help='json file path for saving encoded segmentations') - args = parser.parse_args() heatmap_to_mask(args.map_dir, args.output_path, args.threshold_path) diff --git a/tuning_results.csv b/sample/tuning_results.csv similarity index 100% rename from tuning_results.csv rename to sample/tuning_results.csv diff --git a/tune_heatmap_threshold.py b/tune_heatmap_threshold.py index 1915d8a..a7629ef 100644 --- a/tune_heatmap_threshold.py +++ b/tune_heatmap_threshold.py @@ -1,39 +1,35 @@ """ -Find thresholds (used to binarize the heatmaps) that maximizes mIoU on the validation set. -Pass in a list of thresholds [0.2.0.3,...,0.8]. Save the best threshold for each pathology in a csv file. +Find thresholds (used to binarize the heatmaps) that maximize mIoU on the +validation set. Pass in a list of potential thresholds [0.2, 0.3, ... , 0.8]. +Save the best threshold for each pathology in a csv file. """ - -import pickle +from argparse import ArgumentParser import glob import json -import torch.nn.functional as F -from pathlib import Path import numpy as np import pandas as pd +from pathlib import Path +import pickle +from pycocotools import mask +import torch.nn.functional as F from tqdm import tqdm -from pycocotools import mask from eval import calculate_iou -from heatmap_to_segmentation import pkl_to_mask from eval_constants import LOCALIZATION_TASKS - -from argparse import ArgumentParser +from heatmap_to_segmentation import pkl_to_mask def compute_miou(threshold, cam_pkls, gt): """ - Given a threshold and a list of heatmap pickle files, return the miou + Given a threshold and a list of heatmap pickle files, return the mIoU. Args: - threshould (double): the threshold used to convert heatmaps to segmentations - cam_pkls (list): a list of heatmap pickle files (for a pathology) + threshold (double): the threshold used to convert heatmaps to segmentations + cam_pkls (list): a list of heatmap pickle files (for a given pathology) gt (dict): dictionary of ground truth segmentation masks """ - ious = [] - for pkl_path in tqdm(cam_pkls): - # break down path to image name and task path = str(pkl_path).split('/') task = path[-1].split('_')[-2] @@ -56,11 +52,11 @@ def compute_miou(threshold, cam_pkls, gt): def tune_threshold(task, gt, cam_dir): """ - For a given pathology, find the threshold that maximizes mIoU + For a given pathology, find the threshold that maximizes mIoU. Args: - task (str): localizatoin task - gt (dict): dictionary of the ground truth segmentaiton masks + task (str): localization task + gt (dict): dictionary of the ground truth segmentation masks cam_dir (str): directory with pickle files containing heat maps """ cam_pkls = sorted(list(Path(cam_dir).rglob(f"*{task}_map.pkl"))) @@ -71,23 +67,15 @@ def tune_threshold(task, gt, cam_dir): if __name__ == '__main__': - parser = ArgumentParser() - parser.add_argument( - '--map_dir', - type=str, - default='gradcam', - help='directory with pickle files containing heat maps') - parser.add_argument('--gt_path', - type=str, - help='directory where ground-truth segmentations are \ + parser.add_argument('--map_dir', type=str, + help='directory with pickle files containing heat maps') + parser.add_argument('--gt_path', type=str, + help='json file where ground-truth segmentations are \ saved (encoded)') - parser.add_argument( - '--save_dir', - type=str, - default='.', - help='where to save the best thresholds tuned on the validation set') - + parser.add_argument('--save_dir', type=str, default='.', + help='where to save the best thresholds tuned on the \ + validation set') args = parser.parse_args() with open(args.gt_path) as f: