Skip to content

Commit

Permalink
Dp kidney (#549)
Browse files Browse the repository at this point in the history
* add pancreas

* add pancreas

* use pancreas for pngs

* WIP

* New manifest file

* Fix view creation

* Prevent tensorize from overwriting

* Add tensormaps for pancreas mris

* Add elements of Marcus's setup - L2 weight_decay and cosine decay learning rate schedule

* FIX: Fix bug when a generator has 0 ids

* FIX: Fix bug when key_prefix is not given

* ENH: Allow for no testing during model training

* FIX: Fix typo

* FIX: Fix parser for boolean arguments

* STYLE: Remove unneeded comment

* WIP

* FIX error (which kills a thread and prevents subsequent pngs from being written) if the image size is wrong

* Tensorize can create empty tensors if there are no good series, making you think it's working when it isn't. At least give a warning

* Don't commit code to interpet this specific manifest_tsv file

* STYLE: Rename intensity_thresh_perc -> intensity_thresh_percentile

---------

Co-authored-by: Sam Freesun Friedman <[email protected]>
  • Loading branch information
daniellepace and lucidtronix authored Jan 12, 2024
1 parent 44f5ade commit 0064b83
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 68 deletions.
4 changes: 3 additions & 1 deletion ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def parse_args():
)
parser.add_argument('--balance_csvs', default=[], nargs='*', help='Balances batches with representation from sample IDs in this list of CSVs')
parser.add_argument('--optimizer', default='radam', type=str, help='Optimizer for model training')
parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2'], help='Adjusts learning rate during training.')
parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2', 'cosine_decay'], help='Adjusts learning rate during training.')
parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training')
parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training')
parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value')
Expand Down Expand Up @@ -389,6 +389,8 @@ def parse_args():
parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots')
parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing')
parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
parser.add_argument('--intensity_thresh_percentile', type=float, help='Threshold percentile for preprocessing, between 0 and 100 inclusive')
parser.add_argument('--intensity_thresh_k_means', nargs='*', default=[], type=int, help='Preprocessing using k-means specified as two numbers, the first is the number of clusters and the second is the cluster index to keep')
parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the images has intensity above the threshold')
parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name')

Expand Down
5 changes: 4 additions & 1 deletion ml4h/defines.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def __str__(self):
'aortic_root': 7, 'ascending_aorta': 8, 'pulmonary_artery': 9, 'ascending_aortic_wall': 10, 'LVOT': 11,
}
MRI_LIVER_SEGMENTED_CHANNEL_MAP = {'background': 0, 'liver': 1, 'inferior_vena_cava': 2, 'abdominal_aorta': 3, 'body': 4}

MRI_PANCREAS_SEGMENTED_CHANNEL_MAP = {
'background': 0, 'body': 1, 'pancreas': 2, 'liver': 3, 'stomach': 4, 'spleen': 5,
'kidney': 6, 'bowel': 7, 'spine': 8, 'aorta':9, 'ivc': 10,
}

# TODO: These values should ultimately come from the coding table
CODING_VALUES_LESS_THAN_ONE = [-10, -1001]
Expand Down
39 changes: 27 additions & 12 deletions ml4h/explorations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pandas as pd
import multiprocessing as mp
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from tensorflow.keras.models import Model


Expand Down Expand Up @@ -719,13 +720,26 @@ def _get_csv_row(sample_id, means, medians, stds, date):
csv_row = [sample_id] + res[0].astype('str').tolist() + [date]
return csv_row

def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig_channels):
def _thresh_labels_above(y, img, intensity_thresh, intensity_thresh_percentile, in_labels, out_label, nb_orig_channels):
y = np.argmax(y, axis=-1)[..., np.newaxis]
y[np.logical_and(img >= intensity_thresh, np.isin(y, in_labels))] = out_label
if intensity_thresh:
img_intensity_thresh = intensity_thresh
elif intensity_thresh_percentile:
img_intensity_thresh = np.percentile(img, intensity_thresh_percentile)
y[np.logical_and(img >= img_intensity_thresh, np.isin(y, in_labels))] = out_label
y = y[..., 0]
y = _to_categorical(y, nb_orig_channels)
return y

