diff --git a/resources/contrast_phase_classifiers.pkl b/resources/contrast_phase_classifiers.pkl new file mode 100644 index 000000000..faeb6743a Binary files /dev/null and b/resources/contrast_phase_classifiers.pkl differ diff --git a/totalsegmentator/bin/totalseg_get_phase.py b/totalsegmentator/bin/totalseg_get_phase.py index 74010eead..4691f20c0 100644 --- a/totalsegmentator/bin/totalseg_get_phase.py +++ b/totalsegmentator/bin/totalseg_get_phase.py @@ -16,13 +16,13 @@ def pi_time_to_phase(pi_time: float) -> str: """ Convert the pi time to a phase and get a probability for the value. - + native: 0-10 arterial_early: 10-30 arterial_late: 30-50 portal_venous: 50-100 delayed: 100+ - + returns: phase, probability """ if pi_time < 5: @@ -43,7 +43,7 @@ def pi_time_to_phase(pi_time: float) -> str: return "portal_venous", 0.7 else: return "delayed", 0.7 - + def get_ct_contrast_phase(ct_img: nib.Nifti1Image): @@ -51,19 +51,19 @@ def get_ct_contrast_phase(ct_img: nib.Nifti1Image): "heart", "aorta", "inferior_vena_cava", "portal_vein_and_splenic_vein", "iliac_vena_left", "iliac_vena_right", "iliac_artery_left", "iliac_artery_right", "pulmonary_vein"] - + seg_img, stats = totalsegmentator(ct_img, None, ml=True, fast=True, statistics=True, roi_subset=None, quiet=False) - + features = [] for organ in organs: features.append(stats[organ]["intensity"]) - - # todo: adapt - # classifier_path = Path(__file__).parent / "classifier.pkl" - classifier_path = "/mnt/nvme/data/phase_classification/classifiers.pkl" + + # weights from longitudinalliver dataset + classifier_path = Path(__file__).parents[2] / "resources" / "contrast_phase_classifiers.pkl" + # classifier_path = "/mnt/nvme/data/phase_classification/classifiers.pkl" clfs = pickle.load(open(classifier_path, "rb")) - + # ensemble across folds preds = [] for fold, clf in clfs.items(): @@ -71,13 +71,13 @@ def get_ct_contrast_phase(ct_img: nib.Nifti1Image): preds = np.array(preds) pi_time = round(float(np.mean(preds)), 2) pi_time_std = round(float(np.std(preds)), 4) - + print("Ensemble res:") print(preds) # print(f"mean: {pi_time} +/- {pi_time_std}") print(f"mean: {pi_time} [{preds.min():.1f}-{preds.max():.1f}]") phase, probability = pi_time_to_phase(pi_time) - + return {"pi_time": pi_time, "phase": phase, "probability": probability} @@ -101,10 +101,10 @@ def main(): args = parser.parse_args() res = get_ct_contrast_phase(nib.load(args.input_file)) - + print("Result:") pprint(res) - + with open(args.output_file, "w") as f: f.write(json.dumps(res, indent=4))