Skip to content

Commit

Permalink
Merge pull request #1387 from moloney/dw-fixes
Browse files Browse the repository at this point in the history
BF+TST: Fix 'frame_order' for single frame files
  • Loading branch information
effigies authored Jan 12, 2025
2 parents 59733ff + b1eb9b0 commit d5995db
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 24 deletions.
25 changes: 19 additions & 6 deletions nibabel/nicom/dicomwrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ def b_vector(self):
class FrameFilter:
"""Base class for defining how to filter out (ignore) frames from a multiframe file
It is guaranteed that the `applies` method will on a dataset before the `keep` method
is called on any of the frames inside.
It is guaranteed that the `applies` method will called on a dataset before the `keep`
method is called on any of the frames inside.
"""

def applies(self, dcm_wrp) -> bool:
Expand All @@ -549,7 +549,7 @@ class FilterMultiStack(FrameFilter):
"""Filter out all but one `StackID`"""

def __init__(self, keep_id=None):
self._keep_id = keep_id
self._keep_id = str(keep_id) if keep_id is not None else None

def applies(self, dcm_wrp) -> bool:
first_fcs = dcm_wrp.frames[0].get('FrameContentSequence', (None,))[0]
Expand All @@ -562,10 +562,16 @@ def applies(self, dcm_wrp) -> bool:
self._selected = self._keep_id
if len(stack_ids) > 1:
if self._keep_id is None:
try:
sids = [int(x) for x in stack_ids]
except:
self._selected = dcm_wrp.frames[0].FrameContentSequence[0].StackID
else:
self._selected = str(min(sids))
warnings.warn(
'A multi-stack file was passed without an explicit filter, just using lowest StackID'
'A multi-stack file was passed without an explicit filter, '
f'using StackID = {self._selected}'
)
self._selected = min(stack_ids)
return True
return False

Expand Down Expand Up @@ -707,6 +713,7 @@ def vendor(self):

