diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 3b7b010d..f2d6e8bb 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -146,6 +146,8 @@ def from_tfrecords( ds_obj.map_labels( label_mapping=label_mapping, num_parallel_calls=num_parallel_calls ) + + ds_obj.filter_zero_volumes() # TODO automatically determine batch size ds_obj.batch(1) @@ -385,3 +387,9 @@ def repeat(self, n_repeats): # through once. self.dataset = self.dataset.repeat(n_repeats) return self + + def filter_zero_volumes(self): + self.dataset = self.dataset.filter( + lambda x, y: tf.cast(tf.math.reduce_sum(y), dtype="bool") + ) + return self