From d07c9eac2206e0ab5a28255107d1e01487be9197 Mon Sep 17 00:00:00 2001 From: wasserth Date: Fri, 9 Feb 2024 17:30:39 +0100 Subject: [PATCH] add more python api nifti output tests --- tests/tests_os.py | 8 ++++++++ totalsegmentator/nnunet.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/tests_os.py b/tests/tests_os.py index 605759dd5..65507c925 100755 --- a/tests/tests_os.py +++ b/tests/tests_os.py @@ -34,6 +34,14 @@ def run_tests_and_exit_on_failure(): os.remove("tests/unittest_prediction_fast.nii.gz") if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input and output") + # Test python api 3 - nifti input, nifti output, ml=True (has no effect here), roi_subset=['liver', 'brain'] + input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz") + output_img = totalsegmentator(input_img, None, device="cpu", ml=True, roi_subset=['liver', 'brain']) + nib.save(output_img, "tests/unittest_prediction_roi_subset.nii.gz") + r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_liver_roi_subset"]) + os.remove("tests/unittest_prediction_roi_subset.nii.gz") + if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input and output") + # Test terminal # Test organ predictions - fast - multilabel # makes correct path for windows and linux. Only required for terminal call. Within python diff --git a/totalsegmentator/nnunet.py b/totalsegmentator/nnunet.py index 74683cbd1..73b798e5c 100644 --- a/totalsegmentator/nnunet.py +++ b/totalsegmentator/nnunet.py @@ -180,7 +180,7 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None, else: device = torch.device('mps') step_size = 0.5 - # step_size = 0.8 # overal speedup roughly 11%; for fast model no speedup; dice 0.001 worse + # step_size = 0.8 # overall speedup roughly 11%; for fast model no speedup; dice 0.001 worse disable_tta = not tta verbose = False save_probabilities = False @@ -535,7 +535,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_ new_header.set_data_dtype(np.uint8) img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header) img_out = add_label_map_to_nifti(img_out, class_map[task_name]) - + if file_out is not None and skip_saving is False: if not quiet: print("Saving segmentations...")