def _intensity_thresh_k_means(y, img, intensity_thresh_k_means):
X = img[y==1][...,np.newaxis]
if X.size > 1:
kmeans = KMeans(n_clusters=intensity_thresh_k_means[0], random_state=0, n_init="auto").fit(X)
labels = kmeans.predict(img.flatten()[...,np.newaxis])
labels = np.reshape(labels, img.shape)
y[np.logical_and(labels==intensity_thresh_k_means[1], y==1)] = 0
return y

def _scatter_plots_from_segmented_region_stats(
inference_tsv_true, inference_tsv_pred, structures_to_analyze,
output_folder, id, input_name, output_name,
Expand Down Expand Up @@ -759,13 +773,9 @@ def _scatter_plots_from_segmented_region_stats(
title = col.replace('_', ' ')
ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
if i == 'all':
min_value = -50
max_value = 1300
elif i == 'filter_outliers':
min_value, max_value = plot_data.min(), plot_data.max()
min_value = min([min_value['true'], min_value['pred']]) - 100
max_value = min([max_value['true'], max_value['pred']]) + 100
min_value, max_value = plot_data.min(), plot_data.max()
min_value = min([min_value['true'], min_value['pred']]) - 100
max_value = min([max_value['true'], max_value['pred']]) + 100
ax.set_xlim([min_value, max_value])
ax.set_ylim([min_value, max_value])
res = stats.pearsonr(plot_data['true'], plot_data['pred'])
Expand Down Expand Up @@ -798,7 +808,6 @@ def infer_stats_from_segmented_regions(args):
assert(tm_in.shape[-1] == 1, 'no support here for stats on multiple input channels')

# don't filter datasets for ground truth segmentations if we want to run inference on everything
# TODO HELP - this isn't giving me all 56K anymore
if not args.analyze_ground_truth:
args.output_tensors = []
args.tensor_maps_out = []
Expand All @@ -820,6 +829,8 @@ def infer_stats_from_segmented_regions(args):
# Setup for intensity thresholding
do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure
if do_intensity_thresh:
assert (not (args.intensity_thresh and args.intensity_thresh_percentile))
assert (not (args.intensity_thresh_k_means and len(args.intensity_thresh_in_structures) > 1))
intensity_thresh_in_channels = [tm_out.channel_map[k] for k in args.intensity_thresh_in_structures]
intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure]

Expand Down Expand Up @@ -870,19 +881,23 @@ def infer_stats_from_segmented_regions(args):

if args.analyze_ground_truth:
if do_intensity_thresh:
y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_true = np.delete(y_true, bad_channels, axis=-1)
if args.erosion_radius > 0:
y_true = binary_erosion(y_true, structure).astype(y_true.dtype)
if args.intensity_thresh_k_means:
y_true = _intensity_thresh_k_means(y_true, img, args.intensity_thresh_k_means)
means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_channels)
csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date)
inference_writer_true.writerow(csv_row_true)

if do_intensity_thresh:
y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_pred = np.delete(y_pred, bad_channels, axis=-1)
if args.erosion_radius > 0:
y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype)
if args.intensity_thresh_k_means:
y_pred = _intensity_thresh_k_means(y_pred, img, args.intensity_thresh_k_means)
means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels)
csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date)
inference_writer_pred.writerow(csv_row_pred)
Expand Down
8 changes: 7 additions & 1 deletion ml4h/models/layer_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, Average, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
from tensorflow.keras.regularizers import L1, L2


