Allow JSON mapping usage #16

Open · wants to merge 8 commits into main
compute_metrics_reloaded.py (79 changes: 57 additions & 22 deletions)
@@ -43,6 +43,7 @@
import pandas as pd
from multiprocessing import Pool, cpu_count
from functools import partial
+import json

from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM

@@ -83,6 +84,10 @@ def get_parser():
                             'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html.')
    parser.add_argument('-output', type=str, default='metrics.csv', required=False,
                        help='Path to the output CSV file to save the metrics. Default: metrics.csv')
+    parser.add_argument('-pred-map', type=str, metavar='<json>', default=None, required=False,
+                        help='Path to a JSON file mapping each imaged structure to the integer value it has in the '
+                             'prediction image, e.g., ~/<your_path>/<myjson>.json')
+    parser.add_argument('-ref-map', type=str, metavar='<json>', default=None, required=False,
+                        help='Path to a JSON file mapping each imaged structure to the integer value it has in the '
+                             'reference image, e.g., ~/<your_path>/<myjson>.json')
    parser.add_argument('-jobs', type=int, default=cpu_count()//8, required=False,
                        help='Number of CPU cores to use in parallel. Default: cpu_count()//8.')
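Both mapping files are plain JSON objects whose keys are structure names and whose values are the integer labels used in the corresponding image; this format is inferred from how ref_map and pred_map are used further down, and the structure names below are illustrative only. A hypothetical ref-map file could look like:

    {
        "spinal_cord": 1,
        "lesion": 2
    }

The pred-map file uses the same keys but maps them to the integer convention of the prediction image.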

@@ -126,7 +131,7 @@ def get_images_in_folder(prediction, reference):
    return prediction_files, reference_files


-def compute_metrics_single_subject(prediction, reference, metrics):
+def compute_metrics_single_subject(prediction, reference, metrics, ref_map=None, pred_map=None):
    """
    Compute MetricsReloaded metrics for a single subject
    :param prediction: path to the nifti image with the prediction
@@ -138,21 +143,25 @@ def compute_metrics_single_subject(prediction, reference, metrics):
    prediction_data = load_nifti_image(prediction)
    reference_data = load_nifti_image(reference)

-    # check whether the images have the same shape and orientation
-    if prediction_data.shape != reference_data.shape:
-        raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
-                         f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
-                         f'shape {reference_data.shape}.')
-
-    # get all unique labels (classes)
-    # for example, for nnunet region-based segmentation, spinal cord has label 1, and lesions have label 2
-    unique_labels_reference = np.unique(reference_data)
-    unique_labels_reference = unique_labels_reference[unique_labels_reference != 0]  # remove background
-    unique_labels_prediction = np.unique(prediction_data)
-    unique_labels_prediction = unique_labels_prediction[unique_labels_prediction != 0]  # remove background
-
-    # Get the unique labels that are present in the reference OR prediction images
-    unique_labels = np.unique(np.concatenate((unique_labels_reference, unique_labels_prediction)))
+    # check whether the images have the same shape and orientation
+    # (this must hold whether or not label mappings are provided)
+    if prediction_data.shape != reference_data.shape:
+        raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
+                         f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
+                         f'shape {reference_data.shape}.')
+
+    if ref_map is None and pred_map is None:
+        # get all unique labels (classes)
+        # for example, for nnunet region-based segmentation, spinal cord has label 1, and lesions have label 2
+        unique_labels_reference = np.unique(reference_data)
+        unique_labels_reference = unique_labels_reference[unique_labels_reference != 0]  # remove background
+        unique_labels_prediction = np.unique(prediction_data)
+        unique_labels_prediction = unique_labels_prediction[unique_labels_prediction != 0]  # remove background
+
+        # get the unique labels that are present in the reference OR prediction images
+        unique_labels = np.unique(np.concatenate((unique_labels_reference, unique_labels_prediction)))
+    else:
+        # with mappings, the "labels" are the structure names (strings) present in either JSON file;
+        # the integer value for each image is looked up in the corresponding map inside the loop below
+        unique_labels = np.unique(np.concatenate((list(ref_map.keys()), list(pred_map.keys()))))

    # append entry into the output_list to store the metrics for the current subject
    metrics_dict = {'reference': reference, 'prediction': prediction}
@@ -161,8 +170,12 @@ def compute_metrics_single_subject(prediction, reference, metrics):
    # by doing this, we can compute metrics for each label separately, e.g., separately for spinal cord and lesions
    for label in unique_labels:
        # create binary masks for the current label
-        prediction_data_label = np.array(prediction_data == label, dtype=float)
-        reference_data_label = np.array(reference_data == label, dtype=float)
+        if not isinstance(label, str):
+            # integer label: compare the voxel values directly
+            prediction_data_label = np.array(prediction_data == label, dtype=float)
+            reference_data_label = np.array(reference_data == label, dtype=float)
+        else:
+            # string label (structure name from the JSON mappings): look up the integer value per image;
+            # note that each structure name is expected to appear in both mappings, otherwise a KeyError is raised
+            prediction_data_label = np.array(prediction_data == pred_map[label], dtype=float)
+            reference_data_label = np.array(reference_data == ref_map[label], dtype=float)

        bpm = BPM(prediction_data_label, reference_data_label, measures=metrics)
        dict_seg = bpm.to_dict_meas()
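To make the branch above concrete, here is a minimal, self-contained sketch of how string labels flow through the loop; the mappings, structure names, and toy arrays are hypothetical, not taken from this PR:

    import numpy as np

    ref_map = {"spinal_cord": 1, "lesion": 2}   # hypothetical reference mapping
    pred_map = {"spinal_cord": 2, "lesion": 1}  # hypothetical prediction mapping

    # union of structure names from both mappings -> sorted array of strings
    unique_labels = np.unique(np.concatenate((list(ref_map.keys()), list(pred_map.keys()))))
    print(unique_labels)  # ['lesion' 'spinal_cord']

    reference_data = np.array([0, 1, 2, 2])   # toy reference "image"
    prediction_data = np.array([0, 2, 1, 1])  # toy prediction "image"
    for label in unique_labels:
        # numpy string scalars are instances of str, so isinstance(label, str) routes
        # them to the mapping branch, where each map supplies the per-image integer
        ref_mask = np.array(reference_data == ref_map[label], dtype=float)
        pred_mask = np.array(prediction_data == pred_map[label], dtype=float)
        print(label, ref_mask, pred_mask)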
@@ -214,11 +227,11 @@ def build_output_dataframe(output_list):
    return df


-def process_subject(prediction_file, reference_file, metrics):
+def process_subject(prediction_file, reference_file, metrics, ref_map=None, pred_map=None):
    """
    Wrapper function to process a single subject.
    """
-    return compute_metrics_single_subject(prediction_file, reference_file, metrics)
+    return compute_metrics_single_subject(prediction_file, reference_file, metrics, ref_map, pred_map)


def main():
@@ -229,6 +242,23 @@ def main():
    # Initialize a list to store the output dictionaries (representing a single reference-prediction pair per subject)
    output_list = list()

+    # -ref-map and -pred-map must be used together: fail if exactly one of them was specified
+    if (args.ref_map is None) != (args.pred_map is None):
+        raise ValueError('If used, both -ref-map and -pred-map must be provided.')
+
+    # Load the JSON mappings if provided; otherwise keep them as None
+    if args.ref_map is not None and args.pred_map is not None:
+        with open(args.ref_map, "r") as file:
+            ref_map = json.load(file)
+        with open(args.pred_map, "r") as file:
+            pred_map = json.load(file)
+    else:
+        ref_map = None
+        pred_map = None

    # Print the metrics to be computed
    print(f'Computing metrics: {args.metrics}')
    print(f'Using {args.jobs} CPU cores in parallel ...')
@@ -241,14 +271,19 @@
        # Use multiprocessing to parallelize the computation
        with Pool(args.jobs) as pool:
            # Create a partial function to pass the metrics argument to the process_subject function
-            func = partial(process_subject, metrics=args.metrics)
+            func = partial(
+                process_subject,
+                metrics=args.metrics,
+                ref_map=ref_map,
+                pred_map=pred_map
+            )
            # Compute metrics for each subject in parallel
            results = pool.starmap(func, zip(prediction_files, reference_files))

        # Collect the results
        output_list.extend(results)
    else:
-        metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics)
+        metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics, ref_map, pred_map)
        # Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list
        output_list.append(metrics_dict)
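For reference, a hypothetical end-to-end invocation with the new options; the -prediction, -reference, and -metrics flags come from the unchanged part of the script, and the file names and metric name are illustrative:

    python compute_metrics_reloaded.py \
        -prediction sub-001_pred.nii.gz \
        -reference sub-001_gt.nii.gz \
        -metrics dsc \
        -pred-map pred_map.json \
        -ref-map ref_map.json

Passing only one of -pred-map/-ref-map raises a ValueError; omitting both preserves the original behavior of comparing raw integer labels.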
