diff --git a/ml4h/arguments.py b/ml4h/arguments.py
index 5ca5c8c1c..a8dceddc4 100755
--- a/ml4h/arguments.py
+++ b/ml4h/arguments.py
@@ -263,7 +263,7 @@ def parse_args():
     )
     parser.add_argument('--balance_csvs', default=[], nargs='*', help='Balances batches with representation from sample IDs in this list of CSVs')
     parser.add_argument('--optimizer', default='radam', type=str, help='Optimizer for model training')
-    parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2'], help='Adjusts learning rate during training.')
+    parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2', 'cosine_decay'], help='Adjusts learning rate during training.')
     parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training')
     parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training')
     parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value')
@@ -389,6 +389,8 @@ def parse_args():
     parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots')
     parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing')
     parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
+    parser.add_argument('--intensity_thresh_percentile', type=float, help='Threshold percentile for preprocessing, between 0 and 100 inclusive')
+    parser.add_argument('--intensity_thresh_k_means', nargs='*', default=[], type=int, help='Preprocessing using k-means, specified as two numbers: the number of clusters followed by the index of the cluster to remove')
     parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the images has intensity above the threshold')
     parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name')
diff --git a/ml4h/defines.py b/ml4h/defines.py
index eb18dde92..6081639f1 100755
--- a/ml4h/defines.py
+++ b/ml4h/defines.py
@@ -100,7 +100,10 @@ def __str__(self):
     'aortic_root': 7, 'ascending_aorta': 8, 'pulmonary_artery': 9, 'ascending_aortic_wall': 10, 'LVOT': 11,
 }
 MRI_LIVER_SEGMENTED_CHANNEL_MAP = {'background': 0, 'liver': 1, 'inferior_vena_cava': 2, 'abdominal_aorta': 3, 'body': 4}
-
+MRI_PANCREAS_SEGMENTED_CHANNEL_MAP = {
+    'background': 0, 'body': 1, 'pancreas': 2, 'liver': 3, 'stomach': 4, 'spleen': 5,
+    'kidney': 6, 'bowel': 7, 'spine': 8, 'aorta': 9, 'ivc': 10,
+}
 # TODO: These values should ultimately come from the coding table
 CODING_VALUES_LESS_THAN_ONE = [-10, -1001]
diff --git a/ml4h/explorations.py b/ml4h/explorations.py
index ad60f64a0..b44a314ee 100755
--- a/ml4h/explorations.py
+++ b/ml4h/explorations.py
@@ -20,6 +20,7 @@
 import pandas as pd
 import multiprocessing as mp
 from sklearn.decomposition import PCA
+from sklearn.cluster import KMeans

 from tensorflow.keras.models import Model
@@ -719,13 +720,28 @@ def _get_csv_row(sample_id, means, medians, stds, date):
     csv_row = [sample_id] + res[0].astype('str').tolist() + [date]
     return csv_row

-def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig_channels):
+def _thresh_labels_above(y, img, intensity_thresh, intensity_thresh_percentile, in_labels, out_label, nb_orig_channels):
     y = 
np.argmax(y, axis=-1)[..., np.newaxis]
-    y[np.logical_and(img >= intensity_thresh, np.isin(y, in_labels))] = out_label
+    if intensity_thresh is not None:
+        img_intensity_thresh = intensity_thresh
+    elif intensity_thresh_percentile is not None:
+        img_intensity_thresh = np.percentile(img, intensity_thresh_percentile)
+    else:
+        raise ValueError('_thresh_labels_above requires intensity_thresh or intensity_thresh_percentile')
+    y[np.logical_and(img >= img_intensity_thresh, np.isin(y, in_labels))] = out_label
     y = y[..., 0]
     y = _to_categorical(y, nb_orig_channels)
     return y

+def _intensity_thresh_k_means(y, img, intensity_thresh_k_means):
+    X = img[y == 1][..., np.newaxis]
+    if X.size >= intensity_thresh_k_means[0]:
+        kmeans = KMeans(n_clusters=intensity_thresh_k_means[0], random_state=0, n_init="auto").fit(X)
+        labels = kmeans.predict(img.flatten()[..., np.newaxis])
+        labels = np.reshape(labels, img.shape)
+        y[np.logical_and(labels == intensity_thresh_k_means[1], y == 1)] = 0
+    return y
+
 def _scatter_plots_from_segmented_region_stats(
     inference_tsv_true, inference_tsv_pred, structures_to_analyze,
     output_folder, id, input_name, output_name,
@@ -759,13 +775,9 @@ def _scatter_plots_from_segmented_region_stats(
             title = col.replace('_', ' ')
             ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
             ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
-            if i == 'all':
-                min_value = -50
-                max_value = 1300
-            elif i == 'filter_outliers':
-                min_value, max_value = plot_data.min(), plot_data.max()
-                min_value = min([min_value['true'], min_value['pred']]) - 100
-                max_value = min([max_value['true'], max_value['pred']]) + 100
+            min_value, max_value = plot_data.min(), plot_data.max()
+            min_value = min([min_value['true'], min_value['pred']]) - 100
+            max_value = min([max_value['true'], max_value['pred']]) + 100
             ax.set_xlim([min_value, max_value])
             ax.set_ylim([min_value, max_value])
             res = stats.pearsonr(plot_data['true'], plot_data['pred'])
@@ -798,7 +810,6 @@ def infer_stats_from_segmented_regions(args):
     assert(tm_in.shape[-1] == 1, 'no support here for stats on multiple input channels')

     # don't filter datasets for ground truth segmentations if we want to run inference on everything
-    # TODO HELP - this isn't giving me all 56K anymore
     if not args.analyze_ground_truth:
         args.output_tensors = []
         args.tensor_maps_out = []
@@ -820,6 +831,8 @@ def infer_stats_from_segmented_regions(args):
     # Setup for intensity thresholding
     do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure
     if do_intensity_thresh:
+        assert not (args.intensity_thresh and args.intensity_thresh_percentile), 'Set at most one of intensity_thresh and intensity_thresh_percentile'
+        assert not (args.intensity_thresh_k_means and len(args.intensity_thresh_in_structures) > 1), 'intensity_thresh_k_means supports only one in structure'
         intensity_thresh_in_channels = [tm_out.channel_map[k] for k in args.intensity_thresh_in_structures]
         intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure]
@@ -870,19 +883,23 @@ def infer_stats_from_segmented_regions(args):

             if args.analyze_ground_truth:
                 if do_intensity_thresh:
-                    y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
+                    y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
                 y_true = np.delete(y_true, bad_channels, axis=-1)
                 if args.erosion_radius > 0:
                     y_true = binary_erosion(y_true, structure).astype(y_true.dtype)
+                if args.intensity_thresh_k_means:
+                    y_true = _intensity_thresh_k_means(y_true, img, args.intensity_thresh_k_means)
                 means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, 
nb_good_channels)
                 csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date)
                 inference_writer_true.writerow(csv_row_true)

             if do_intensity_thresh:
-                y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
+                y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
             y_pred = np.delete(y_pred, bad_channels, axis=-1)
             if args.erosion_radius > 0:
                 y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype)
+            if args.intensity_thresh_k_means:
+                y_pred = _intensity_thresh_k_means(y_pred, img, args.intensity_thresh_k_means)
             means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels)
             csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date)
             inference_writer_pred.writerow(csv_row_pred)
diff --git a/ml4h/models/layer_wrappers.py b/ml4h/models/layer_wrappers.py
index 7ba1187c0..083765a03 100755
--- a/ml4h/models/layer_wrappers.py
+++ b/ml4h/models/layer_wrappers.py
@@ -25,6 +25,7 @@
 from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, Average, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
 from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
 from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
+from tensorflow.keras.regularizers import L1, L2

 Tensor = tf.Tensor
@@ -52,9 +53,13 @@
     # class name -> (dimension -> class)
     'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
     'dropout': defaultdict(lambda _: Dropout),
+    'l1': L1,
+    'l2': L2,
 }
 DENSE_REGULARIZATION_CLASSES = {
-    'dropout': Dropout,  # TODO: add l1, l2
+    'dropout': Dropout,
+    'l1': L1,
+    'l2': L2,
 }
diff --git a/ml4h/models/legacy_models.py b/ml4h/models/legacy_models.py
index 0ba07c5e6..5eafa292e 100755
--- a/ml4h/models/legacy_models.py
+++ b/ml4h/models/legacy_models.py
@@ -30,6 +30,7 @@
 from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
 from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
 from tensorflow.keras.layers.experimental.preprocessing import RandomRotation, RandomZoom, RandomContrast
+from tensorflow.keras.regularizers import L1, L2
 import tensorflow_probability as tfp

 from ml4h.metrics import get_metric_dict
@@ -79,9 +80,13 @@ class BottleneckType(Enum):
     # class name -> (dimension -> class)
     'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
     'dropout': defaultdict(lambda _: Dropout),
+    'l1': L1,
+    'l2': L2,
 }
 DENSE_REGULARIZATION_CLASSES = {
-    'dropout': Dropout,  # TODO: add l1, l2
+    'dropout': Dropout,
+    'l1': L1,
+    'l2': L2,
 }
diff --git a/ml4h/optimizers.py b/ml4h/optimizers.py
index c9e79d5f0..5dc130c2b 100755
--- a/ml4h/optimizers.py
+++ b/ml4h/optimizers.py
@@ -6,6 +6,7 @@
 from tensorflow.keras import backend as K
 from tensorflow.keras.models import Model
 from tensorflow_addons.optimizers import RectifiedAdam, TriangularCyclicalLearningRate, Triangular2CyclicalLearningRate
+from tensorflow.keras.optimizers.schedules import CosineDecay

 from ml4h.plots import plot_find_learning_rate
 from ml4h.tensor_generators import TensorGenerator
@@ -40,6 +41,8 @@ def _get_learning_rate_schedule(learning_rate: float, 
learning_rate_schedule: st initial_learning_rate=learning_rate / 5, maximal_learning_rate=learning_rate, step_size=steps_per_epoch * 5, ) + if learning_rate_schedule == 'cosine_decay': + return CosineDecay(initial_learning_rate=learning_rate, decay_steps=steps_per_epoch) else: raise ValueError(f'Learning rate schedule "{learning_rate_schedule}" unknown.') diff --git a/ml4h/recipes.py b/ml4h/recipes.py index a97afa0bc..fc567b846 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -220,37 +220,39 @@ def train_multimodal_multitask(args): if merger: merger.save(f'{args.output_folder}{args.id}/merger.h5') - test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps) - performance_metrics = _predict_and_evaluate( - model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected, - args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths, - args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height, - ) + performance_metrics = {} + if args.test_steps > 0: + test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps) + performance_metrics = _predict_and_evaluate( + model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected, + args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths, + args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height, + ) - predictions_list = model.predict(test_data) - samples = min(args.test_steps * args.batch_size, 12) - out_path = os.path.join(args.output_folder, args.id, 'reconstructions/') - if len(args.tensor_maps_out) == 1: - predictions_list = [predictions_list] - predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)} - logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}') - - for i, etm in enumerate(encoders): - embed = encoders[etm].predict(test_data[etm.input_name()]) - if etm.output_name() in predictions_dict: - plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples) - for dtm in decoders: - reconstruction = decoders[dtm].predict(embed) - logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}') - my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/') - os.makedirs(os.path.dirname(my_out_path), exist_ok=True) - if dtm.axes() > 1: - plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples) - else: - evaluate_predictions( - dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path, - test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height, - ) + predictions_list = model.predict(test_data) + samples = min(args.test_steps * args.batch_size, 12) + out_path = os.path.join(args.output_folder, args.id, 'reconstructions/') + if len(args.tensor_maps_out) == 1: + predictions_list = [predictions_list] + predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)} + logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}') + + for i, etm in enumerate(encoders): + embed = encoders[etm].predict(test_data[etm.input_name()]) + if etm.output_name() in predictions_dict: + plot_reconstruction(etm, 
test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
+        for dtm in decoders:
+            reconstruction = decoders[dtm].predict(embed)
+            logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
+            my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
+            os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
+            if dtm.axes() > 1:
+                plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
+            else:
+                evaluate_predictions(
+                    dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
+                    test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
+                )

     return performance_metrics
diff --git a/ml4h/tensor_generators.py b/ml4h/tensor_generators.py
index 7e698ddf8..f88617724 100755
--- a/ml4h/tensor_generators.py
+++ b/ml4h/tensor_generators.py
@@ -95,7 +95,7 @@ def __init__(
         :param paths: If weights is provided, paths should be a list of path lists the same length as weights
         """
         self.augment = augment
-        self.paths = sum(paths) if isinstance(paths[0], list) else paths
+        self.paths = sum(paths, []) if (len(paths) > 0 and isinstance(paths[0], list)) else paths
         self.run_on_main_thread = num_workers == 0
         self.q = None
         self.stats_q = None
diff --git a/ml4h/tensorize/tensor_writer_ukbb.py b/ml4h/tensorize/tensor_writer_ukbb.py
index 6d86dce6e..99cba9004 100755
--- a/ml4h/tensorize/tensor_writer_ukbb.py
+++ b/ml4h/tensorize/tensor_writer_ukbb.py
@@ -62,6 +62,7 @@
     'shmolli_192i_b7_sax_b7s_sax_b7s_sax_b7s_t1map',
 ]

+MRI_PANCREAS_SERIES = ['shmolli_192i_pancreas_t1map']
 MRI_CARDIAC_SERIES_SEGMENTED = [series+'_segmented' for series in MRI_CARDIAC_SERIES]
 MRI_BRAIN_SERIES = ['t1_p2_1mm_fov256_sag_ti_880', 't2_flair_sag_p2_1mm_fs_ellip_pf78']
 MRI_NIFTI_FIELD_ID_TO_ROOT = {'20251': 'SWI', '20252': 'T1', '20253': 'T2_FLAIR'}
@@ -71,7 +72,7 @@
 DICOM_MRI_FIELDS = [
     '20209', '20208', '20210', '20212', '20213', '20214', '20204', '20203', '20254', '20216', '20220', '20218',
-    '20227', '20225', '20217', '20158',
+    '20227', '20225', '20217', '20158', '20259',
 ]

 DXA_FIELD = '20158'
@@ -136,7 +137,7 @@ def write_tensors(
         if _prune_sample(sample_id, min_sample_id, max_sample_id, mri_field_ids, xml_field_ids, zip_folder, xml_folder):
             continue
         try:
-            with h5py.File(tp, 'w') as hd5:
+            with h5py.File(tp, 'a') as hd5:
                 _write_tensors_from_zipped_dicoms(write_pngs, tensors, mri_unzip, mri_field_ids, zip_folder, hd5, sample_id, stats)
                 _write_tensors_from_zipped_niftis(zip_folder, mri_field_ids, hd5, sample_id, stats)
                 _write_tensors_from_xml(xml_field_ids, xml_folder, hd5, sample_id, write_pngs, stats, continuous_stats)
@@ -209,6 +210,9 @@ def write_tensors_from_dicom_pngs(
         except FileNotFoundError:
             logging.warning(f'Could not find file: {os.path.join(png_path, dicom_file + png_postfix)}')
             stats['File not found error'] += 1
+        except ValueError:
+            logging.warning(f'Could not convert file: {os.path.join(png_path, dicom_file + png_postfix)}')
+            stats['Value error'] += 1
     for k in stats:
         if sample_header in k and stats[k] == 50:
             continue
@@ -433,7 +437,7 @@ def _write_tensors_from_dicoms(
             if series + '_12bit' in MRI_LIVER_SERIES_12BIT and d.LargestImagePixelValue > 2048:
                 views[series + '_12bit'].append(d)
                 stats[series + '_12bit'] += 1
-            elif series in MRI_LIVER_SERIES + MRI_CARDIAC_SERIES + MRI_BRAIN_SERIES:
+            elif series in MRI_LIVER_SERIES + MRI_CARDIAC_SERIES + MRI_BRAIN_SERIES + MRI_PANCREAS_SERIES:
                 views[series].append(d)
                 stats[series] += 1
             elif series 
== 'dxa_images': @@ -441,6 +445,8 @@ def _write_tensors_from_dicoms( dxa_number = dicom.split('.')[-4] name = f'dxa_{series_num}_{dxa_number}' create_tensor_in_hd5(hd5, f'ukb_dxa/', name, d.pixel_array, stats) + else: + stats[f'Could not process series {series}'] += 1 if series in MRI_LIVER_IDEAL_PROTOCOL: min_ideal_series = min(min_ideal_series, int(d.SeriesNumber)) @@ -455,6 +461,8 @@ def _write_tensors_from_dicoms( mri_group = 'ukb_liver_mri' elif v in MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED: mri_group = 'ukb_cardiac_mri' + elif v in MRI_PANCREAS_SERIES: + mri_group = 'ukb_pancreas_mri' else: mri_group = 'ukb_mri' @@ -564,14 +572,14 @@ def _tensorize_brain_mri(slices: List[pydicom.Dataset], series: str, mri_date: d def _save_pixel_dimensions_if_missing(slicer, series, hd5): - if MRI_PIXEL_WIDTH + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + if MRI_PIXEL_WIDTH + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES: hd5.create_dataset(MRI_PIXEL_WIDTH + '_' + series, data=float(slicer.PixelSpacing[0])) - if MRI_PIXEL_HEIGHT + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + if MRI_PIXEL_HEIGHT + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES: hd5.create_dataset(MRI_PIXEL_HEIGHT + '_' + series, data=float(slicer.PixelSpacing[1])) def _save_slice_thickness_if_missing(slicer, series, hd5): - if MRI_SLICE_THICKNESS + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + if MRI_SLICE_THICKNESS + '_' + series not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES: hd5.create_dataset(MRI_SLICE_THICKNESS + '_' + series, data=float(slicer.SliceThickness)) @@ -581,9 +589,9 @@ def _save_series_orientation_and_position_if_missing(slicer, series, hd5, instan if instance: orientation_ds_name += HD5_GROUP_CHAR + instance position_ds_name += HD5_GROUP_CHAR + instance - if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + if orientation_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES: hd5.create_dataset(orientation_ds_name, data=[float(x) for x in slicer.ImageOrientationPatient]) - if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT: + if position_ds_name not in hd5 and series in MRI_BRAIN_SERIES + MRI_CARDIAC_SERIES + MRI_CARDIAC_SERIES_SEGMENTED + MRI_LIVER_SERIES + MRI_LIVER_SERIES_12BIT + MRI_PANCREAS_SERIES: hd5.create_dataset(position_ds_name, data=[float(x) for x in slicer.ImagePositionPatient]) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 2af000328..828e5f059 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -20,7 +20,8 @@ 
MRI_LAX_2CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_SEGMENTED_CHANNEL_MAP, LAX_4CH_HEART_LABELS, LAX_4CH_MYOCARDIUM_LABELS, StorageType, LAX_3CH_HEART_LABELS, \ LAX_2CH_HEART_LABELS from ml4h.tensormap.general import get_tensor_at_first_date, normalized_first_date, pad_or_crop_array_to_shape, tensor_from_hd5 -from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS +from ml4h.defines import MRI_LAX_3CH_SEGMENTED_CHANNEL_MAP, MRI_LAX_4CH_SEGMENTED_CHANNEL_MAP, MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP, \ + MRI_AO_SEGMENTED_CHANNEL_MAP, MRI_LIVER_SEGMENTED_CHANNEL_MAP, SAX_HEART_LABELS, MRI_PANCREAS_SEGMENTED_CHANNEL_MAP def _slice_subset_tensor( @@ -2734,17 +2735,12 @@ def _mdrk_projection_both_views_pretrained(tm, hd5, dependents={}): tensor_from_file=None, ) -def _pad_crop_single_channel(tm, hd5, dependents={}): - if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: - key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' - elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: - key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' - else: - raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}') - +def _pad_crop_single_channel(tm, hd5, dependents={}, key_prefix=None): + if key_prefix is None: + key_prefix = tm.hd5_key_guess() img = np.array( - tm.hd5_first_dataset_in_group(hd5, key_prefix), - dtype=np.float32, + tm.hd5_first_dataset_in_group(hd5, key_prefix), + dtype=np.float32, ) img = img[...,[1]] return pad_or_crop_array_to_shape( @@ -2752,15 +2748,32 @@ def _pad_crop_single_channel(tm, hd5, dependents={}): img, ) +def _pad_crop_single_channel_t1map_b2(tm, hd5, dependents={}): + if f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: + key_prefix = f'/{tm.path_prefix}/shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' + elif f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' in hd5: + key_prefix = f'/{tm.path_prefix}/shmolli_192i_b2_sax_b2s_sax_b2s_sax_b2s_t1map/instance_2' + else: + raise ValueError(f'Could not find T1 Map image for tensormap: {tm.name}') + return _pad_crop_single_channel(tm, hd5, dependents, key_prefix) + t1map_b2 = TensorMap( 'shmolli_192i_sax_b2s_sax_b2s_sax_b2s_t1map', shape=(384, 384, 1), path_prefix='ukb_cardiac_mri', normalization=Standardize(mean=455.81, std=609.50), + tensor_from_file=_pad_crop_single_channel_t1map_b2, +) + +t1map_pancreas = TensorMap( + 'shmolli_192i_pancreas_t1map', + shape=(288, 384, 1), + path_prefix='ukb_pancreas_mri', + normalization=Standardize(mean=389.49, std=658.36), tensor_from_file=_pad_crop_single_channel, ) -def _segmented_t1map(tm, hd5, dependents={}): +def _segmented_t1map_b2(tm, hd5, dependents={}): if f'{tm.path_prefix}/{tm.name}_1' in hd5: categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}_1') elif f'{tm.path_prefix}/{tm.name}_2' in hd5: @@ -2780,13 +2793,56 @@ def _segmented_t1map(tm, hd5, dependents={}): tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot) return tensor +def _segmented_t1map_pancreas(tm, hd5, dependents={}): + if f'{tm.path_prefix}/{tm.name}' in hd5: + categorical_index_slice = get_tensor_at_first_date(hd5, tm.path_prefix, f'{tm.name}') + else: + raise ValueError(f'Could not find T1 Map 
segmentation for tensormap: {tm.name}')
+
+    # merge the extra raw labels (11, 12, 13) into channel 6 and drop those channels
+    orig_num_channels = len(tm.channel_map) + 3
+    categorical_one_hot = to_categorical(categorical_index_slice, orig_num_channels)
+    categorical_one_hot[..., 6] += (
+        categorical_one_hot[..., 11] +
+        categorical_one_hot[..., 12] +
+        categorical_one_hot[..., 13]
+    )
+    categorical_one_hot = np.delete(categorical_one_hot, [11, 12, 13], axis=-1)
+
+    # padding/cropping
+    tensor = np.zeros(tm.shape, dtype=np.float32)
+    tensor[..., :] = pad_or_crop_array_to_shape(tensor[..., :].shape, categorical_one_hot)
+    return tensor
+
 t1map_b2_segmentation = TensorMap(
     'b2s_t1map_kassir_annotated',
     interpretation=Interpretation.CATEGORICAL,
     shape=(384, 384, len(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP)),
     channel_map=MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP,
     path_prefix='ukb_cardiac_mri',
-    tensor_from_file=_segmented_t1map,
+    tensor_from_file=_segmented_t1map_b2,
     loss=dice,
     metrics=['categorical_accuracy'] + per_class_dice(MRI_SAX_PAP_SEGMENTED_CHANNEL_MAP),
 )
+
+t1map_pancreas_segmentation_cce = TensorMap(
+    'shmolli_192i_pancreas_t1map_annotated_2',
+    interpretation=Interpretation.CATEGORICAL,
+    shape=(288, 384, len(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP)),
+    channel_map=MRI_PANCREAS_SEGMENTED_CHANNEL_MAP,
+    path_prefix='ukb_pancreas_mri',
+    tensor_from_file=_segmented_t1map_pancreas,
+    loss='categorical_crossentropy',
+    metrics=['categorical_accuracy'] + per_class_dice(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP),
+)
+
+t1map_pancreas_segmentation_dice = TensorMap(
+    'shmolli_192i_pancreas_t1map_annotated_2',
+    interpretation=Interpretation.CATEGORICAL,
+    shape=(288, 384, len(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP)),
+    channel_map=MRI_PANCREAS_SEGMENTED_CHANNEL_MAP,
+    path_prefix='ukb_pancreas_mri',
+    tensor_from_file=_segmented_t1map_pancreas,
+    loss=dice,
+    metrics=['categorical_accuracy'] + per_class_dice(MRI_PANCREAS_SEGMENTED_CHANNEL_MAP),
+)
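
Note on the new thresholding options: the snippet below is a minimal, self-contained sketch of what --intensity_thresh_percentile and --intensity_thresh_k_means do to a segmentation mask. The data and names (img, mask, cutoff) are synthetic, not from the ml4h pipeline; in the patch this logic lives in _thresh_labels_above and _intensity_thresh_k_means.

# Synthetic sketch of percentile thresholding plus k-means pruning of a mask.
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
img = rng.normal(500.0, 100.0, size=(64, 64))       # fake T1 intensity image
mask = (rng.random((64, 64)) < 0.2).astype(int)     # fake binary structure mask

# --intensity_thresh_percentile 95: relabel structure pixels at/above the cutoff
cutoff = np.percentile(img, 95)
mask[np.logical_and(img >= cutoff, mask == 1)] = 0

# --intensity_thresh_k_means 2 1: cluster the remaining structure intensities
# with k=2 and drop pixels that fall in cluster index 1
X = img[mask == 1][..., np.newaxis]
if X.size >= 2:  # the patch uses n_init="auto", which needs scikit-learn >= 1.2
    km = KMeans(n_clusters=2, random_state=0, n_init=10).fit(X)
    labels = km.predict(img.reshape(-1, 1)).reshape(img.shape)
    mask[np.logical_and(labels == 1, mask == 1)] = 0
print(mask.sum(), 'structure pixels kept')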
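
On the cosine_decay schedule: with decay_steps=steps_per_epoch as wired in _get_learning_rate_schedule, the learning rate anneals to zero within the first epoch and then stays flat, which is worth double-checking against the intended training length. A small sketch with made-up values:

# Sketch of the cosine_decay schedule as configured in the patch.
import tensorflow as tf

learning_rate, steps_per_epoch = 1e-3, 100
schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=learning_rate, decay_steps=steps_per_epoch,
)
for step in (0, 50, 100, 200):
    # reaches 0 at step 100 (one epoch) and stays flat afterwards
    print(step, float(schedule(step)))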
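
On the L1/L2 additions to the regularization dicts: unlike Dropout, tf.keras.regularizers.L1 and L2 are not layers; the standard Keras pattern attaches them to a layer's kernel or activity, as in the reference sketch below. Whether the ml4h block factories invoke the dict entries as layers or as regularizers is not visible in this hunk, so this is reference usage only.

# Reference Keras usage of the L1/L2 regularizers imported by this patch.
import tensorflow as tf
from tensorflow.keras.regularizers import L1, L2

inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(
    16,
    kernel_regularizer=L2(1e-4),      # weight penalty, collected in model.losses
    activity_regularizer=L1(1e-5),    # output penalty, added per forward pass
)(inputs)
model = tf.keras.Model(inputs, outputs)
print(model.losses)  # includes the L2 kernel penalty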
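
On the h5py.File mode change in write_tensors ('w' to 'a'): append mode keeps previously written groups, so pancreas tensorization can extend existing per-sample HD5 files instead of truncating them. A minimal sketch with a hypothetical file name:

# 'w' truncates an existing HD5 file; 'a' opens read/write and preserves it.
import h5py

with h5py.File('demo.hd5', 'w') as hd5:
    hd5.create_dataset('ukb_liver_mri/example', data=[1, 2, 3])

with h5py.File('demo.hd5', 'a') as hd5:  # the old 'w' here would wipe the liver group
    hd5.create_dataset('ukb_pancreas_mri/example', data=[4, 5, 6])
    print(sorted(hd5.keys()))            # ['ukb_liver_mri', 'ukb_pancreas_mri']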
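
On _segmented_t1map_pancreas: the raw annotation carries three more labels than MRI_PANCREAS_SEGMENTED_CHANNEL_MAP, and the function folds raw labels 11-13 into channel 6 before cropping. A tiny synthetic example of that channel merge:

# One-hot channel merge as in _segmented_t1map_pancreas:
# 14 raw labels become the 11 channels of MRI_PANCREAS_SEGMENTED_CHANNEL_MAP.
import numpy as np
from tensorflow.keras.utils import to_categorical

raw = np.array([[0, 11], [12, 13]])            # fake annotation with extra labels
one_hot = to_categorical(raw, num_classes=14)  # shape (2, 2, 14)
one_hot[..., 6] += one_hot[..., 11] + one_hot[..., 12] + one_hot[..., 13]
one_hot = np.delete(one_hot, [11, 12, 13], axis=-1)
print(one_hot.shape)     # (2, 2, 11)
print(one_hot[..., 6])   # the three extra-label pixels now sit in channel 6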