Tensor = tf.Tensor
Expand Down Expand Up @@ -52,9 +53,14 @@
# class name -> (dimension -> class)
'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
'dropout': defaultdict(lambda _: Dropout),
'l1': L1,
'l2': L2,
}
DENSE_REGULARIZATION_CLASSES = {
'dropout': Dropout, # TODO: add l1, l2
'dropout': Dropout,
'dropout': Dropout,
'l1': L1,
'l2': L2,
}


Expand Down
7 changes: 6 additions & 1 deletion ml4h/models/legacy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
from tensorflow.keras.layers.experimental.preprocessing import RandomRotation, RandomZoom, RandomContrast
from tensorflow.keras.regularizers import L1, L2
import tensorflow_probability as tfp

from ml4h.metrics import get_metric_dict
Expand Down Expand Up @@ -79,9 +80,13 @@ class BottleneckType(Enum):
# class name -> (dimension -> class)
'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
'dropout': defaultdict(lambda _: Dropout),
'l1': L1,
'l2': L2,
}
DENSE_REGULARIZATION_CLASSES = {
'dropout': Dropout, # TODO: add l1, l2
'dropout': Dropout,
'l1': L1,
'l2': L2,
}


Expand Down
3 changes: 3 additions & 0 deletions ml4h/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow_addons.optimizers import RectifiedAdam, TriangularCyclicalLearningRate, Triangular2CyclicalLearningRate
from tensorflow.keras.optimizers.schedules import CosineDecay

from ml4h.plots import plot_find_learning_rate
from ml4h.tensor_generators import TensorGenerator
Expand Down Expand Up @@ -40,6 +41,8 @@ def _get_learning_rate_schedule(learning_rate: float, learning_rate_schedule: st
initial_learning_rate=learning_rate / 5, maximal_learning_rate=learning_rate,
step_size=steps_per_epoch * 5,
)
if learning_rate_schedule == 'cosine_decay':
return CosineDecay(initial_learning_rate=learning_rate, decay_steps=steps_per_epoch)
else:
raise ValueError(f'Learning rate schedule "{learning_rate_schedule}" unknown.')

Expand Down
62 changes: 32 additions & 30 deletions ml4h/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,37 +220,39 @@ def train_multimodal_multitask(args):
if merger:
merger.save(f'{args.output_folder}{args.id}/merger.h5')

test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
performance_metrics = _predict_and_evaluate(
model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
)
performance_metrics = {}
if args.test_steps > 0:
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
performance_metrics = _predict_and_evaluate(
model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
)

predictions_list = model.predict(test_data)
samples = min(args.test_steps * args.batch_size, 12)
out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
if len(args.tensor_maps_out) == 1:
predictions_list = [predictions_list]
predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')

for i, etm in enumerate(encoders):
embed = encoders[etm].predict(test_data[etm.input_name()])
if etm.output_name() in predictions_dict:
plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
for dtm in decoders:
reconstruction = decoders[dtm].predict(embed)
logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
if dtm.axes() > 1:
plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
else:
evaluate_predictions(
dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
)
predictions_list = model.predict(test_data)
samples = min(args.test_steps * args.batch_size, 12)
out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
if len(args.tensor_maps_out) == 1:
predictions_list = [predictions_list]
predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')

for i, etm in enumerate(encoders):
embed = encoders[etm].predict(test_data[etm.input_name()])
if etm.output_name() in predictions_dict:
plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
for dtm in decoders:
reconstruction = decoders[dtm].predict(embed)
logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
if dtm.axes() > 1:
plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
else:
evaluate_predictions(
dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
)
return performance_metrics


Expand Down
2 changes: 1 addition & 1 deletion ml4h/tensor_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
:param paths: If weights is provided, paths should be a list of path lists the same length as weights
"""
self.augment = augment
self.paths = sum(paths) if isinstance(paths[0], list) else paths
self.paths = sum(paths) if (len(paths) > 0 and isinstance(paths[0], list)) else paths
self.run_on_main_thread = num_workers == 0
self.q = None
self.stats_q = None
Expand Down
Loading

0 comments on commit 0064b83

Please sign in to comment.