From 44f5ade954eaba2bd6720609b21a93fe4f998850 Mon Sep 17 00:00:00 2001
From: Danielle Pace
Date: Fri, 12 Jan 2024 11:43:15 -0500
Subject: [PATCH 1/2] Dp sf pap (#547)

* use both _1 and _2 segmentations

* use both _1 and _2 segmentations

* TEMP: add dice metrics, copied from neuron

* ENH: Use dice loss and metrics

* ENH: Remove kidney label and merge body/background labels

* FIX: Fix bad number of channels

* ENH: Use only one channel of the input image

* WIP: hacking bottom of U-net

* ENH: Add mean and std for normalization

* ENH: Add neurite and voxelmorph to docker

* FIX: Use dice loss from neurite

* STYLE: Fix up WIP code on hacking bottom of U-net

* ENH: Add merged paps for segmentation tensormap

* WIP: Fix Unet concats

* FIX: Fix soft dice metrics

* ENH: Add plot_dice to compare

* ENH: Add median computation for papillary segmentation project

* FIX: Fix double plot on one graph

* ENH: Allow generator to have empty path, e.g., to test on all images

* ENH: Prune list of structures for which we do stats

* STYLE: rearranging

* WIP: Handle inference without ground truth labels

* ENH: Remove option for merged paps

* FIX: Get all b2s images, instance_2s only

* ENH: Add mri dates

* FIX: Fix normalization with correct padding

* FIX: Fix soft dice metrics again

* COMP: Add option for environment variable for jupyter notebooks

* WIP: data augmentation

* WIP: Better scatter plots for medians

* ENH: Report std too

* ENH: Improve dice plots for a single model

* ENH: Log pearson correlation coefficients

* STYLE: Adding TODOs to fix tensor_generators

* WIP: Add temporary code to save Dice scores

* WIP: Add temporary code for plotting medians

* STYLE: Clean up code for infer_medians

* STYLE: Clean up medians code

* STYLE: Add command-line args for median computations

* ENH: Add percentiles and tsv for dice calculations

* STYLE: Add command-line args for data augmentation

* ENH: Improve log files for dice compare

* STYLE: Small edits requested in PR

* STYLE: docstring and typehints for plot_dice

* STYLE: docstring for infer_statistics_from_segmented_regions

* STYLE: Docstring and typehints for augment_using_layers

* FIX: Fix typo

* FIX: Fix parser for boolean arguments

* STYLE: Rename _unit_disk(r) to unit_disk(r)

* ENH: Remove --no_analyze_ground_truth option

---------

Co-authored-by: Sam Freesun Friedman
---
 docker/vm_boot_images/config/tensorflow-requirements.txt | 2 +-
 ml4h/arguments.py                                        | 2 +-
 ml4h/explorations.py                                     | 6 +++---
 ml4h/recipes.py                                          | 7 ++-----
 ml4h/tensorize/tensor_writer_ukbb.py                     | 8 ++++----
 5 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/docker/vm_boot_images/config/tensorflow-requirements.txt b/docker/vm_boot_images/config/tensorflow-requirements.txt
index 88ca7bf7a..90d35a13b 100644
--- a/docker/vm_boot_images/config/tensorflow-requirements.txt
+++ b/docker/vm_boot_images/config/tensorflow-requirements.txt
@@ -43,4 +43,4 @@ google-cloud-storage
 umap-learn[plot]
 neurite
 voxelmorph
-pystrum
+pystrum
\ No newline at end of file

diff --git a/ml4h/arguments.py b/ml4h/arguments.py
index 39cdf50b3..5ca5c8c1c 100755
--- a/ml4h/arguments.py
+++ b/ml4h/arguments.py
@@ -385,7 +385,7 @@ def parse_args():
     )

     # Arguments for explorations/infer_stats_from_segmented_regions
-    parser.add_argument('--analyze_ground_truth', default=True, help='Whether or not to filter by images with ground truth segmentations, for comparison')
+    parser.add_argument('--analyze_ground_truth', default=False, action='store_true', help='Whether or not to filter by images with ground truth segmentations, for comparison')
     parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots')
     parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing')
     parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
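A note on the `--analyze_ground_truth` change above (the "FIX: Fix parser for boolean arguments" commit): with a plain `default=True` or `type=bool` argument, argparse converts any non-empty string to `True`, so passing `False` on the command line silently keeps the option enabled; `action='store_true'` turns it into a presence flag. A minimal, self-contained sketch of the pitfall (the `--broken_flag` name is illustrative, not from the repo):

    import argparse

    parser = argparse.ArgumentParser()
    # Broken: bool('False') is True, so any CLI value enables the flag.
    parser.add_argument('--broken_flag', type=bool, default=False)
    # Fixed: the flag is False unless it appears on the command line.
    parser.add_argument('--analyze_ground_truth', default=False, action='store_true')

    args = parser.parse_args(['--broken_flag', 'False', '--analyze_ground_truth'])
    print(args.broken_flag)           # True -- the string 'False' is truthy
    print(args.analyze_ground_truth)  # True -- flag was present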
diff --git a/ml4h/explorations.py b/ml4h/explorations.py
index 5da14d5d9..ad60f64a0 100755
--- a/ml4h/explorations.py
+++ b/ml4h/explorations.py
@@ -37,7 +37,7 @@
 from ml4h.plots import evaluate_predictions, subplot_rocs, subplot_scatters, plot_categorical_tmap_over_time
 from ml4h.defines import JOIN_CHAR, MRI_SEGMENTED_CHANNEL_MAP, CODING_VALUES_MISSING, CODING_VALUES_LESS_THAN_ONE
 from ml4h.defines import TENSOR_EXT, IMAGE_EXT, ECG_CHAR_2_IDX, ECG_IDX_2_CHAR, PARTNERS_CHAR_2_IDX, PARTNERS_IDX_2_CHAR, PARTNERS_READ_TEXT
-from ml4h.tensorize.tensor_writer_ukbb import _unit_disk
+from ml4h.tensorize.tensor_writer_ukbb import unit_disk

 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import StandardScaler
@@ -815,7 +815,7 @@ def infer_stats_from_segmented_regions(args):

     # Structuring element used for the erosion
     if args.erosion_radius > 0:
-        structure = _unit_disk(args.erosion_radius)[np.newaxis, ..., np.newaxis]
+        structure = unit_disk(args.erosion_radius)[np.newaxis, ..., np.newaxis]

     # Setup for intensity thresholding
     do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure
@@ -896,7 +896,7 @@ def infer_stats_from_segmented_regions(args):
         if args.analyze_ground_truth:
             _scatter_plots_from_segmented_region_stats(
                 inference_tsv_true, inference_tsv_pred, args.structures_to_analyze,
-                args.output_folder, args.id, tm_in.input_name(), args.output_name,
+                args.output_folder, args.id, tm_in.input_name(), tm_out.output_name(),
             )

 def _softmax(x):
diff --git a/ml4h/recipes.py b/ml4h/recipes.py
index 58eab9fac..a97afa0bc 100755
--- a/ml4h/recipes.py
+++ b/ml4h/recipes.py
@@ -27,7 +27,6 @@
 from ml4h.ml4ht_integration.tensor_generator import TensorMapDataLoader2
 from ml4h.explorations import test_labels_to_label_map, infer_with_pixels
 from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX
-
 from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_stats_from_segmented_regions
 from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model
 from ml4h.plots import plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
@@ -36,7 +35,6 @@
 from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
 from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription
 from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients
-
 from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice
 from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
@@ -141,11 +139,10 @@ def run(args):
     except Exception as e:
         logging.exception(e)
-    
+
     if args.gcs_cloud_bucket is not None:
         save_to_google_cloud(args)
-        
-    
+
     end_time = timer()
     elapsed_time = end_time - start_time
     logging.info("Executed the '{}' operation in {:.2f} seconds".format(args.mode, elapsed_time))
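Several commits in this patch replace cross entropy with a Dice loss ("FIX: Use dice loss from neurite", "FIX: Fix soft dice metrics"). As a rough sketch of what a soft Dice loss computes over one-hot segmentation tensors — an illustration of the idea, not neurite's exact implementation:

    import tensorflow as tf

    def soft_dice_loss(y_true, y_pred, eps=1e-6):
        # Sum overlap and sizes over the spatial axes, keeping batch and channel.
        axes = tuple(range(1, len(y_pred.shape) - 1))
        intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
        sizes = tf.reduce_sum(y_true + y_pred, axis=axes)
        dice_per_channel = (2.0 * intersection + eps) / (sizes + eps)
        # Loss is 1 minus the mean Dice over channels and batch.
        return 1.0 - tf.reduce_mean(dice_per_channel)

This assumes tensors with a static rank and shape (batch, ..., channels); the epsilon keeps empty structures from producing a division by zero.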
diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py
index a352022e7..6d86dce6e 100755
--- a/ml4h/tensorize/tensor_writer_ukbb.py
+++ b/ml4h/tensorize/tensor_writer_ukbb.py
@@ -647,22 +647,22 @@ def _get_overlay_from_dicom(d, debug=False) -> Tuple[np.ndarray, np.ndarray]:
     short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1]))
     small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR)
     big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR)
-    small_structure = _unit_disk(small_radius)
+    small_structure = unit_disk(small_radius)
     m1 = binary_closing(overlay, small_structure).astype(np.int)
-    big_structure = _unit_disk(big_radius)
+    big_structure = unit_disk(big_radius)
     m2 = binary_closing(overlay, big_structure).astype(np.int)
     anatomical_mask = m1 + m2
     ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
     myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'])
     if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM:  # try to rescue small ventricles
-        erode_structure = _unit_disk(small_radius*1.5)
+        erode_structure = unit_disk(small_radius * 1.5)
         anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int)
         ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
         myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'])
     return overlay, anatomical_mask, ventricle_pixels, myocardium_pixels


-def _unit_disk(r) -> np.ndarray:
+def unit_disk(r) -> np.ndarray:
     y, x = np.ogrid[-r: r + 1, -r: r + 1]
     return (x ** 2 + y ** 2 <= r ** 2).astype(np.int32)
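For readers unfamiliar with the morphology helper renamed above: `unit_disk(r)` builds a binary disk used as the structuring element for `scipy.ndimage` operations such as the closing and erosion calls in `_get_overlay_from_dicom`, and for the `--erosion_radius` preprocessing in explorations.py. A toy, self-contained sketch (mask and sizes are illustrative):

    import numpy as np
    from scipy.ndimage import binary_erosion

    def unit_disk(r) -> np.ndarray:
        y, x = np.ogrid[-r: r + 1, -r: r + 1]
        return (x ** 2 + y ** 2 <= r ** 2).astype(np.int32)

    mask = np.zeros((64, 64), dtype=np.int32)
    mask[20:40, 20:40] = 1                       # toy segmentation
    eroded = binary_erosion(mask, unit_disk(2))  # peel ~2 pixels off the boundary
    print(mask.sum(), eroded.sum())              # erosion shrinks the region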
From 0064b836282cde0d586d501583dd1d942e54e7bc Mon Sep 17 00:00:00 2001
From: Danielle Pace
Date: Fri, 12 Jan 2024 11:53:37 -0500
Subject: [PATCH 2/2] Dp kidney (#549)

* add pancreas

* add pancreas

* use pancreas for pngs

* WIP

* New manifest file

* Fix view creation

* Prevent tensorize from overwriting

* Add tensormaps for pancreas mris

* Add elements of Marcus's setup - L2 weight_decay and cosine decay learning rate schedule

* FIX: Fix bug when a generator has 0 ids

* FIX: Fix bug when key_prefix is not given

* ENH: Allow for no testing during model training

* FIX: Fix typo

* FIX: Fix parser for boolean arguments

* STYLE: Remove unneeded comment

* WIP

* FIX error (which kills a thread and prevents subsequent pngs from being written) if the image size is wrong

* Tensorize can create empty tensors if there are no good series, making you think it's working when it isn't. At least give a warning

* Don't commit code to interpret this specific manifest_tsv file

* STYLE: Rename intensity_thresh_perc -> intensity_thresh_percentile

---------

Co-authored-by: Sam Freesun Friedman
---
 ml4h/arguments.py                    |  4 +-
 ml4h/defines.py                      |  5 +-
 ml4h/explorations.py                 | 39 +++++++++----
 ml4h/models/layer_wrappers.py        |  8 ++-
 ml4h/models/legacy_models.py         |  7 ++-
 ml4h/optimizers.py                   |  3 +
 ml4h/recipes.py                      | 62 +++++++++++----------
 ml4h/tensor_generators.py            |  2 +-
 ml4h/tensorize/tensor_writer_ukbb.py | 24 +++++---
 ml4h/tensormap/ukb/mri.py            | 82 +++++++++++++++++++++++-----
 10 files changed, 168 insertions(+), 68 deletions(-)

diff --git a/ml4h/arguments.py b/ml4h/arguments.py
index 5ca5c8c1c..a8dceddc4 100755
--- a/ml4h/arguments.py
+++ b/ml4h/arguments.py
@@ -263,7 +263,7 @@ def parse_args():
     )
     parser.add_argument('--balance_csvs', default=[], nargs='*', help='Balances batches with representation from sample IDs in this list of CSVs')
     parser.add_argument('--optimizer', default='radam', type=str, help='Optimizer for model training')
-    parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2'], help='Adjusts learning rate during training.')
+    parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2', 'cosine_decay'], help='Adjusts learning rate during training.')
     parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training')
     parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training')
     parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value')
@@ -389,6 +389,8 @@ def parse_args():
     parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots')
     parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing')
     parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
+    parser.add_argument('--intensity_thresh_percentile', type=float, help='Threshold percentile for preprocessing, between 0 and 100 inclusive')
+    parser.add_argument('--intensity_thresh_k_means', nargs='*', default=[], type=int, help='Preprocessing using k-means specified as two numbers, the first is the number of clusters and the second is the cluster index to keep')
     parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the image has intensity above the threshold')
     parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name')

diff --git a/ml4h/defines.py b/ml4h/defines.py
index eb18dde92..6081639f1 100755
--- a/ml4h/defines.py
+++ b/ml4h/defines.py
@@ -100,7 +100,10 @@ def __str__(self):
     'aortic_root': 7, 'ascending_aorta': 8, 'pulmonary_artery': 9, 'ascending_aortic_wall': 10, 'LVOT': 11,
 }
 MRI_LIVER_SEGMENTED_CHANNEL_MAP = {'background': 0, 'liver': 1, 'inferior_vena_cava': 2, 'abdominal_aorta': 3, 'body': 4}
-
+MRI_PANCREAS_SEGMENTED_CHANNEL_MAP = {
+    'background': 0, 'body': 1, 'pancreas': 2, 'liver': 3, 'stomach': 4, 'spleen': 5,
+    'kidney': 6, 'bowel': 7, 'spine': 8, 'aorta': 9, 'ivc': 10,
+}

 # TODO: These values should ultimately come from the coding table
 CODING_VALUES_LESS_THAN_ONE = [-10, -1001]
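The new `MRI_PANCREAS_SEGMENTED_CHANNEL_MAP` follows the same convention as the other channel maps in defines.py: structure names map to channel indices of a one-hot segmentation tensor. A minimal sketch of how such a map relates label images to one-hot tensors (the random label image is just a stand-in):

    import numpy as np

    MRI_PANCREAS_SEGMENTED_CHANNEL_MAP = {
        'background': 0, 'body': 1, 'pancreas': 2, 'liver': 3, 'stomach': 4, 'spleen': 5,
        'kidney': 6, 'bowel': 7, 'spine': 8, 'aorta': 9, 'ivc': 10,
    }

    label_image = np.random.randint(0, 11, size=(288, 384))   # toy label indices
    one_hot = np.eye(len(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP))[label_image]
    pancreas_mask = one_hot[..., MRI_PANCREAS_SEGMENTED_CHANNEL_MAP['pancreas']]
    print(one_hot.shape)                                       # (288, 384, 11)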
diff --git a/ml4h/explorations.py b/ml4h/explorations.py
index ad60f64a0..b44a314ee 100755
--- a/ml4h/explorations.py
+++ b/ml4h/explorations.py
@@ -20,6 +20,7 @@
 import pandas as pd
 import multiprocessing as mp
 from sklearn.decomposition import PCA
+from sklearn.cluster import KMeans

 from tensorflow.keras.models import Model

@@ -719,13 +720,26 @@ def _get_csv_row(sample_id, means, medians, stds, date):
     csv_row = [sample_id] + res[0].astype('str').tolist() + [date]
     return csv_row

-def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig_channels):
+def _thresh_labels_above(y, img, intensity_thresh, intensity_thresh_percentile, in_labels, out_label, nb_orig_channels):
     y = np.argmax(y, axis=-1)[..., np.newaxis]
-    y[np.logical_and(img >= intensity_thresh, np.isin(y, in_labels))] = out_label
+    if intensity_thresh:
+        img_intensity_thresh = intensity_thresh
+    elif intensity_thresh_percentile:
+        img_intensity_thresh = np.percentile(img, intensity_thresh_percentile)
+    y[np.logical_and(img >= img_intensity_thresh, np.isin(y, in_labels))] = out_label
     y = y[..., 0]
     y = _to_categorical(y, nb_orig_channels)
     return y

+def _intensity_thresh_k_means(y, img, intensity_thresh_k_means):
+    X = img[y==1][...,np.newaxis]
+    if X.size > 1:
+        kmeans = KMeans(n_clusters=intensity_thresh_k_means[0], random_state=0, n_init="auto").fit(X)
+        labels = kmeans.predict(img.flatten()[...,np.newaxis])
+        labels = np.reshape(labels, img.shape)
+        y[np.logical_and(labels==intensity_thresh_k_means[1], y==1)] = 0
+    return y
+
 def _scatter_plots_from_segmented_region_stats(
     inference_tsv_true, inference_tsv_pred, structures_to_analyze,
     output_folder, id, input_name, output_name,
@@ -759,13 +773,9 @@ def _scatter_plots_from_segmented_region_stats(
             title = col.replace('_', ' ')
             ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
             ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
-            if i == 'all':
-                min_value = -50
-                max_value = 1300
-            elif i == 'filter_outliers':
-                min_value, max_value = plot_data.min(), plot_data.max()
-                min_value = min([min_value['true'], min_value['pred']]) - 100
-                max_value = min([max_value['true'], max_value['pred']]) + 100
+            min_value, max_value = plot_data.min(), plot_data.max()
+            min_value = min([min_value['true'], min_value['pred']]) - 100
+            max_value = min([max_value['true'], max_value['pred']]) + 100
             ax.set_xlim([min_value, max_value])
             ax.set_ylim([min_value, max_value])
             res = stats.pearsonr(plot_data['true'], plot_data['pred'])
@@ -798,7 +808,6 @@ def infer_stats_from_segmented_regions(args):
     assert(tm_in.shape[-1] == 1, 'no support here for stats on multiple input channels')

     # don't filter datasets for ground truth segmentations if we want to run inference on everything
-    # TODO HELP - this isn't giving me all 56K anymore
     if not args.analyze_ground_truth:
         args.output_tensors = []
         args.tensor_maps_out = []
@@ -820,6 +829,8 @@ def infer_stats_from_segmented_regions(args):
     # Setup for intensity thresholding
     do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure
     if do_intensity_thresh:
+        assert (not (args.intensity_thresh and args.intensity_thresh_percentile))
+        assert (not (args.intensity_thresh_k_means and len(args.intensity_thresh_in_structures) > 1))
         intensity_thresh_in_channels = [tm_out.channel_map[k] for k in args.intensity_thresh_in_structures]
         intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure]
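Two details of the thresholding helpers above are worth spelling out. First, `_thresh_labels_above` picks whichever of the two cutoffs is set (the new assert forbids setting both), so exactly one of `--intensity_thresh` / `--intensity_thresh_percentile` is expected when thresholding is enabled. Second, `_intensity_thresh_k_means` clusters the intensities inside a binary mask and zeroes out pixels that fall in a chosen cluster. A standalone sketch of the same idea (two clusters, drop cluster 0; data here is synthetic):

    import numpy as np
    from sklearn.cluster import KMeans

    img = np.random.rand(64, 64) * 1000.0
    mask = np.zeros((64, 64), dtype=np.int64)
    mask[16:48, 16:48] = 1

    X = img[mask == 1][..., np.newaxis]            # intensities inside the mask
    kmeans = KMeans(n_clusters=2, random_state=0, n_init='auto').fit(X)
    labels = kmeans.predict(img.flatten()[..., np.newaxis]).reshape(img.shape)
    mask[np.logical_and(labels == 0, mask == 1)] = 0   # drop pixels in cluster 0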
@@ -870,19 +881,23 @@ def infer_stats_from_segmented_regions(args):
             if args.analyze_ground_truth:
                 if do_intensity_thresh:
-                    y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
+                    y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
                 y_true = np.delete(y_true, bad_channels, axis=-1)
                 if args.erosion_radius > 0:
                     y_true = binary_erosion(y_true, structure).astype(y_true.dtype)
+                if args.intensity_thresh_k_means:
+                    y_true = _intensity_thresh_k_means(y_true, img, args.intensity_thresh_k_means)
                 means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_channels)
                 csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date)
                 inference_writer_true.writerow(csv_row_true)

             if do_intensity_thresh:
-                y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
+                y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
             y_pred = np.delete(y_pred, bad_channels, axis=-1)
             if args.erosion_radius > 0:
                 y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype)
+            if args.intensity_thresh_k_means:
+                y_pred = _intensity_thresh_k_means(y_pred, img, args.intensity_thresh_k_means)
             means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels)
             csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date)
             inference_writer_pred.writerow(csv_row_pred)

diff --git a/ml4h/models/layer_wrappers.py b/ml4h/models/layer_wrappers.py
index 7ba1187c0..083765a03 100755
--- a/ml4h/models/layer_wrappers.py
+++ b/ml4h/models/layer_wrappers.py
@@ -25,6 +25,7 @@
 from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, Average, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
 from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
 from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
+from tensorflow.keras.regularizers import L1, L2

 Tensor = tf.Tensor

@@ -52,9 +53,13 @@
     # class name -> (dimension -> class)
     'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
     'dropout': defaultdict(lambda _: Dropout),
+    'l1': L1,
+    'l2': L2,
 }
 DENSE_REGULARIZATION_CLASSES = {
-    'dropout': Dropout,  # TODO: add l1, l2
+    'dropout': Dropout,
+    'l1': L1,
+    'l2': L2,
 }
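Registering `L1`/`L2` here lets regularization be selected by name from the command line, like dropout. One caveat worth noting: `Dropout` is a Keras layer while `L1`/`L2` are regularizer objects attached to a layer's weights, so the consuming model code presumably has to treat the two kinds differently. A minimal sketch of how an L2 entry from such a dictionary would be applied (strength value illustrative):

    from tensorflow.keras.layers import Dense
    from tensorflow.keras.regularizers import L1, L2

    DENSE_REGULARIZATION_CLASSES = {'l1': L1, 'l2': L2}

    reg = DENSE_REGULARIZATION_CLASSES['l2'](1e-4)   # weight decay strength
    layer = Dense(32, kernel_regularizer=reg)        # penalty added to the loss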
diff --git a/ml4h/models/legacy_models.py b/ml4h/models/legacy_models.py
index 0ba07c5e6..5eafa292e 100755
--- a/ml4h/models/legacy_models.py
+++ b/ml4h/models/legacy_models.py
@@ -30,6 +30,7 @@
 from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
 from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
 from tensorflow.keras.layers.experimental.preprocessing import RandomRotation, RandomZoom, RandomContrast
+from tensorflow.keras.regularizers import L1, L2
 import tensorflow_probability as tfp

 from ml4h.metrics import get_metric_dict
@@ -79,9 +80,13 @@ class BottleneckType(Enum):
     # class name -> (dimension -> class)
     'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
     'dropout': defaultdict(lambda _: Dropout),
+    'l1': L1,
+    'l2': L2,
 }
 DENSE_REGULARIZATION_CLASSES = {
-    'dropout': Dropout,  # TODO: add l1, l2
+    'dropout': Dropout,
+    'l1': L1,
+    'l2': L2,
 }

diff --git a/ml4h/optimizers.py b/ml4h/optimizers.py
index c9e79d5f0..5dc130c2b 100755
--- a/ml4h/optimizers.py
+++ b/ml4h/optimizers.py
@@ -6,6 +6,7 @@
 from tensorflow.keras import backend as K
 from tensorflow.keras.models import Model
 from tensorflow_addons.optimizers import RectifiedAdam, TriangularCyclicalLearningRate, Triangular2CyclicalLearningRate
+from tensorflow.keras.optimizers.schedules import CosineDecay

 from ml4h.plots import plot_find_learning_rate
 from ml4h.tensor_generators import TensorGenerator
@@ -40,6 +41,8 @@ def _get_learning_rate_schedule(learning_rate: float, learning_rate_schedule: st
             initial_learning_rate=learning_rate / 5, maximal_learning_rate=learning_rate,
             step_size=steps_per_epoch * 5,
         )
+    if learning_rate_schedule == 'cosine_decay':
+        return CosineDecay(initial_learning_rate=learning_rate, decay_steps=steps_per_epoch)
     else:
         raise ValueError(f'Learning rate schedule "{learning_rate_schedule}" unknown.')
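The new `cosine_decay` schedule returns a Keras `CosineDecay` object that any optimizer can consume in place of a fixed learning rate. A small usage sketch (optimizer choice and numbers are illustrative; note that with `decay_steps=steps_per_epoch` as wired above, the rate reaches its floor after one epoch's worth of steps and stays there):

    import tensorflow as tf
    from tensorflow.keras.optimizers.schedules import CosineDecay

    steps_per_epoch = 500
    schedule = CosineDecay(initial_learning_rate=1e-3, decay_steps=steps_per_epoch)
    optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
    # The learning rate follows a half cosine from 1e-3 down toward 0
    # over decay_steps optimizer updates.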
diff --git a/ml4h/recipes.py b/ml4h/recipes.py
index a97afa0bc..fc567b846 100755
--- a/ml4h/recipes.py
+++ b/ml4h/recipes.py
@@ -220,37 +220,39 @@ def train_multimodal_multitask(args):
     if merger:
         merger.save(f'{args.output_folder}{args.id}/merger.h5')

-    test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
-    performance_metrics = _predict_and_evaluate(
-        model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
-        args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
-        args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
-    )
+    performance_metrics = {}
+    if args.test_steps > 0:
+        test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
+        performance_metrics = _predict_and_evaluate(
+            model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
+            args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
+            args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
+        )

-    predictions_list = model.predict(test_data)
-    samples = min(args.test_steps * args.batch_size, 12)
-    out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
-    if len(args.tensor_maps_out) == 1:
-        predictions_list = [predictions_list]
-    predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
-    logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')
-
-    for i, etm in enumerate(encoders):
-        embed = encoders[etm].predict(test_data[etm.input_name()])
-        if etm.output_name() in predictions_dict:
-            plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
-        for dtm in decoders:
-            reconstruction = decoders[dtm].predict(embed)
-            logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
-            my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
-            os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
-            if dtm.axes() > 1:
-                plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
-            else:
-                evaluate_predictions(
-                    dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
-                    test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
-                )
+        predictions_list = model.predict(test_data)
+        samples = min(args.test_steps * args.batch_size, 12)
+        out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
+        if len(args.tensor_maps_out) == 1:
+            predictions_list = [predictions_list]
+        predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
+        logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')
+
+        for i, etm in enumerate(encoders):
+            embed = encoders[etm].predict(test_data[etm.input_name()])
+            if etm.output_name() in predictions_dict:
+                plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
+            for dtm in decoders:
+                reconstruction = decoders[dtm].predict(embed)
+                logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
+                my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
+                os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
+                if dtm.axes() > 1:
+                    plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
+                else:
+                    evaluate_predictions(
+                        dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
+                        test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
+                    )

     return performance_metrics

diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py
index 7e698ddf8..f88617724 100755
--- a/ml4h/tensor_generators.py
+++ b/ml4h/tensor_generators.py
@@ -95,7 +95,7 @@ def __init__(
         :param paths: If weights is provided, paths should be a list of path lists the same length as weights
         """
         self.augment = augment
-        self.paths = sum(paths) if isinstance(paths[0], list) else paths
+        self.paths = sum(paths) if (len(paths) > 0 and isinstance(paths[0], list)) else paths
         self.run_on_main_thread = num_workers == 0
         self.q = None
        self.stats_q = None
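The tensor_generators.py change is the "Fix bug when a generator has 0 ids" commit: indexing `paths[0]` on an empty list raises `IndexError` before training even starts. A toy sketch of the guarded flattening (here using `sum(paths, [])`, since plain `sum` needs an empty-list start value to concatenate lists; the helper name is illustrative):

    def flatten_paths(paths):
        # Flatten a list of path lists, tolerating an empty input.
        return sum(paths, []) if (len(paths) > 0 and isinstance(paths[0], list)) else paths

    print(flatten_paths([]))                      # [] -- previously crashed
    print(flatten_paths([['a.hd5'], ['b.hd5']]))  # ['a.hd5', 'b.hd5']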
diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py
index 6d86dce6e..99cba9004 100755
--- a/ml4h/tensorize/tensor_writer_ukbb.py
+++ b/ml4h/tensorize/tensor_writer_ukbb.py
@@ -62,6 +62,7 @@
     'shmolli_192i_b7_sax_b7s_sax_b7s_sax_b7s_t1map',
 ]

+MRI_PANCREAS_SERIES = ['shmolli_192i_pancreas_t1map']
 MRI_CARDIAC_SERIES_SEGMENTED = [series+'_segmented' for series in MRI_CARDIAC_SERIES]
 MRI_BRAIN_SERIES = ['t1_p2_1mm_fov256_sag_ti_880', 't2_flair_sag_p2_1mm_fs_ellip_pf78']
 MRI_NIFTI_FIELD_ID_TO_ROOT = {'20251': 'SWI', '20252': 'T1', '20253': 'T2_FLAIR'}
@@ -71,7 +72,7 @@

 DICOM_MRI_FIELDS = [
     '20209', '20208', '20210', '20212', '20213', '20214', '20204', '20203', '20254', '20216', '20220', '20218',
-    '20227', '20225', '20217', '20158',
+    '20227', '20225', '20217', '20158', '20259',
 ]

 DXA_FIELD = '20158'
@@ -136,7 +137,7 @@ def write_tensors(
         if _prune_sample(sample_id, min_sample_id, max_sample_id, mri_field_ids, xml_field_ids, zip_folder, xml_folder):
             continue
         try:
-            with h5py.File(tp, 'w') as hd5:
+            with h5py.File(tp, 'a') as hd5:
                 _write_tensors_from_zipped_dicoms(write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, sample_id, stats)
                 _write_tensors_from_zipped_niftis(zip_folder, mri_field_ids, hd5, sample_id, stats)
                 _write_tensors_from_xml(xml_field_ids, xml_folder, hd5, sample_id, write_pngs, stats, continuous_stats)
@@ -209,6 +210,9 @@ def write_tensors_from_dicom_pngs(
         except FileNotFoundError:
             logging.warning(f'Could not find file: {os.path.join(png_path, dicom_file + png_postfix)}')
             stats['File not found error'] += 1
+        except ValueError:
+            logging.warning(f'Could not convert file: {os.path.join(png_path, dicom_file + png_postfix)}')
+            stats['Value error'] += 1
     for k in stats:
         if sample_header in k and stats[k] == 50:
             continue
@@ -433,7 +437,7 @@ def _write_tensors_from_dicoms(
             if series + '_12bit' in MRI_LIVER_SERIES_12BIT and d.LargestImagePixelValue > 2048:
                 views[series + '_12bit'].append(d)
                 stats[series + '_12bit'] += 1
-            elif series in MRI_LIVER_SERIES + MRI_CARDIAC_SERIES + MRI_BRAIN_SERIES:
+            elif series in MRI_LIVER_SERIES + MRI_CARDIAC_SERIES + MRI_BRAIN_SERIES + MRI_PANCREAS_SERIES:
                 views[series].append(d)
                 stats[series] += 1
             elif series == 'dxa_images':
@@ -441,6 +445,8 @@ def _write_tensors_from_dicoms(
                 dxa_number = dicom.split('.')[-4]
                 name = f'dxa_{series_num}_{dxa_number}'
                 create_tensor_in_hd5(hd5, f'ukb_dxa/', name, d.pixel_array, stats)
+            else:
+                stats[f'Could not process series {series}'] += 1
             if series in MRI_LIVER_IDEAL_PROTOCOL:
                 min_ideal_series = min(min_ideal_series, int(d.SeriesNumber))
@@ -455,6 +461,8 @@ def _write_tensors_from_dicoms(
             mri_group = 'ukb_liver_mri'
         elif v in MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED:
             mri_group = 'ukb_cardiac_mri'
+        elif v in MRI_PANCREAS_SERIES:
+            mri_group = 'ukb_pancreas_mri'
         else:
             mri_group = 'ukb_mri'

@@ -564,14 +572,14 @@ def _tensorize_brain_mri(slices: List[pydicom.Dataset], series: str, mri_date: d


 def _save_pixel_dimensions_if_missing(slicer, series, hd5):
-    if MRI_PIXEL_WIDTH + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
+    if MRI_PIXEL_WIDTH + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES:
         hd5.create_dataset(MRI_PIXEL_WIDTH + '_' + series, data=float(slicer.PixelSpacing[0]))
-    if MRI_PIXEL_HEIGHT + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
+    if MRI_PIXEL_HEIGHT + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES:
         hd5.create_dataset(MRI_PIXEL_HEIGHT + '_' + series, data=float(slicer.PixelSpacing[1]))


 def _save_slice_thickness_if_missing(slicer, series, hd5):
-    if MRI_SLICE_THICKNESS + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
+    if MRI_SLICE_THICKNESS + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES:
         hd5.create_dataset(MRI_SLICE_THICKNESS + '_' + series, data=float(slicer.SliceThickness))
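The `h5py.File(tp, 'w')` to `'a'` switch is the "Prevent tensorize from overwriting" commit: mode `'w'` truncates an existing HD5 file, while `'a'` opens it for append, which pairs with the create-only-if-missing pattern used throughout this file. A minimal sketch of that pattern (file name and spacing value are illustrative):

    import h5py

    series = 'shmolli_192i_pancreas_t1map'
    with h5py.File('example.hd5', 'a') as hd5:   # 'a' appends; 'w' would truncate
        key = f'mri_pixel_width_{series}'
        if key not in hd5:                       # only write the dataset once
            hd5.create_dataset(key, data=1.40625)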
@@ -581,9 +589,9 @@ def _save_series_orientation_and_position_if_missing(slicer, series, hd5, instan
     if instance:
         orientation_ds_name += HD5_GROUP_CHAR + instance
         position_ds_name += HD5_GROUP_CHAR + instance
-    if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
+    if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES:
         hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient])
-    if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT:
+    if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES:
         hd5.create_dataset(position_ds_name, data=[float(x) for x in slicer.ImagePositionPatient])

diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py
index 2af000328..828e5f059 100755
--- a/ml4h/tensormap/ukb/mri.py
+++ b/ml4h/tensormap/ukb/mri.py
@@ -20,7 +20,8 @@
     MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_SEGMENTED_CHANNEL_MAP, LAX_4CH_HEART_LABELS, LAX_4CH_MYOCARDIUM_LABELS, StorageType, LAX_3CH_HEART_LABELS, \
     LAX_2CH_HEART_LABELS
 from ml4h.tensormap.general import get_tensor_at_first_date, normalized_first_date, pad_or_crop_array_to_shape, tensor_from_hd5
-from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS
+from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, \
+    MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS, MRI_PANCREAS_SEGMENTED_CHANNEL_MAP


 def _slice_subset_tensor(
@@ -2734,17 +2735,12 @@ def _mdrk_projection_both_views_pretrained(tm, hd5, dependents={}):
     tensor_from_file=None,
 )

-def _pad_crop_single_channel(tm, hd5, dependents={}):
-    if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
-        key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
-    elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
-        key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
-    else:
-        raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}')
-
+def _pad_crop_single_channel(tm, hd5, dependents={}, key_prefix=None):
+    if key_prefix is None:
+        key_prefix = tm.hd5_key_guess()
     img = np.array(
-        tm.hd5_first_dataset_in_group(hd5, key_prefix),
-        dtype=np.float32,
+        tm.hd5_first_dataset_in_group(hd5, key_prefix),
+        dtype=np.float32,
     )
     img = img[...,[1]]
     return pad_or_crop_array_to_shape(
         tm.shape,
         img,
     )

@@ -2752,15 +2748,32 @@
+def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}):
+    if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
+        key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
+    elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5:
+        key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2'
+    else:
+        raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}')
+    return _pad_crop_single_channel(tm, hd5, dependents, key_prefix)
+
 t1map_b2 = TensorMap(
     'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map',
     shape=(384, 384, 1),
     path_prefix='ukb_cardiac_mri',
     normalization=Standardize(mean=455.81, std=609.50),
-    tensor_from_file=_pad_crop_single_channel,
+    tensor_from_file=_pad_crop_single_channel_t1map_b2,
 )
+
+t1map_pancreas = TensorMap(
+    'shmolli_192i_pancreas_t1map',
+    shape=(288, 384, 1),
+    path_prefix='ukb_pancreas_mri',
+    normalization=Standardize(mean=389.49, std=658.36),
+    tensor_from_file=_pad_crop_single_channel,
+)

-def _segmented_t1map(tm, hd5, dependents={}):
+def _segmented_t1map_b2(tm, hd5, dependents={}):
     if f'{tm.path_prefix}/{tm.name}_1' in hd5:
         categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_1')
     elif f'{tm.path_prefix}/{tm.name}_2' in hd5:
         categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_2')
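Both T1 map tensormaps funnel through `pad_or_crop_array_to_shape` so that inputs of slightly different acquisition sizes fit the fixed network shape. A rough standalone equivalent, assuming centered zero-padding/cropping semantics (the function name and centering behavior here are assumptions for illustration, not ml4h's exact implementation):

    import numpy as np

    def pad_or_crop_to_shape(target_shape, arr):
        # Zero-pad or crop each axis around the array center (assumed semantics).
        out = np.zeros(target_shape, dtype=arr.dtype)
        src_slices, dst_slices = [], []
        for t, s in zip(target_shape, arr.shape):
            if s >= t:                      # crop: take the centered window
                start = (s - t) // 2
                src_slices.append(slice(start, start + t))
                dst_slices.append(slice(0, t))
            else:                           # pad: place centered in zeros
                start = (t - s) // 2
                src_slices.append(slice(0, s))
                dst_slices.append(slice(start, start + s))
        out[tuple(dst_slices)] = arr[tuple(src_slices)]
        return out

    img = np.random.rand(300, 400, 1).astype(np.float32)
    print(pad_or_crop_to_shape((288, 384, 1), img).shape)   # (288, 384, 1)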
@@ -2780,13 +2793,56 @@ def _segmented_t1map(tm, hd5, dependents={}):
     tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot)
     return tensor

+def _segmented_t1map_pancreas(tm, hd5, dependents={}):
+    if f'{tm.path_prefix}/{tm.name}' in hd5:
+        categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}')
+    else:
+        raise ValueError(f'Could not find T1 Map segmentation for tensormap: {tm.name}')
+
+    # remove kidney label and merge body/background labels
+    orig_num_channels = len(tm.channel_map) + 3
+    categorical_one_hot = to_categorical(categorical_index_slice, orig_num_channels)
+    categorical_one_hot[..., 6] += (
+        categorical_one_hot[..., 11] +
+        categorical_one_hot[..., 12] +
+        categorical_one_hot[..., 13]
+    )
+    categorical_one_hot = np.delete(categorical_one_hot, [11, 12, 13], axis=-1)
+
+    # padding/cropping
+    tensor = np.zeros(tm.shape, dtype=np.float32)
+    tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot)
+    return tensor
+
 t1map_b2_segmentation = TensorMap(
     'b2s_t1map_kassir_annotated',
     interpretation=Interpretation.CATEGORICAL,
     shape=(384, 384, len(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP)),
     channel_map=MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP,
     path_prefix='ukb_cardiac_mri',
-    tensor_from_file=_segmented_t1map,
+    tensor_from_file=_segmented_t1map_b2,
     loss=dice,
     metrics=['categorical_accuracy'] + per_class_dice(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP),
 )
+
+t1map_pancreas_segmentation_cce = TensorMap(
+    'shmolli_192i_pancreas_t1map_annotated_2',
+    interpretation=Interpretation.CATEGORICAL,
+    shape=(288, 384, len(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP)),
+    channel_map=MRI_PANCREAS_SEGMENTED_CHANNEL_MAP,
+    path_prefix='ukb_pancreas_mri',
+    tensor_from_file=_segmented_t1map_pancreas,
+    loss='categorical_crossentropy',
+    metrics=['categorical_accuracy'] + per_class_dice(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP),
+)
+
+t1map_pancreas_segmentation_dice = TensorMap(
+    'shmolli_192i_pancreas_t1map_annotated_2',
+    interpretation=Interpretation.CATEGORICAL,
+    shape=(288, 384, len(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP)),
+    channel_map=MRI_PANCREAS_SEGMENTED_CHANNEL_MAP,
+    path_prefix='ukb_pancreas_mri',
+    tensor_from_file=_segmented_t1map_pancreas,
+    loss=dice,
+    metrics=['categorical_accuracy'] + per_class_dice(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP),
+)
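The channel bookkeeping in `_segmented_t1map_pancreas` can be hard to follow: the raw annotation has three more labels than the 11-entry channel map (hence `len(tm.channel_map) + 3`), and channels 11-13 are folded into channel 6 before being deleted. A minimal numpy-only sketch of that merge-then-delete step (the 14-label annotation here is synthetic):

    import numpy as np

    num_annotated = 14                      # channels in the raw annotation
    label_slice = np.random.randint(0, num_annotated, size=(288, 384))
    one_hot = np.eye(num_annotated, dtype=np.float32)[label_slice]

    # Fold channels 11-13 into channel 6, then drop them, leaving 11 channels.
    one_hot[..., 6] += one_hot[..., 11] + one_hot[..., 12] + one_hot[..., 13]
    one_hot = np.delete(one_hot, [11, 12, 13], axis=-1)
    print(one_hot.shape)                    # (288, 384, 11)

The two pancreas TensorMaps differ only in loss (categorical cross entropy versus Dice), which makes it easy to compare the two objectives on the same segmentation task.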