Skip to content

Commit

Permalink
add more python api nifti output tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Feb 9, 2024
1 parent 4a206ad commit d07c9ea
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 8 additions & 0 deletions tests/tests_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def run_tests_and_exit_on_failure():
os.remove("tests/unittest_prediction_fast.nii.gz")
if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input and output")

# Test python api 3 - nifti input, nifti output, ml=True (has no effect here), roi_subset=['liver', 'brain']
input_img = nib.load("tests/reference_files/example_ct_sm.nii.gz")
output_img = totalsegmentator(input_img, None, device="cpu", ml=True, roi_subset=['liver', 'brain'])
nib.save(output_img, "tests/unittest_prediction_roi_subset.nii.gz")
r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_prediction_liver_roi_subset"])
os.remove("tests/unittest_prediction_roi_subset.nii.gz")
if r != 0: sys.exit("Test failed: test_prediction_fast with Nifti input and output")

# Test terminal
# Test organ predictions - fast - multilabel
# makes correct path for windows and linux. Only required for terminal call. Within python
Expand Down
4 changes: 2 additions & 2 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
else:
device = torch.device('mps')
step_size = 0.5
# step_size = 0.8 # overal speedup roughly 11%; for fast model no speedup; dice 0.001 worse
# step_size = 0.8 # overall speedup roughly 11%; for fast model no speedup; dice 0.001 worse
disable_tta = not tta
verbose = False
save_probabilities = False
Expand Down Expand Up @@ -535,7 +535,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
new_header.set_data_dtype(np.uint8)
img_out = nib.Nifti1Image(img_data, img_pred.affine, new_header)
img_out = add_label_map_to_nifti(img_out, class_map[task_name])

if file_out is not None and skip_saving is False:
if not quiet: print("Saving segmentations...")

Expand Down

0 comments on commit d07c9ea

Please sign in to comment.