From a724d2d8fbe4c8aae54036c6e78d0e108180bdf3 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 5 Sep 2023 15:58:09 -0400 Subject: [PATCH 01/50] use both _1 and _2 segmentations --- ml4h/tensormap/ukb/mri.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 8fc6a3c6e..30ad0e311 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2681,7 +2681,12 @@ def _mdrk_projection_both_views_pretrained(tm, hd5, dependents={}): def _segmented_t1map(tm, hd5, dependents={}): tensor = np.zeros(tm.shape, dtype=np.float32) - categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, tm.name) + if f'{tm.path_prefix}/{tm.name}_1' in hd5: + categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_1') + elif f'{tm.path_prefix}/{tm.name}_2' in hd5: + categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_2') + else: + raise ValueError(f'Could not find T1 Map segmentation for tensormap: {tm.name}') categorical_one_hot = to_categorical(categorical_index_slice, len(tm.channel_map)) tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot) return tensor From 7d0fa9cec015548e6b3346aa955bfd38f4e8110b Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 5 Sep 2023 16:03:15 -0400 Subject: [PATCH 02/50] use both _1 and _2 segmentations --- ml4h/tensormap/ukb/mri.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 30ad0e311..fe0d22a4e 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2693,7 +2693,7 @@ def _segmented_t1map(tm, hd5, dependents={}): t1map_b2_segmentation = TensorMap( - 'b2s_t1map_kassir_annotated_2', + 'b2s_t1map_kassir_annotated', interpretation=Interpretation.CATEGORICAL, shape=(384, 384, len(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP)), channel_map=MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, From 21d95bd3e8388e44634db73e5764fab3dc97810e Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 14 Sep 2023 13:58:44 +0000 Subject: [PATCH 03/50] TEMP: add dice metrics, copied from neuron --- ml4h/metrics.py | 255 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 255 insertions(+) diff --git a/ml4h/metrics.py b/ml4h/metrics.py index 27e8fb1c4..24b241de5 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -261,6 +261,261 @@ def loss(y_true, y_pred): return loss + + + + +def batch_channel_flatten(x): + """ + flatten volume elements outside of batch and channel + + using naming based on keras backend + + Args: + x: a Tensor of N dims, size [batch_size, ..., channel_size] + + Returns: + a Tensor of size [batch_size, V, channel_size] where V is the number of elements in the + middle N-2 dimensions + """ + return flatten_axes(x, range(1, K.ndim(x) - 1)) + +def flatten_axes(x, axes): + """ + Flatten the axes[0], ..., axes[-1] of a Tensor x + + Example: + v = tf.random.uniform((3, 4, 5, 6)) + ne.utils.flatten_axes(v, [1, 2]).shape # returns TensorShape([3, 20, 6]) + + Args: + x (Tensor): Tensor to flatten axes + axes: list of axes to flatten + + Returns: + Tensor: with flattened axes + + See Also: + batch_channel_flatten + tensorflow.keras.backend.batch_flatten + """ + assert isinstance(axes, (list, tuple, range)), \ + 'axes must be list or tuple of axes to be flattened' + assert np.all(np.diff(axes) == 1), 'axes need to be contiguous' + if axes[0] < 0: + assert axes[-1] < 0, 'if one axis is 
negative, all have to be negative' + assert axes[-1] < K.ndim(x), 'axis %d outside max axis %d' % (axes[-1], K.ndim(x) - 1) + + shp = K.shape(x) + lst = [shp[:axes[0]], - tf.ones((1,), dtype=tf.int32)] + if axes[-1] < len(x.shape) - 1 and not (axes[-1] == -1): + lst.append(shp[axes[-1] + 1:]) + reshape = tf.concat(lst, 0) + return K.reshape(x, reshape) + +class Dice: + """ + Dice of two Tensors. + Enables both 'soft' and 'hard' Dice, and weighting per label (or per batch entry) + + More information/citations: + - Dice. Measures of the amount of ecologic association between species. Ecology. 1945 + [orial paper describing metric] + - Dalca AV, Guttag J, Sabuncu MR Anatomical Priors in Convolutional Networks for + Unsupervised Biomedical Segmentation. CVPR 2018. https://arxiv.org/abs/1903.03148 + [paper for which we developed this method] + """ + + def __init__( + self, + dice_type='soft', + input_type='prob', + nb_labels=None, + weights=None, + check_input_limits=True, + laplace_smoothing=0., + normalize=False, + ): # regularization for bottom of Dice coeff + """ + Dice of two Tensors. + + If Tensors are probablistic/one-hot, should be size + [batch_size, *vol_size, nb_labels], where vol_size is the size of the volume (n-dims) + e.g. for a 2D vol, y has 4 dimensions, where each entry is a prob for that voxel + If Tensors contain the label id at each location, size should be + i.e. [batch_size, *vol_size], where vol_size is the size of the volume (n-dims). + e.g. for a 2D vol, y has 3 dimensions, where each entry is the max label of that voxel + If you provide [batch_size, *vol_size, 1], everything will still work since that just + assumes a volume with an extra dimension, but the Dice score would be the same. + + Args: + dice_type (str, optional): 'soft' or 'hard'. Defaults to 'soft'. + hard dice will not provide gradients (and hence should not be used with backprop) + input_type (str, optional): 'prob', 'one_hot', or 'max_label' + 'prob' (or 'one_hot' which will be treated the same) means we assume prob label maps + 'max_label' means we assume each volume location entry has the id of the seg label + Defaults to 'prob'. + nb_labels (int, optional): number of labels (maximum label + 1) + *Required* if using hard dice with max_label data. Defaults to None. + weights (np.array or tf.Tensor, optional): weights matrix, broadcastable to + [batch_size, nb_labels]. most often, would want to weight the labels, so would be + an array of size [1, nb_labels]. + Defaults to None. + check_input_limits (bool, optional): whether to check that input Tensors are in [0, 1]. + using tf debugging asserts. Defaults to True. + laplace_smoothing (float, optional): amount of laplace smoothing + (adding to the numerator and denominator), + use 0 for no smoothing (in which case we employ div_no_nan) + Default to 0. + normalize (bool, optional): whether to renormalize probabilistic Tensors. + Defaults to False. 
+ """ + # input_type is 'prob', or 'max_label' + # dice_type is hard or soft + + self.dice_type = dice_type + self.input_type = input_type + self.nb_labels = nb_labels + self.weights = weights + self.normalize = normalize + self.check_input_limits = check_input_limits + self.laplace_smoothing = laplace_smoothing + + # checks + assert self.input_type in ['prob', 'max_label'] + + if self.dice_type == 'hard' and self.input_type == 'max_label': + assert self.nb_labels is not None, 'If doing hard Dice need nb_labels' + + if self.dice_type == 'soft': + assert self.input_type in ['prob', 'one_hot'], \ + 'if doing soft Dice, must use probabilistic (one_hot)encoding' + + def dice(self, y_true, y_pred): + """ + compute dice between two Tensors + + Args: + y_pred, y_true: Tensors + - if prob/onehot, then shape [batch_size, ..., nb_labels] + - if max_label (label at each location), then shape [batch_size, ...] + + Returns: + Tensor of size [batch_size, nb_labels] + """ + + # input checks + if self.input_type in ['prob', 'one_hot']: + + # Optionally re-normalize. + # Note that in some cases you explicitly don't wnat to, e.g. if you only return a + # subset of the labels + if self.normalize: + y_true = tf.math.divide_no_nan(y_true, K.sum(y_true, axis=-1, keepdims=True)) + y_pred = tf.math.divide_no_nan(y_pred, K.sum(y_pred, axis=-1, keepdims=True)) + + # some value checking + if self.check_input_limits: + msg = 'value outside range' + tf.debugging.assert_greater_equal(y_true, 0., msg) + tf.debugging.assert_greater_equal(y_pred, 0., msg) + tf.debugging.assert_less_equal(y_true, 1., msg) + tf.debugging.assert_less_equal(y_pred, 1., msg) + + # Prepare the volumes to operate on + # If we're doing 'hard' Dice, then we will prepare one_hot-based matrices of size + # [batch_size, nb_voxels, nb_labels], where for each voxel in each batch entry, + # the entries are either 0 or 1 + if self.dice_type == 'hard': + + # if given predicted probability, transform to "hard max"" + if self.input_type == 'prob': + # this breaks differentiability, since argmax is not differentiable. + # warnings.warn('You are using ne.metrics.Dice with probabilistic inputs' + # 'and computing *hard* dice. \n For this, we use argmax to' + # 'get the optimal label at each location, which is not' + # 'differentiable. Do not use expecting gradients.') + + if self.nb_labels is None: + self.nb_labels = y_pred.shape.as_list()[-1] + + y_pred = K.argmax(y_pred, axis=-1) + y_true = K.argmax(y_true, axis=-1) + + # transform to one hot notation + y_pred = K.one_hot(y_pred, self.nb_labels) + y_true = K.one_hot(y_true, self.nb_labels) + + # reshape to [batch_size, nb_voxels, nb_labels] + y_true = batch_channel_flatten(y_true) + y_pred = batch_channel_flatten(y_pred) + + # compute dice for each entry in batch. + # dice will now be [batch_size, nb_labels] + top = 2 * K.sum(y_true * y_pred, 1) + bottom = K.sum(K.square(y_true), 1) + K.sum(K.square(y_pred), 1) + if self.laplace_smoothing > 0: + eps = self.laplace_smoothing + return (top + eps) / (bottom + eps) + else: + return tf.math.divide_no_nan(top, bottom) + + def mean_dice(self, y_true, y_pred): + """ + mean dice across all patches and labels + optionally weighted + + Args: + y_pred, y_true: Tensors + - if prob/onehot, then shape [batch_size, ..., nb_labels] + - if max_label (label at each location), then shape [batch_size, ...] 
+ + Returns: + dice (Tensor of size 1, tf.float32) + """ + + # compute dice, which will now be [batch_size, nb_labels] + dice_metric = self.dice(y_true, y_pred) + + # weigh the entries in the dice matrix: + if self.weights is not None: + assert len(self.weights.shape) == 2, \ + 'weights should be a matrix broadcastable to [batch_size, nb_labels]' + dice_metric *= self.weights + + # return one minus mean dice as loss + mean_dice_metric = K.mean(dice_metric) + tf.debugging.assert_all_finite(mean_dice_metric, 'metric not finite') + return mean_dice_metric + + def loss(self, y_true, y_pred): + """ + Deprecate anytime after 12/01/2021 + """ + # warnings.warn('ne.metrics.*.loss functions are deprecated.' + # 'Please use the ne.losses.*.loss functions.') + + return - self.mean_dice(y_true, y_pred) + +def dice(y_true, y_pred): + return Dice(laplace_smoothing=0.05).loss(y_true, y_pred) + +def per_class_dice(labels): + dice_fxns = [] + for label_key in labels: + label_idx = labels[label_key] + fxn_name = label_key.replace('-', '_').replace(' ', '_') + string_fxn = 'def ' + fxn_name + '_dice(y_true, y_pred):\n' + string_fxn += '\tdice = Dice(laplace_smoothing=0.05).dice(y_true, y_pred)\n' + string_fxn += '\tdice = K.mean(dice, axis=0)['+str(label_idx)+']\n' + string_fxn += '\treturn dice' + + exec(string_fxn) + dice_fxn = eval(fxn_name + '_dice') + dice_fxns.append(dice_fxn) + + return dice_fxns + def euclid_dist(v): return (v[0] - v[1])**2 From f5d90ffc12cd7278565ab6fdd41fa3fd32e3ce1c Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 14 Sep 2023 14:01:00 +0000 Subject: [PATCH 04/50] ENH: Use dice loss and metrics --- ml4h/tensormap/ukb/mri.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index fe0d22a4e..8212d8005 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -12,7 +12,7 @@ import tensorflow as tf from tensorflow.keras.utils import to_categorical -from ml4h.metrics import weighted_crossentropy +from ml4h.metrics import weighted_crossentropy, dice, per_class_dice from ml4h.normalizer import ZeroMeanStd1, Standardize from ml4h.TensorMap import TensorMap, Interpretation, make_range_validator from ml4h.tensormap.ukb.demographics import is_genetic_man, is_genetic_woman @@ -2688,6 +2688,8 @@ def _segmented_t1map(tm, hd5, dependents={}): else: raise ValueError(f'Could not find T1 Map segmentation for tensormap: {tm.name}') categorical_one_hot = to_categorical(categorical_index_slice, len(tm.channel_map)) + + # padding/cropping tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot) return tensor @@ -2699,4 +2701,6 @@ def _segmented_t1map(tm, hd5, dependents={}): channel_map=MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, path_prefix='ukb_cardiac_mri', tensor_from_file=_segmented_t1map, + loss=dice, + metrics=['categorical_accuracy'] + per_class_dice(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP), ) From 814f68f47135ab2949c1d90c2b93903005ba811a Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 14 Sep 2023 14:01:52 +0000 Subject: [PATCH 05/50] ENH: Remove kidney label and merge body/background labels --- ml4h/defines.py | 6 +++--- ml4h/tensormap/ukb/mri.py | 5 +++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ml4h/defines.py b/ml4h/defines.py index 408c368f8..eb18dde92 100755 --- a/ml4h/defines.py +++ b/ml4h/defines.py @@ -78,9 +78,9 @@ def __str__(self): 'interventricular_septum': 5, 'interatrial_septum': 6, 'crista_terminalis': 7, } MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP 
= { - 'background': 0, 'body': 1, 'thoracic_cavity': 2, 'liver': 3, 'stomach': 4, 'spleen': 5, 'kidney': 6, - 'interventricular_septum': 7, 'LV_free_wall': 8, 'anterolateral_pap': 9, 'posteromedial_pap': 10, 'LV_cavity': 11, - 'RV_free_wall': 12, 'RV_cavity': 13, + 'background': 0, 'thoracic_cavity': 1, 'liver': 2, 'stomach': 3, 'spleen': 4, + 'interventricular_septum': 5, 'LV_free_wall': 6, 'anterolateral_pap': 7, 'posteromedial_pap': 8, 'LV_cavity': 9, + 'RV_free_wall': 10, 'RV_cavity': 11, } MRI_SAX_SEGMENTED_CHANNEL_MAP = { 'background': 0, 'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4, diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 8212d8005..3068e18c5 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2689,6 +2689,11 @@ def _segmented_t1map(tm, hd5, dependents={}): raise ValueError(f'Could not find T1 Map segmentation for tensormap: {tm.name}') categorical_one_hot = to_categorical(categorical_index_slice, len(tm.channel_map)) + # remove kidney label and merge body/background labels + categorical_one_hot = np.delete(categorical_one_hot, 6, axis=-1) + categorical_one_hot[..., 0] += categorical_one_hot[..., 1] + categorical_one_hot = np.delete(categorical_one_hot, 1, axis=-1) + # padding/cropping tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot) return tensor From aa36994ba88e7aae8400875949a849428ff56102 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 14 Sep 2023 14:24:36 +0000 Subject: [PATCH 06/50] FIX: Fix bad number of channels --- ml4h/tensormap/ukb/mri.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 3068e18c5..b5dde9bc3 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2680,25 +2680,25 @@ def _mdrk_projection_both_views_pretrained(tm, hd5, dependents={}): def _segmented_t1map(tm, hd5, dependents={}): - tensor = np.zeros(tm.shape, dtype=np.float32) if f'{tm.path_prefix}/{tm.name}_1' in hd5: categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_1') elif f'{tm.path_prefix}/{tm.name}_2' in hd5: categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_2') else: raise ValueError(f'Could not find T1 Map segmentation for tensormap: {tm.name}') - categorical_one_hot = to_categorical(categorical_index_slice, len(tm.channel_map)) # remove kidney label and merge body/background labels + orig_num_channels = len(tm.channel_map) + 2 + categorical_one_hot = to_categorical(categorical_index_slice, orig_num_channels) categorical_one_hot = np.delete(categorical_one_hot, 6, axis=-1) categorical_one_hot[..., 0] += categorical_one_hot[..., 1] categorical_one_hot = np.delete(categorical_one_hot, 1, axis=-1) # padding/cropping + tensor = np.zeros(tm.shape, dtype=np.float32) tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot) return tensor - t1map_b2_segmentation = TensorMap( 'b2s_t1map_kassir_annotated', interpretation=Interpretation.CATEGORICAL, From 6d40407d7152b7e250c7f12c5a57561cd9055c36 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 14 Sep 2023 14:38:59 +0000 Subject: [PATCH 07/50] ENH: Use only one channel of the input image --- ml4h/tensormap/ukb/mri.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index b5dde9bc3..f398a82fa 100644 --- a/ml4h/tensormap/ukb/mri.py +++ 
b/ml4h/tensormap/ukb/mri.py @@ -2669,16 +2669,25 @@ def _mdrk_projection_both_views_pretrained(tm, hd5, dependents={}): tensor_from_file=None, ) +def _pad_crop_single_channel(tm, hd5, dependents={}): + img = np.array( + tm.hd5_first_dataset_in_group(hd5, tm.hd5_key_guess()), + dtype=np.float32, + ) + img = img[...,[1]] + return pad_or_crop_array_to_shape( + tm.shape, + img, + ) t1map_b2 = TensorMap( 'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map', - shape=(384, 384, 2), + shape=(384, 384, 1), path_prefix='ukb_cardiac_mri', normalization=ZeroMeanStd1(), - tensor_from_file=_pad_crop_tensor, + tensor_from_file=_pad_crop_single_channel, ) - def _segmented_t1map(tm, hd5, dependents={}): if f'{tm.path_prefix}/{tm.name}_1' in hd5: categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_1') From 5005d62a469deee2adb06e2636a1f9ed3249dae6 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 21 Sep 2023 18:47:45 +0000 Subject: [PATCH 08/50] WIP: hacking bottom of U-net --- ml4h/arguments.py | 1 + ml4h/models/conv_blocks.py | 114 ++++++++++++++++++++++++++++++++--- ml4h/models/model_factory.py | 22 +++++-- ml4h/recipes.py | 1 + 4 files changed, 123 insertions(+), 15 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index 3277b9eb4..73706f3e0 100644 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -188,6 +188,7 @@ def parse_args(): parser.add_argument('--pool_z', default=1, type=int, help='Pooling size in the z-axis, if 1 no pooling will be performed.') parser.add_argument('--padding', default='same', help='Valid or same border padding on the convolutional layers.') parser.add_argument('--dense_blocks', nargs='*', default=[32, 32, 32], type=int, help='List of number of kernels in convolutional layers.') + parser.add_argument('--bottom_dense_blocks', nargs='*', default=[32], type=int, help='List of number of kernels in convolutional layers at the bottom.') parser.add_argument('--encoder_blocks', nargs='*', default=['conv_encode'], help='List of encoding blocks.') parser.add_argument('--merge_blocks', nargs='*', default=['concat'], help='List of merge blocks.') parser.add_argument('--decoder_blocks', nargs='*', default=['conv_decode', 'dense_decode'], help='List of decoding blocks.') diff --git a/ml4h/models/conv_blocks.py b/ml4h/models/conv_blocks.py index 2003bfbb1..ee16ef428 100644 --- a/ml4h/models/conv_blocks.py +++ b/ml4h/models/conv_blocks.py @@ -67,14 +67,15 @@ def __init__( ) for filters, x, y, z in zip(dense_blocks, x_filters[len(conv_layers):], y_filters[len(conv_layers):], z_filters[len(conv_layers):]) ] self.pools = _pool_layers_from_kind_and_dimension(dimension, pool_type, len(dense_blocks) + 1, pool_x, pool_y, pool_z) - self.fully_connected = DenseBlock( - widths=dense_layers, - activation=activation, - normalization=dense_normalize, - regularization=dense_regularize, - regularization_rate=dense_regularize_rate, - name=self.tensor_map.embed_name(), - ) if dense_layers else None + # self.fully_connected = DenseBlock( + # widths=dense_layers, + # activation=activation, + # normalization=dense_normalize, + # regularization=dense_regularize, + # regularization_rate=dense_regularize_rate, + # name=self.tensor_map.embed_name(), + # ) if dense_layers else None # TODO + self.fully_connected = None def can_apply(self): return self.tensor_map.axes() > 1 @@ -96,6 +97,93 @@ def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = Non return x +class ConvEncoderBottomBlock(Block): + def __init__( + self, + *, + tensor_map: TensorMap, + 
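        # Editor's note (illustrative, not part of the original patch): this WIP block
        # appears to add extra pool -> DenseConvolutional stages at the bottom of the
        # U-net, consuming the encoder's last spatial intermediate rather than a flattened
        # embedding, and later upsampling back to the pre-pool resolution. The new
        # --bottom_dense_blocks argument above would set the kernels per bottom stage,
        # e.g. --bottom_dense_blocks 48 48 for two stages (assumed usage, not from the diff).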
bottom_dense_blocks: List[int] = [32], + dense_layers: List[int] = [32], + dense_normalize: str = None, + dense_regularize: str = None, + dense_regularize_rate: float = 0.0, + conv_layers: List[int] = [32], + conv_type: str = 'conv', + conv_width: List[int] = [71], + conv_x: List[int] = [3], + conv_y: List[int] = [3], + conv_z: List[int] = [3], + block_size: int = 3, + activation: str = 'swish', + conv_normalize: str = None, + conv_regularize: str = None, + conv_regularize_rate: float = 0.0, + conv_dilate: bool = False, + pool_type: str = 'max', + pool_x: int = 2, + pool_y: int = 2, + pool_z: int = 1, + **kwargs, + ): + self.tensor_map = tensor_map + if not self.can_apply(): + return + dimension = self.tensor_map.axes() + + # list of filter dimensions should match the total number of convolutional layers + x_filters = _repeat_dimension(conv_width if dimension == 2 else conv_x, len(bottom_dense_blocks)) + y_filters = _repeat_dimension(conv_y, len(bottom_dense_blocks)) + z_filters = _repeat_dimension(conv_z, len(bottom_dense_blocks)) + + #self.preprocess_block = PreprocessBlock(['rotate'], [0.3]) + # self.res_block = Residual( + # dimension=dimension, filters_per_conv=conv_layers, conv_layer_type=conv_type, conv_x=x_filters[:len(conv_layers)], + # conv_y=y_filters[:len(conv_layers)], conv_z=z_filters[:len(conv_layers)], activation=activation, normalization=conv_normalize, + # regularization=conv_regularize, regularization_rate=conv_regularize_rate, dilate=conv_dilate, + # ) + + self.dense_blocks = [ + DenseConvolutional( + dimension=dimension, conv_layer_type=conv_type, filters=filters, conv_x=[x] * block_size, conv_y=[y] * block_size, + conv_z=[z]*block_size, block_size=block_size, activation=activation, normalization=conv_normalize, + regularization=conv_regularize, regularization_rate=conv_regularize_rate, + ) for filters, x, y, z in zip(bottom_dense_blocks, x_filters, y_filters, z_filters) + ] + self.pools = _pool_layers_from_kind_and_dimension(dimension, pool_type, len(bottom_dense_blocks), pool_x, pool_y, pool_z) + # self.fully_connected = DenseBlock( + # widths=dense_layers, + # activation=activation, + # normalization=dense_normalize, + # regularization=dense_regularize, + # regularization_rate=dense_regularize_rate, + # name=self.tensor_map.embed_name(), + # ) if dense_layers else None # TODO + # self.fully_connected = None + self.upsamples = [_upsampler(dimension, pool_x, pool_y, pool_z) for _ in range(len(bottom_dense_blocks))] + + + def can_apply(self): + return self.tensor_map.axes() > 1 + + def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor: + # if not self.can_apply(): + # return x + #x = self.preprocess_block(x) # TODO: upgrade to tensorflow 2.3+ + # x = self.res_block(x) + # intermediates[self.tensor_map].append(x) + y = intermediates[self.tensor_map][-1] + for i, (dense_block, pool) in enumerate(zip(self.dense_blocks, self.pools)): + y = pool(y) + y = dense_block(y) + # intermediates[self.tensor_map].append(x) + # if self.fully_connected: + # x = Flatten()(x) + # x = self.fully_connected(x, intermediates) + # intermediates[self.tensor_map].append(x) + for i, upsample in enumerate(self.upsamples): + y = upsample(y) + return y + class ConvDecoderBlock(Block): def __init__( self, @@ -150,8 +238,14 @@ def can_apply(self): def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor: if not self.can_apply(): return x - if x.shape != self.start_shape: - x = self.reshape(x) + + logging.info(f'x.shape') + 
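        # Editor's note (illustrative, not part of the original patch): the FlatToStructure
        # reshape below is bypassed in this WIP commit, presumably because the merge output
        # is now a spatial feature map from the bottom-of-U-net block rather than a flat
        # latent vector, so it can enter the decoder's upsample/concatenate path directly.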
logging.info(x.shape) + logging.info(f'start_shape') + logging.info(self.start_shape) + + # if x.shape != self.start_shape: + # x = self.reshape(x) for i, (dense_block, upsample) in enumerate(zip(self.dense_conv_blocks, self.upsamples)): intermediate = [intermediates[tm][len(self.upsamples)-(i+1)] for tm in self.u_connect_parents] x = concatenate(intermediate + [x]) if intermediate else x diff --git a/ml4h/models/model_factory.py b/ml4h/models/model_factory.py index b2cd68f5b..6c99a2643 100644 --- a/ml4h/models/model_factory.py +++ b/ml4h/models/model_factory.py @@ -18,7 +18,7 @@ from ml4h.optimizers import NON_KERAS_OPTIMIZERS, get_optimizer from ml4h.models.layer_wrappers import ACTIVATION_FUNCTIONS, NORMALIZATION_CLASSES from ml4h.models.pretrained_blocks import ResNetEncoder, MoviNetEncoder, BertEncoder -from ml4h.models.conv_blocks import ConvEncoderBlock, ConvDecoderBlock, ResidualBlock, PoolBlock, ConvUp, ConvDown +from ml4h.models.conv_blocks import ConvEncoderBlock, ConvEncoderBottomBlock, ConvDecoderBlock, ResidualBlock, PoolBlock, ConvUp, ConvDown from ml4h.models.transformer_blocks import TransformerDecoder, TransformerEncoder, PositionalEncoding, MultiHeadAttention from ml4h.models.merge_blocks import GlobalAveragePoolBlock, EncodeIdentityBlock, L2LossLayer, CosineLossLayer, VariationalDiagNormal from ml4h.models.merge_blocks import FlatConcatDenseBlock, FlatConcatBlock, AverageBlock, PairLossBlock, ReduceMean, ContrastiveLossLayer @@ -28,6 +28,7 @@ BLOCK_CLASSES = { 'conv_encode': ConvEncoderBlock, + 'conv_encode_bottom': ConvEncoderBottomBlock, 'conv_decode': ConvDecoderBlock, 'conv_up': ConvUp, 'conv_down': ConvDown, @@ -172,9 +173,9 @@ def multimodal_multitask_model( merge = identity for merge_block in merge_blocks: if isinstance(merge_block, Block): - merge = compose(merge, merge_block(**kwargs)) + merge = compose(merge, merge_block(tensor_map=tensor_maps_in[0], **kwargs)) else: - merge = compose(merge, BLOCK_CLASSES[merge_block](**kwargs)) + merge = compose(merge, BLOCK_CLASSES[merge_block](tensor_map=tm, **kwargs)) decoder_block_functions = {tm: identity for tm in tensor_maps_out} for tm in tensor_maps_out: @@ -246,13 +247,24 @@ def make_multimodal_multitask_model_block( multimodal_activation = merge(encodings, intermediates) merge_model = Model(list(inputs.values()), multimodal_activation) + + # TODO take me out + logging.info(f'TEMP') + merge_model.summary(print_fn=logging.info, expand_nested=True) + if isinstance(multimodal_activation, list): - latent_inputs = Input(shape=(multimodal_activation[0].shape[-1],), name='input_multimodal_space') + # latent_inputs = Input(shape=(multimodal_activation[0].shape[-1],), name='input_multimodal_space') + latent_inputs = Input(shape=multimodal_activation[0].shape[1:], name='input_multimodal_space') else: - latent_inputs = Input(shape=(multimodal_activation.shape[-1],), name='input_multimodal_space') + # latent_inputs = Input(shape=(multimodal_activation.shape[-1],), name='input_multimodal_space') + latent_inputs = Input(shape=multimodal_activation.shape[1:], name='input_multimodal_space') logging.info(f'multimodal_activation.shapes: {multimodal_activation.shape}') logging.info(f'Graph from input TensorMaps has intermediates: {[(tm, [ti.shape for ti in t]) for tm, t in intermediates.items()]}') + # TODO take me out + logging.info('latent inputs') + logging.info(latent_inputs) + decoders: Dict[TensorMap, Model] = {} decoder_outputs = [] for tm, decoder_block in decoder_block_functions.items(): # TODO this needs to be a topological 
sorted according to parents hierarchy diff --git a/ml4h/recipes.py b/ml4h/recipes.py index a913c6911..f88ede762 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -262,6 +262,7 @@ def inference_file_name(output_folder: str, id_: str) -> str: def infer_multimodal_multitask(args): + stats = Counter() tensor_paths_inferred = set() inference_tsv = inference_file_name(args.output_folder, args.id) From d91c7f62466afc5605ac655a4c9f6dc25d65393d Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 22 Sep 2023 21:24:44 +0000 Subject: [PATCH 09/50] ENH: Add mean and std for normalization --- ml4h/tensormap/ukb/mri.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index f398a82fa..0b2f521be 100644 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2684,7 +2684,7 @@ def _pad_crop_single_channel(tm, hd5, dependents={}): 'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map', shape=(384, 384, 1), path_prefix='ukb_cardiac_mri', - normalization=ZeroMeanStd1(), + normalization=Standardize(mean=548.15, std=627.32), tensor_from_file=_pad_crop_single_channel, ) From 3f1f27b3d901f5dc28543b73dc13413f90bc1c25 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Tue, 26 Sep 2023 16:31:01 +0000 Subject: [PATCH 10/50] ENH: Add neurite and voxelmorph to docker --- docker/vm_boot_images/config/tensorflow-requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/vm_boot_images/config/tensorflow-requirements.txt b/docker/vm_boot_images/config/tensorflow-requirements.txt index 4f01a8daf..d8e2c49c7 100644 --- a/docker/vm_boot_images/config/tensorflow-requirements.txt +++ b/docker/vm_boot_images/config/tensorflow-requirements.txt @@ -41,3 +41,5 @@ boto3 ml4ht==0.0.10 google-cloud-storage umap-learn[plot] +neurite +voxelmorph From e173f101162bd2404926067cccf0cca8c1b8e785 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Tue, 26 Sep 2023 17:55:11 +0000 Subject: [PATCH 11/50] FIX: Use dice loss from neurite --- ml4h/metrics.py | 239 +----------------------------------------------- 1 file changed, 2 insertions(+), 237 deletions(-) diff --git a/ml4h/metrics.py b/ml4h/metrics.py index 24b241de5..fa414d483 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -10,6 +10,8 @@ from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error +from neurite.tf.losses import Dice + STRING_METRICS = [ 'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae', 'mean_squared_error', 'mse', 'cosine_similarity', 'logcosh', 'sparse_categorical_crossentropy', @@ -260,243 +262,6 @@ def loss(y_true, y_pred): return loss - - - - - -def batch_channel_flatten(x): - """ - flatten volume elements outside of batch and channel - - using naming based on keras backend - - Args: - x: a Tensor of N dims, size [batch_size, ..., channel_size] - - Returns: - a Tensor of size [batch_size, V, channel_size] where V is the number of elements in the - middle N-2 dimensions - """ - return flatten_axes(x, range(1, K.ndim(x) - 1)) - -def flatten_axes(x, axes): - """ - Flatten the axes[0], ..., axes[-1] of a Tensor x - - Example: - v = tf.random.uniform((3, 4, 5, 6)) - ne.utils.flatten_axes(v, [1, 2]).shape # returns TensorShape([3, 20, 6]) - - Args: - x (Tensor): Tensor to flatten axes - axes: list of axes to flatten - - Returns: - Tensor: with flattened axes - - 
See Also: - batch_channel_flatten - tensorflow.keras.backend.batch_flatten - """ - assert isinstance(axes, (list, tuple, range)), \ - 'axes must be list or tuple of axes to be flattened' - assert np.all(np.diff(axes) == 1), 'axes need to be contiguous' - if axes[0] < 0: - assert axes[-1] < 0, 'if one axis is negative, all have to be negative' - assert axes[-1] < K.ndim(x), 'axis %d outside max axis %d' % (axes[-1], K.ndim(x) - 1) - - shp = K.shape(x) - lst = [shp[:axes[0]], - tf.ones((1,), dtype=tf.int32)] - if axes[-1] < len(x.shape) - 1 and not (axes[-1] == -1): - lst.append(shp[axes[-1] + 1:]) - reshape = tf.concat(lst, 0) - return K.reshape(x, reshape) - -class Dice: - """ - Dice of two Tensors. - Enables both 'soft' and 'hard' Dice, and weighting per label (or per batch entry) - - More information/citations: - - Dice. Measures of the amount of ecologic association between species. Ecology. 1945 - [orial paper describing metric] - - Dalca AV, Guttag J, Sabuncu MR Anatomical Priors in Convolutional Networks for - Unsupervised Biomedical Segmentation. CVPR 2018. https://arxiv.org/abs/1903.03148 - [paper for which we developed this method] - """ - - def __init__( - self, - dice_type='soft', - input_type='prob', - nb_labels=None, - weights=None, - check_input_limits=True, - laplace_smoothing=0., - normalize=False, - ): # regularization for bottom of Dice coeff - """ - Dice of two Tensors. - - If Tensors are probablistic/one-hot, should be size - [batch_size, *vol_size, nb_labels], where vol_size is the size of the volume (n-dims) - e.g. for a 2D vol, y has 4 dimensions, where each entry is a prob for that voxel - If Tensors contain the label id at each location, size should be - i.e. [batch_size, *vol_size], where vol_size is the size of the volume (n-dims). - e.g. for a 2D vol, y has 3 dimensions, where each entry is the max label of that voxel - If you provide [batch_size, *vol_size, 1], everything will still work since that just - assumes a volume with an extra dimension, but the Dice score would be the same. - - Args: - dice_type (str, optional): 'soft' or 'hard'. Defaults to 'soft'. - hard dice will not provide gradients (and hence should not be used with backprop) - input_type (str, optional): 'prob', 'one_hot', or 'max_label' - 'prob' (or 'one_hot' which will be treated the same) means we assume prob label maps - 'max_label' means we assume each volume location entry has the id of the seg label - Defaults to 'prob'. - nb_labels (int, optional): number of labels (maximum label + 1) - *Required* if using hard dice with max_label data. Defaults to None. - weights (np.array or tf.Tensor, optional): weights matrix, broadcastable to - [batch_size, nb_labels]. most often, would want to weight the labels, so would be - an array of size [1, nb_labels]. - Defaults to None. - check_input_limits (bool, optional): whether to check that input Tensors are in [0, 1]. - using tf debugging asserts. Defaults to True. - laplace_smoothing (float, optional): amount of laplace smoothing - (adding to the numerator and denominator), - use 0 for no smoothing (in which case we employ div_no_nan) - Default to 0. - normalize (bool, optional): whether to renormalize probabilistic Tensors. - Defaults to False. 
- """ - # input_type is 'prob', or 'max_label' - # dice_type is hard or soft - - self.dice_type = dice_type - self.input_type = input_type - self.nb_labels = nb_labels - self.weights = weights - self.normalize = normalize - self.check_input_limits = check_input_limits - self.laplace_smoothing = laplace_smoothing - - # checks - assert self.input_type in ['prob', 'max_label'] - - if self.dice_type == 'hard' and self.input_type == 'max_label': - assert self.nb_labels is not None, 'If doing hard Dice need nb_labels' - - if self.dice_type == 'soft': - assert self.input_type in ['prob', 'one_hot'], \ - 'if doing soft Dice, must use probabilistic (one_hot)encoding' - - def dice(self, y_true, y_pred): - """ - compute dice between two Tensors - - Args: - y_pred, y_true: Tensors - - if prob/onehot, then shape [batch_size, ..., nb_labels] - - if max_label (label at each location), then shape [batch_size, ...] - - Returns: - Tensor of size [batch_size, nb_labels] - """ - - # input checks - if self.input_type in ['prob', 'one_hot']: - - # Optionally re-normalize. - # Note that in some cases you explicitly don't wnat to, e.g. if you only return a - # subset of the labels - if self.normalize: - y_true = tf.math.divide_no_nan(y_true, K.sum(y_true, axis=-1, keepdims=True)) - y_pred = tf.math.divide_no_nan(y_pred, K.sum(y_pred, axis=-1, keepdims=True)) - - # some value checking - if self.check_input_limits: - msg = 'value outside range' - tf.debugging.assert_greater_equal(y_true, 0., msg) - tf.debugging.assert_greater_equal(y_pred, 0., msg) - tf.debugging.assert_less_equal(y_true, 1., msg) - tf.debugging.assert_less_equal(y_pred, 1., msg) - - # Prepare the volumes to operate on - # If we're doing 'hard' Dice, then we will prepare one_hot-based matrices of size - # [batch_size, nb_voxels, nb_labels], where for each voxel in each batch entry, - # the entries are either 0 or 1 - if self.dice_type == 'hard': - - # if given predicted probability, transform to "hard max"" - if self.input_type == 'prob': - # this breaks differentiability, since argmax is not differentiable. - # warnings.warn('You are using ne.metrics.Dice with probabilistic inputs' - # 'and computing *hard* dice. \n For this, we use argmax to' - # 'get the optimal label at each location, which is not' - # 'differentiable. Do not use expecting gradients.') - - if self.nb_labels is None: - self.nb_labels = y_pred.shape.as_list()[-1] - - y_pred = K.argmax(y_pred, axis=-1) - y_true = K.argmax(y_true, axis=-1) - - # transform to one hot notation - y_pred = K.one_hot(y_pred, self.nb_labels) - y_true = K.one_hot(y_true, self.nb_labels) - - # reshape to [batch_size, nb_voxels, nb_labels] - y_true = batch_channel_flatten(y_true) - y_pred = batch_channel_flatten(y_pred) - - # compute dice for each entry in batch. - # dice will now be [batch_size, nb_labels] - top = 2 * K.sum(y_true * y_pred, 1) - bottom = K.sum(K.square(y_true), 1) + K.sum(K.square(y_pred), 1) - if self.laplace_smoothing > 0: - eps = self.laplace_smoothing - return (top + eps) / (bottom + eps) - else: - return tf.math.divide_no_nan(top, bottom) - - def mean_dice(self, y_true, y_pred): - """ - mean dice across all patches and labels - optionally weighted - - Args: - y_pred, y_true: Tensors - - if prob/onehot, then shape [batch_size, ..., nb_labels] - - if max_label (label at each location), then shape [batch_size, ...] 
- - Returns: - dice (Tensor of size 1, tf.float32) - """ - - # compute dice, which will now be [batch_size, nb_labels] - dice_metric = self.dice(y_true, y_pred) - - # weigh the entries in the dice matrix: - if self.weights is not None: - assert len(self.weights.shape) == 2, \ - 'weights should be a matrix broadcastable to [batch_size, nb_labels]' - dice_metric *= self.weights - - # return one minus mean dice as loss - mean_dice_metric = K.mean(dice_metric) - tf.debugging.assert_all_finite(mean_dice_metric, 'metric not finite') - return mean_dice_metric - - def loss(self, y_true, y_pred): - """ - Deprecate anytime after 12/01/2021 - """ - # warnings.warn('ne.metrics.*.loss functions are deprecated.' - # 'Please use the ne.losses.*.loss functions.') - - return - self.mean_dice(y_true, y_pred) - def dice(y_true, y_pred): return Dice(laplace_smoothing=0.05).loss(y_true, y_pred) From ab321bb620830c097f14037f14ad0124b97ecd9b Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Tue, 26 Sep 2023 19:41:18 +0000 Subject: [PATCH 12/50] STYLE: Fix up WIP code on hacking bottom of U-net --- ml4h/arguments.py | 3 +- ml4h/models/conv_blocks.py | 90 +++++++++++------------------------- ml4h/models/model_factory.py | 22 +++------ ml4h/recipes.py | 1 - 4 files changed, 34 insertions(+), 82 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index 0c0291fc0..1a9b107da 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -189,7 +189,8 @@ def parse_args(): parser.add_argument('--pool_z', default=1, type=int, help='Pooling size in the z-axis, if 1 no pooling will be performed.') parser.add_argument('--padding', default='same', help='Valid or same border padding on the convolutional layers.') parser.add_argument('--dense_blocks', nargs='*', default=[32, 32, 32], type=int, help='List of number of kernels in convolutional layers.') - parser.add_argument('--bottom_dense_blocks', nargs='*', default=[32], type=int, help='List of number of kernels in convolutional layers at the bottom.') + parser.add_argument('--merge_dimension', default=3, type=int, help='Dimension of the merge layer.') + parser.add_argument('--merge_dense_blocks', nargs='*', default=[32], type=int, help='List of number of kernels in convolutional merge layer.') parser.add_argument('--encoder_blocks', nargs='*', default=['conv_encode'], help='List of encoding blocks.') parser.add_argument('--merge_blocks', nargs='*', default=['concat'], help='List of merge blocks.') parser.add_argument('--decoder_blocks', nargs='*', default=['conv_decode', 'dense_decode'], help='List of decoding blocks.') diff --git a/ml4h/models/conv_blocks.py b/ml4h/models/conv_blocks.py index ee16ef428..6888ca766 100755 --- a/ml4h/models/conv_blocks.py +++ b/ml4h/models/conv_blocks.py @@ -67,15 +67,16 @@ def __init__( ) for filters, x, y, z in zip(dense_blocks, x_filters[len(conv_layers):], y_filters[len(conv_layers):], z_filters[len(conv_layers):]) ] self.pools = _pool_layers_from_kind_and_dimension(dimension, pool_type, len(dense_blocks) + 1, pool_x, pool_y, pool_z) - # self.fully_connected = DenseBlock( - # widths=dense_layers, - # activation=activation, - # normalization=dense_normalize, - # regularization=dense_regularize, - # regularization_rate=dense_regularize_rate, - # name=self.tensor_map.embed_name(), - # ) if dense_layers else None # TODO - self.fully_connected = None + + use_fully_connected = dense_layers and (not np.all(np.array(dense_layers)==0)) + self.fully_connected = DenseBlock( + widths=dense_layers, + activation=activation, + 
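        # Editor's note (illustrative, not part of the original patch): with the
        # use_fully_connected guard above, passing --dense_layers 0 appears to skip this
        # flatten + DenseBlock bottleneck entirely, leaving the encoder output spatial so
        # a fully convolutional U-net can be assembled (assumed intent of the guard).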
normalization=dense_normalize, + regularization=dense_regularize, + regularization_rate=dense_regularize_rate, + name=self.tensor_map.embed_name(), + ) if use_fully_connected else None def can_apply(self): return self.tensor_map.axes() > 1 @@ -97,17 +98,12 @@ def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = Non return x -class ConvEncoderBottomBlock(Block): +class ConvEncoderMergeBlock(Block): def __init__( self, *, - tensor_map: TensorMap, - bottom_dense_blocks: List[int] = [32], - dense_layers: List[int] = [32], - dense_normalize: str = None, - dense_regularize: str = None, - dense_regularize_rate: float = 0.0, - conv_layers: List[int] = [32], + merge_dimension, + merge_dense_blocks: List[int] = [32], conv_type: str = 'conv', conv_width: List[int] = [71], conv_x: List[int] = [3], @@ -118,68 +114,40 @@ def __init__( conv_normalize: str = None, conv_regularize: str = None, conv_regularize_rate: float = 0.0, - conv_dilate: bool = False, pool_type: str = 'max', pool_x: int = 2, pool_y: int = 2, pool_z: int = 1, **kwargs, ): - self.tensor_map = tensor_map + self.merge_dimension = merge_dimension if not self.can_apply(): return - dimension = self.tensor_map.axes() # list of filter dimensions should match the total number of convolutional layers - x_filters = _repeat_dimension(conv_width if dimension == 2 else conv_x, len(bottom_dense_blocks)) - y_filters = _repeat_dimension(conv_y, len(bottom_dense_blocks)) - z_filters = _repeat_dimension(conv_z, len(bottom_dense_blocks)) - - #self.preprocess_block = PreprocessBlock(['rotate'], [0.3]) - # self.res_block = Residual( - # dimension=dimension, filters_per_conv=conv_layers, conv_layer_type=conv_type, conv_x=x_filters[:len(conv_layers)], - # conv_y=y_filters[:len(conv_layers)], conv_z=z_filters[:len(conv_layers)], activation=activation, normalization=conv_normalize, - # regularization=conv_regularize, regularization_rate=conv_regularize_rate, dilate=conv_dilate, - # ) + x_filters = _repeat_dimension(conv_width if self.merge_dimension == 2 else conv_x, len(merge_dense_blocks)) + y_filters = _repeat_dimension(conv_y, len(merge_dense_blocks)) + z_filters = _repeat_dimension(conv_z, len(merge_dense_blocks)) self.dense_blocks = [ DenseConvolutional( - dimension=dimension, conv_layer_type=conv_type, filters=filters, conv_x=[x] * block_size, conv_y=[y] * block_size, + dimension=self.merge_dimension, conv_layer_type=conv_type, filters=filters, conv_x=[x] * block_size, conv_y=[y] * block_size, conv_z=[z]*block_size, block_size=block_size, activation=activation, normalization=conv_normalize, regularization=conv_regularize, regularization_rate=conv_regularize_rate, - ) for filters, x, y, z in zip(bottom_dense_blocks, x_filters, y_filters, z_filters) + ) for filters, x, y, z in zip(merge_dense_blocks, x_filters, y_filters, z_filters) ] - self.pools = _pool_layers_from_kind_and_dimension(dimension, pool_type, len(bottom_dense_blocks), pool_x, pool_y, pool_z) - # self.fully_connected = DenseBlock( - # widths=dense_layers, - # activation=activation, - # normalization=dense_normalize, - # regularization=dense_regularize, - # regularization_rate=dense_regularize_rate, - # name=self.tensor_map.embed_name(), - # ) if dense_layers else None # TODO - # self.fully_connected = None - self.upsamples = [_upsampler(dimension, pool_x, pool_y, pool_z) for _ in range(len(bottom_dense_blocks))] - + self.pools = _pool_layers_from_kind_and_dimension(self.merge_dimension, pool_type, len(merge_dense_blocks), pool_x, pool_y, pool_z) + self.upsamples = 
[_upsampler(self.merge_dimension, pool_x, pool_y, pool_z) for _ in range(len(merge_dense_blocks))] def can_apply(self): - return self.tensor_map.axes() > 1 + return self.merge_dimension > 1 def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor: - # if not self.can_apply(): - # return x - #x = self.preprocess_block(x) # TODO: upgrade to tensorflow 2.3+ - # x = self.res_block(x) - # intermediates[self.tensor_map].append(x) - y = intermediates[self.tensor_map][-1] + y = [x[-1] for tm, x in intermediates.items()] + y = concatenate(y) if len(y) > 1 else y[0] for i, (dense_block, pool) in enumerate(zip(self.dense_blocks, self.pools)): y = pool(y) y = dense_block(y) - # intermediates[self.tensor_map].append(x) - # if self.fully_connected: - # x = Flatten()(x) - # x = self.fully_connected(x, intermediates) - # intermediates[self.tensor_map].append(x) for i, upsample in enumerate(self.upsamples): y = upsample(y) return y @@ -238,14 +206,8 @@ def can_apply(self): def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor: if not self.can_apply(): return x - - logging.info(f'x.shape') - logging.info(x.shape) - logging.info(f'start_shape') - logging.info(self.start_shape) - - # if x.shape != self.start_shape: - # x = self.reshape(x) + if x.shape[1:] != self.start_shape: + x = self.reshape(x) for i, (dense_block, upsample) in enumerate(zip(self.dense_conv_blocks, self.upsamples)): intermediate = [intermediates[tm][len(self.upsamples)-(i+1)] for tm in self.u_connect_parents] x = concatenate(intermediate + [x]) if intermediate else x diff --git a/ml4h/models/model_factory.py b/ml4h/models/model_factory.py index 68a5a3e65..ac751ea4d 100755 --- a/ml4h/models/model_factory.py +++ b/ml4h/models/model_factory.py @@ -18,7 +18,7 @@ from ml4h.optimizers import NON_KERAS_OPTIMIZERS, get_optimizer from ml4h.models.layer_wrappers import ACTIVATION_FUNCTIONS, NORMALIZATION_CLASSES from ml4h.models.pretrained_blocks import ResNetEncoder, MoviNetEncoder, BertEncoder -from ml4h.models.conv_blocks import ConvEncoderBlock, ConvEncoderBottomBlock, ConvDecoderBlock, ResidualBlock, PoolBlock, ConvUp, ConvDown +from ml4h.models.conv_blocks import ConvEncoderBlock, ConvEncoderMergeBlock, ConvDecoderBlock, ResidualBlock, PoolBlock, ConvUp, ConvDown from ml4h.models.transformer_blocks import TransformerDecoder, TransformerEncoder, PositionalEncoding from ml4h.models.transformer_blocks_embedding import TransformerEncoderEmbedding,MultiHeadAttention from ml4h.models.perceiver_blocks import PerceiverEncoder,PerceiverLatentLayer @@ -30,7 +30,7 @@ BLOCK_CLASSES = { 'conv_encode': ConvEncoderBlock, - 'conv_encode_bottom': ConvEncoderBottomBlock, + 'merge_conv_encode': ConvEncoderMergeBlock, 'conv_decode': ConvDecoderBlock, 'conv_up': ConvUp, 'conv_down': ConvDown, @@ -178,9 +178,9 @@ def multimodal_multitask_model( merge = identity for merge_block in merge_blocks: if isinstance(merge_block, Block): - merge = compose(merge, merge_block(tensor_map=tensor_maps_in[0], **kwargs)) + merge = compose(merge, merge_block(**kwargs)) else: - merge = compose(merge, BLOCK_CLASSES[merge_block](tensor_map=tm, **kwargs)) + merge = compose(merge, BLOCK_CLASSES[merge_block](**kwargs)) decoder_block_functions = {tm: identity for tm in tensor_maps_out} for tm in tensor_maps_out: @@ -252,23 +252,13 @@ def make_multimodal_multitask_model_block( multimodal_activation = merge(encodings, intermediates) merge_model = Model(list(inputs.values()), multimodal_activation) - # TODO take me out 
- logging.info(f'TEMP') - merge_model.summary(print_fn=logging.info, expand_nested=True) - if isinstance(multimodal_activation, list): - # latent_inputs = Input(shape=(multimodal_activation[0].shape[-1],), name='input_multimodal_space') - latent_inputs = Input(shape=multimodal_activation[0].shape[1:], name='input_multimodal_space') + latent_inputs = Input(shape=(multimodal_activation[0].shape[-1],), name='input_multimodal_space') else: - # latent_inputs = Input(shape=(multimodal_activation.shape[-1],), name='input_multimodal_space') - latent_inputs = Input(shape=multimodal_activation.shape[1:], name='input_multimodal_space') + latent_inputs = Input(shape=(multimodal_activation.shape[-1],), name='input_multimodal_space') logging.info(f'multimodal_activation.shapes: {multimodal_activation.shape}') logging.info(f'Graph from input TensorMaps has intermediates: {[(tm, [ti.shape for ti in t]) for tm, t in intermediates.items()]}') - # TODO take me out - logging.info('latent inputs') - logging.info(latent_inputs) - decoders: Dict[TensorMap, Model] = {} decoder_outputs = [] for tm, decoder_block in decoder_block_functions.items(): # TODO this needs to be a topological sorted according to parents hierarchy diff --git a/ml4h/recipes.py b/ml4h/recipes.py index e2673bf92..fb9ae763e 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -298,7 +298,6 @@ def inference_file_name(output_folder: str, id_: str) -> str: def infer_multimodal_multitask(args): - stats = Counter() tensor_paths_inferred = set() inference_tsv = inference_file_name(args.output_folder, args.id) From 9e2d864e04abda99e5781ae4e5ba521bbe2e82a3 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 29 Sep 2023 15:46:40 +0000 Subject: [PATCH 13/50] ENH: Add merged paps for segmentation tensormap --- ml4h/defines.py | 5 +++++ ml4h/tensormap/ukb/mri.py | 13 ++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ml4h/defines.py b/ml4h/defines.py index eb18dde92..b0f703512 100755 --- a/ml4h/defines.py +++ b/ml4h/defines.py @@ -82,6 +82,11 @@ def __str__(self): 'interventricular_septum': 5, 'LV_free_wall': 6, 'anterolateral_pap': 7, 'posteromedial_pap': 8, 'LV_cavity': 9, 'RV_free_wall': 10, 'RV_cavity': 11, } +MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP = { + 'background': 0, 'thoracic_cavity': 1, 'liver': 2, 'stomach': 3, 'spleen': 4, + 'interventricular_septum': 5, 'LV_free_wall': 6, 'anterolateral_pap': 7, 'posteromedial_pap': 7, 'LV_cavity': 8, + 'RV_free_wall': 9, 'RV_cavity': 10, +} MRI_SAX_SEGMENTED_CHANNEL_MAP = { 'background': 0, 'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4, 'RV_cavity': 5, 'thoracic_cavity': 6, 'liver': 7, 'stomach': 8, 'spleen': 9, 'kidney': 11, 'body': 10, diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 0b2f521be..d7f790187 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -20,7 +20,7 @@ MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_SEGMENTED_CHANNEL_MAP, LAX_4CH_HEART_LABELS, LAX_4CH_MYOCARDIUM_LABELS, StorageType, LAX_3CH_HEART_LABELS, \ LAX_2CH_HEART_LABELS from ml4h.tensormap.general import get_tensor_at_first_date, normalized_first_date, pad_or_crop_array_to_shape, tensor_from_hd5 -from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS +from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, 
MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP, MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS def _slice_subset_tensor( @@ -2718,3 +2718,14 @@ def _segmented_t1map(tm, hd5, dependents={}): loss=dice, metrics=['categorical_accuracy'] + per_class_dice(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP), ) + +t1map_b2_segmentation_merged_paps = TensorMap( + 'b2s_t1map_kassir_annotated', + interpretation=Interpretation.CATEGORICAL, + shape=(384, 384, len(MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP)), + channel_map=MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP, + path_prefix='ukb_cardiac_mri', + tensor_from_file=_segmented_t1map, + loss=dice, + metrics=['categorical_accuracy'] + per_class_dice(MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP), +) From 78dd55d7464867446df2767b41131b29050cba7a Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 29 Sep 2023 20:30:36 +0000 Subject: [PATCH 14/50] WIP: Fix Unet concats --- ml4h/arguments.py | 1 + ml4h/models/conv_blocks.py | 74 ++++++++++++++++++++++++++++++++++-- ml4h/models/model_factory.py | 3 +- 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index 1a9b107da..e8c54fe1b 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -191,6 +191,7 @@ def parse_args(): parser.add_argument('--dense_blocks', nargs='*', default=[32, 32, 32], type=int, help='List of number of kernels in convolutional layers.') parser.add_argument('--merge_dimension', default=3, type=int, help='Dimension of the merge layer.') parser.add_argument('--merge_dense_blocks', nargs='*', default=[32], type=int, help='List of number of kernels in convolutional merge layer.') + parser.add_argument('--decoder_dense_blocks', nargs='*', default=[32, 32, 32], type=int, help='List of number of kernels in convolutional decoder layers.') parser.add_argument('--encoder_blocks', nargs='*', default=['conv_encode'], help='List of encoding blocks.') parser.add_argument('--merge_blocks', nargs='*', default=['concat'], help='List of merge blocks.') parser.add_argument('--decoder_blocks', nargs='*', default=['conv_decode', 'dense_decode'], help='List of decoding blocks.') diff --git a/ml4h/models/conv_blocks.py b/ml4h/models/conv_blocks.py index 6888ca766..c06e20443 100755 --- a/ml4h/models/conv_blocks.py +++ b/ml4h/models/conv_blocks.py @@ -148,8 +148,6 @@ def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = Non for i, (dense_block, pool) in enumerate(zip(self.dense_blocks, self.pools)): y = pool(y) y = dense_block(y) - for i, upsample in enumerate(self.upsamples): - y = upsample(y) return y class ConvDecoderBlock(Block): @@ -206,7 +204,7 @@ def can_apply(self): def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor: if not self.can_apply(): return x - if x.shape[1:] != self.start_shape: + if x.shape != self.start_shape: x = self.reshape(x) for i, (dense_block, upsample) in enumerate(zip(self.dense_conv_blocks, self.upsamples)): intermediate = [intermediates[tm][len(self.upsamples)-(i+1)] for tm in self.u_connect_parents] @@ -218,6 +216,76 @@ def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = Non return self.conv_label(x) +class ConvUnetDecoderBlock(Block): + def __init__( + self, + *, + tensor_map: TensorMap, + decoder_dense_blocks: List[int] = [32, 32, 32], + conv_type: str = 'conv', + conv_width: List[int] = [71], + conv_x: List[int] = [3], + conv_y: List[int] = [3], + conv_z: List[int] = [3], + block_size: int = 3, + 
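        # Editor's note (illustrative, not part of the original patch): unlike
        # ConvDecoderBlock, this decoder appears to upsample first, then concatenate the
        # u_connect skip tensor at the matching resolution, then apply the dense conv
        # block; the new --decoder_dense_blocks argument sets the kernels per decoding
        # level, e.g. --decoder_dense_blocks 32 32 32 for three levels (assumed usage).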
activation: str = 'swish', + conv_normalize: str = None, + conv_regularize: str = None, + conv_regularize_rate: float = 0.0, + pool_x: int = 2, + pool_y: int = 2, + pool_z: int = 1, + u_connect_parents: List[TensorMap] = None, + **kwargs, + ): + self.tensor_map = tensor_map + if not self.can_apply(): + return + dimension = tensor_map.axes() + + x_filters = _repeat_dimension(conv_width if dimension == 2 else conv_x, len(decoder_dense_blocks)) + y_filters = _repeat_dimension(conv_y, len(decoder_dense_blocks)) + z_filters = _repeat_dimension(conv_z, len(decoder_dense_blocks)) + self.dense_conv_blocks = [ + DenseConvolutional( + dimension=tensor_map.axes(), conv_layer_type=conv_type, filters=filters, conv_x=[x] * block_size, + conv_y=[y]*block_size, conv_z=[z]*block_size, block_size=block_size, activation=activation, normalization=conv_normalize, + regularization=conv_regularize, regularization_rate=conv_regularize_rate, + ) + for filters, x, y, z in zip(decoder_dense_blocks, x_filters, y_filters, z_filters) + ] + conv_layer, _ = _conv_layer_from_kind_and_dimension(dimension, 'conv', conv_x, conv_y, conv_z) + self.conv_label = conv_layer(tensor_map.shape[-1], _one_by_n_kernel(dimension), activation=tensor_map.activation, name=tensor_map.output_name()) + self.upsamples = [_upsampler(dimension, pool_x, pool_y, pool_z) for _ in range(len(decoder_dense_blocks))] + self.u_connect_parents = u_connect_parents or [] + self.start_shape = _start_shape_before_pooling( + num_upsamples=len(decoder_dense_blocks), output_shape=tensor_map.shape, + upsample_rates=[pool_x, pool_y, pool_z], channels=decoder_dense_blocks[0], + ) + self.reshape = FlatToStructure(output_shape=self.start_shape, activation=activation, normalization=conv_normalize) + logging.info(f'Built a decoder with: {len(self.dense_conv_blocks)} and reshape {self.start_shape}') + + def can_apply(self): + return self.tensor_map.axes() > 1 + + def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor: + if not self.can_apply(): + return x + + logging.info('x') + logging.info(x.shape) + logging.info('self.start_shape') + logging.info(self.start_shape) + + if x.shape[1:-1] != self.start_shape[:-1]: + x = self.reshape(x) + for i, (dense_block, upsample) in enumerate(zip(self.dense_conv_blocks, self.upsamples)): + intermediate = [intermediates[tm][len(self.upsamples)-(i+1)] for tm in self.u_connect_parents] + x = upsample(x) + x = concatenate(intermediate + [x]) if intermediate else x + x = dense_block(x) + return self.conv_label(x) + class ResidualBlock(Block): def __init__( self, diff --git a/ml4h/models/model_factory.py b/ml4h/models/model_factory.py index ac751ea4d..804edd10b 100755 --- a/ml4h/models/model_factory.py +++ b/ml4h/models/model_factory.py @@ -18,7 +18,7 @@ from ml4h.optimizers import NON_KERAS_OPTIMIZERS, get_optimizer from ml4h.models.layer_wrappers import ACTIVATION_FUNCTIONS, NORMALIZATION_CLASSES from ml4h.models.pretrained_blocks import ResNetEncoder, MoviNetEncoder, BertEncoder -from ml4h.models.conv_blocks import ConvEncoderBlock, ConvEncoderMergeBlock, ConvDecoderBlock, ResidualBlock, PoolBlock, ConvUp, ConvDown +from ml4h.models.conv_blocks import ConvEncoderBlock, ConvEncoderMergeBlock, ConvDecoderBlock, ConvUnetDecoderBlock, ResidualBlock, PoolBlock, ConvUp, ConvDown from ml4h.models.transformer_blocks import TransformerDecoder, TransformerEncoder, PositionalEncoding from ml4h.models.transformer_blocks_embedding import TransformerEncoderEmbedding,MultiHeadAttention from 
ml4h.models.perceiver_blocks import PerceiverEncoder,PerceiverLatentLayer @@ -32,6 +32,7 @@ 'conv_encode': ConvEncoderBlock, 'merge_conv_encode': ConvEncoderMergeBlock, 'conv_decode': ConvDecoderBlock, + 'unet_conv_decode': ConvUnetDecoderBlock, 'conv_up': ConvUp, 'conv_down': ConvDown, 'residual': ResidualBlock, From 33248c30f6d5a48808cf40a7eef8f84934661454 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 5 Oct 2023 14:01:25 +0000 Subject: [PATCH 15/50] FIX: Fix soft dice metrics --- ml4h/metrics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml4h/metrics.py b/ml4h/metrics.py index fa414d483..2ec7ab93f 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -10,7 +10,7 @@ from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error -from neurite.tf.losses import Dice +from neurite.tf.losses import SoftDice STRING_METRICS = [ 'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae', @@ -263,7 +263,7 @@ def loss(y_true, y_pred): return loss def dice(y_true, y_pred): - return Dice(laplace_smoothing=0.05).loss(y_true, y_pred) + return SoftDice(laplace_smoothing=1e-05, check_input_limits=False).mean_loss(y_true, y_pred) def per_class_dice(labels): dice_fxns = [] @@ -271,7 +271,7 @@ def per_class_dice(labels): label_idx = labels[label_key] fxn_name = label_key.replace('-', '_').replace(' ', '_') string_fxn = 'def ' + fxn_name + '_dice(y_true, y_pred):\n' - string_fxn += '\tdice = Dice(laplace_smoothing=0.05).dice(y_true, y_pred)\n' + string_fxn += '\tdice = SoftDice(laplace_smoothing=1e-05, check_input_limits=False).dice(y_true, y_pred)\n' string_fxn += '\tdice = K.mean(dice, axis=0)['+str(label_idx)+']\n' string_fxn += '\treturn dice' From 196e050dde4d73711c91c82c9da40e7ba9dcf8c2 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 6 Oct 2023 16:39:25 +0000 Subject: [PATCH 16/50] ENH: Add plot_dice to compare --- .../config/tensorflow-requirements.txt | 2 + ml4h/plots.py | 60 +++++++++++++++++++ ml4h/recipes.py | 9 ++- 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/docker/vm_boot_images/config/tensorflow-requirements.txt b/docker/vm_boot_images/config/tensorflow-requirements.txt index d8e2c49c7..77233b2e0 100644 --- a/docker/vm_boot_images/config/tensorflow-requirements.txt +++ b/docker/vm_boot_images/config/tensorflow-requirements.txt @@ -43,3 +43,5 @@ google-cloud-storage umap-learn[plot] neurite voxelmorph +pystrum + diff --git a/ml4h/plots.py b/ml4h/plots.py index 2ae6b71b5..155e05655 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -54,6 +54,8 @@ from scipy.ndimage.filters import gaussian_filter from scipy import stats +from pystrum.medipy.metrics import dice + import ml4h.tensormap.ukb.ecg import ml4h.tensormap.mgb.ecg from ml4h.tensormap.mgb.dynamic import make_waveform_maps @@ -73,6 +75,7 @@ RECALL_LABEL = "Recall | Sensitivity | True Positive Rate | TP/(TP+FN)" FALLOUT_LABEL = "Fallout | 1 - Specificity | False Positive Rate | FP/(FP+TN)" PRECISION_LABEL = "Precision | Positive Predictive Value | TP/(TP+FP)" +DICE_LABEL = "Dice Score" SUBPLOT_SIZE = 7 @@ -2733,6 +2736,63 @@ def plot_precision_recalls(predictions, truth, labels, title, prefix="./figures/ plt.savefig(figure_path) logging.info("Saved Precision Recall curve at: {}".format(figure_path)) +def plot_dice(predictions, truth, labels, title, prefix="./figures/", dpi=300, 
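
Note on the Dice metrics wired up in these patches: the loss and per-class metrics approximate a Laplace-smoothed soft Dice over probabilistic label maps. A small NumPy sketch of that quantity (illustrative only, not the neurite implementation imported above):

    import numpy as np

    def soft_dice_per_class(y_true, y_pred, smooth=1e-5):
        # y_true, y_pred: (batch, height, width, n_classes) probabilistic maps.
        axes = (1, 2)  # sum over the spatial dimensions only
        intersection = np.sum(y_true * y_pred, axis=axes)
        denominator = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
        return (2.0 * intersection + smooth) / (denominator + smooth)  # (batch, n_classes)

    y = np.eye(3)[np.random.randint(0, 3, size=(2, 8, 8))]  # one-hot toy masks
    print(soft_dice_per_class(y, y).round(3))               # identical maps give ~1.0 everywhere
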
width=3, height=3): + label_names = labels.keys() + logging.info(f"label_names: {label_names}") + label_vals = [labels[k] for k in label_names] + batch_size = truth.shape[0] + y_true = truth.argmax(-1) + y_true_unique = [np.unique(y_true[i]) for i in range(batch_size)] + missing_truth_label_vals = [[k for k in label_vals if k not in y_true_unique[i]] for i in range(batch_size)] + + dice_scores = {} + mean_dice_scores = {} + for p in predictions: + y_pred = predictions[p].argmax(-1) + dice_scores[p] = np.stack([dice(y_true[i], y_pred[i], labels=label_vals) for i in range(batch_size)], axis=0) + + # If a label is not in y_true nor y_pred, this is actually a perfect score + y_pred_unique = [np.unique(y_pred[i]) for i in range(batch_size)] + replace = {i: [k for k in missing_truth_label_vals[i] if k not in y_pred_unique[i]] for i in range(batch_size)} + for i in range(batch_size): + for k in replace[i]: + dice_scores[p][i,k] = 1.0 + + mean_dice_scores[p] = np.average(dice_scores[p], axis=0) + logging.info(f"{p} mean Dice scores {mean_dice_scores[p]}") + + row = 0 + col = 0 + total_plots = len(label_names) + cols = int(math.ceil(math.sqrt(total_plots))) + rows = int(math.ceil(total_plots / cols)) + f, axes = plt.subplots( + rows, cols, figsize=(int(cols * width), int(rows * height)), dpi=dpi, + ) + + for i,k in enumerate(label_names): + for j,p in enumerate(predictions): + axes[row, col].boxplot(dice_scores[p][:,i], positions = [j], labels=['']) + label_text = [f"{p} mean dice:{mean_dice_scores[p][i]:.3f}" for p in predictions] + axes[row, col].set_title(f"{k}") + axes[row, col].set_ylabel(DICE_LABEL) + axes[row, col].legend(label_text, loc="lower right") + + row += 1 + if row == rows: + row = 0 + col += 1 + if col >= cols: + break + + plt.tight_layout() + plt.suptitle(f"{title} n={batch_size:.0f}") + now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'dice_{now_string}_{title}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path) + logging.info(f"Saved Dice plots at: {figure_path}") def get_fpr_tpr_roc_pred(y_pred, test_truth, labels): # Compute ROC curve and ROC area for each class diff --git a/ml4h/recipes.py b/ml4h/recipes.py index fb9ae763e..4e96177ee 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -28,7 +28,7 @@ from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients -from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival +from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model @@ -753,7 +753,6 @@ def _get_predictions(args, models_inputs_outputs, input_data, outputs, input_pre # We 
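
Note on plot_dice above: it scores each sample and label with hard Dice on argmax label maps, and a label absent from both truth and prediction counts as a perfect 1.0. A hedged NumPy equivalent of that per-label computation (illustrative, not the pystrum dice call used in the patch):

    import numpy as np

    def hard_dice(y_true_labels, y_pred_labels, label_vals):
        # Integer label maps of identical shape; one Dice score per requested label.
        scores = []
        for lab in label_vals:
            t = y_true_labels == lab
            p = y_pred_labels == lab
            if not t.any() and not p.any():
                scores.append(1.0)  # missing from both: treated as perfect
            else:
                scores.append(2.0 * np.logical_and(t, p).sum() / (t.sum() + p.sum()))
        return np.array(scores)

    truth = np.random.randint(0, 3, size=(64, 64))
    print(hard_dice(truth, truth.copy(), label_vals=[0, 1, 2, 3]))  # [1. 1. 1. 1.]
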
can feed 'model.predict()' the entire input data because it knows what subset to use y_pred = model.predict(input_data, batch_size=args.batch_size) - for i, tm in enumerate(args.tensor_maps_out): if tm in outputs: if len(args.tensor_maps_out) == 1: @@ -863,11 +862,15 @@ def _calculate_and_plot_prediction_stats(args, predictions, outputs, paths): aucs = {"ROC": roc_aucs, "Precision-Recall": precision_recall_aucs} log_aucs(**aucs) elif tm.is_categorical() and tm.axes() == 3: + # have to plot dice before the reshape + plot_dice( + predictions[tm], outputs[tm.output_name()], tm.channel_map, plot_title, plot_folder, + dpi=args.dpi, width=args.plot_width, height=args.plot_height, + ) for p in predictions[tm]: y = predictions[tm][p] melt_shape = (y.shape[0]*y.shape[1]*y.shape[2], y.shape[3]) predictions[tm][p] = y.reshape(melt_shape) - y_truth = outputs[tm.output_name()].reshape(melt_shape) plot_rocs( predictions[tm], y_truth, tm.channel_map, plot_title, plot_folder, From 6051b8d8305882a41d7f9ecbc89083cf117d0573 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 13 Oct 2023 16:54:30 +0000 Subject: [PATCH 17/50] ENH: Add median computation for papillary segmentation project --- ml4h/explorations.py | 145 ++++++++++++++++++++++++--- ml4h/recipes.py | 4 +- ml4h/tensorize/tensor_writer_ukbb.py | 2 +- 3 files changed, 135 insertions(+), 16 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 2d0800f1c..eaa6d5393 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -28,6 +28,7 @@ import matplotlib.pyplot as plt # First import matplotlib, then use Agg, then import plt from ml4h.models.legacy_models import legacy_multimodal_multitask_model +from ml4h.models.model_factory import make_multimodal_multitask_model from ml4h.TensorMap import TensorMap, Interpretation, decompress_data from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX @@ -35,19 +36,23 @@ from ml4h.plots import evaluate_predictions, subplot_rocs, subplot_scatters, plot_categorical_tmap_over_time from ml4h.defines import JOIN_CHAR, MRI_SEGMENTED_CHANNEL_MAP, CODING_VALUES_MISSING, CODING_VALUES_LESS_THAN_ONE from ml4h.defines import TENSOR_EXT, IMAGE_EXT, ECG_CHAR_2_IDX, ECG_IDX_2_CHAR, PARTNERS_CHAR_2_IDX, PARTNERS_IDX_2_CHAR, PARTNERS_READ_TEXT +from ml4h.tensorize.tensor_writer_ukbb import _unit_disk from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, Ridge, Lasso +from scipy.ndimage import binary_erosion CSV_EXT = '.tsv' -def stratify_and_project_latent_space(stratify_column: str,stratify_thresh: float,stratify_std: float,latent_cols: List[str], - latent_df: pd.DataFrame, - normalize: bool = False, - train_ratio: int = 1.0): +def stratify_and_project_latent_space( + stratify_column: str,stratify_thresh: float,stratify_std: float,latent_cols: List[str], + latent_df: pd.DataFrame, + normalize: bool = False, + train_ratio: int = 1.0, +): """ Stratify data and project it to new latent space. 
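
Note on stratify_and_project_latent_space (reformatted in this hunk): the idea, visible in the body below, is to build a phenotype axis as the difference of the 'hit' and 'miss' latent means, project held-out samples onto it, and t-test the two groups of projections. A compact NumPy/SciPy sketch of that idea with made-up latent vectors:

    import numpy as np
    from scipy import stats

    rng = np.random.default_rng(1)
    hit = rng.normal(0.5, 1.0, size=(100, 16))    # latents of high-phenotype samples
    miss = rng.normal(0.0, 1.0, size=(120, 16))   # latents of low-phenotype samples

    axis = hit.mean(axis=0) - miss.mean(axis=0)   # phenotype direction in latent space
    t, p = stats.ttest_ind(hit @ axis, miss @ axis, equal_var=False)
    print(f't={t:.2f}, p={p:.2g}')
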
Args: @@ -61,13 +66,13 @@ def stratify_and_project_latent_space(stratify_column: str,stratify_thresh: floa Returns: Dict[str, Tuple[float,float,float]] - """ + """ if train_ratio == 1.0: train = latent_df test = latent_df else: train = latent_df.sample(frac=train_ratio) - test = latent_df.drop(train.index) + test = latent_df.drop(train.index) hit = train.loc[train[stratify_column] >= stratify_thresh+(1*stratify_std)] miss = train.loc[train[stratify_column] < stratify_thresh-(1*stratify_std)] hit_np = hit[latent_cols].to_numpy() @@ -75,10 +80,10 @@ def stratify_and_project_latent_space(stratify_column: str,stratify_thresh: floa miss_mean_vector = np.mean(miss_np, axis=0) hit_mean_vector = np.mean(hit_np, axis=0) angle = angle_between(miss_mean_vector, hit_mean_vector) - + hit_test = test.loc[test[stratify_column] >= stratify_thresh+(1*stratify_std)] miss_test = test.loc[test[stratify_column] < stratify_thresh-(1*stratify_std)] - + if normalize: phenotype_vector = unit_vector(hit_mean_vector-miss_mean_vector) hit_dots = [np.dot(phenotype_vector, unit_vector(v)) for v in hit_test[latent_cols].to_numpy()] @@ -88,8 +93,8 @@ def stratify_and_project_latent_space(stratify_column: str,stratify_thresh: floa hit_dots = [np.dot(phenotype_vector, v) for v in hit_test[latent_cols].to_numpy()] miss_dots = [np.dot(phenotype_vector, v) for v in miss_test[latent_cols].to_numpy()] t2, p2 = stats.ttest_ind(hit_dots, miss_dots, equal_var = False) - - return {f'{stratify_column}': (t2, p2, len(hit)) } + + return {f'{stratify_column}': (t2, p2, len(hit))} @@ -100,7 +105,7 @@ def plot_nested_dictionary(all_scores: DefaultDict[str, DefaultDict[str, Tuple[f all_scores (DefaultDict[str, DefaultDict[str, Tuple[float, float, float]]]): Nested dictionary containing the scores. Returns: None - """ + """ n = 4 eps = 1e-300 for model in all_scores: @@ -221,10 +226,12 @@ def confounder_matrix(adjust_cols: List[str], df: pd.DataFrame, space: np.ndarra vectors.append(cv) return np.array(vectors), scores -def iterative_subspace_removal(adjust_cols: List[str], latent_df: pd.DataFrame, latent_cols: List[str], - r2_thresh: float = 0.01, fit_pca: bool = False): +def iterative_subspace_removal( + adjust_cols: List[str], latent_df: pd.DataFrame, latent_cols: List[str], + r2_thresh: float = 0.01, fit_pca: bool = False, +): """ - Perform iterative subspace removal based on specified columns, a latent dataframe, + Perform iterative subspace removal based on specified columns, a latent dataframe, and other parameters to remove confounder variables. Args: @@ -691,6 +698,116 @@ def infer_with_pixels(args): logging.info(f"Wrote:{stats['count']} rows of inference. 
Last tensor:{tensor_paths[0]}") +def _compute_masked_stats(img, y, nb_classes): + img = np.tile(img, nb_classes) + melt_shape = (img.shape[0], img.shape[1] * img.shape[2], img.shape[3]) + img = img.reshape(melt_shape) + y = y.reshape(melt_shape) + + masked_img = np.ma.array(img, mask=np.logical_not(y)) + means = masked_img.mean(axis=1).data + medians = np.ma.median(masked_img, axis=1).data + stds = masked_img.std(axis=1).data + + return means, medians, stds + +def _to_categorical(y, nb_classes): + return np.eye(nb_classes)[y] + +def _get_csv_row(means, medians, stds, tensor_paths): + res = np.concatenate([means, medians, stds], axis=-1) + sample_id = os.path.basename(tensor_paths[0]).replace(TENSOR_EXT, '') + csv_row = [sample_id] + csv_row += res[0].astype('str').tolist() + return csv_row + +def infer_medians(args): + # Structuring element used for the erosion + structure = _unit_disk(2)[np.newaxis, ..., np.newaxis] + print(structure[0, :, :, 0]) + + tm_in = args.tensor_maps_in[0] + tm_out = args.tensor_maps_out[0] + assert(len(args.tensor_maps_out) == 1) + assert (tm_in.shape[-1] == 1) + assert (args.batch_size == 1) + + _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) + model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) + + stats = Counter() + tensor_paths_inferred = set() + inference_tsv_true = os.path.join(args.output_folder, args.id, f'pixel_inference_true_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') + inference_tsv_pred = os.path.join(args.output_folder, args.id, f'pixel_inference_pred_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') + + with open(inference_tsv_true, mode='w') as inference_file_true, open(inference_tsv_pred, mode='w') as inference_file_pred: + inference_writer_true = csv.writer(inference_file_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + inference_writer_pred = csv.writer(inference_file_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + + header = ['sample_id'] + header += [f'{k}_mean' for k in tm_out.channel_map.keys()] + header += [f'{k}_median' for k in tm_out.channel_map.keys()] + header += [f'{k}_std' for k in tm_out.channel_map.keys()] + inference_writer_true.writerow(header) + inference_writer_pred.writerow(header) + + while True: + batch = next(generate_test) + data, labels, tensor_paths = batch[BATCH_INPUT_INDEX], batch[BATCH_OUTPUT_INDEX], batch[BATCH_PATHS_INDEX] + if tensor_paths[0] in tensor_paths_inferred: + next(generate_test) # this print end of epoch info + logging.info( + f"Inference on {stats['count']} tensors finished. 
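
Note on _compute_masked_stats defined above: it leans on NumPy masked arrays, inverting the one-hot segmentation into a mask so that only pixels inside each structure contribute to the mean/median/std. A minimal illustration with toy values (not the T1-map pipeline):

    import numpy as np

    img = np.array([10., 20., 30., 40.])   # pixel intensities
    seg = np.array([0, 1, 1, 0])           # 1 = inside the structure

    masked = np.ma.array(img, mask=np.logical_not(seg))       # True in the mask means "exclude"
    print(masked.mean(), np.ma.median(masked), masked.std())  # 25.0 25.0 5.0
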
Inference TSV files at: {inference_tsv_true}, {inference_tsv_pred}", + ) + break + + img = data[tm_in.input_name()] + img = tm_in.rescale(img) + y_true = labels[tm_out.output_name()] + nb_classes = y_true.shape[-1] + y_pred = model.predict(data, batch_size=args.batch_size, verbose=0) + y_pred = np.argmax(y_pred, axis=-1) + y_pred = _to_categorical(y_pred, nb_classes) + + y_true = binary_erosion(y_true, structure).astype(y_true.dtype) + y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) + + means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_classes) + means_pred, medians_pred, stds_pred = _compute_masked_stats(img, y_pred, nb_classes) + + csv_row_true = _get_csv_row(means_true, medians_true, stds_true, tensor_paths) + csv_row_pred = _get_csv_row(means_pred, medians_pred, stds_pred, tensor_paths) + + inference_writer_true.writerow(csv_row_true) + inference_writer_pred.writerow(csv_row_pred) + + tensor_paths_inferred.add(tensor_paths[0]) + stats['count'] += 1 + if stats['count'] % 250 == 0: + logging.info(f"Wrote:{stats['count']} rows of inference. Last tensor:{tensor_paths[0]}") + + inference_tsv_true = os.path.join(args.output_folder, args.id, f'pixel_inference_true_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') + inference_tsv_pred = os.path.join(args.output_folder, args.id, f'pixel_inference_pred_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') + + df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + + plt.scatter(df_true.anterolateral_pap_median, df_pred.anterolateral_pap_median) + plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) + plt.title('anterolateral_pap_median') + plt.xlabel('true') + plt.ylabel('pred') + figure_path = os.path.join(args.output_folder, args.id, f'pixel_inference_anterolateral_pap_median_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.png') + plt.savefig(figure_path) + + plt.scatter(df_true.posteromedial_pap_median, df_pred.posteromedial_pap_median) + plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) + plt.title('posteromedial_pap_median') + plt.xlabel('true') + plt.ylabel('pred') + figure_path = os.path.join(args.output_folder, args.id, f'pixel_inference_posteromedial_pap_median_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.png') + plt.savefig(figure_path) + def _softmax(x): """Compute softmax values for each sets of scores in x.""" e_x = np.exp(x - np.max(x)) diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 4e96177ee..53252101f 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -23,7 +23,7 @@ from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb from ml4h.models.model_factory import make_multimodal_multitask_model from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX -from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe +from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_medians from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator @@ -66,6 +66,8 @@ def run(args): 
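
Note on the binary_erosion step above: each one-hot mask is shrunk with a small disk before the statistics are taken, presumably so boundary pixels (most prone to partial-volume and segmentation error) do not skew the medians. A hedged 2D sketch of that erosion:

    import numpy as np
    from scipy.ndimage import binary_erosion

    def unit_disk(r):
        # Same idea as _unit_disk: a filled circle of radius r.
        y, x = np.ogrid[-r: r + 1, -r: r + 1]
        return (x ** 2 + y ** 2 <= r ** 2).astype(np.int32)

    mask = np.zeros((11, 11), dtype=np.int32)
    mask[2:9, 2:9] = 1                                   # a 7x7 square region
    eroded = binary_erosion(mask, unit_disk(2)).astype(np.int32)
    print(mask.sum(), eroded.sum())                      # 49 -> 9: only the interior survives
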
infer_hidden_layer_multimodal_multitask(args) elif 'infer_pixels' == args.mode: infer_with_pixels(args) + elif 'infer_medians' == args.mode: + infer_medians(args) elif 'infer_encoders' == args.mode: infer_encoders_block_multimodal_multitask(args) elif 'test_scalar' == args.mode: diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py index f76f9dca2..a352022e7 100755 --- a/ml4h/tensorize/tensor_writer_ukbb.py +++ b/ml4h/tensorize/tensor_writer_ukbb.py @@ -664,7 +664,7 @@ def _get_overlay_from_dicom(d, debug=False) -> Tuple[np.ndarray, np.ndarray]: def _unit_disk(r) -> np.ndarray: y, x = np.ogrid[-r: r + 1, -r: r + 1] - return (x ** 2 + y ** 2 <= r ** 2).astype(np.int) + return (x ** 2 + y ** 2 <= r ** 2).astype(np.int32) def _outline_to_mask(labeled_outline, idx) -> np.ndarray: From 598f4c6b6d7cb81914fec322c68893d910a5ec4d Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 13 Oct 2023 16:59:58 +0000 Subject: [PATCH 18/50] FIX: Fix double plot on one graph --- ml4h/explorations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index eaa6d5393..d038936aa 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -724,7 +724,6 @@ def _get_csv_row(means, medians, stds, tensor_paths): def infer_medians(args): # Structuring element used for the erosion structure = _unit_disk(2)[np.newaxis, ..., np.newaxis] - print(structure[0, :, :, 0]) tm_in = args.tensor_maps_in[0] tm_out = args.tensor_maps_out[0] @@ -792,6 +791,7 @@ def infer_medians(args): df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + plt.figure() plt.scatter(df_true.anterolateral_pap_median, df_pred.anterolateral_pap_median) plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) plt.title('anterolateral_pap_median') @@ -800,6 +800,7 @@ def infer_medians(args): figure_path = os.path.join(args.output_folder, args.id, f'pixel_inference_anterolateral_pap_median_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.png') plt.savefig(figure_path) + plt.figure() plt.scatter(df_true.posteromedial_pap_median, df_pred.posteromedial_pap_median) plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) plt.title('posteromedial_pap_median') From 3e8e2d73fd16841b8c029ccaef6750756134de8e Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 13 Oct 2023 17:55:47 +0000 Subject: [PATCH 19/50] ENH: Allow generator to have empty path, e.g., to test on all images --- ml4h/tensor_generators.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py index 4a0634585..60f421141 100755 --- a/ml4h/tensor_generators.py +++ b/ml4h/tensor_generators.py @@ -666,13 +666,6 @@ def get_train_valid_test_paths( logging.info(f'Found {len(train_paths)} train, {len(valid_paths)} validation, and {len(test_paths)} testing tensors at: {tensors}') logging.debug(f'Discarded {len(discard_paths)} tensors due to given ratios') - if len(train_paths) == 0 or len(valid_paths) == 0 or len(test_paths) == 0: - raise ValueError( - f'Not enough tensors at {tensors}\n' - f'Found {len(train_paths)} training, {len(valid_paths)} validation, and {len(test_paths)} testing tensors\n' - f'Discarded {len(discard_paths)} tensors', - ) - return train_paths, valid_paths, test_paths @@ -811,7 +804,10 @@ def test_train_valid_tensor_generators( 
num_train_workers = int(training_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0) num_valid_workers = int(validation_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0) - generator_class = pick_generator(train_paths, weights, mixup_alpha, siamese) + # use the longest list of [train_paths, valid_paths, test_paths], avoiding hard-coding one + # in case it is empty + paths = max([train_paths, valid_paths, test_paths], key=len) + generator_class = pick_generator(paths, weights, mixup_alpha, siamese) # TODO generate_train = generator_class( batch_size=batch_size, input_maps=tensor_maps_in, output_maps=tensor_maps_out, From f65a8f4db4729f3cfcbe4dff61985cb9e5969792 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 13 Oct 2023 19:42:08 +0000 Subject: [PATCH 20/50] ENH: Prune list of structures for which we do stats --- ml4h/explorations.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index d038936aa..bd03ade0c 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -708,7 +708,6 @@ def _compute_masked_stats(img, y, nb_classes): means = masked_img.mean(axis=1).data medians = np.ma.median(masked_img, axis=1).data stds = masked_img.std(axis=1).data - return means, medians, stds def _to_categorical(y, nb_classes): @@ -722,18 +721,32 @@ def _get_csv_row(means, medians, stds, tensor_paths): return csv_row def infer_medians(args): - # Structuring element used for the erosion - structure = _unit_disk(2)[np.newaxis, ..., np.newaxis] + assert (args.batch_size == 1) # no support here for iterating over larger batches + assert (len(args.tensor_maps_out) == 1) # no support here for stats on multi-channel inputs tm_in = args.tensor_maps_in[0] tm_out = args.tensor_maps_out[0] - assert(len(args.tensor_maps_out) == 1) assert (tm_in.shape[-1] == 1) - assert (args.batch_size == 1) _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) + # TODO remove this hard-coding + important_structures = [ + 'interventricular_septum', 'LV_free_wall', 'anterolateral_pap', 'posteromedial_pap', + 'LV_cavity', 'RV_free_wall', 'RV_cavity', + ] + # end TODO remove this hard-coding + + good_channels = sorted([tm_out.channel_map[k] for k in important_structures]) + good_structures = [[k for k in tm_out.channel_map.keys() if tm_out.channel_map[k] == v][0] for v in good_channels] + nb_orig_classes = len(tm_out.channel_map) + nb_good_classes = len(good_channels) + bad_channels = [k for k in range(nb_orig_classes) if k not in good_channels] + + # Structuring element used for the erosion + structure = _unit_disk(2)[np.newaxis, ..., np.newaxis] + stats = Counter() tensor_paths_inferred = set() inference_tsv_true = os.path.join(args.output_folder, args.id, f'pixel_inference_true_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') @@ -744,9 +757,9 @@ def infer_medians(args): inference_writer_pred = csv.writer(inference_file_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) header = ['sample_id'] - header += [f'{k}_mean' for k in tm_out.channel_map.keys()] - header += [f'{k}_median' for k in tm_out.channel_map.keys()] - header += [f'{k}_std' for k in tm_out.channel_map.keys()] + header += [f'{k}_mean' for k in good_structures] + header += [f'{k}_median' for k in good_structures] + header += [f'{k}_std' for k in good_structures] 
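
Note on the channel pruning set up above: structure names are mapped to channel indices through the channel map, and every other channel is later dropped from the one-hot tensors with np.delete so that stats are only written for the structures of interest. A toy version with a made-up channel map:

    import numpy as np

    channel_map = {'background': 0, 'septum': 1, 'LV_cavity': 2, 'RV_cavity': 3}
    important = ['septum', 'LV_cavity']

    good_channels = sorted(channel_map[k] for k in important)                      # [1, 2]
    bad_channels = [c for c in range(len(channel_map)) if c not in good_channels]  # [0, 3]

    y = np.random.rand(1, 8, 8, len(channel_map))      # one-hot-like tensor
    print(np.delete(y, bad_channels, axis=-1).shape)   # (1, 8, 8, 2)
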
inference_writer_true.writerow(header) inference_writer_pred.writerow(header) @@ -763,16 +776,19 @@ def infer_medians(args): img = data[tm_in.input_name()] img = tm_in.rescale(img) y_true = labels[tm_out.output_name()] - nb_classes = y_true.shape[-1] y_pred = model.predict(data, batch_size=args.batch_size, verbose=0) y_pred = np.argmax(y_pred, axis=-1) - y_pred = _to_categorical(y_pred, nb_classes) + y_pred = _to_categorical(y_pred, nb_orig_classes) + + # prune unnecessary labels + y_true = np.delete(y_true, bad_channels, axis=-1) + y_pred = np.delete(y_pred, bad_channels, axis=-1) y_true = binary_erosion(y_true, structure).astype(y_true.dtype) y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) - means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_classes) - means_pred, medians_pred, stds_pred = _compute_masked_stats(img, y_pred, nb_classes) + means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_good_classes) + means_pred, medians_pred, stds_pred = _compute_masked_stats(img, y_pred, nb_good_classes) csv_row_true = _get_csv_row(means_true, medians_true, stds_true, tensor_paths) csv_row_pred = _get_csv_row(means_pred, medians_pred, stds_pred, tensor_paths) From d012a1397b25cc09283497467bfb44133fce1875 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 13 Oct 2023 20:26:34 +0000 Subject: [PATCH 21/50] STYLE: rearranging --- ml4h/explorations.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index bd03ade0c..b3aedd356 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -722,11 +722,11 @@ def _get_csv_row(means, medians, stds, tensor_paths): def infer_medians(args): assert (args.batch_size == 1) # no support here for iterating over larger batches - assert (len(args.tensor_maps_out) == 1) # no support here for stats on multi-channel inputs + assert (len(args.tensor_maps_out) == 1) # no support here for multiple output channels tm_in = args.tensor_maps_in[0] tm_out = args.tensor_maps_out[0] - assert (tm_in.shape[-1] == 1) + assert (tm_in.shape[-1] == 1) # no support here for stats on multiple input channels _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) @@ -782,18 +782,15 @@ def infer_medians(args): # prune unnecessary labels y_true = np.delete(y_true, bad_channels, axis=-1) - y_pred = np.delete(y_pred, bad_channels, axis=-1) - y_true = binary_erosion(y_true, structure).astype(y_true.dtype) - y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) - means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_good_classes) - means_pred, medians_pred, stds_pred = _compute_masked_stats(img, y_pred, nb_good_classes) - csv_row_true = _get_csv_row(means_true, medians_true, stds_true, tensor_paths) - csv_row_pred = _get_csv_row(means_pred, medians_pred, stds_pred, tensor_paths) - inference_writer_true.writerow(csv_row_true) + + y_pred = np.delete(y_pred, bad_channels, axis=-1) + y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) + means_pred, medians_pred, stds_pred = _compute_masked_stats(img, y_pred, nb_good_classes) + csv_row_pred = _get_csv_row(means_pred, medians_pred, stds_pred, tensor_paths) inference_writer_pred.writerow(csv_row_pred) tensor_paths_inferred.add(tensor_paths[0]) From fd7c0cf98524b965c65e4183e37752f44b38b231 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 13 Oct 2023 21:13:31 +0000 Subject: 
[PATCH 22/50] WIP: Handle inference without ground truth labels --- ml4h/explorations.py | 82 ++++++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index b3aedd356..1c0e3d0d1 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -38,6 +38,10 @@ from ml4h.defines import TENSOR_EXT, IMAGE_EXT, ECG_CHAR_2_IDX, ECG_IDX_2_CHAR, PARTNERS_CHAR_2_IDX, PARTNERS_IDX_2_CHAR, PARTNERS_READ_TEXT from ml4h.tensorize.tensor_writer_ukbb import _unit_disk +# TODO remove this hard-coding +from ml4h.defines import MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP +# end TODO remove this hard-coding + from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, Ridge, Lasso @@ -722,12 +726,23 @@ def _get_csv_row(means, medians, stds, tensor_paths): def infer_medians(args): assert (args.batch_size == 1) # no support here for iterating over larger batches - assert (len(args.tensor_maps_out) == 1) # no support here for multiple output channels + assert (len(args.tensor_maps_out) <= 1) # no support here for multiple output channels tm_in = args.tensor_maps_in[0] - tm_out = args.tensor_maps_out[0] assert (tm_in.shape[-1] == 1) # no support here for stats on multiple input channels + # TODO remove this hard-coding + if len(args.tensor_maps_out) == 0: + has_y_true = False + channel_map = MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP + output_name = 'output_b2s_t1map_kassir_annotated' + elif len(args.tensor_maps_out) == 1: + has_y_true = True + tm_out = args.tensor_maps_out[0] + channel_map = tm_out.channel_map + output_name = tm_out.output_name() + # end TODO remove this hard-coding + _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) @@ -738,9 +753,9 @@ def infer_medians(args): ] # end TODO remove this hard-coding - good_channels = sorted([tm_out.channel_map[k] for k in important_structures]) - good_structures = [[k for k in tm_out.channel_map.keys() if tm_out.channel_map[k] == v][0] for v in good_channels] - nb_orig_classes = len(tm_out.channel_map) + good_channels = sorted([channel_map[k] for k in important_structures]) + good_structures = [[k for k in channel_map.keys() if channel_map[k] == v][0] for v in good_channels] + nb_orig_classes = len(channel_map) nb_good_classes = len(good_channels) bad_channels = [k for k in range(nb_orig_classes) if k not in good_channels] @@ -749,8 +764,8 @@ def infer_medians(args): stats = Counter() tensor_paths_inferred = set() - inference_tsv_true = os.path.join(args.output_folder, args.id, f'pixel_inference_true_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') - inference_tsv_pred = os.path.join(args.output_folder, args.id, f'pixel_inference_pred_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') + inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') + inference_tsv_pred = os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') with open(inference_tsv_true, mode='w') as inference_file_true, open(inference_tsv_pred, mode='w') as inference_file_pred: inference_writer_true = csv.writer(inference_file_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) @@ -775,17 +790,17 @@ def infer_medians(args): img = data[tm_in.input_name()] 
img = tm_in.rescale(img) - y_true = labels[tm_out.output_name()] y_pred = model.predict(data, batch_size=args.batch_size, verbose=0) y_pred = np.argmax(y_pred, axis=-1) y_pred = _to_categorical(y_pred, nb_orig_classes) - # prune unnecessary labels - y_true = np.delete(y_true, bad_channels, axis=-1) - y_true = binary_erosion(y_true, structure).astype(y_true.dtype) - means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_good_classes) - csv_row_true = _get_csv_row(means_true, medians_true, stds_true, tensor_paths) - inference_writer_true.writerow(csv_row_true) + if has_y_true: + y_true = labels[tm_out.output_name()] + y_true = np.delete(y_true, bad_channels, axis=-1) + y_true = binary_erosion(y_true, structure).astype(y_true.dtype) + means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_good_classes) + csv_row_true = _get_csv_row(means_true, medians_true, stds_true, tensor_paths) + inference_writer_true.writerow(csv_row_true) y_pred = np.delete(y_pred, bad_channels, axis=-1) y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) @@ -798,29 +813,30 @@ def infer_medians(args): if stats['count'] % 250 == 0: logging.info(f"Wrote:{stats['count']} rows of inference. Last tensor:{tensor_paths[0]}") - inference_tsv_true = os.path.join(args.output_folder, args.id, f'pixel_inference_true_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') - inference_tsv_pred = os.path.join(args.output_folder, args.id, f'pixel_inference_pred_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.tsv') + inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') + inference_tsv_pred = os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) - plt.figure() - plt.scatter(df_true.anterolateral_pap_median, df_pred.anterolateral_pap_median) - plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) - plt.title('anterolateral_pap_median') - plt.xlabel('true') - plt.ylabel('pred') - figure_path = os.path.join(args.output_folder, args.id, f'pixel_inference_anterolateral_pap_median_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.png') - plt.savefig(figure_path) - - plt.figure() - plt.scatter(df_true.posteromedial_pap_median, df_pred.posteromedial_pap_median) - plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) - plt.title('posteromedial_pap_median') - plt.xlabel('true') - plt.ylabel('pred') - figure_path = os.path.join(args.output_folder, args.id, f'pixel_inference_posteromedial_pap_median_{args.id}_{tm_in.input_name()}_{tm_out.output_name()}.png') - plt.savefig(figure_path) + if has_y_true: + plt.figure() + plt.scatter(df_true.anterolateral_pap_median, df_pred.anterolateral_pap_median) + plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) + plt.title('anterolateral_pap_median') + plt.xlabel('true') + plt.ylabel('pred') + figure_path = os.path.join(args.output_folder, args.id, f'medians_inference_anterolateral_pap_median_{args.id}_{tm_in.input_name()}_{output_name}.png') + plt.savefig(figure_path) + + plt.figure() + plt.scatter(df_true.posteromedial_pap_median, df_pred.posteromedial_pap_median) + plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) + 
plt.title('posteromedial_pap_median') + plt.xlabel('true') + plt.ylabel('pred') + figure_path = os.path.join(args.output_folder, args.id, f'medians_inference_posteromedial_pap_median_{args.id}_{tm_in.input_name()}_{output_name}.png') + plt.savefig(figure_path) def _softmax(x): """Compute softmax values for each sets of scores in x.""" From 989c680e23b65d491d90a75df052b6a6682c0e36 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Mon, 16 Oct 2023 19:57:36 +0000 Subject: [PATCH 23/50] ENH: Remove option for merged paps --- ml4h/defines.py | 5 ----- ml4h/tensormap/ukb/mri.py | 13 +------------ 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/ml4h/defines.py b/ml4h/defines.py index b0f703512..eb18dde92 100755 --- a/ml4h/defines.py +++ b/ml4h/defines.py @@ -82,11 +82,6 @@ def __str__(self): 'interventricular_septum': 5, 'LV_free_wall': 6, 'anterolateral_pap': 7, 'posteromedial_pap': 8, 'LV_cavity': 9, 'RV_free_wall': 10, 'RV_cavity': 11, } -MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP = { - 'background': 0, 'thoracic_cavity': 1, 'liver': 2, 'stomach': 3, 'spleen': 4, - 'interventricular_septum': 5, 'LV_free_wall': 6, 'anterolateral_pap': 7, 'posteromedial_pap': 7, 'LV_cavity': 8, - 'RV_free_wall': 9, 'RV_cavity': 10, -} MRI_SAX_SEGMENTED_CHANNEL_MAP = { 'background': 0, 'RV_free_wall': 1, 'interventricular_septum': 2, 'LV_free_wall': 3, 'LV_cavity': 4, 'RV_cavity': 5, 'thoracic_cavity': 6, 'liver': 7, 'stomach': 8, 'spleen': 9, 'kidney': 11, 'body': 10, diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index d7f790187..0b2f521be 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -20,7 +20,7 @@ MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_SEGMENTED_CHANNEL_MAP, LAX_4CH_HEART_LABELS, LAX_4CH_MYOCARDIUM_LABELS, StorageType, LAX_3CH_HEART_LABELS, \ LAX_2CH_HEART_LABELS from ml4h.tensormap.general import get_tensor_at_first_date, normalized_first_date, pad_or_crop_array_to_shape, tensor_from_hd5 -from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP, MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS +from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS def _slice_subset_tensor( @@ -2718,14 +2718,3 @@ def _segmented_t1map(tm, hd5, dependents={}): loss=dice, metrics=['categorical_accuracy'] + per_class_dice(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP), ) - -t1map_b2_segmentation_merged_paps = TensorMap( - 'b2s_t1map_kassir_annotated', - interpretation=Interpretation.CATEGORICAL, - shape=(384, 384, len(MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP)), - channel_map=MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP, - path_prefix='ukb_cardiac_mri', - tensor_from_file=_segmented_t1map, - loss=dice, - metrics=['categorical_accuracy'] + per_class_dice(MRI_SAX_MERGED_PAP_SEGMENTED_CHANNEL_MAP), -) From 83e5dc695b47a6c8cd6ca19dcff70e199a280fbe Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Tue, 17 Oct 2023 18:15:46 +0000 Subject: [PATCH 24/50] FIX: Get all b2s images, instance_2s only --- ml4h/tensormap/ukb/mri.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 0b2f521be..738eff20f 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2670,8 +2670,15 @@ def 
_mdrk_projection_both_views_pretrained(tm, hd5, dependents={}): ) def _pad_crop_single_channel(tm, hd5, dependents={}): + if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: + key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' + elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: + key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' + else: + raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}') + img = np.array( - tm.hd5_first_dataset_in_group(hd5, tm.hd5_key_guess()), + tm.hd5_first_dataset_in_group(hd5, key_prefix), dtype=np.float32, ) img = img[...,[1]] From 75b0446583c0d0e10e7413ae631197da0df4d5db Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Tue, 17 Oct 2023 18:50:58 +0000 Subject: [PATCH 25/50] ENH: Add mri dates --- ml4h/explorations.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 1c0e3d0d1..25caa39aa 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -717,11 +717,9 @@ def _compute_masked_stats(img, y, nb_classes): def _to_categorical(y, nb_classes): return np.eye(nb_classes)[y] -def _get_csv_row(means, medians, stds, tensor_paths): +def _get_csv_row(sample_id, means, medians, stds, date): res = np.concatenate([means, medians, stds], axis=-1) - sample_id = os.path.basename(tensor_paths[0]).replace(TENSOR_EXT, '') - csv_row = [sample_id] - csv_row += res[0].astype('str').tolist() + csv_row = [sample_id] + res[0].astype('str').tolist() + [date] return csv_row def infer_medians(args): @@ -762,6 +760,14 @@ def infer_medians(args): # Structuring element used for the erosion structure = _unit_disk(2)[np.newaxis, ..., np.newaxis] + # Get the dates + # TODO remove this hard-coding + dates_filename = '/home/pace/csvs/mri_dates_instance2.csv' + with open(dates_filename, mode='r') as dates_file: + dates_reader = csv.reader(dates_file) + dates_dict = {rows[0]:rows[1] for rows in dates_reader} + # end TODO remove this hard-coding + stats = Counter() tensor_paths_inferred = set() inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') @@ -775,6 +781,7 @@ def infer_medians(args): header += [f'{k}_mean' for k in good_structures] header += [f'{k}_median' for k in good_structures] header += [f'{k}_std' for k in good_structures] + header += ['mri_date'] inference_writer_true.writerow(header) inference_writer_pred.writerow(header) @@ -794,18 +801,21 @@ def infer_medians(args): y_pred = np.argmax(y_pred, axis=-1) y_pred = _to_categorical(y_pred, nb_orig_classes) + sample_id = os.path.basename(tensor_paths[0]).replace(TENSOR_EXT, '') + date = dates_dict[sample_id] + if has_y_true: y_true = labels[tm_out.output_name()] y_true = np.delete(y_true, bad_channels, axis=-1) y_true = binary_erosion(y_true, structure).astype(y_true.dtype) means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_good_classes) - csv_row_true = _get_csv_row(means_true, medians_true, stds_true, tensor_paths) + csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date) inference_writer_true.writerow(csv_row_true) y_pred = np.delete(y_pred, bad_channels, axis=-1) y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) means_pred, medians_pred, stds_pred = _compute_masked_stats(img, y_pred, nb_good_classes) - csv_row_pred = 
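
Note on the instance_2 key fallback added to _pad_crop_single_channel above: trying a list of candidate HD5 group names and raising only when none exist generalizes to any set of alternative series names. A hedged h5py sketch (hypothetical file and key names):

    import h5py

    def first_existing_key(hd5, candidates):
        # Return the first group/dataset path that is present in the HD5 file.
        for key in candidates:
            if key in hd5:
                return key
        raise ValueError(f'None of {candidates} found')

    # with h5py.File('sample.hd5', 'r') as hd5:  # hypothetical file
    #     key = first_existing_key(hd5, [
    #         '/ukb_cardiac_mri/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2',
    #         '/ukb_cardiac_mri/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2',
    #     ])
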
_get_csv_row(means_pred, medians_pred, stds_pred, tensor_paths) + csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date) inference_writer_pred.writerow(csv_row_pred) tensor_paths_inferred.add(tensor_paths[0]) From e54d4af8c0feda676383af9981c00d747d27f22b Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 20 Oct 2023 11:14:12 -0400 Subject: [PATCH 26/50] FIX: Fix normalization with correct padding --- ml4h/tensormap/ukb/mri.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 738eff20f..788b768ba 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -2691,7 +2691,7 @@ def _pad_crop_single_channel(tm, hd5, dependents={}): 'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map', shape=(384, 384, 1), path_prefix='ukb_cardiac_mri', - normalization=Standardize(mean=548.15, std=627.32), + normalization=Standardize(mean=455.81, std=609.50), tensor_from_file=_pad_crop_single_channel, ) From ddb9d46dc20ef03aa68ffb1a3b723449d342bc32 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 20 Oct 2023 11:15:08 -0400 Subject: [PATCH 27/50] FIX: Fix soft dice metrics again --- ml4h/metrics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml4h/metrics.py b/ml4h/metrics.py index 2ec7ab93f..d4444f174 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -10,7 +10,7 @@ from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy from tensorflow.keras.losses import logcosh, cosine_similarity, mean_squared_error, mean_absolute_error, mean_absolute_percentage_error -from neurite.tf.losses import SoftDice +from neurite.tf.losses import Dice STRING_METRICS = [ 'categorical_crossentropy','binary_crossentropy','mean_absolute_error','mae', @@ -263,7 +263,7 @@ def loss(y_true, y_pred): return loss def dice(y_true, y_pred): - return SoftDice(laplace_smoothing=1e-05, check_input_limits=False).mean_loss(y_true, y_pred) + return Dice(laplace_smoothing=1e-05).mean_loss(y_true, y_pred) def per_class_dice(labels): dice_fxns = [] @@ -271,7 +271,7 @@ def per_class_dice(labels): label_idx = labels[label_key] fxn_name = label_key.replace('-', '_').replace(' ', '_') string_fxn = 'def ' + fxn_name + '_dice(y_true, y_pred):\n' - string_fxn += '\tdice = SoftDice(laplace_smoothing=1e-05, check_input_limits=False).dice(y_true, y_pred)\n' + string_fxn += '\tdice = Dice(laplace_smoothing=1e-05).dice(y_true, y_pred)\n' string_fxn += '\tdice = K.mean(dice, axis=0)['+str(label_idx)+']\n' string_fxn += '\treturn dice' From 84c5365dc120b879a82422c61fed5a23ad8b9cb6 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 20 Oct 2023 11:18:31 -0400 Subject: [PATCH 28/50] COMP: Add option for environment variable for jupyter notebooks --- scripts/jupyter.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/jupyter.sh b/scripts/jupyter.sh index 26ae27364..e6e85d51b 100755 --- a/scripts/jupyter.sh +++ b/scripts/jupyter.sh @@ -15,6 +15,7 @@ DOCKER_COMMAND="docker" PORT="8888" SCRIPT_NAME=$( echo $0 | sed 's#.*/##g' ) GPU_DEVICE="--gpus all" +ENV="" ################### HELP TEXT ############################################ @@ -36,12 +37,14 @@ usage() -p Port to use, by default '${PORT}' -i Run Docker with the specified custom . The default image is '${DOCKER_IMAGE}'. 
+ + -e Run Docker with the specified environment variables set USAGE_MESSAGE } ################### OPTION PARSING ####################################### -while getopts ":i:p:ch" opt ; do +while getopts ":i:p:e:ch" opt ; do case ${opt} in h) usage @@ -57,6 +60,9 @@ while getopts ":i:p:ch" opt ; do DOCKER_IMAGE=${DOCKER_IMAGE_NO_GPU} GPU_DEVICE="" ;; + e) + ENV="--env $OPTARG" + ;; :) echo "ERROR: Option -${OPTARG} requires an argument." 1>&2 usage @@ -102,6 +108,7 @@ ${GPU_DEVICE} \ -v /home/${USER}/:/home/${USER}/ \ -v /mnt/:/mnt/ \ -p 0.0.0.0:${PORT}:${PORT} \ +${ENV} \ ${DOCKER_IMAGE} /bin/bash -c "pip3 install --upgrade pip pip3 install /home/${USER}/ml4h; jupyter notebook --no-browser --ip=0.0.0.0 --port=${PORT} --NotebookApp.token= --allow-root --notebook-dir=/home/${USER}" From 01bb4f56ffa100c7a7c1f55280ceb87c1234aba6 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 20 Oct 2023 11:19:31 -0400 Subject: [PATCH 29/50] WIP: data augmentation --- ml4h/tensor_generators.py | 55 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py index 60f421141..d53127b90 100755 --- a/ml4h/tensor_generators.py +++ b/ml4h/tensor_generators.py @@ -726,6 +726,48 @@ def get_train_valid_test_paths_split_by_csvs( return train_paths, valid_paths, test_paths +# TODO prototyping +# https://stackoverflow.com/questions/65475057/keras-data-augmentation-pipeline-for-image-segmentation-dataset-image-and-mask +def augment_using_layers(images, mask): + + def aug(): + + rota = tf.keras.layers.RandomRotation(factor=(0.014), fill_mode='constant') + + zoom = tf.keras.layers.RandomZoom( + height_factor=(-0.05, 0.05), + width_factor=None, + fill_mode='constant', + ) + + trans = tf.keras.layers.RandomTranslation( + height_factor=(-0.042, 0.042), + width_factor=(-0.042, 0.042), + fill_mode='constant', + ) + + layers = [rota, zoom, trans] + aug_model = tf.keras.Sequential(layers) + + return aug_model + + aug = aug() + + images = images['input_shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map_continuous'] + mask = mask['output_b2s_t1map_kassir_annotated_categorical'] + + images_mask = tf.concat([images, mask], -1) + images_mask = aug(images_mask) + + image = images_mask[..., 0] + image = image[..., tf.newaxis] + mask = images_mask[..., 1:] + + return image, mask +# end TODO prototyping + + + def test_train_valid_tensor_generators( tensor_maps_in: List[TensorMap], tensor_maps_out: List[TensorMap], @@ -824,7 +866,8 @@ def test_train_valid_tensor_generators( paths=test_paths, num_workers=num_train_workers, cache_size=0, weights=weights, keep_paths=keep_paths or keep_paths_test, mixup_alpha=0, name='test_worker', siamese=siamese, augment=False, ) - if wrap_with_tf_dataset: + # TODO prototyping + if True: in_shapes = {tm.input_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_in} out_shapes = {tm.output_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_out} train_dataset = tf.data.Dataset.from_generator( @@ -832,6 +875,11 @@ def test_train_valid_tensor_generators( output_types=({k: tf.float32 for k in in_shapes}, {k: tf.float32 for k in out_shapes}), output_shapes=(in_shapes, out_shapes), ) + train_dataset = train_dataset.map(lambda x, y: augment_using_layers(x, y)) + # end TODO prototyping + if wrap_with_tf_dataset: + in_shapes = {tm.input_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_in} + out_shapes = {tm.output_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_out} 
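
Note on augment_using_layers above: the image and its one-hot mask are concatenated along the channel axis before the random layers, so both receive exactly the same spatial transform, and are split apart afterwards. A minimal self-contained version of that trick (toy shapes; the mask channels are interpolated smoothly, so thresholding or argmax afterwards may be needed):

    import tensorflow as tf

    aug = tf.keras.Sequential([
        tf.keras.layers.RandomRotation(factor=0.014, fill_mode='constant'),
        tf.keras.layers.RandomTranslation(0.042, 0.042, fill_mode='constant'),
    ])

    image = tf.random.uniform((1, 384, 384, 1))
    mask = tf.one_hot(tf.random.uniform((1, 384, 384), maxval=3, dtype=tf.int32), depth=3)

    both = tf.concat([image, mask], axis=-1)       # (1, 384, 384, 4)
    both = aug(both, training=True)                # one transform applied to all channels
    image_aug, mask_aug = both[..., :1], both[..., 1:]
    print(image_aug.shape, mask_aug.shape)         # (1, 384, 384, 1) (1, 384, 384, 3)
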
valid_dataset = tf.data.Dataset.from_generator( generate_valid, output_types=({k: tf.float32 for k in in_shapes}, {k: tf.float32 for k in out_shapes}), @@ -844,7 +892,10 @@ def test_train_valid_tensor_generators( ) return train_dataset, valid_dataset, test_dataset else: - return generate_train, generate_valid, generate_test + # TODO prototyping + # return generate_train, generate_valid, generate_test + return train_dataset, generate_valid, generate_test + # end TODO prototyping def _log_first_error(stats: Counter, tensor_path: str): From ef92261db08b675d279ab46969f8a332489e9113 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Tue, 24 Oct 2023 15:48:40 +0000 Subject: [PATCH 30/50] WIP: Better scatter plots for medians --- ml4h/explorations.py | 68 ++++++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 25caa39aa..66b9931ca 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -9,6 +9,7 @@ import operator import datetime from scipy import stats +from seaborn import lmplot from functools import reduce from itertools import combinations from collections import defaultdict, Counter, OrderedDict @@ -768,7 +769,7 @@ def infer_medians(args): dates_dict = {rows[0]:rows[1] for rows in dates_reader} # end TODO remove this hard-coding - stats = Counter() + stats_counter = Counter() tensor_paths_inferred = set() inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') inference_tsv_pred = os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') @@ -791,7 +792,7 @@ def infer_medians(args): if tensor_paths[0] in tensor_paths_inferred: next(generate_test) # this print end of epoch info logging.info( - f"Inference on {stats['count']} tensors finished. Inference TSV files at: {inference_tsv_true}, {inference_tsv_pred}", + f"Inference on {stats_counter['count']} tensors finished. Inference TSV files at: {inference_tsv_true}, {inference_tsv_pred}", ) break @@ -819,9 +820,9 @@ def infer_medians(args): inference_writer_pred.writerow(csv_row_pred) tensor_paths_inferred.add(tensor_paths[0]) - stats['count'] += 1 - if stats['count'] % 250 == 0: - logging.info(f"Wrote:{stats['count']} rows of inference. Last tensor:{tensor_paths[0]}") + stats_counter['count'] += 1 + if stats_counter['count'] % 250 == 0: + logging.info(f"Wrote:{stats_counter['count']} rows of inference. 
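
Note on the tf.data wrapping used in this prototype: Dataset.from_generator just needs output types and shapes that match what the Python generator yields, and augmentation is then chained with .map(). A toy self-contained version of the pattern (dummy generator, not the ml4h TensorGenerator):

    import numpy as np
    import tensorflow as tf

    def toy_generator():
        while True:
            x = {'image': np.random.rand(1, 8, 8, 1).astype(np.float32)}
            y = {'mask': np.random.rand(1, 8, 8, 3).astype(np.float32)}
            yield x, y

    dataset = tf.data.Dataset.from_generator(
        toy_generator,
        output_types=({'image': tf.float32}, {'mask': tf.float32}),
        output_shapes=({'image': (1, 8, 8, 1)}, {'mask': (1, 8, 8, 3)}),
    )
    dataset = dataset.map(lambda x, y: (x['image'] * 2.0, y['mask']))  # stand-in for augmentation
    for img, msk in dataset.take(1):
        print(img.shape, msk.shape)  # (1, 8, 8, 1) (1, 8, 8, 3)
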
Last tensor:{tensor_paths[0]}") inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') inference_tsv_pred = os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') @@ -830,23 +831,46 @@ def infer_medians(args): df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) if has_y_true: - plt.figure() - plt.scatter(df_true.anterolateral_pap_median, df_pred.anterolateral_pap_median) - plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) - plt.title('anterolateral_pap_median') - plt.xlabel('true') - plt.ylabel('pred') - figure_path = os.path.join(args.output_folder, args.id, f'medians_inference_anterolateral_pap_median_{args.id}_{tm_in.input_name()}_{output_name}.png') - plt.savefig(figure_path) - - plt.figure() - plt.scatter(df_true.posteromedial_pap_median, df_pred.posteromedial_pap_median) - plt.plot([0, 1300], [0, 1300], color='k', linestyle='--', linewidth=2) - plt.title('posteromedial_pap_median') - plt.xlabel('true') - plt.ylabel('pred') - figure_path = os.path.join(args.output_folder, args.id, f'medians_inference_posteromedial_pap_median_{args.id}_{tm_in.input_name()}_{output_name}.png') - plt.savefig(figure_path) + cols = ['anterolateral_pap_median', 'posteromedial_pap_median'] + for col in cols: + plot_data = pd.concat( + [df_true[col], df_pred[col]], + axis=1, keys=['true', 'pred'], + ) + + for i in range(2): + if i == 1: + plot_data = plot_data[plot_data.true != 0] + plot_data = plot_data[plot_data.pred != 0] + + plt.figure() + g = lmplot(x='true', y='pred', data=plot_data) + ax = plt.gca() + title = col.replace('_', ' ') + ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation') + ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation') + if i == 0: + min_value = -50 + max_value = 1300 + else: + min_value, max_value = plot_data.min(), plot_data.max() + min_value = min([min_value['true'], min_value['pred']]) - 100 + max_value = min([max_value['true'], max_value['pred']]) + 100 + ax.set_xlim([min_value, max_value]) + ax.set_ylim([min_value, max_value]) + res = stats.pearsonr(plot_data['true'], plot_data['pred']) + conf = res.confidence_interval(confidence_level=0.95) + text = f'Pearson Correlation Coefficient r={res.statistic:.2f},\n95% CI {conf.low:.2f} - {conf.high:.2f}' + ax.text(0.25, 0.1, text, transform=ax.transAxes) + if i == 0: + postfix = '' + else: + postfix = '_no_zeros' + figure_path = os.path.join( + args.output_folder, args.id, + f'medians_inference_{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png', + ) + plt.savefig(figure_path) def _softmax(x): """Compute softmax values for each sets of scores in x.""" From 79e87a634653878747271907b315a900585a46ea Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 25 Oct 2023 00:54:25 +0000 Subject: [PATCH 31/50] ENH: Report std too --- ml4h/plots.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ml4h/plots.py b/ml4h/plots.py index 155e05655..fada6c5e2 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2747,6 +2747,7 @@ def plot_dice(predictions, truth, labels, title, prefix="./figures/", dpi=300, w dice_scores = {} mean_dice_scores = {} + std_dice_scores = {} for p in predictions: y_pred = predictions[p].argmax(-1) dice_scores[p] = np.stack([dice(y_true[i], y_pred[i], labels=label_vals) for i in range(batch_size)], axis=0) @@ -2760,6 +2761,8 @@ def plot_dice(predictions, truth, labels, title, 
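
Note on the correlation annotation above: recent SciPy versions return a result object from stats.pearsonr that carries the statistic, p-value, and a confidence interval, which is what the scatter-plot text uses. A small standalone example of that call (synthetic data standing in for the manual vs. model T1 medians):

    import numpy as np
    from scipy import stats

    rng = np.random.default_rng(0)
    true_vals = rng.normal(1000, 100, size=200)            # e.g. medians from manual masks
    pred_vals = true_vals + rng.normal(0, 30, size=200)    # e.g. medians from model masks

    res = stats.pearsonr(true_vals, pred_vals)
    ci = res.confidence_interval(confidence_level=0.95)
    print(f'r={res.statistic:.3f}, 95% CI [{ci.low:.3f}, {ci.high:.3f}], p={res.pvalue:.2g}')
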
prefix="./figures/", dpi=300, w mean_dice_scores[p] = np.average(dice_scores[p], axis=0) logging.info(f"{p} mean Dice scores {mean_dice_scores[p]}") + std_dice_scores[p] = np.std(dice_scores[p], axis=0) + logging.info(f"{p} std Dice scores {std_dice_scores[p]}") row = 0 col = 0 From b87407c3282d1028aa535f39a6f67caa00695e6f Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 25 Oct 2023 15:39:22 +0000 Subject: [PATCH 32/50] ENH: Improve dice plots for a single model --- ml4h/plots.py | 56 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index fada6c5e2..1719d9f10 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2764,32 +2764,42 @@ def plot_dice(predictions, truth, labels, title, prefix="./figures/", dpi=300, w std_dice_scores[p] = np.std(dice_scores[p], axis=0) logging.info(f"{p} std Dice scores {std_dice_scores[p]}") - row = 0 - col = 0 - total_plots = len(label_names) - cols = int(math.ceil(math.sqrt(total_plots))) - rows = int(math.ceil(total_plots / cols)) - f, axes = plt.subplots( - rows, cols, figsize=(int(cols * width), int(rows * height)), dpi=dpi, - ) + if len(predictions) > 1: + row = 0 + col = 0 + total_plots = len(label_names) + cols = int(math.ceil(math.sqrt(total_plots))) + rows = int(math.ceil(total_plots / cols)) + f, axes = plt.subplots( + rows, cols, figsize=(int(cols * width), int(rows * height)), dpi=dpi, + ) + + for i,k in enumerate(label_names): + for j,p in enumerate(predictions): + axes[row, col].boxplot(dice_scores[p][:,i], positions = [j], labels=['']) + label_text = [f"{p} mean dice:{mean_dice_scores[p][i]:.3f}" for p in predictions] + axes[row, col].set_title(f"{k}") + axes[row, col].set_ylabel(DICE_LABEL) + axes[row, col].legend(label_text, loc="lower right") - for i,k in enumerate(label_names): - for j,p in enumerate(predictions): - axes[row, col].boxplot(dice_scores[p][:,i], positions = [j], labels=['']) - label_text = [f"{p} mean dice:{mean_dice_scores[p][i]:.3f}" for p in predictions] - axes[row, col].set_title(f"{k}") - axes[row, col].set_ylabel(DICE_LABEL) - axes[row, col].legend(label_text, loc="lower right") - - row += 1 - if row == rows: - row = 0 - col += 1 - if col >= cols: - break + row += 1 + if row == rows: + row = 0 + col += 1 + if col >= cols: + break + + else: + logging.info([p for p in predictions]) + p = list(predictions.keys())[0] + for i,k in enumerate(label_names): + plt.boxplot(dice_scores[p][:,i], positions = [i], labels=[k]) + + ax = plt.gca() + ax.set_ylabel(DICE_LABEL) + plt.xticks(rotation=90) plt.tight_layout() - plt.suptitle(f"{title} n={batch_size:.0f}") now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') figure_path = os.path.join(prefix, f'dice_{now_string}_{title}{IMAGE_EXT}') if not os.path.exists(os.path.dirname(figure_path)): From d2449a935ef4f2e93504e409d9b0ddf3784034c8 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 25 Oct 2023 16:08:46 +0000 Subject: [PATCH 33/50] ENH: Log pearson correlation coefficients --- ml4h/explorations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 66b9931ca..a07ae0894 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -866,6 +866,7 @@ def infer_medians(args): postfix = '' else: postfix = '_no_zeros' + logging.info(f'{col} pearson{postfix} {res.statistic}') figure_path = os.path.join( args.output_folder, args.id, f'medians_inference_{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png', From 
081b5da965e9e2c7124a94dfb624ae56a704b667 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 1 Nov 2023 15:10:15 +0000 Subject: [PATCH 34/50] STYLE: Adding TODOs to fix tensor_generators --- ml4h/tensor_generators.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py index d53127b90..b49b40d6f 100755 --- a/ml4h/tensor_generators.py +++ b/ml4h/tensor_generators.py @@ -731,7 +731,7 @@ def get_train_valid_test_paths_split_by_csvs( def augment_using_layers(images, mask): def aug(): - + # TODO pull these parameters out and default to None - if any are not none then go through wrapper rota = tf.keras.layers.RandomRotation(factor=(0.014), fill_mode='constant') zoom = tf.keras.layers.RandomZoom( @@ -753,12 +753,14 @@ def aug(): aug = aug() + # TODO just one for now, assert to make sure it's not multimodal images = images['input_shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map_continuous'] mask = mask['output_b2s_t1map_kassir_annotated_categorical'] images_mask = tf.concat([images, mask], -1) images_mask = aug(images_mask) + # TODO can get from tm shape image = images_mask[..., 0] image = image[..., tf.newaxis] mask = images_mask[..., 1:] @@ -867,6 +869,7 @@ def test_train_valid_tensor_generators( keep_paths=keep_paths or keep_paths_test, mixup_alpha=0, name='test_worker', siamese=siamese, augment=False, ) # TODO prototyping + # TODO if wrap_with_tf_dataset or if True: in_shapes = {tm.input_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_in} out_shapes = {tm.output_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_out} From 7b0643d1a781c155f742ab5555102c494f014711 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 1 Nov 2023 15:11:23 +0000 Subject: [PATCH 35/50] WIP: Add temporary code to save Dice scores --- ml4h/plots.py | 22 +++++++++++++++++++++- ml4h/recipes.py | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index 1719d9f10..989bf265c 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2736,7 +2736,7 @@ def plot_precision_recalls(predictions, truth, labels, title, prefix="./figures/ plt.savefig(figure_path) logging.info("Saved Precision Recall curve at: {}".format(figure_path)) -def plot_dice(predictions, truth, labels, title, prefix="./figures/", dpi=300, width=3, height=3): +def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi=300, width=3, height=3): label_names = labels.keys() logging.info(f"label_names: {label_names}") label_vals = [labels[k] for k in label_names] @@ -2759,6 +2759,26 @@ def plot_dice(predictions, truth, labels, title, prefix="./figures/", dpi=300, w for k in replace[i]: dice_scores[p][i,k] = 1.0 + # TODO take me out + al_pap_dice_scores = dice_scores[p][:,7] + pm_pap_dice_scores = dice_scores[p][:,8] + al_percentiles = [np.percentile(al_pap_dice_scores, perc) for perc in [5, 25, 50, 75, 95]] + pm_percentiles = [np.percentile(pm_pap_dice_scores, perc) for perc in [5, 25, 50, 75, 95]] + al_idx = [min(range(len(al_pap_dice_scores)), key=lambda i: abs(al_pap_dice_scores[i] - perc)) for perc in al_percentiles] + pm_idx = [min(range(len(pm_pap_dice_scores)), key=lambda i: abs(pm_pap_dice_scores[i] - perc)) for perc in pm_percentiles] + logging.info([paths[i] for i in al_idx]) + logging.info([paths[i] for i in pm_idx]) + logging.info('sorted al paps (worst to best):') + sorted_al_paths = [paths[k] for k in sorted(range(len(al_pap_dice_scores)), key=lambda k:al_pap_dice_scores[k])] + for p in 
sorted_al_paths: + logging.info(p) + logging.info('sorted pm paps (worst to best):') + sorted_pm_paths = [paths[k] for k in sorted(range(len(pm_pap_dice_scores)), key=lambda k:pm_pap_dice_scores[k])] + for p in sorted_pm_paths: + logging.info(p) + assert(False) + # end TODO take me out + mean_dice_scores[p] = np.average(dice_scores[p], axis=0) logging.info(f"{p} mean Dice scores {mean_dice_scores[p]}") std_dice_scores[p] = np.std(dice_scores[p], axis=0) diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 53252101f..171084a8e 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -866,7 +866,7 @@ def _calculate_and_plot_prediction_stats(args, predictions, outputs, paths): elif tm.is_categorical() and tm.axes() == 3: # have to plot dice before the reshape plot_dice( - predictions[tm], outputs[tm.output_name()], tm.channel_map, plot_title, plot_folder, + predictions[tm], outputs[tm.output_name()], tm.channel_map, paths, plot_title, plot_folder, dpi=args.dpi, width=args.plot_width, height=args.plot_height, ) for p in predictions[tm]: From e92f1339bc034416fbabfede3b2d5d65296d8678 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 1 Nov 2023 15:23:17 +0000 Subject: [PATCH 36/50] WIP: Add temporary code for plotting medians --- ml4h/explorations.py | 47 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index a07ae0894..2e73f1ec7 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -759,7 +759,8 @@ def infer_medians(args): bad_channels = [k for k in range(nb_orig_classes) if k not in good_channels] # Structuring element used for the erosion - structure = _unit_disk(2)[np.newaxis, ..., np.newaxis] + # TODO this can be a parameter + structure = _unit_disk(1)[np.newaxis, ..., np.newaxis] # Get the dates # TODO remove this hard-coding @@ -797,25 +798,42 @@ def infer_medians(args): break img = data[tm_in.input_name()] - img = tm_in.rescale(img) + rescaled_img = tm_in.rescale(img) y_pred = model.predict(data, batch_size=args.batch_size, verbose=0) y_pred = np.argmax(y_pred, axis=-1) - y_pred = _to_categorical(y_pred, nb_orig_classes) + # y_pred = _to_categorical(y_pred, nb_orig_classes) # TODO may need this with no threshold sample_id = os.path.basename(tensor_paths[0]).replace(TENSOR_EXT, '') date = dates_dict[sample_id] if has_y_true: y_true = labels[tm_out.output_name()] + + # TODO new + y_true = np.argmax(y_true, axis=-1)[..., np.newaxis] + y_true[np.logical_and(img >= 1.37, y_true == 7)] = 9 + y_true[np.logical_and(img >= 1.37, y_true == 8)] = 9 + y_true = y_true[...,0] + y_true = _to_categorical(y_true, nb_orig_classes) + # end TODO new + y_true = np.delete(y_true, bad_channels, axis=-1) y_true = binary_erosion(y_true, structure).astype(y_true.dtype) - means_true, medians_true, stds_true = _compute_masked_stats(img, y_true, nb_good_classes) + means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_classes) csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date) inference_writer_true.writerow(csv_row_true) + # TODO new + y_pred = y_pred[..., np.newaxis] + y_pred[np.logical_and(img >= 1.37, y_pred == 7)] = 9 + y_pred[np.logical_and(img >= 1.37, y_pred == 8)] = 9 + y_pred = y_pred[...,0] + y_pred = _to_categorical(y_pred, nb_orig_classes) + # end TODO new + y_pred = np.delete(y_pred, bad_channels, axis=-1) y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) - means_pred, medians_pred, stds_pred = 
_compute_masked_stats(img, y_pred, nb_good_classes) + means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_classes) csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date) inference_writer_pred.writerow(csv_row_pred) @@ -830,18 +848,26 @@ def infer_medians(args): df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + # TODO fix me if has_y_true: cols = ['anterolateral_pap_median', 'posteromedial_pap_median'] for col in cols: - plot_data = pd.concat( - [df_true[col], df_pred[col]], - axis=1, keys=['true', 'pred'], - ) - for i in range(2): + plot_data = pd.concat( + [df_true['sample_id'], df_true[col], df_pred[col]], + axis=1, keys=['sample_id', 'true', 'pred'], + ) + + if i == 0: + logging.info(plot_data) + true_outliers = plot_data[plot_data.true == 0] + pred_outliers = plot_data[plot_data.pred == 0] + logging.info(true_outliers) + logging.info(pred_outliers) if i == 1: plot_data = plot_data[plot_data.true != 0] plot_data = plot_data[plot_data.pred != 0] + plot_data = plot_data.drop('sample_id', axis=1) plt.figure() g = lmplot(x='true', y='pred', data=plot_data) @@ -872,6 +898,7 @@ def infer_medians(args): f'medians_inference_{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png', ) plt.savefig(figure_path) + # end TODO fix me def _softmax(x): """Compute softmax values for each sets of scores in x.""" From 18eed01e6502c03c050a0d58ed471e320024a306 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Wed, 1 Nov 2023 16:27:48 +0000 Subject: [PATCH 37/50] STYLE: Clean up code for infer_medians --- ml4h/explorations.py | 153 +++++++++++++++++++++---------------------- 1 file changed, 75 insertions(+), 78 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 2e73f1ec7..5644675f8 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -39,10 +39,6 @@ from ml4h.defines import TENSOR_EXT, IMAGE_EXT, ECG_CHAR_2_IDX, ECG_IDX_2_CHAR, PARTNERS_CHAR_2_IDX, PARTNERS_IDX_2_CHAR, PARTNERS_READ_TEXT from ml4h.tensorize.tensor_writer_ukbb import _unit_disk -# TODO remove this hard-coding -from ml4h.defines import MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP -# end TODO remove this hard-coding - from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, Ridge, Lasso @@ -723,58 +719,73 @@ def _get_csv_row(sample_id, means, medians, stds, date): csv_row = [sample_id] + res[0].astype('str').tolist() + [date] return csv_row +def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label): + y = np.argmax(y, axis=-1)[..., np.newaxis] + y[np.logical_and(img >= intensity_thresh, np.isin(y, in_labels))] = out_label + y = y[..., 0] + y = _to_categorical(y, nb_orig_classes) + return y + def infer_medians(args): - assert (args.batch_size == 1) # no support here for iterating over larger batches - assert (len(args.tensor_maps_out) <= 1) # no support here for multiple output channels - tm_in = args.tensor_maps_in[0] - assert (tm_in.shape[-1] == 1) # no support here for stats on multiple input channels - - # TODO remove this hard-coding - if len(args.tensor_maps_out) == 0: - has_y_true = False - channel_map = MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP - output_name = 'output_b2s_t1map_kassir_annotated' - elif len(args.tensor_maps_out) == 1: - has_y_true = True - tm_out = 
args.tensor_maps_out[0] - channel_map = tm_out.channel_map - output_name = tm_out.output_name() - # end TODO remove this hard-coding - - _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) - model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) - - # TODO remove this hard-coding - important_structures = [ + # TODO make these command-line arguments + has_y_true = True + dates_filename = '/home/pace/csvs/mri_dates_instance2.csv' + structures_to_analyze = [ 'interventricular_septum', 'LV_free_wall', 'anterolateral_pap', 'posteromedial_pap', 'LV_cavity', 'RV_free_wall', 'RV_cavity', ] - # end TODO remove this hard-coding + erosion_radius = 1 + intensity_thresh = 1.37 + intensity_thresh_in_structures = ['anterolateral_pap', 'posteromedial_pap'] + intensity_thresh_out_structure = 'LV_cavity' + results_to_plot = ['anterolateral_pap_median', 'posteromedial_pap_median'] + # end TODO make these command-line argument + + assert(args.batch_size == 1) # no support here for iterating over larger batches + assert(len(args.tensor_maps_in) == 1) # no support here for multiple input maps + assert(len(args.tensor_maps_out) == 1) # no support here for multiple output channels + + tm_in = args.tensor_maps_in[0] + tm_out = args.tensor_maps_out[0] + assert(tm_in.shape[-1] == 1) # no support here for stats on multiple input channels + assert(tm_out.shape[-1] == 1) # no support here for stats on multiple output channels - good_channels = sorted([channel_map[k] for k in important_structures]) - good_structures = [[k for k in channel_map.keys() if channel_map[k] == v][0] for v in good_channels] - nb_orig_classes = len(channel_map) - nb_good_classes = len(good_channels) - bad_channels = [k for k in range(nb_orig_classes) if k not in good_channels] + # don't filter datasets for ground truth segmentations if we want to run inference on everything + gen_args = copy.deep_copy(args) + if not has_y_true: + gen_args.tensor_maps_out = [] + + _, _, generate_test = test_train_valid_tensor_generators(**gen_args.__dict__) + model, _, _, _ = make_multimodal_multitask_model(**gen_args.__dict__) + + # good_structures has to be sorted by channel idx + good_channels = sorted([tm_out.channel_map[k] for k in structures_to_analyze]) + good_structures = [[k for k in tm_out.channel_map.keys() if tm_out.channel_map[k] == v][0] for v in good_channels] + nb_orig_channels = len(tm_out.channel_map) + nb_good_channels = len(good_channels) + bad_channels = [k for k in range(nb_orig_channels) if k not in good_channels] # Structuring element used for the erosion - # TODO this can be a parameter - structure = _unit_disk(1)[np.newaxis, ..., np.newaxis] + structure = _unit_disk(erosion_radius)[np.newaxis, ..., np.newaxis] + + # Setup for intensity thresholding + if intensity_thresh: + intensity_thresh_in_channels = [tm_out.channel_map[k] for k in intensity_thresh_in_structures] + intensity_thresh_out_channel = tm_out.channel_map[intensity_thresh_out_structure] # Get the dates - # TODO remove this hard-coding - dates_filename = '/home/pace/csvs/mri_dates_instance2.csv' with open(dates_filename, mode='r') as dates_file: dates_reader = csv.reader(dates_file) dates_dict = {rows[0]:rows[1] for rows in dates_reader} - # end TODO remove this hard-coding - stats_counter = Counter() - tensor_paths_inferred = set() + output_name = tm_out.output_name() inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') inference_tsv_pred = 
os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') + stats_counter = Counter() + tensor_paths_inferred = set() + with open(inference_tsv_true, mode='w') as inference_file_true, open(inference_tsv_pred, mode='w') as inference_file_pred: inference_writer_true = csv.writer(inference_file_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) inference_writer_pred = csv.writer(inference_file_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) @@ -801,39 +812,27 @@ def infer_medians(args): rescaled_img = tm_in.rescale(img) y_pred = model.predict(data, batch_size=args.batch_size, verbose=0) y_pred = np.argmax(y_pred, axis=-1) - # y_pred = _to_categorical(y_pred, nb_orig_classes) # TODO may need this with no threshold + y_pred = _to_categorical(y_pred, nb_orig_channels) + if has_y_true: + y_true = labels[tm_out.output_name()] sample_id = os.path.basename(tensor_paths[0]).replace(TENSOR_EXT, '') date = dates_dict[sample_id] if has_y_true: - y_true = labels[tm_out.output_name()] - - # TODO new - y_true = np.argmax(y_true, axis=-1)[..., np.newaxis] - y_true[np.logical_and(img >= 1.37, y_true == 7)] = 9 - y_true[np.logical_and(img >= 1.37, y_true == 8)] = 9 - y_true = y_true[...,0] - y_true = _to_categorical(y_true, nb_orig_classes) - # end TODO new - + if intensity_thresh: + y_true = _thresh_labels_above(y_true, img, intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel) y_true = np.delete(y_true, bad_channels, axis=-1) y_true = binary_erosion(y_true, structure).astype(y_true.dtype) - means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_classes) + means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_channels) csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date) inference_writer_true.writerow(csv_row_true) - # TODO new - y_pred = y_pred[..., np.newaxis] - y_pred[np.logical_and(img >= 1.37, y_pred == 7)] = 9 - y_pred[np.logical_and(img >= 1.37, y_pred == 8)] = 9 - y_pred = y_pred[...,0] - y_pred = _to_categorical(y_pred, nb_orig_classes) - # end TODO new - + if intensity_thresh: + y_pred = _thresh_labels_above(y_pred, img, intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel) y_pred = np.delete(y_pred, bad_channels, axis=-1) y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) - means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_classes) + means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels) csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date) inference_writer_pred.writerow(csv_row_pred) @@ -845,26 +844,25 @@ def infer_medians(args): inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') inference_tsv_pred = os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') - df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) - df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) - - # TODO fix me + # Scatter plots if has_y_true: - cols = ['anterolateral_pap_median', 'posteromedial_pap_median'] - for col in cols: - for i in range(2): + df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', 
quoting=csv.QUOTE_MINIMAL) + df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + + for col in results_to_plot: + for i in ['all', 'filter_outliers']: # Two types of plots plot_data = pd.concat( [df_true['sample_id'], df_true[col], df_pred[col]], axis=1, keys=['sample_id', 'true', 'pred'], ) - if i == 0: - logging.info(plot_data) + if i == 'all': + logging.info(plot_data) # TODO fix me up true_outliers = plot_data[plot_data.true == 0] pred_outliers = plot_data[plot_data.pred == 0] - logging.info(true_outliers) - logging.info(pred_outliers) - if i == 1: + logging.info(true_outliers) # TODO fix me up + logging.info(pred_outliers) # TODO fix me up + elif i == 'filter_outliers': plot_data = plot_data[plot_data.true != 0] plot_data = plot_data[plot_data.pred != 0] plot_data = plot_data.drop('sample_id', axis=1) @@ -875,10 +873,10 @@ def infer_medians(args): title = col.replace('_', ' ') ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation') ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation') - if i == 0: + if i == 'all': min_value = -50 max_value = 1300 - else: + elif i == 'filter_outliers': min_value, max_value = plot_data.min(), plot_data.max() min_value = min([min_value['true'], min_value['pred']]) - 100 max_value = min([max_value['true'], max_value['pred']]) + 100 @@ -888,17 +886,16 @@ def infer_medians(args): conf = res.confidence_interval(confidence_level=0.95) text = f'Pearson Correlation Coefficient r={res.statistic:.2f},\n95% CI {conf.low:.2f} - {conf.high:.2f}' ax.text(0.25, 0.1, text, transform=ax.transAxes) - if i == 0: + if i == 'all': postfix = '' - else: - postfix = '_no_zeros' + elif i == 'filter_outliers': + postfix = '_no_outliers' logging.info(f'{col} pearson{postfix} {res.statistic}') figure_path = os.path.join( args.output_folder, args.id, f'medians_inference_{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png', ) plt.savefig(figure_path) - # end TODO fix me def _softmax(x): """Compute softmax values for each sets of scores in x.""" From 9fa389709f7a6071821ee5c78bc9672b661bafe1 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 3 Nov 2023 00:46:08 +0000 Subject: [PATCH 38/50] STYLE: Clean up medians code --- ml4h/explorations.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 5644675f8..5730799af 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -719,11 +719,11 @@ def _get_csv_row(sample_id, means, medians, stds, date): csv_row = [sample_id] + res[0].astype('str').tolist() + [date] return csv_row -def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label): +def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig_channels): y = np.argmax(y, axis=-1)[..., np.newaxis] y[np.logical_and(img >= intensity_thresh, np.isin(y, in_labels))] = out_label y = y[..., 0] - y = _to_categorical(y, nb_orig_classes) + y = _to_categorical(y, nb_orig_channels) return y def infer_medians(args): @@ -749,15 +749,13 @@ def infer_medians(args): tm_in = args.tensor_maps_in[0] tm_out = args.tensor_maps_out[0] assert(tm_in.shape[-1] == 1) # no support here for stats on multiple input channels - assert(tm_out.shape[-1] == 1) # no support here for stats on multiple output channels # don't filter datasets for ground truth segmentations if we want to run inference on everything - gen_args = copy.deep_copy(args) if not has_y_true: - gen_args.tensor_maps_out = [] + 
args.tensor_maps_out = [] - _, _, generate_test = test_train_valid_tensor_generators(**gen_args.__dict__) - model, _, _, _ = make_multimodal_multitask_model(**gen_args.__dict__) + _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) + model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) # good_structures has to be sorted by channel idx good_channels = sorted([tm_out.channel_map[k] for k in structures_to_analyze]) @@ -780,8 +778,8 @@ def infer_medians(args): dates_dict = {rows[0]:rows[1] for rows in dates_reader} output_name = tm_out.output_name() - inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') - inference_tsv_pred = os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') + inference_tsv_true = os.path.join(args.output_folder, args.id, f'true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') + inference_tsv_pred = os.path.join(args.output_folder, args.id, f'pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') stats_counter = Counter() tensor_paths_inferred = set() @@ -821,7 +819,11 @@ def infer_medians(args): if has_y_true: if intensity_thresh: - y_true = _thresh_labels_above(y_true, img, intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel) + y_true = _thresh_labels_above( + y_true, img, intensity_thresh, + intensity_thresh_in_channels, intensity_thresh_out_channel, + nb_orig_channels, + ) y_true = np.delete(y_true, bad_channels, axis=-1) y_true = binary_erosion(y_true, structure).astype(y_true.dtype) means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_channels) @@ -829,7 +831,11 @@ def infer_medians(args): inference_writer_true.writerow(csv_row_true) if intensity_thresh: - y_pred = _thresh_labels_above(y_pred, img, intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel) + y_pred = _thresh_labels_above( + y_pred, img, intensity_thresh, + intensity_thresh_in_channels, intensity_thresh_out_channel, + nb_orig_channels, + ) y_pred = np.delete(y_pred, bad_channels, axis=-1) y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels) @@ -841,9 +847,6 @@ def infer_medians(args): if stats_counter['count'] % 250 == 0: logging.info(f"Wrote:{stats_counter['count']} rows of inference. 
Last tensor:{tensor_paths[0]}") - inference_tsv_true = os.path.join(args.output_folder, args.id, f'medians_inference_true_{args.id}_{tm_in.input_name()}_{output_name}.tsv') - inference_tsv_pred = os.path.join(args.output_folder, args.id, f'medians_inference_pred_{args.id}_{tm_in.input_name()}_{output_name}.tsv') - # Scatter plots if has_y_true: df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) @@ -857,11 +860,12 @@ def infer_medians(args): ) if i == 'all': - logging.info(plot_data) # TODO fix me up true_outliers = plot_data[plot_data.true == 0] pred_outliers = plot_data[plot_data.pred == 0] - logging.info(true_outliers) # TODO fix me up - logging.info(pred_outliers) # TODO fix me up + logging.info(f'sample_ids where {col} is zero in the manual segmentation:') + logging.info(true_outliers['sample_id'].to_list()) + logging.info(f'sample_ids where {col} is zero in the model segmentation:') + logging.info(pred_outliers['sample_id'].to_list()) elif i == 'filter_outliers': plot_data = plot_data[plot_data.true != 0] plot_data = plot_data[plot_data.pred != 0] @@ -893,7 +897,7 @@ def infer_medians(args): logging.info(f'{col} pearson{postfix} {res.statistic}') figure_path = os.path.join( args.output_folder, args.id, - f'medians_inference_{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png', + f'{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png', ) plt.savefig(figure_path) From 0bf07bcc1c9650479cc1b2b86e9d4d28391fc155 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 3 Nov 2023 13:15:34 +0000 Subject: [PATCH 39/50] STYLE: Add command-line args for median computations --- ml4h/arguments.py | 10 +++++++ ml4h/explorations.py | 63 ++++++++++++++++---------------------------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index e8c54fe1b..1807ad164 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -379,6 +379,16 @@ def parse_args(): default='3M', ) + # Arguments for explorations/infer_medians + parser.add_argument('--analyze_ground_truth', default=True, help='Whether or not to filter by images with ground truth segmentations, for comparison') + parser.add_argument('--dates_file', help='File containing dates for each sample_id') + parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv outputs') + parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing') + parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing') + parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the images has intensity above the threshold') + parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name') + parser.add_argument('--results_to_plot', nargs='*', default=[], help='Structure names to make scatter plots') + # TensorMap prefix for convenience parser.add_argument('--tensormap_prefix', default="ml4h.tensormap", type=str, help="Module prefix path for TensorMaps. 
Defaults to \"ml4h.tensormap\"") diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 5730799af..666589618 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -727,21 +727,6 @@ def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig return y def infer_medians(args): - - # TODO make these command-line arguments - has_y_true = True - dates_filename = '/home/pace/csvs/mri_dates_instance2.csv' - structures_to_analyze = [ - 'interventricular_septum', 'LV_free_wall', 'anterolateral_pap', 'posteromedial_pap', - 'LV_cavity', 'RV_free_wall', 'RV_cavity', - ] - erosion_radius = 1 - intensity_thresh = 1.37 - intensity_thresh_in_structures = ['anterolateral_pap', 'posteromedial_pap'] - intensity_thresh_out_structure = 'LV_cavity' - results_to_plot = ['anterolateral_pap_median', 'posteromedial_pap_median'] - # end TODO make these command-line argument - assert(args.batch_size == 1) # no support here for iterating over larger batches assert(len(args.tensor_maps_in) == 1) # no support here for multiple input maps assert(len(args.tensor_maps_out) == 1) # no support here for multiple output channels @@ -751,29 +736,33 @@ def infer_medians(args): assert(tm_in.shape[-1] == 1) # no support here for stats on multiple input channels # don't filter datasets for ground truth segmentations if we want to run inference on everything - if not has_y_true: + # TODO HELP - this isn't giving me all 56K anymore + if not args.analyze_ground_truth: + args.output_tensors = [] args.tensor_maps_out = [] _, _, generate_test = test_train_valid_tensor_generators(**args.__dict__) model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) # good_structures has to be sorted by channel idx - good_channels = sorted([tm_out.channel_map[k] for k in structures_to_analyze]) + good_channels = sorted([tm_out.channel_map[k] for k in args.structures_to_analyze]) good_structures = [[k for k in tm_out.channel_map.keys() if tm_out.channel_map[k] == v][0] for v in good_channels] nb_orig_channels = len(tm_out.channel_map) nb_good_channels = len(good_channels) bad_channels = [k for k in range(nb_orig_channels) if k not in good_channels] # Structuring element used for the erosion - structure = _unit_disk(erosion_radius)[np.newaxis, ..., np.newaxis] + if args.erosion_radius > 0: + structure = _unit_disk(args.erosion_radius)[np.newaxis, ..., np.newaxis] # Setup for intensity thresholding - if intensity_thresh: - intensity_thresh_in_channels = [tm_out.channel_map[k] for k in intensity_thresh_in_structures] - intensity_thresh_out_channel = tm_out.channel_map[intensity_thresh_out_structure] + do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure + if do_intensity_thresh: + intensity_thresh_in_channels = [tm_out.channel_map[k] for k in args.intensity_thresh_in_structures] + intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure] # Get the dates - with open(dates_filename, mode='r') as dates_file: + with open(args.dates_file, mode='r') as dates_file: dates_reader = csv.reader(dates_file) dates_dict = {rows[0]:rows[1] for rows in dates_reader} @@ -811,33 +800,27 @@ def infer_medians(args): y_pred = model.predict(data, batch_size=args.batch_size, verbose=0) y_pred = np.argmax(y_pred, axis=-1) y_pred = _to_categorical(y_pred, nb_orig_channels) - if has_y_true: + if args.analyze_ground_truth: y_true = labels[tm_out.output_name()] sample_id = os.path.basename(tensor_paths[0]).replace(TENSOR_EXT, '') date = dates_dict[sample_id] - if 
has_y_true: - if intensity_thresh: - y_true = _thresh_labels_above( - y_true, img, intensity_thresh, - intensity_thresh_in_channels, intensity_thresh_out_channel, - nb_orig_channels, - ) + if args.analyze_ground_truth: + if do_intensity_thresh: + y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels) y_true = np.delete(y_true, bad_channels, axis=-1) - y_true = binary_erosion(y_true, structure).astype(y_true.dtype) + if args.erosion_radius > 0: + y_true = binary_erosion(y_true, structure).astype(y_true.dtype) means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_channels) csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date) inference_writer_true.writerow(csv_row_true) - if intensity_thresh: - y_pred = _thresh_labels_above( - y_pred, img, intensity_thresh, - intensity_thresh_in_channels, intensity_thresh_out_channel, - nb_orig_channels, - ) + if do_intensity_thresh: + y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels) y_pred = np.delete(y_pred, bad_channels, axis=-1) - y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) + if args.erosion_radius > 0: + y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype) means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels) csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date) inference_writer_pred.writerow(csv_row_pred) @@ -848,11 +831,11 @@ def infer_medians(args): logging.info(f"Wrote:{stats_counter['count']} rows of inference. Last tensor:{tensor_paths[0]}") # Scatter plots - if has_y_true: + if args.analyze_ground_truth: df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) - for col in results_to_plot: + for col in args.results_to_plot: for i in ['all', 'filter_outliers']: # Two types of plots plot_data = pd.concat( [df_true['sample_id'], df_true[col], df_pred[col]], From ebcaa9cfb78c62d83d8197b8fb8743652ebdb671 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 3 Nov 2023 15:13:43 +0000 Subject: [PATCH 40/50] ENH: Add percentiles and tsv for dice calculations --- ml4h/plots.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index 989bf265c..2b8f04058 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2,6 +2,7 @@ # Imports import os +import csv import re import math import h5py @@ -2738,7 +2739,6 @@ def plot_precision_recalls(predictions, truth, labels, title, prefix="./figures/ def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi=300, width=3, height=3): label_names = labels.keys() - logging.info(f"label_names: {label_names}") label_vals = [labels[k] for k in label_names] batch_size = truth.shape[0] y_true = truth.argmax(-1) @@ -2759,31 +2759,20 @@ def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi for k in replace[i]: dice_scores[p][i,k] = 1.0 - # TODO take me out - al_pap_dice_scores = dice_scores[p][:,7] - pm_pap_dice_scores = dice_scores[p][:,8] - al_percentiles = [np.percentile(al_pap_dice_scores, perc) for perc in [5, 25, 50, 75, 95]] - pm_percentiles = [np.percentile(pm_pap_dice_scores, perc) for perc in [5, 25, 
50, 75, 95]] - al_idx = [min(range(len(al_pap_dice_scores)), key=lambda i: abs(al_pap_dice_scores[i] - perc)) for perc in al_percentiles] - pm_idx = [min(range(len(pm_pap_dice_scores)), key=lambda i: abs(pm_pap_dice_scores[i] - perc)) for perc in pm_percentiles] - logging.info([paths[i] for i in al_idx]) - logging.info([paths[i] for i in pm_idx]) - logging.info('sorted al paps (worst to best):') - sorted_al_paths = [paths[k] for k in sorted(range(len(al_pap_dice_scores)), key=lambda k:al_pap_dice_scores[k])] - for p in sorted_al_paths: - logging.info(p) - logging.info('sorted pm paps (worst to best):') - sorted_pm_paths = [paths[k] for k in sorted(range(len(pm_pap_dice_scores)), key=lambda k:pm_pap_dice_scores[k])] - for p in sorted_pm_paths: - logging.info(p) - assert(False) - # end TODO take me out - + # stats mean_dice_scores[p] = np.average(dice_scores[p], axis=0) logging.info(f"{p} mean Dice scores {mean_dice_scores[p]}") std_dice_scores[p] = np.std(dice_scores[p], axis=0) logging.info(f"{p} std Dice scores {std_dice_scores[p]}") + # percentiles + for k in label_names: + structure_dice_scores = dice_scores[p][:,labels[k]] + structure_dice_percentiles = [np.percentile(structure_dice_scores, perc) for perc in [5, 25, 50, 75, 95]] + structure_dice_percentile_idxs = [min(range(len(structure_dice_scores)), key=lambda i: abs(structure_dice_scores[i] - perc)) for perc in structure_dice_percentiles] + logging.info(f'{p}: sample_ids for [5, 25, 50, 75, 95] percentiles for {k}: {[paths[i] for i in structure_dice_percentile_idxs]}') + + # Plot fig if len(predictions) > 1: row = 0 col = 0 @@ -2810,7 +2799,6 @@ def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi break else: - logging.info([p for p in predictions]) p = list(predictions.keys())[0] for i,k in enumerate(label_names): plt.boxplot(dice_scores[p][:,i], positions = [i], labels=[k]) @@ -2827,6 +2815,16 @@ def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi plt.savefig(figure_path) logging.info(f"Saved Dice plots at: {figure_path}") + # Save tsv + for p in predictions: + tsv_path = os.path.join(prefix, f'dice_{p}_{now_string}_{title}.tsv') + with open(tsv_path, mode='w') as tsv_file: + tsv_writer = csv.writer(tsv_file, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) + tsv_writer.writerow(label_names) + for r in range(dice_scores[p].shape[0]): + tsv_writer.writerow(dice_scores[p][r,:]) + logging.info(f"Saved dice tsv at: {tsv_path}") + def get_fpr_tpr_roc_pred(y_pred, test_truth, labels): # Compute ROC curve and ROC area for each class fpr = dict() From 94902dae86d89249d35e7bedfda271ccd6ab7d67 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 3 Nov 2023 16:24:48 +0000 Subject: [PATCH 41/50] STYLE: Add command-line args for data augmentation --- ml4h/arguments.py | 5 +++ ml4h/tensor_generators.py | 73 ++++++++++++++++++++++++--------------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index 1807ad164..e3c154f02 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -272,6 +272,11 @@ def parse_args(): help='If true saves the model weights from the last training epoch, otherwise the model with best validation loss is saved.', ) + # 2D image data augmentation parameters + parser.add_argument('--rotation_factor', default=0., type=float, help='a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]') + 
parser.add_argument('--zoom_factor', default=0., type=float, help='a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]') + parser.add_argument('--translation_factor', default=0., type=float, help='a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions') + # Run specific and debugging arguments parser.add_argument('--id', default='no_id', help='Identifier for this run, user-defined string to keep experiments organized.') parser.add_argument('--random_seed', default=12878, type=int, help='Random seed to use throughout run. Always use np.random.') diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py index b49b40d6f..f9316befa 100755 --- a/ml4h/tensor_generators.py +++ b/ml4h/tensor_generators.py @@ -725,24 +725,24 @@ def get_train_valid_test_paths_split_by_csvs( logging.info(f"CSV:{balance_csvs[i-1]}\nhas: {len(train_paths[i])} train, {len(valid_paths[i])} valid, {len(test_paths[i])} test tensors.") return train_paths, valid_paths, test_paths - -# TODO prototyping # https://stackoverflow.com/questions/65475057/keras-data-augmentation-pipeline-for-image-segmentation-dataset-image-and-mask -def augment_using_layers(images, mask): +def augment_using_layers(images, mask, in_shapes, out_shapes, rotation_factor, zoom_factor, translation_factor): + + assert(len(in_shapes) == 1) # no support for multiple inputs + assert(len(out_shapes) == 1) # no support for mulitple outputs def aug(): - # TODO pull these parameters out and default to None - if any are not none then go through wrapper - rota = tf.keras.layers.RandomRotation(factor=(0.014), fill_mode='constant') + rota = tf.keras.layers.RandomRotation(factor=rotation_factor, fill_mode='constant') zoom = tf.keras.layers.RandomZoom( - height_factor=(-0.05, 0.05), + height_factor=(-zoom_factor, zoom_factor), width_factor=None, fill_mode='constant', ) trans = tf.keras.layers.RandomTranslation( - height_factor=(-0.042, 0.042), - width_factor=(-0.042, 0.042), + height_factor=(-translation_factor, translation_factor), + width_factor=(-translation_factor, translation_factor), fill_mode='constant', ) @@ -753,22 +753,23 @@ def aug(): aug = aug() - # TODO just one for now, assert to make sure it's not multimodal - images = images['input_shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map_continuous'] - mask = mask['output_b2s_t1map_kassir_annotated_categorical'] + # we know there's just one + in_key = list(in_shapes.keys())[0] + out_key = list(out_shapes.keys())[0] + images = images[in_key] + mask = mask[out_key] + # concatenate the inputs and outputs together into a single tensor, and do data augmentation images_mask = tf.concat([images, mask], -1) images_mask = aug(images_mask) - # TODO can get from tm shape + # split the inputs and outputs again + assert(in_shapes[in_key][-1] == 1) # we are only handling one channel in the input image = images_mask[..., 0] image = image[..., tf.newaxis] mask = images_mask[..., 1:] return image, mask -# end TODO prototyping - - def test_train_valid_tensor_generators( tensor_maps_in: List[TensorMap], @@ -791,6 +792,9 @@ def test_train_valid_tensor_generators( valid_csv: str = None, test_csv: str = None, siamese: bool = False, + rotation_factor: float = 0, + zoom_factor: float = 0, + translation_factor: float = 0, wrap_with_tf_dataset: bool = False, **kwargs ) -> Tuple[TensorGeneratorABC, TensorGeneratorABC, 
TensorGeneratorABC]: @@ -817,6 +821,9 @@ def test_train_valid_tensor_generators( :param valid_csv: CSV file of sample ids to use for validation, mutually exclusive with valid_ratio :param test_csv: CSV file of sample ids to use for testing, mutually exclusive with test_ratio :param siamese: if True generate input for a siamese model i.e. a left and right input tensors for every input TensorMap + :param rotation_factor: rotation for data augmentation: a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees] + :param zoom_factor: zoom for data augmentation: a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%] + :param translation_factor: translation for data augmentation: a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions :param wrap_with_tf_dataset: if True will return tf.dataset objects for the 3 generators :return: A tuple of three generators. Each yields a Tuple of dictionaries of input and output numpy arrays for training, validation and testing. """ @@ -851,7 +858,7 @@ def test_train_valid_tensor_generators( # use the longest list of [train_paths, valid_paths, test_paths], avoiding hard-coding one # in case it is empty paths = max([train_paths, valid_paths, test_paths], key=len) - generator_class = pick_generator(paths, weights, mixup_alpha, siamese) # TODO + generator_class = pick_generator(paths, weights, mixup_alpha, siamese) generate_train = generator_class( batch_size=batch_size, input_maps=tensor_maps_in, output_maps=tensor_maps_out, @@ -868,21 +875,31 @@ def test_train_valid_tensor_generators( paths=test_paths, num_workers=num_train_workers, cache_size=0, weights=weights, keep_paths=keep_paths or keep_paths_test, mixup_alpha=0, name='test_worker', siamese=siamese, augment=False, ) - # TODO prototyping - # TODO if wrap_with_tf_dataset or - if True: + + do_augmentation = bool(rotation_factor or zoom_factor or translation_factor) + logging.info(f'doing_augmentation {do_augmentation}') + + if do_augmentation: + assert(len(tensor_maps_in) == 1) # no support for multiple input tensors + assert(len(tensor_maps_out) == 1) # no support for multiple output tensors + + if wrap_with_tf_dataset or do_augmentation: in_shapes = {tm.input_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_in} out_shapes = {tm.output_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_out} + train_dataset = tf.data.Dataset.from_generator( generate_train, output_types=({k: tf.float32 for k in in_shapes}, {k: tf.float32 for k in out_shapes}), output_shapes=(in_shapes, out_shapes), ) - train_dataset = train_dataset.map(lambda x, y: augment_using_layers(x, y)) - # end TODO prototyping + train_dataset = train_dataset.map( + lambda x, y: augment_using_layers( + x, y, in_shapes, out_shapes, + rotation_factor, zoom_factor, translation_factor, + ), + ) + if wrap_with_tf_dataset: - in_shapes = {tm.input_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_in} - out_shapes = {tm.output_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_out} valid_dataset = tf.data.Dataset.from_generator( generate_valid, output_types=({k: tf.float32 for k in in_shapes}, {k: tf.float32 for k in out_shapes}), @@ -893,13 +910,13 @@ def test_train_valid_tensor_generators( output_types=({k: 
tf.float32 for k in in_shapes}, {k: tf.float32 for k in out_shapes}), output_shapes=(in_shapes, out_shapes), ) + + if wrap_with_tf_dataset: return train_dataset, valid_dataset, test_dataset - else: - # TODO prototyping - # return generate_train, generate_valid, generate_test + elif do_augmentation: return train_dataset, generate_valid, generate_test - # end TODO prototyping - + else: + return generate_train, generate_valid, generate_test def _log_first_error(stats: Counter, tensor_path: str): for k in stats: From 8d8f6b02ccec8730dc3ab461fbbd9ab13a57ea4a Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Thu, 9 Nov 2023 20:27:10 +0000 Subject: [PATCH 42/50] ENH: Improve log files for dice compare --- ml4h/plots.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index 2b8f04058..9acd91630 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2745,6 +2745,7 @@ def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi y_true_unique = [np.unique(y_true[i]) for i in range(batch_size)] missing_truth_label_vals = [[k for k in label_vals if k not in y_true_unique[i]] for i in range(batch_size)] + logging.info(f"label_names: {label_names}") dice_scores = {} mean_dice_scores = {} std_dice_scores = {} @@ -2770,6 +2771,7 @@ def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi structure_dice_scores = dice_scores[p][:,labels[k]] structure_dice_percentiles = [np.percentile(structure_dice_scores, perc) for perc in [5, 25, 50, 75, 95]] structure_dice_percentile_idxs = [min(range(len(structure_dice_scores)), key=lambda i: abs(structure_dice_scores[i] - perc)) for perc in structure_dice_percentiles] + logging.info(f'{p}: [5, 25, 50, 75, 95] percentiles for {k}: {structure_dice_percentiles}') logging.info(f'{p}: sample_ids for [5, 25, 50, 75, 95] percentiles for {k}: {[paths[i] for i in structure_dice_percentile_idxs]}') # Plot fig @@ -2820,9 +2822,9 @@ def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi tsv_path = os.path.join(prefix, f'dice_{p}_{now_string}_{title}.tsv') with open(tsv_path, mode='w') as tsv_file: tsv_writer = csv.writer(tsv_file, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL) - tsv_writer.writerow(label_names) - for r in range(dice_scores[p].shape[0]): - tsv_writer.writerow(dice_scores[p][r,:]) + tsv_writer.writerow(['sample_id'] + list(label_names)) + for i in range(dice_scores[p].shape[0]): + tsv_writer.writerow([paths[i]] + list(dice_scores[p][i,:])) logging.info(f"Saved dice tsv at: {tsv_path}") def get_fpr_tpr_roc_pred(y_pred, test_truth, labels): From 3967219683467ec2c805560fcbb602344705788d Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Tue, 14 Nov 2023 14:57:42 +0000 Subject: [PATCH 43/50] STYLE: Small edits requested in PR --- ml4h/arguments.py | 14 ++--- ml4h/explorations.py | 123 ++++++++++++++++++++------------------ ml4h/recipes.py | 6 +- ml4h/tensor_generators.py | 17 ++++-- 4 files changed, 87 insertions(+), 73 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index e3c154f02..3bbb011d3 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -72,7 +72,7 @@ def parse_args(): parser.add_argument('--dicoms', default='./dicoms/', help='Path to folder of dicoms.') parser.add_argument('--sample_csv', default=None, help='Path to CSV with Sample IDs to restrict tensor paths') parser.add_argument('--tsv_style', default='standard', choices=['standard', 'genetics'], help='Format choice for the TSV file produced in output by 
infer and explore modes.') - parser.add_argument('--app_csv', help='Path to file used to link sample IDs between UKBB applications 17488 and 7089') + parser.add_argument('--app_csv', help='Path to file used by the recipe') parser.add_argument('--tensors', help='Path to folder containing tensors, or where tensors will be written.') parser.add_argument('--output_folder', default='./recipes_output/', help='Path to output folder for recipes.py runs.') parser.add_argument('--model_file', help='Path to a saved model architecture and weights (hd5).') @@ -273,9 +273,9 @@ def parse_args(): ) # 2D image data augmentation parameters - parser.add_argument('--rotation_factor', default=0., type=float, help='a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]') - parser.add_argument('--zoom_factor', default=0., type=float, help='a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]') - parser.add_argument('--translation_factor', default=0., type=float, help='a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions') + parser.add_argument('--rotation_factor', default=0., type=float, help='for data augmentation, a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]') + parser.add_argument('--zoom_factor', default=0., type=float, help='for dat augmentation, a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]') + parser.add_argument('--translation_factor', default=0., type=float, help='for data augmentation, a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions') # Run specific and debugging arguments parser.add_argument('--id', default='no_id', help='Identifier for this run, user-defined string to keep experiments organized.') @@ -384,15 +384,13 @@ def parse_args(): default='3M', ) - # Arguments for explorations/infer_medians + # Arguments for explorations/infer_stats_from_segmented_regions parser.add_argument('--analyze_ground_truth', default=True, help='Whether or not to filter by images with ground truth segmentations, for comparison') - parser.add_argument('--dates_file', help='File containing dates for each sample_id') - parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv outputs') + parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots') parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing') parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing') parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the images has intensity above the threshold') parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name') - parser.add_argument('--results_to_plot', nargs='*', default=[], help='Structure names to make scatter plots') # 
TensorMap prefix for convenience
     parser.add_argument('--tensormap_prefix', default="ml4h.tensormap", type=str, help="Module prefix path for TensorMaps. Defaults to \"ml4h.tensormap\"")
diff --git a/ml4h/explorations.py b/ml4h/explorations.py
index 666589618..91ddefd3e 100755
--- a/ml4h/explorations.py
+++ b/ml4h/explorations.py
@@ -726,14 +726,70 @@ def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig
     y = _to_categorical(y, nb_orig_channels)
     return y
 
-def infer_medians(args):
-    assert(args.batch_size == 1) # no support here for iterating over larger batches
-    assert(len(args.tensor_maps_in) == 1) # no support here for multiple input maps
-    assert(len(args.tensor_maps_out) == 1) # no support here for multiple output channels
+def _scatter_plots_from_segmented_region_stats(
+    inference_tsv_true, inference_tsv_pred, structures_to_analyze,
+    output_folder, id, input_name, output_name,
+):
+    df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
+    df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
+
+    results_to_plot = [f'{s}_median' for s in structures_to_analyze]
+    for col in results_to_plot:
+        for i in ['all', 'filter_outliers']: # Two types of plots
+            plot_data = pd.concat(
+                [df_true['sample_id'], df_true[col], df_pred[col]],
+                axis=1, keys=['sample_id', 'true', 'pred'],
+            )
+
+            if i == 'all':
+                true_outliers = plot_data[plot_data.true == 0]
+                pred_outliers = plot_data[plot_data.pred == 0]
+                logging.info(f'sample_ids where {col} is zero in the manual segmentation:')
+                logging.info(true_outliers['sample_id'].to_list())
+                logging.info(f'sample_ids where {col} is zero in the model segmentation:')
+                logging.info(pred_outliers['sample_id'].to_list())
+            elif i == 'filter_outliers':
+                plot_data = plot_data[plot_data.true != 0]
+                plot_data = plot_data[plot_data.pred != 0]
+                plot_data = plot_data.drop('sample_id', axis=1)
+
+            plt.figure()
+            g = lmplot(x='true', y='pred', data=plot_data)
+            ax = plt.gca()
+            title = col.replace('_', ' ')
+            ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
+            ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
+            if i == 'all':
+                min_value = -50
+                max_value = 1300
+            elif i == 'filter_outliers':
+                min_value, max_value = plot_data.min(), plot_data.max()
+                min_value = min([min_value['true'], min_value['pred']]) - 100
+                max_value = min([max_value['true'], max_value['pred']]) + 100
+            ax.set_xlim([min_value, max_value])
+            ax.set_ylim([min_value, max_value])
+            res = stats.pearsonr(plot_data['true'], plot_data['pred'])
+            conf = res.confidence_interval(confidence_level=0.95)
+            text = f'Pearson Correlation Coefficient r={res.statistic:.2f},\n95% CI {conf.low:.2f} - {conf.high:.2f}'
+            ax.text(0.25, 0.1, text, transform=ax.transAxes)
+            if i == 'all':
+                postfix = ''
+            elif i == 'filter_outliers':
+                postfix = '_no_outliers'
+            logging.info(f'{col} pearson{postfix} {res.statistic}')
+            figure_path = os.path.join(
+                output_folder, id, f'{col}_{id}_{input_name}_{output_name}{postfix}.png',
+            )
+            plt.savefig(figure_path)
+
+def infer_stats_from_segmented_regions(args):
+    assert args.batch_size == 1, 'no support here for iterating over larger batches'
+    assert len(args.tensor_maps_in) == 1, 'no support here for multiple input maps'
+    assert len(args.tensor_maps_out) == 1, 'no support here for multiple output channels'
     tm_in = args.tensor_maps_in[0]
     tm_out = args.tensor_maps_out[0]
-    assert(tm_in.shape[-1] == 1) # no support here for stats on multiple input channels
+    assert tm_in.shape[-1] == 1, 'no support here for stats on multiple input channels'
 
     # don't filter datasets for ground truth segmentations if we want to run inference on everything
     # TODO HELP - this isn't giving me all 56K anymore
@@ -762,7 +818,7 @@ def infer_medians(args):
     intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure]
 
     # Get the dates
-    with open(args.dates_file, mode='r') as dates_file:
+    with open(args.app_csv, mode='r') as dates_file:
         dates_reader = csv.reader(dates_file)
         dates_dict = {rows[0]:rows[1] for rows in dates_reader}
 
@@ -832,57 +888,10 @@ def infer_medians(args):
 
     # Scatter plots
     if args.analyze_ground_truth:
-        df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
-        df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
-
-        for col in args.results_to_plot:
-            for i in ['all', 'filter_outliers']: # Two types of plots
-                plot_data = pd.concat(
-                    [df_true['sample_id'], df_true[col], df_pred[col]],
-                    axis=1, keys=['sample_id', 'true', 'pred'],
-                )
-
-                if i == 'all':
-                    true_outliers = plot_data[plot_data.true == 0]
-                    pred_outliers = plot_data[plot_data.pred == 0]
-                    logging.info(f'sample_ids where {col} is zero in the manual segmentation:')
-                    logging.info(true_outliers['sample_id'].to_list())
-                    logging.info(f'sample_ids where {col} is zero in the model segmentation:')
-                    logging.info(pred_outliers['sample_id'].to_list())
-                elif i == 'filter_outliers':
-                    plot_data = plot_data[plot_data.true != 0]
-                    plot_data = plot_data[plot_data.pred != 0]
-                    plot_data = plot_data.drop('sample_id', axis=1)
-
-                plt.figure()
-                g = lmplot(x='true', y='pred', data=plot_data)
-                ax = plt.gca()
-                title = col.replace('_', ' ')
-                ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
-                ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
-                if i == 'all':
-                    min_value = -50
-                    max_value = 1300
-                elif i == 'filter_outliers':
-                    min_value, max_value = plot_data.min(), plot_data.max()
-                    min_value = min([min_value['true'], min_value['pred']]) - 100
-                    max_value = min([max_value['true'], max_value['pred']]) + 100
-                ax.set_xlim([min_value, max_value])
-                ax.set_ylim([min_value, max_value])
-                res = stats.pearsonr(plot_data['true'], plot_data['pred'])
-                conf = res.confidence_interval(confidence_level=0.95)
-                text = f'Pearson Correlation Coefficient r={res.statistic:.2f},\n95% CI {conf.low:.2f} - {conf.high:.2f}'
-                ax.text(0.25, 0.1, text, transform=ax.transAxes)
-                if i == 'all':
-                    postfix = ''
-                elif i == 'filter_outliers':
-                    postfix = '_no_outliers'
-                logging.info(f'{col} pearson{postfix} {res.statistic}')
-                figure_path = os.path.join(
-                    args.output_folder, args.id,
-                    f'{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png',
-                )
-                plt.savefig(figure_path)
+        _scatter_plots_from_segmented_region_stats(
+            inference_tsv_true, inference_tsv_pred, args.structures_to_analyze,
+            args.output_folder, args.id, tm_in.input_name(), args.output_name,
+        )
 
 def _softmax(x):
     """Compute softmax values for each sets of scores in x."""
diff --git a/ml4h/recipes.py b/ml4h/recipes.py
index 171084a8e..33ca54b7f 100755
--- a/ml4h/recipes.py
+++ b/ml4h/recipes.py
@@ -23,7 +23,7 @@
 from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb
 from ml4h.models.model_factory import make_multimodal_multitask_model
 from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX
-from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_medians
+from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_stats_from_segmented_regions
 from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model
 from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv
 from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
@@ -66,8 +66,8 @@ def run(args):
         infer_hidden_layer_multimodal_multitask(args)
     elif 'infer_pixels' == args.mode:
         infer_with_pixels(args)
-    elif 'infer_medians' == args.mode:
-        infer_medians(args)
+    elif 'infer_stats_from_segmented_regions' == args.mode:
+        infer_stats_from_segmented_regions(args)
     elif 'infer_encoders' == args.mode:
         infer_encoders_block_multimodal_multitask(args)
     elif 'test_scalar' == args.mode:
diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py
index f9316befa..fdc530393 100755
--- a/ml4h/tensor_generators.py
+++ b/ml4h/tensor_generators.py
@@ -666,6 +666,13 @@ def get_train_valid_test_paths(
     logging.info(f'Found {len(train_paths)} train, {len(valid_paths)} validation, and {len(test_paths)} testing tensors at: {tensors}')
     logging.debug(f'Discarded {len(discard_paths)} tensors due to given ratios')
 
+    if len(train_paths) == 0 and len(valid_paths) == 0 and len(test_paths) == 0:
+        raise ValueError(
+            f'Not enough tensors at {tensors}\n'
+            f'Found {len(train_paths)} training, {len(valid_paths)} validation, and {len(test_paths)} testing tensors\n'
+            f'Discarded {len(discard_paths)} tensors',
+        )
+
     return train_paths, valid_paths, test_paths
 
 
@@ -728,8 +735,8 @@ def get_train_valid_test_paths_split_by_csvs(
 
 # https://stackoverflow.com/questions/65475057/keras-data-augmentation-pipeline-for-image-segmentation-dataset-image-and-mask
 def augment_using_layers(images, mask, in_shapes, out_shapes, rotation_factor, zoom_factor, translation_factor):
-    assert(len(in_shapes) == 1) # no support for multiple inputs
-    assert(len(out_shapes) == 1) # no support for mulitple outputs
+    assert len(in_shapes) == 1, 'no support for multiple inputs'
+    assert len(out_shapes) == 1, 'no support for multiple outputs'
 
     def aug():
         rota = tf.keras.layers.RandomRotation(factor=rotation_factor, fill_mode='constant')
@@ -877,11 +884,11 @@ def test_train_valid_tensor_generators(
     )
 
     do_augmentation = bool(rotation_factor or zoom_factor or translation_factor)
-    logging.info(f'doing_augmentation {do_augmentation}')
+    logging.info(f'doing_augmentation {do_augmentation} with rotation {rotation_factor}, zoom {zoom_factor}, translation {translation_factor}')
 
     if do_augmentation:
-        assert(len(tensor_maps_in) == 1) # no support for multiple input tensors
-        assert(len(tensor_maps_out) == 1) # no support for multiple output tensors
+        assert len(tensor_maps_in) == 1, 'no support for multiple input tensors'
+        assert len(tensor_maps_out) == 1, 'no support for multiple output tensors'
 
     if wrap_with_tf_dataset or do_augmentation:
         in_shapes = {tm.input_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_in}
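The augmentation that this patch gates behind the single-input/single-output assertions follows the Stack Overflow recipe linked in augment_using_layers: the image and its one-hot segmentation must receive the same random rotation, zoom and translation. Below is a minimal, illustrative sketch of that idea only; it assumes a single-channel image stacked with a one-hot mask on the channel axis and is not the ml4h implementation.

import tensorflow as tf

def make_joint_augmenter(rotation_factor, zoom_factor, translation_factor):
    # One Sequential, so a single forward pass applies identical random transforms
    # to every channel of whatever tensor it is given.
    return tf.keras.Sequential([
        tf.keras.layers.RandomRotation(factor=rotation_factor, fill_mode='constant'),
        tf.keras.layers.RandomZoom(height_factor=zoom_factor, width_factor=zoom_factor, fill_mode='constant'),
        tf.keras.layers.RandomTranslation(height_factor=translation_factor, width_factor=translation_factor, fill_mode='constant'),
    ])

def augment_pair(image, mask, augmenter):
    # image: (batch, height, width, 1) float32; mask: (batch, height, width, num_labels) one-hot float32
    stacked = tf.concat([image, mask], axis=-1)      # stack so image and mask share the same transform
    augmented = augmenter(stacked, training=True)    # training=True forces the random ops to fire
    aug_image = augmented[..., :image.shape[-1]]
    aug_mask = augmented[..., image.shape[-1]:]      # note: interpolation softens the one-hot mask slightly
    return aug_image, aug_mask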
From 86a980365286999049c0b1bddd73997fe7f29122 Mon Sep 17 00:00:00 2001
From: Danielle Pace
Date: Tue, 14 Nov 2023 15:53:04 +0000
Subject: [PATCH 44/50] STYLE: docstring and typehints for plot_dice

---
 ml4h/plots.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/ml4h/plots.py b/ml4h/plots.py
index 9acd91630..5f756eb3b 100755
--- a/ml4h/plots.py
+++ b/ml4h/plots.py
@@ -2737,7 +2737,30 @@ def plot_precision_recalls(predictions, truth, labels, title, prefix="./figures/
             plt.savefig(figure_path)
             logging.info("Saved Precision Recall curve at: {}".format(figure_path))
 
-def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi=300, width=3, height=3):
+def plot_dice(
+    predictions: Dict[str, List[np.ndarray]],
+    truth: np.ndarray,
+    labels: Dict[str, int],
+    paths: List[str],
+    title: str,
+    prefix: str = "./figures/",
+    dpi: int = 300,
+    width: int = 3,
+    height: int = 3,
+) -> None:
+    """
+    Produces boxplots of Dice score distributions and .tsv files of Dice scores for individual images and structures.
+    :param predictions: dictionary of predicted segmentations for each model, in which keys are model names and values are lists of arrays with shape (height, width, num_labels)
+    :param truth: ground truth segmentations, with shape (num_images, height, width, num_labels)
+    :param labels: channel map dictionary mapping label names to integer values
+    :param paths: paths of input hd5 files
+    :param title: name for the output files
+    :param prefix: directory that the outputs will be written to
+    :param dpi: dots per inch of the plot
+    :param width: width of the plot
+    :param height: height of the plot
+    :return: None
+    """
     label_names = labels.keys()
     label_vals = [labels[k] for k in label_names]
     batch_size = truth.shape[0]

From 364574abff864353a5c2a5b26d92bfa2a8f3e66a Mon Sep 17 00:00:00 2001
From: Danielle Pace
Date: Tue, 14 Nov 2023 16:20:35 +0000
Subject: [PATCH 45/50] STYLE: docstring for infer_stats_from_segmented_regions

---
 ml4h/explorations.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/ml4h/explorations.py b/ml4h/explorations.py
index 91ddefd3e..5da14d5d9 100755
--- a/ml4h/explorations.py
+++ b/ml4h/explorations.py
@@ -783,6 +783,12 @@
 def infer_stats_from_segmented_regions(args):
+    """
+    Computes .tsv files of intensity means, medians and standard deviations within predicted segmentations for
+    a given list of structures of interest. If ground truth segmentations are available, computes the same
+    statistics within them, as well as scatter plots that compare median intensities within predicted and
+    ground truth segmentations.
+    """
     assert args.batch_size == 1, 'no support here for iterating over larger batches'
     assert len(args.tensor_maps_in) == 1, 'no support here for multiple input maps'
     assert len(args.tensor_maps_out) == 1, 'no support here for multiple output channels'
From 0d2eb5fd9b92a3568565829e5b4e2c0a35d6dd53 Mon Sep 17 00:00:00 2001
From: Danielle Pace
Date: Tue, 14 Nov 2023 16:47:38 +0000
Subject: [PATCH 46/50] STYLE: Docstring and typehints for augment_using_layers

---
 ml4h/arguments.py | 2 +-
 ml4h/tensor_generators.py | 31 ++++++++++++++++++++++++-----
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/ml4h/arguments.py b/ml4h/arguments.py
index 3bbb011d3..39cdf50b3 100755
--- a/ml4h/arguments.py
+++ b/ml4h/arguments.py
@@ -274,7 +274,7 @@ def parse_args():
 
     # 2D image data augmentation parameters
     parser.add_argument('--rotation_factor', default=0., type=float, help='for data augmentation, a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]')
-    parser.add_argument('--zoom_factor', default=0., type=float, help='for dat augmentation, a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]')
+    parser.add_argument('--zoom_factor', default=0., type=float, help='for data augmentation, a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]')
     parser.add_argument('--translation_factor', default=0., type=float, help='for data augmentation, a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions')
 
     # Run specific and debugging arguments
diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py
index fdc530393..7e698ddf8 100755
--- a/ml4h/tensor_generators.py
+++ b/ml4h/tensor_generators.py
@@ -732,8 +732,29 @@ def get_train_valid_test_paths_split_by_csvs(
         logging.info(f"CSV:{balance_csvs[i-1]}\nhas: {len(train_paths[i])} train, {len(valid_paths[i])} valid, {len(test_paths[i])} test tensors.")
     return train_paths, valid_paths, test_paths
 
-# https://stackoverflow.com/questions/65475057/keras-data-augmentation-pipeline-for-image-segmentation-dataset-image-and-mask
-def augment_using_layers(images, mask, in_shapes, out_shapes, rotation_factor, zoom_factor, translation_factor):
+def augment_using_layers(
+    images: Dict[str, tf.Tensor],
+    mask: Dict[str, tf.Tensor],
+    in_shapes: Dict[str, Tuple[int, int, int, int]],
+    out_shapes: Dict[str, Tuple[int, int, int, int]],
+    rotation_factor: float,
+    zoom_factor: float,
+    translation_factor: float,
+) -> Tuple[tf.Tensor, tf.Tensor]:
+    """
+    Applies random data augmentation (rotation, zoom and/or translation) to pairs of 2D images and segmentations.
+    :param images: a dictionary mapping an input tensor map's name to an image tensor
+    :param mask: a dictionary mapping an output tensor map's name to a segmentation tensor
+    :param in_shapes: a dictionary mapping an input tensor map's name to its shape (including the batch_size)
+    :param out_shapes: a dictionary mapping an output tensor map's name to its shape (including the batch_size)
+    :param rotation_factor: a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]
+    :param zoom_factor: a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]
+    :param translation_factor: a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions
+    :return: an augmented image tensor and its corresponding augmented segmentation tensor
+    """
+
+    # Adapted from:
+    # https://stackoverflow.com/questions/65475057/keras-data-augmentation-pipeline-for-image-segmentation-dataset-image-and-mask
     assert len(in_shapes) == 1, 'no support for multiple inputs'
     assert len(out_shapes) == 1, 'no support for multiple outputs'
 
     def aug():
         rota = tf.keras.layers.RandomRotation(factor=rotation_factor, fill_mode='constant')
@@ -828,9 +849,9 @@ def test_train_valid_tensor_generators(
     :param valid_csv: CSV file of sample ids to use for validation, mutually exclusive with valid_ratio
     :param test_csv: CSV file of sample ids to use for testing, mutually exclusive with test_ratio
     :param siamese: if True generate input for a siamese model i.e. a left and right input tensors for every input TensorMap
-    :param rotation_factor: rotation for data augmentation: a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]
-    :param zoom_factor: zoom for data augmentation: a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]
-    :param translation_factor: translation for data augmentation: a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions
+    :param rotation_factor: for data augmentation, a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]
+    :param zoom_factor: for data augmentation, a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]
+    :param translation_factor: for data augmentation, a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions
     :param wrap_with_tf_dataset: if True will return tf.dataset objects for the 3 generators
     :return: A tuple of three generators. Each yields a Tuple of dictionaries of input and output numpy arrays for training, validation and testing.
     """
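For reference, wrapping one of the returned generators in a tf.data pipeline and applying an augmentation function with the documented images/mask signature might look roughly like the sketch below. This is illustrative only: the generator, tensor map names, shapes and batch size are placeholders, not the ml4h wiring.

import tensorflow as tf

# Hypothetical shapes for one input image TensorMap and one output segmentation TensorMap.
in_shapes = {'input_image': (4, 384, 384, 1)}
out_shapes = {'output_segmentation': (4, 384, 384, 13)}

def wrap_generator_with_augmentation(batch_generator_fn, augment_fn):
    # batch_generator_fn: zero-argument callable yielding (input_dict, output_dict) numpy batches
    dataset = tf.data.Dataset.from_generator(
        batch_generator_fn,
        output_signature=(
            {k: tf.TensorSpec(shape=v, dtype=tf.float32) for k, v in in_shapes.items()},
            {k: tf.TensorSpec(shape=v, dtype=tf.float32) for k, v in out_shapes.items()},
        ),
    )
    # augment_fn takes (images_dict, mask_dict) and returns augmented tensors,
    # mirroring the images/mask parameters documented above
    return dataset.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)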
""" From c64275ecd92b0a1be467db103b99a87c581380d3 Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 8 Dec 2023 17:10:12 +0000 Subject: [PATCH 47/50] FIX: Fix typo --- ml4h/explorations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 5da14d5d9..612aaa0c0 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -896,7 +896,7 @@ def infer_stats_from_segmented_regions(args): if args.analyze_ground_truth: _scatter_plots_from_segmented_region_stats( inference_tsv_true, inference_tsv_pred, args.structures_to_analyze, - args.output_folder, args.id, tm_in.input_name(), args.output_name, + args.output_folder, args.id, tm_in.input_name(), tm_out.output_name(), ) def _softmax(x): From ba4e5fd396888c1ba88fb65804f1b87c9f5dc8ae Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 8 Dec 2023 17:29:55 +0000 Subject: [PATCH 48/50] FIX: Fix parser for boolean arguments --- ml4h/arguments.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index 39cdf50b3..ed0c08479 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -385,7 +385,9 @@ def parse_args(): ) # Arguments for explorations/infer_stats_from_segmented_regions - parser.add_argument('--analyze_ground_truth', default=True, help='Whether or not to filter by images with ground truth segmentations, for comparison') + parser.add_argument('--analyze_ground_truth', action='store_true', help='Filter by images with ground truth segmentations, for comparison') + parser.add_argument('--no_analyze_ground_truth', dest='analyze_ground_truth', action='store_false', help='Do not filter by images with ground truth segmentations, for comparison') + parser.set_defaults(analyze_ground_truth=True) parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots') parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing') parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing') From 55a58b05f67a54901f889c246368abfef7b2a96e Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 12 Jan 2024 11:15:34 -0500 Subject: [PATCH 49/50] STYLE: Rename _unit_disk(r) to unit_disk(r) --- ml4h/explorations.py | 4 ++-- ml4h/tensorize/tensor_writer_ukbb.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ml4h/explorations.py b/ml4h/explorations.py index 612aaa0c0..ad60f64a0 100755 --- a/ml4h/explorations.py +++ b/ml4h/explorations.py @@ -37,7 +37,7 @@ from ml4h.plots import evaluate_predictions, subplot_rocs, subplot_scatters, plot_categorical_tmap_over_time from ml4h.defines import JOIN_CHAR, MRI_SEGMENTED_CHANNEL_MAP, CODING_VALUES_MISSING, CODING_VALUES_LESS_THAN_ONE from ml4h.defines import TENSOR_EXT, IMAGE_EXT, ECG_CHAR_2_IDX, ECG_IDX_2_CHAR, PARTNERS_CHAR_2_IDX, PARTNERS_IDX_2_CHAR, PARTNERS_READ_TEXT -from ml4h.tensorize.tensor_writer_ukbb import _unit_disk +from ml4h.tensorize.tensor_writer_ukbb import unit_disk from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler @@ -815,7 +815,7 @@ def infer_stats_from_segmented_regions(args): # Structuring element used for the erosion if args.erosion_radius > 0: - structure = _unit_disk(args.erosion_radius)[np.newaxis, ..., np.newaxis] + structure = unit_disk(args.erosion_radius)[np.newaxis, ..., np.newaxis] # Setup for intensity thresholding 
do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py index a352022e7..6d86dce6e 100755 --- a/ml4h/tensorize/tensor_writer_ukbb.py +++ b/ml4h/tensorize/tensor_writer_ukbb.py @@ -647,22 +647,22 @@ def _get_overlay_from_dicom(d, debug=False) -> Tuple[np.ndarray, np.ndarray]: short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1])) small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR) big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR) - small_structure = _unit_disk(small_radius) + small_structure = unit_disk(small_radius) m1 = binary_closing(overlay, small_structure).astype(np.int) - big_structure = _unit_disk(big_radius) + big_structure = unit_disk(big_radius) m2 = binary_closing(overlay, big_structure).astype(np.int) anatomical_mask = m1 + m2 ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium']) if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM: # try to rescue small ventricles - erode_structure = _unit_disk(small_radius*1.5) + erode_structure = unit_disk(small_radius * 1.5) anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int) ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle']) myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium']) return overlay, anatomical_mask, ventricle_pixels, myocardium_pixels -def _unit_disk(r) -> np.ndarray: +def unit_disk(r) -> np.ndarray: y, x = np.ogrid[-r: r + 1, -r: r + 1] return (x ** 2 + y ** 2 <= r ** 2).astype(np.int32) From 618d8578a16fbfb6cbd78266ceb593a1f08779aa Mon Sep 17 00:00:00 2001 From: Danielle Pace Date: Fri, 12 Jan 2024 11:23:26 -0500 Subject: [PATCH 50/50] ENH: Remove --no_analyze_ground_truth option --- ml4h/arguments.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index ed0c08479..0ead3003d 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -385,9 +385,7 @@ def parse_args(): ) # Arguments for explorations/infer_stats_from_segmented_regions - parser.add_argument('--analyze_ground_truth', action='store_true', help='Filter by images with ground truth segmentations, for comparison') - parser.add_argument('--no_analyze_ground_truth', dest='analyze_ground_truth', action='store_false', help='Do not filter by images with ground truth segmentations, for comparison') - parser.set_defaults(analyze_ground_truth=True) + parser.add_argument('--analyze_ground_truth', default=False, action='store_true', help='Filter by images with ground truth segmentations, for comparison') parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots') parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing') parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')