diff --git a/compute_metrics_reloaded.py b/compute_metrics_reloaded.py
index 05494ae..d43cd52 100644
--- a/compute_metrics_reloaded.py
+++ b/compute_metrics_reloaded.py
@@ -43,6 +43,7 @@ import pandas as pd
 from multiprocessing import Pool, cpu_count
 from functools import partial
+import json
 
 from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM
 
 
@@ -83,6 +84,10 @@ def get_parser():
                              'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html.')
     parser.add_argument('-output', type=str, default='metrics.csv', required=False,
                         help='Path to the output CSV file to save the metrics. Default: metrics.csv')
+    parser.add_argument('-pred-map', type=str, metavar='', default=None, required=False,
+                        help='JSON file containing the mapping between each imaged structure and its corresponding integer value in the prediction image ~//.json')
+    parser.add_argument('-ref-map', type=str, metavar='', default=None, required=False,
+                        help='JSON file containing the mapping between each imaged structure and its corresponding integer value in the reference image ~//.json')
     parser.add_argument('-jobs', type=int, default=cpu_count()//8, required=False,
                         help='Number of CPU cores to use in parallel. Default: cpu_count()//8.')
 
@@ -126,7 +131,7 @@ def get_images_in_folder(prediction, reference):
     return prediction_files, reference_files
 
 
-def compute_metrics_single_subject(prediction, reference, metrics):
+def compute_metrics_single_subject(prediction, reference, metrics, ref_map=None, pred_map=None):
     """
     Compute MetricsReloaded metrics for a single subject
     :param prediction: path to the nifti image with the prediction
@@ -138,21 +143,25 @@ def compute_metrics_single_subject(prediction, reference, metrics):
     prediction_data = load_nifti_image(prediction)
     reference_data = load_nifti_image(reference)
 
-    # check whether the images have the same shape and orientation
-    if prediction_data.shape != reference_data.shape:
-        raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
-                         f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
-                         f'shape {reference_data.shape}.')
-
-    # get all unique labels (classes)
-    # for example, for nnunet region-based segmentation, spinal cord has label 1, and lesions have label 2
-    unique_labels_reference = np.unique(reference_data)
-    unique_labels_reference = unique_labels_reference[unique_labels_reference != 0]  # remove background
-    unique_labels_prediction = np.unique(prediction_data)
-    unique_labels_prediction = unique_labels_prediction[unique_labels_prediction != 0]  # remove background
-
-    # Get the unique labels that are present in the reference OR prediction images
-    unique_labels = np.unique(np.concatenate((unique_labels_reference, unique_labels_prediction)))
+    if ref_map is None and pred_map is None:
+        # check whether the images have the same shape and orientation
+        if prediction_data.shape != reference_data.shape:
+            raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
+                             f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
+                             f'shape {reference_data.shape}.')
+
+        # get all unique labels (classes)
+        # for example, for nnunet region-based segmentation, spinal cord has label 1, and lesions have label 2
+        unique_labels_reference = np.unique(reference_data)
+        unique_labels_reference = unique_labels_reference[unique_labels_reference != 0]  # remove background
+        unique_labels_prediction = np.unique(prediction_data)
+        unique_labels_prediction = unique_labels_prediction[unique_labels_prediction != 0]  # remove background
+
+        # Get the unique labels that are present in the reference OR prediction images
+        unique_labels = np.unique(np.concatenate((unique_labels_reference, unique_labels_prediction)))
+    else:
+        # Get the unique structure names that are present in the reference OR prediction mappings
+        unique_labels = np.unique(np.concatenate((list(ref_map.keys()), list(pred_map.keys()))))
 
     # append entry into the output_list to store the metrics for the current subject
     metrics_dict = {'reference': reference, 'prediction': prediction}
@@ -161,8 +170,12 @@ def compute_metrics_single_subject(prediction, reference, metrics):
     # by doing this, we can compute metrics for each label separately, e.g., separately for spinal cord and lesions
     for label in unique_labels:
         # create binary masks for the current label
-        prediction_data_label = np.array(prediction_data == label, dtype=float)
-        reference_data_label = np.array(reference_data == label, dtype=float)
+        if not isinstance(label, str):
+            prediction_data_label = np.array(prediction_data == label, dtype=float)
+            reference_data_label = np.array(reference_data == label, dtype=float)
+        else:
+            prediction_data_label = np.array(prediction_data == pred_map[label], dtype=float)
+            reference_data_label = np.array(reference_data == ref_map[label], dtype=float)
 
         bpm = BPM(prediction_data_label, reference_data_label, measures=metrics)
         dict_seg = bpm.to_dict_meas()
@@ -214,11 +227,11 @@ def build_output_dataframe(output_list):
     return df
 
 
-def process_subject(prediction_file, reference_file, metrics):
+def process_subject(prediction_file, reference_file, metrics, ref_map=None, pred_map=None):
     """
     Wrapper function to process a single subject.
     """
-    return compute_metrics_single_subject(prediction_file, reference_file, metrics)
+    return compute_metrics_single_subject(prediction_file, reference_file, metrics, ref_map, pred_map)
 
 
 def main():
@@ -229,6 +242,23 @@ def main():
     # Initialize a list to store the output dictionaries (representing a single reference-prediction pair per subject)
     output_list = list()
 
+    # Ensure that both -ref-map and -pred-map are provided if either one is specified
+    if any((args.ref_map is None, args.pred_map is None)) and any((args.ref_map is not None, args.pred_map is not None)):
+        raise ValueError('If used, both -ref-map and -pred-map must be provided.')
+
+    # Load the JSON mappings if provided
+    if not any((args.ref_map is None, args.pred_map is None)):
+        # Load the reference mapping into a dictionary
+        with open(args.ref_map, "r") as file:
+            ref_map = json.load(file)
+        # Load the prediction mapping into a dictionary
+        with open(args.pred_map, "r") as file:
+            pred_map = json.load(file)
+    else:
+        # Assign None when the mappings are not used
+        ref_map = None
+        pred_map = None
+
     # Print the metrics to be computed
     print(f'Computing metrics: {args.metrics}')
     print(f'Using {args.jobs} CPU cores in parallel ...')
@@ -241,14 +271,19 @@ def main():
         # Use multiprocessing to parallelize the computation
         with Pool(args.jobs) as pool:
             # Create a partial function to pass the metrics argument to the process_subject function
-            func = partial(process_subject, metrics=args.metrics)
+            func = partial(
+                process_subject,
+                metrics=args.metrics,
+                ref_map=ref_map,
+                pred_map=pred_map
+            )
             # Compute metrics for each subject in parallel
             results = pool.starmap(func, zip(prediction_files, reference_files))
 
             # Collect the results
             output_list.extend(results)
     else:
-        metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics)
+        metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics, ref_map, pred_map)
         # Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list
         output_list.append(metrics_dict)
 