
Commit

update test
pritamrungta committed Dec 10, 2024
1 parent 052a733 commit 0613100
Showing 2 changed files with 69 additions and 38 deletions.
67 changes: 35 additions & 32 deletions redbrick/utils/dicom.py
@@ -528,6 +528,7 @@ async def process_nifti_upload(

group_map: Dict[int, Set[int]] = {}
map_instances: Set[int] = set()
file_instances: Set[int] = set()
reverse_map: Dict[Tuple[int, ...], int] = {}
for instance_id, instance_groups in instances.items():
map_instances.add(instance_id)
@@ -548,9 +549,14 @@ async def process_nifti_upload(
if binary_mask and files[0] in reverse_masks:
instance_number = reverse_masks[files[0]][0]
base_data[np.nonzero(base_data)] = instance_number
file_instances.add(instance_number)
else:
file_instances.update(
[x.item() for x in np.unique(base_data[np.nonzero(base_data)])]
)

instance_pool = sorted(
set(range(1, 65536)) - map_instances,
set(range(1, 65536)) - map_instances - file_instances,
reverse=True,
)
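
An editorial aside, not part of the commit: the pool above is every unused 16-bit instance ID, and the change also excludes IDs that already occur in the segmentation files. A minimal sketch with hypothetical ID sets, assuming the pool is consumed with list.pop() (which this hunk does not show):

map_instances = {1, 2, 3}    # IDs referenced by the segmentMap (toy values)
file_instances = {2, 7}      # IDs already present in the segmentation file(s) (toy values)
instance_pool = sorted(
    set(range(1, 65536)) - map_instances - file_instances,
    reverse=True,
)
# Descending order, so pop() from the end of the list yields the lowest free ID first.
next_group = instance_pool.pop()  # -> 4 with these toy sets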

@@ -609,6 +615,7 @@ async def process_nifti_upload(
if base_v == 0:
# No instance, so we can just set the base value to the instance number
base_data[v_indices] = instance_number
file_instances.add(instance_number)
reverse_map[tuple(sorted(mask_instances))] = instance_number
else:
# An existing instance or group, so we create a new group with the
@@ -631,38 +638,34 @@ async def process_nifti_upload(
reverse_map[group_key] = next_group

base_data[v_indices] = next_group

if label_validate or prune_segmentations:
file_instances: Set[int] = set(
x.item() for x in np.unique(base_data[np.nonzero(base_data)])
file_instances.add(next_group)

if prune_segmentations and file_instances != map_instances:
if file_excess := file_instances - map_instances:
logger.warning(f"Pruning segmentation instances: {file_excess}")
base_non_zero = np.nonzero(base_data)
base_data[base_non_zero] = np.where(
np.isin(base_data[base_non_zero], list(file_excess)),
0,
base_data[base_non_zero],
)
file_instances -= file_excess

if map_excess := map_instances - file_instances:
logger.warning(f"Pruning segmentMap instances: {map_excess}")
map_instances -= map_excess

if label_validate and file_instances != map_instances:
raise ValueError(
"Instance IDs in segmentation file(s) and segmentMap do not match.\n"
+ f"Segmentation file(s) have instances: {file_instances} and "
+ f"segmentMap has instances: {map_instances}\n"
+ f"Segmentation(s): {files}"
)
if file_instances != map_instances:
if prune_segmentations:
if excess := file_instances - map_instances:
logger.warning(f"Pruning segmentation instances: {excess}")
base_non_zero = np.nonzero(base_data)
base_data[base_non_zero] = np.where(
np.isin(base_data[base_non_zero], list(excess)),
0,
base_data[base_non_zero],
)
file_instances -= excess

if excess := map_instances - file_instances:
logger.warning(f"Pruning segmentMap instances: {excess}")
map_instances -= excess

elif label_validate:
raise ValueError(
"Instance IDs in segmentation file(s) and segmentMap do not match.\n"
+ f"Segmentation file(s) have instances: {file_instances} and "
+ f"segmentMap has instances: {map_instances}\n"
+ f"Segmentation(s): {files}"
)

if not any(v >= 256 for v in file_instances):
base_img.set_data_dtype(np.uint8)
base_data = base_data.astype(np.uint8)
if not any(v >= 256 for v in file_instances):
base_img.set_data_dtype(np.uint8)
base_data = base_data.astype(np.uint8)

if isinstance(base_img, Nifti1Image):
new_img = Nifti1Image(base_data, base_img.affine, base_img.header)
@@ -675,7 +678,7 @@ async def process_nifti_upload(
nib_save(new_img, filename)

segment_map: Dict[int, Optional[List[int]]] = {}
for instance in map_instances:
for instance in map_instances | file_instances:
if instance in group_map:
for instance_id in group_map[instance]:
groups = segment_map.get(instance_id)
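
An editorial aside, not part of the commit: with prune_segmentations enabled, the new branch earlier in this diff zeroes out any instance ID that appears in the segmentation file but not in the segmentMap, then drops it from file_instances. A minimal sketch of that step on a toy 1-D array with hypothetical ID sets:

import numpy as np

base_data = np.array([0, 1, 2, 7, 7, 2])  # voxel values are instance IDs (toy data)
file_instances = {1, 2, 7}                # IDs found in the file (toy values)
map_instances = {1, 2}                    # IDs listed in the segmentMap (toy values)

if excess := file_instances - map_instances:  # {7}: in the file, unknown to the map
    base_non_zero = np.nonzero(base_data)
    base_data[base_non_zero] = np.where(
        np.isin(base_data[base_non_zero], list(excess)),
        0,  # pruned voxels become background
        base_data[base_non_zero],
    )
    file_instances -= excess

# base_data is now [0, 1, 2, 0, 0, 2]; file_instances == map_instances == {1, 2}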
40 changes: 34 additions & 6 deletions tests/test_utils/test_dicom.py
@@ -1,6 +1,7 @@
"""Tests for `redbrick.utils.dicom`."""

import os
from typing import Dict, List, Optional
from unittest.mock import patch

import numpy as np
@@ -469,37 +470,47 @@ async def test_process_nifti_download(

@pytest.mark.unit
@pytest.mark.asyncio
async def test_process_nifti_upload(tmpdir, nifti_instance_files_png):
async def test_process_nifti_upload(tmpdir, mock_labels, nifti_instance_files_png):
"""Test dicom.process_nifti_upload"""
files = nifti_instance_files_png
instances = {1, 2, 3, 4, 5, 9}
instance_ids = {1, 2, 3, 4, 5, 9}
instances: Dict[int, Optional[List[int]]] = {}
for label in mock_labels:
if label.get("dicom", {}).get("instanceid") in instance_ids:
instances[label["dicom"]["instanceid"]] = label["dicom"].get("groupids")
instance_ids.difference_update([label["dicom"]["instanceid"]])
instance_ids.difference_update(label["dicom"].get("groupids") or [])

for instance_id in instance_ids:
instances[instance_id] = None

semantic_mask = False # not used
png_mask = False # not supported
binary_mask = True
_mask = nifti_instance_files_png[0]
_mask_inst_id = _mask.split(".")[-3].split("-")[-1]
masks = {_mask_inst_id: _mask}
label_validate = False
prune_segmentations = False

with patch.object(
dicom, "config_path", return_value=str(tmpdir)
) as mock_config_path:
result, group_map = await dicom.process_nifti_upload(
files,
{inst: None for inst in instances},
instances,
binary_mask,
semantic_mask,
png_mask,
masks,
label_validate,
prune_segmentations,
)

mock_config_path.assert_called_once()
assert isinstance(result, str) and result.endswith("label.nii.gz")
assert os.path.isfile(result)
assert isinstance(group_map, dict)
assert set(group_map) == instances
assert isinstance(group_map, dict)
assert group_map.keys() == instances.keys()

# Ensure no group ID has the same value as any instance ID
assert (
@@ -509,6 +520,23 @@ async def test_process_nifti_upload(tmpdir, nifti_instance_files_png):
== set()
)

# Verify that we can produce the expected masks from the label file and group map
expected_masks = {
1: np.array([[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]),
2: np.array([[[0], [0], [1]], [[1], [1], [0]], [[0], [0], [0]]]),
3: np.array([[[1], [0], [0]], [[0], [0], [1]], [[0], [1], [0]]]),
4: np.array([[[0], [0], [0]], [[0], [0], [0]], [[0], [0], [1]]]),
5: np.array([[[0], [1], [0]], [[0], [1], [0]], [[0], [0], [1]]]),
9: np.array([[[0], [0], [0]], [[0], [0], [0]], [[1], [0], [0]]]),
}
global_mask = np.asanyarray(nib.load(result).dataobj, np.uint8)
for instance_id in instances:
selector = global_mask == instance_id
for group_id in (group_map or {}).get(instance_id) or []:
selector = np.logical_or(selector, global_mask == group_id)
mask = selector.astype(np.uint8)
assert np.all(mask == expected_masks[instance_id]), instance_id


@pytest.mark.unit
@pytest.mark.asyncio