@cached_property
def frame_order(self):
"""The ordering of frames to make nD array"""
if self._frame_indices is None:
_ = self.image_shape
return np.lexsort(self._frame_indices.T)
Expand Down Expand Up @@ -742,14 +749,20 @@ def image_shape(self):
rows, cols = self.get('Rows'), self.get('Columns')
if None in (rows, cols):
raise WrapperError('Rows and/or Columns are empty.')
# Check number of frames, initialize array of frame indices
# Check number of frames and handle single frame files
n_frames = len(self.frames)
if n_frames == 1:
self._frame_indices = np.array([[0]], dtype=np.int64)
return (rows, cols)
# Initialize array of frame indices
try:
frame_indices = np.array(
[frame.FrameContentSequence[0].DimensionIndexValues for frame in self.frames]
)
except AttributeError:
raise WrapperError("Can't find frame 'DimensionIndexValues'")
if len(frame_indices.shape) == 1:
frame_indices = frame_indices.reshape(frame_indices.shape + (1,))
# Determine the shape and which indices to use
shape = [rows, cols]
curr_parts = n_frames
Expand Down
51 changes: 33 additions & 18 deletions nibabel/nicom/tests/test_dicomwrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,13 +427,6 @@ def fake_shape_dependents(
generate ipp values so slice location is negatively correlated with slice index
"""

class PrintBase:
def __repr__(self):
attr_strs = [
f'{attr}={getattr(self, attr)}' for attr in dir(self) if attr[0].isupper()
]
return f"{self.__class__.__name__}({', '.join(attr_strs)})"

class DimIdxSeqElem(pydicom.Dataset):
def __init__(self, dip=(0, 0), fgp=None):
super().__init__()
Expand All @@ -444,8 +437,8 @@ def __init__(self, dip=(0, 0), fgp=None):
class FrmContSeqElem(pydicom.Dataset):
def __init__(self, div, sid):
super().__init__()
self.DimensionIndexValues = div
self.StackID = sid
self.DimensionIndexValues = list(div)
self.StackID = str(sid)

class PlnPosSeqElem(pydicom.Dataset):
def __init__(self, ipp):
Expand Down Expand Up @@ -545,17 +538,28 @@ def test_shape(self):
with pytest.raises(didw.WrapperError):
dw.image_shape
fake_mf.Rows = 32
# No frame data raises WrapperError
# Single frame doesn't need dimension index values
assert dw.image_shape == (32, 64)
assert len(dw.frame_order) == 1
assert dw.frame_order[0] == 0
# Multiple frames do require dimension index values
fake_mf.PerFrameFunctionalGroupsSequence = [pydicom.Dataset(), pydicom.Dataset()]
with pytest.raises(didw.WrapperError):
dw.image_shape
MFW(fake_mf).image_shape
# check 2D shape with StackID index is 0
div_seq = ((1, 1),)
fake_mf.update(fake_shape_dependents(div_seq, sid_dim=0))
assert MFW(fake_mf).image_shape == (32, 64)
dw = MFW(fake_mf)
assert dw.image_shape == (32, 64)
assert len(dw.frame_order) == 1
assert dw.frame_order[0] == 0
# Check 2D shape with extraneous extra indices
div_seq = ((1, 1, 2),)
fake_mf.update(fake_shape_dependents(div_seq, sid_dim=0))
assert MFW(fake_mf).image_shape == (32, 64)
dw = MFW(fake_mf)
assert dw.image_shape == (32, 64)
assert len(dw.frame_order) == 1
assert dw.frame_order[0] == 0
# Check 2D plus time
div_seq = ((1, 1, 1), (1, 1, 2), (1, 1, 3))
fake_mf.update(fake_shape_dependents(div_seq, sid_dim=0))
Expand All @@ -569,7 +573,7 @@ def test_shape(self):
fake_mf.update(fake_shape_dependents(div_seq, sid_dim=0))
with pytest.warns(
UserWarning,
match='A multi-stack file was passed without an explicit filter, just using lowest StackID',
match='A multi-stack file was passed without an explicit filter,',
):
assert MFW(fake_mf).image_shape == (32, 64, 3)
# No warning if we expclitly select that StackID to keep
Expand All @@ -581,7 +585,7 @@ def test_shape(self):
fake_mf.update(fake_shape_dependents(div_seq, sid_seq=sid_seq))
with pytest.warns(
UserWarning,
match='A multi-stack file was passed without an explicit filter, just using lowest StackID',
match='A multi-stack file was passed without an explicit filter,',
):
assert MFW(fake_mf).image_shape == (32, 64, 3)
# No warning if we expclitly select that StackID to keep
Expand All @@ -590,6 +594,17 @@ def test_shape(self):
# Check for error when explicitly requested StackID is missing
with pytest.raises(didw.WrapperError):
MFW(fake_mf, frame_filters=(didw.FilterMultiStack(3),))
# StackID can be a string
div_seq = ((1,), (2,), (3,), (4,))
sid_seq = ('a', 'a', 'a', 'b')
fake_mf.update(fake_shape_dependents(div_seq, sid_seq=sid_seq))
with pytest.warns(
UserWarning,
match='A multi-stack file was passed without an explicit filter,',
):
assert MFW(fake_mf).image_shape == (32, 64, 3)
assert MFW(fake_mf, frame_filters=(didw.FilterMultiStack('a'),)).image_shape == (32, 64, 3)
assert MFW(fake_mf, frame_filters=(didw.FilterMultiStack('b'),)).image_shape == (32, 64)
# Make some fake frame data for 4D when StackID index is 0
div_seq = ((1, 1, 1), (1, 2, 1), (1, 1, 2), (1, 2, 2), (1, 1, 3), (1, 2, 3))
fake_mf.update(fake_shape_dependents(div_seq, sid_dim=0))
Expand All @@ -599,7 +614,7 @@ def test_shape(self):
fake_mf.update(fake_shape_dependents(div_seq, sid_dim=0))
with pytest.warns(
UserWarning,
match='A multi-stack file was passed without an explicit filter, just using lowest StackID',
match='A multi-stack file was passed without an explicit filter,',
):
with pytest.raises(didw.WrapperError):
MFW(fake_mf).image_shape
Expand Down Expand Up @@ -638,7 +653,7 @@ def test_shape(self):
fake_mf.update(fake_shape_dependents(div_seq, sid_seq=sid_seq))
with pytest.warns(
UserWarning,
match='A multi-stack file was passed without an explicit filter, just using lowest StackID',
match='A multi-stack file was passed without an explicit filter,',
):
with pytest.raises(didw.WrapperError):
MFW(fake_mf).image_shape
Expand All @@ -651,7 +666,7 @@ def test_shape(self):
fake_mf.update(fake_shape_dependents(div_seq, sid_dim=1))
with pytest.warns(
UserWarning,
match='A multi-stack file was passed without an explicit filter, just using lowest StackID',
match='A multi-stack file was passed without an explicit filter,',
):
assert MFW(fake_mf).image_shape == (32, 64, 3)
# Make some fake frame data for 4D when StackID index is 1
Expand Down

0 comments on commit d5995db

Please sign in to comment.