diff --git a/docker/vm_boot_images/config/tensorflow-requirements.txt b/docker/vm_boot_images/config/tensorflow-requirements.txt index 77233b2e0..88ca7bf7a 100644 --- a/docker/vm_boot_images/config/tensorflow-requirements.txt +++ b/docker/vm_boot_images/config/tensorflow-requirements.txt @@ -44,4 +44,3 @@ umap-learn[plot] neurite voxelmorph pystrum - diff --git a/ml4h/TensorMap.py b/ml4h/TensorMap.py index bd4e505f2..cc179031b 100755 --- a/ml4h/TensorMap.py +++ b/ml4h/TensorMap.py @@ -204,8 +204,6 @@ def __init__( elif self.activation is None and (self.is_survival_curve() or self.is_time_to_event()): self.activation = 'sigmoid' - - if self.channel_map is None and self.is_time_to_event(): self.channel_map = DEFAULT_TIME_TO_EVENT_CHANNELS diff --git a/ml4h/data_descriptions.py b/ml4h/data_descriptions.py index 86dbbcc72..d9b41e484 100644 --- a/ml4h/data_descriptions.py +++ b/ml4h/data_descriptions.py @@ -5,6 +5,7 @@ from typing import Callable, List, Union, Optional, Tuple, Dict, Any import h5py +import datetime import numcodecs import numpy as np import pandas as pd @@ -331,10 +332,9 @@ def __init__( ): """ Gets data from a column of the provided DataFrame. - :param df: Must be multi-indexed with sample_id, loading_option - # TODO: allow multiple loading options :param col: The column name to get data from :param process_col: Function to turn the column value into Tensor + :param name: Optional overwrite of the df column name """ self.process_col = process_col or self._default_process_call self.df = df diff --git a/ml4h/models/legacy_models.py b/ml4h/models/legacy_models.py index 3a9aab20c..0ba07c5e6 100755 --- a/ml4h/models/legacy_models.py +++ b/ml4h/models/legacy_models.py @@ -350,7 +350,7 @@ def make_hidden_layer_model(parent_model: Model, tensor_maps_in: List[TensorMap] dummy_input = {tm.input_name(): np.zeros((1,) + parent_model.get_layer(tm.input_name()).input_shape[0][1:]) for tm in tensor_maps_in} intermediate_layer_model = Model(inputs=parent_inputs, outputs=target_layer.output) # If we do not predict here then the graph is disconnected, I do not know why?! - intermediate_layer_model.predict(dummy_input) + intermediate_layer_model.predict(dummy_input, verbose=0) return intermediate_layer_model @@ -1344,7 +1344,7 @@ def make_paired_autoencoder_model( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def embed_model_predict(model, tensor_maps_in, embed_layer, test_data, batch_size): embed_model = make_hidden_layer_model(model, tensor_maps_in, embed_layer) - return embed_model.predict(test_data, batch_size=batch_size) + return embed_model.predict(test_data, batch_size=batch_size, verbose=0) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/ml4h/plots.py b/ml4h/plots.py index 5f756eb3b..af5add9bc 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -770,7 +770,7 @@ def plot_scatter( ax1.set_xlabel("Predictions") ax1.set_ylabel("Actual") - ax1.set_title(title) + ax1.set_title(f'{title} N = {len(prediction)}' ) ax1.legend(loc="lower right") sns.distplot(prediction, label="Predicted", color="r", ax=ax2) @@ -2253,7 +2253,7 @@ def plot_ecg_rest( tensor_paths: List[str], rows: List[int], out_folder: str, - is_blind: bool, + is_blind: bool ) -> None: """Plots resting ECGs including annotations and LVH criteria diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 33ca54b7f..58eab9fac 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -7,7 +7,9 @@ import h5py import glob import logging +import datetime import numpy as np +import pandas as pd from functools import reduce from google.cloud import storage @@ -22,19 +24,32 @@ from ml4h.tensormap.tensor_map_maker import write_tensor_maps from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb from ml4h.models.model_factory import make_multimodal_multitask_model +from ml4h.ml4ht_integration.tensor_generator import TensorMapDataLoader2 +from ml4h.explorations import test_labels_to_label_map, infer_with_pixels from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX + from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_stats_from_segmented_regions from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model +from ml4h.plots import plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv +from ml4h.models.legacy_models import get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator +from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients + from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp + from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model -from ml4h.models.legacy_models import get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model, make_paired_autoencoder_model +from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival from ml4h.tensorize.tensor_writer_ukbb import write_tensors, append_fields_from_csv, append_gene_csv, write_tensors_from_dicom_pngs, write_tensors_from_ecg_pngs +from ml4ht.data.util.date_selector import DATE_OPTION_KEY +from ml4ht.data.sample_getter import DataDescriptionSampleGetter +from ml4ht.data.data_loader import SampleGetterIterableDataset, shuffle_get_epoch + +from torch.utils.data import DataLoader def run(args): start_time = timer() # Keep track of elapsed execution time @@ -56,6 +71,8 @@ def run(args): train_multimodal_multitask(args) elif 'train_legacy' == args.mode: train_legacy(args) + elif 'train_xdl' == args.mode: + train_xdl(args) elif 'test' == args.mode: test_multimodal_multitask(args) elif 'compare' == args.mode: @@ -70,6 +87,8 @@ def run(args): infer_stats_from_segmented_regions(args) elif 'infer_encoders' == args.mode: infer_encoders_block_multimodal_multitask(args) + elif 'infer_xdl' == args.mode: + infer_xdl(args) elif 'test_scalar' == args.mode: test_multimodal_scalar_tasks(args) elif 'compare_scalar' == args.mode: @@ -123,7 +142,6 @@ def run(args): except Exception as e: logging.exception(e) - if args.gcs_cloud_bucket is not None: save_to_google_cloud(args) @@ -278,6 +296,208 @@ def compare_multimodal_scalar_task_models(args): _calculate_and_plot_prediction_stats(args, predictions, labels, paths) +def standardize_by_sample_ecg(ecg, _): + """Z Score ECG""" + return (ecg - np.mean(ecg)) / (np.std(ecg) + 1e-6) + + +def train_xdl(args): + mrn_df = pd.read_csv(args.app_csv) + if 'start_fu_age' in mrn_df: + mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu_age).dt.days + elif 'start_fu' in mrn_df: + mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu).dt.days + for ot in args.tensor_maps_out: + mrn_df = mrn_df[mrn_df[ot.name].notna()] + mrn_df = mrn_df.set_index('sample_id') + + output_dds = [dataframe_data_description_from_tensor_map(tmap, mrn_df) for tmap in args.tensor_maps_out] + + ecg_dd = ECGDataDescription( + args.tensors, + name=args.tensor_maps_in[0].input_name(), + ecg_len=5000, # all ECGs will be linearly interpolated to be this length + transforms=[standardize_by_sample_ecg], # these will be applied in order + # data will be automatically localized from s3 + ) + + def option_picker(sample_id, data_descriptions): + ecg_dts = ecg_dd.get_loading_options(sample_id) + htn_dt = output_dds[0].get_loading_options(sample_id)[0]['start_fu_datetime'] + min_ecg_dt = htn_dt - pd.to_timedelta("1095d") + max_ecg_dt = htn_dt + pd.to_timedelta("1095d") + dates = [] + for dt in ecg_dts: + if min_ecg_dt <= dt[DATE_OPTION_KEY] <= max_ecg_dt: + dates.append(dt) + if len(dates) == 0: + raise ValueError('No matching dates') + chosen_dt = np.random.choice(dates) + chosen_dt['day_delta'] = (htn_dt - chosen_dt[DATE_OPTION_KEY]).days + return {dd: chosen_dt for dd in data_descriptions} + + sg = DataDescriptionSampleGetter( + input_data_descriptions=[ecg_dd], # what we want a model to use as input data + output_data_descriptions=output_dds, # what we want a model to predict from the input data + option_picker=option_picker, + ) + model, encoders, decoders, merger = make_multimodal_multitask_model(**args.__dict__) + + train_ids = list(mrn_df[mrn_df.split == 'train'].index) + valid_ids = list(mrn_df[mrn_df.split == 'valid'].index) + test_ids = list(mrn_df[mrn_df.split == 'test'].index) + + train_dataset = SampleGetterIterableDataset(sample_ids=list(train_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch) + valid_dataset = SampleGetterIterableDataset(sample_ids=list(valid_ids), sample_getter=sg, + get_epoch=shuffle_get_epoch) + + num_train_workers = int(args.training_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) + num_valid_workers = int(args.validation_steps / (args.training_steps + args.validation_steps) * args.num_workers) or (1 if args.num_workers else 0) + + generate_train = TensorMapDataLoader2( + batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out, + dataset=train_dataset, + num_workers=num_train_workers, + ) + generate_valid = TensorMapDataLoader2( + batch_size=args.batch_size, input_maps=args.tensor_maps_in, output_maps=args.tensor_maps_out, + dataset=valid_dataset, + num_workers=num_valid_workers, + ) + + model = train_model_from_generators( + model, generate_train, generate_valid, args.training_steps, args.validation_steps, args.batch_size, args.epochs, + args.patience, args.output_folder, args.id, args.inspect_model, args.inspect_show_labels, args.tensor_maps_out, + save_last_model=args.save_last_model, + ) + for tm in encoders: + encoders[tm].save(f'{args.output_folder}{args.id}/encoder_{tm.name}.h5') + for tm in decoders: + decoders[tm].save(f'{args.output_folder}{args.id}/decoder_{tm.name}.h5') + if merger: + merger.save(f'{args.output_folder}{args.id}/merger.h5') + + +def datetime_to_float(d): + return pd.to_datetime(d, utc=True).timestamp() + + +def float_to_datetime(fl): + return pd.to_datetime(fl, unit='s', utc=True) + + +def infer_from_dataloader(dataloader, model, tensor_maps_out, max_batches=125000): + dataloader_iterator = iter(dataloader) + space_dict = defaultdict(list) + for i in range(max_batches): + try: + data, target = next(dataloader_iterator) + for k in data: + data[k] = np.array(data[k]) + prediction = model.predict(data, verbose=0) + if len(model.output_names) == 1: + prediction = [prediction] + predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} + for b in range(prediction[0].shape[0]): + for otm in tensor_maps_out: + y = predictions_dict[otm.output_name()] + if otm.is_categorical(): + space_dict[f'{otm.name}_prediction'].append(y[b, 1]) + elif otm.is_continuous(): + space_dict[f'{otm.name}_prediction'].append(y[b, 0]) + elif otm.is_survival_curve(): + intervals = otm.shape[-1] // 2 + days_per_bin = 1 + (2*otm.days_window) // intervals + predicted_survivals = np.cumprod(y[:, :intervals], axis=1) + space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[0, -1])) + sick = np.sum(target[otm.output_name()][:, intervals:], axis=-1) + follow_up = np.cumsum(target[otm.output_name()][:, :intervals], axis=-1)[:, -1] * days_per_bin + space_dict[f'{otm.name}_event'].append(str(sick[0])) + space_dict[f'{otm.name}_follow_up'].append(str(follow_up[0])) + for k in target: + if k in ['MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous' ]: + space_dict[f'{k}'].append(target[k][b].numpy()) + elif k in ['datetime']: + space_dict[f'{k}'].append(float_to_datetime(int(target[k][b].numpy()))) + else: + space_dict[f'{k}'].append(target[k][b, -1].numpy()) + if i % 100 == 0: + logging.info(f'Inferred on {i} batches, {len(space_dict[k])} rows') + except StopIteration: + logging.info(f'Inferred on all {i} batches.') + break + return pd.DataFrame.from_dict(space_dict) + + +def infer_xdl(args): + mrn_df = pd.read_csv(args.app_csv) + if 'start_fu_age' in mrn_df: + mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu_age).dt.days + elif 'start_fu' in mrn_df: + mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu).dt.days + mrn_df = mrn_df.rename(columns={'Dem.Gender.no_filter_x': 'sex', 'Dem.Gender.no_filter': 'sex'}) + mrn_df['is_c3po'] = mrn_df.cohort == 'c3po' + for ot in args.tensor_maps_out: + mrn_df = mrn_df[mrn_df[ot.name].notna()] + mrn_df = mrn_df.set_index('sample_id') + + output_dds = [dataframe_data_description_from_tensor_map(tmap, mrn_df) for tmap in args.tensor_maps_out] + + output_dds.append(DataFrameDataDescription(mrn_df, col="MRN")) + output_dds.append(DataFrameDataDescription(mrn_df, col="linker_id")) + output_dds.append(DataFrameDataDescription(mrn_df, col="is_c3po")) + output_dds.append(DataFrameDataDescription(mrn_df, col="datetime", process_col=datetime_to_float)) + + ecg_dd = ECGDataDescription( + args.tensors, + name=args.tensor_maps_in[0].input_name(), + ecg_len=5000, # all ECGs will be linearly interpolated to be this length + transforms=[standardize_by_sample_ecg], # these will be applied in order + ) + + def test_option_picker(sample_id, data_descriptions): + ecg_dts = ecg_dd.get_loading_options(sample_id) + htn_dt = output_dds[0].get_loading_options(sample_id)[0]['start_fu_datetime'] + min_ecg_dt = htn_dt - pd.to_timedelta("1095d") + max_ecg_dt = htn_dt - pd.to_timedelta("1d") + dates = [] + for dt in ecg_dts: + if min_ecg_dt <= dt[DATE_OPTION_KEY] <= max_ecg_dt: + dates.append(dt) + if len(dates) == 0: + raise ValueError('No matching dates') + chosen_dt = dates[-1] + return {dd: chosen_dt for dd in data_descriptions} + + sg = DataDescriptionSampleGetter( + input_data_descriptions=[ecg_dd], # what we want a model to use as input data + output_data_descriptions=output_dds, # what we want a model to predict from the input data + option_picker=test_option_picker, + ) + + dataset = SampleGetterIterableDataset(sample_ids=list(mrn_df.index), sample_getter=sg, get_epoch=shuffle_get_epoch) + dataloader = DataLoader(dataset, num_workers=0, batch_size=args.batch_size) + + model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) + infer_df = infer_from_dataloader(dataloader, model, args.tensor_maps_out) + if 'mgh' in args.tensors: + hospital = 'mgh' + infer_df = infer_df.rename(columns={'MRN': 'MGH_MRN'}) + infer_df.MGH_MRN = infer_df.MGH_MRN.astype(int) + else: + hospital = 'bwh' + infer_df = infer_df.rename(columns={'MRN': 'BWH_MRN'}) + infer_df.BWH_MRN = infer_df.BWH_MRN.astype(int) + + infer_df.linker_id = infer_df.linker_id.astype(int) + names = '_'.join([otm.name for otm in args.tensor_maps_out]) + now_string = datetime.datetime.now().strftime('%Y_%m_%d') + out_file = f'{args.output_folder}/{args.id}/infer_{names}_{hospital}_v{now_string}.tsv' + infer_df.to_csv(out_file, sep='\t', index=False) + logging.info(f'Infer dataframe head: {infer_df.head()} \n\n Saved inferences to: {out_file}') + + def _make_tmap_nan_on_fail(tmap): """ Builds a copy TensorMap with a tensor_from_file that returns nans on errors instead of raising an error @@ -384,8 +604,13 @@ def infer_multimodal_multitask(args): hd5_path = os.path.join(args.output_folder, args.id, 'inferred_hd5s', f'{sample_id}{TENSOR_EXT}') os.makedirs(os.path.dirname(hd5_path), exist_ok=True) with h5py.File(hd5_path, 'a') as hd5: - hd5.create_dataset(f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]), compression='gzip') - hd5.create_dataset(f'{otm.name}_prediction', data=otm.rescale(y[0]), compression='gzip') + hd5.create_dataset(f'{otm.name}_truth', data=otm.rescale(output_data[otm.output_name()][0]), + compression='gzip') + if otm.path_prefix == 'ukb_ecg_rest': + for lead in otm.channel_map: + hd5.create_dataset(f'/ukb_ecg_rest/{lead}/instance_0', + data=otm.rescale(y[0, otm.channel_map[lead]]), + compression='gzip') inference_writer.writerow(csv_row) tensor_paths_inferred.add(tensor_paths[0]) stats['count'] += 1 @@ -619,7 +844,7 @@ def _predict_and_evaluate( scatters = [] rocs = [] - y_predictions = model.predict(test_data, batch_size=batch_size) + y_predictions = model.predict(test_data, batch_size=batch_size, verbose=0) protected_data = {tm: test_labels[tm.output_name()] for tm in tensor_maps_protected} for y, tm in zip(y_predictions, tensor_maps_out): if tm.output_name() not in layer_names: @@ -668,7 +893,7 @@ def _predict_scalars_and_evaluate_from_generator( for i in range(steps): batch = next(generate_test) input_data, output_data, tensor_paths = batch[BATCH_INPUT_INDEX], batch[BATCH_OUTPUT_INDEX], batch[BATCH_PATHS_INDEX] - y_predictions = model.predict(input_data) + y_predictions = model.predict(input_data, verbose=0) test_paths.extend(tensor_paths) if hidden_layer in layer_names: x_embed = embed_model_predict(model, tensor_maps_in, hidden_layer, input_data, 2) @@ -685,6 +910,9 @@ def _predict_scalars_and_evaluate_from_generator( if tm_output_name in scalar_predictions: scalar_predictions[tm_output_name].extend(np.copy(y)) + if i % 100 == 0: + logging.info(f'Processed {i} batches, {len(test_paths)} tensors.') + performance_metrics = {} scatters = [] rocs = [] diff --git a/ml4h/tensormap/mgb/xdl.py b/ml4h/tensormap/mgb/xdl.py new file mode 100644 index 000000000..fe54db2cd --- /dev/null +++ b/ml4h/tensormap/mgb/xdl.py @@ -0,0 +1,45 @@ +from typing import Dict + +import h5py +import numpy as np +from ml4h.TensorMap import TensorMap, Interpretation + +ecg_5000_std = TensorMap('ecg_5000_std', Interpretation.CONTINUOUS, shape=(5000, 12)) + +hypertension_icd_only = TensorMap(name='hypertension_icd_only', interpretation=Interpretation.CATEGORICAL, + channel_map={'no_hypertension_icd_only': 0, 'hypertension_icd_only': 1}) +hypertension_icd_bp = TensorMap(name='hypertension_icd_bp', interpretation=Interpretation.CATEGORICAL, + channel_map={'no_hypertension_icd_bp': 0, 'hypertension_icd_bp': 1}) +hypertension_icd_bp_med = TensorMap(name='hypertension_icd_bp_med', interpretation=Interpretation.CATEGORICAL, + channel_map={'no_hypertension_icd_bp_med': 0, 'hypertension_icd_bp_med': 1}) +hypertension_med = TensorMap(name='start_fu_hypertension_med', interpretation=Interpretation.CATEGORICAL, + channel_map={'no_hypertension_medication': 0, 'hypertension_medication': 1}) + +lvef = TensorMap(name='LVEF', interpretation=Interpretation.CONTINUOUS, channel_map={'LVEF': 0}) + +age = TensorMap(name='age_in_days', interpretation=Interpretation.CONTINUOUS, channel_map={'age_in_days': 0}) +sex = TensorMap(name='sex', interpretation=Interpretation.CATEGORICAL, channel_map={'Female': 0, 'Male': 1}) + +cad = TensorMap(name='cad', interpretation=Interpretation.CATEGORICAL, channel_map={'no_cad': 0, 'cad': 1}) +dm = TensorMap(name='dm', interpretation=Interpretation.CATEGORICAL, channel_map={'no_dm': 0, 'dm': 1}) +hypercholesterolemia = TensorMap(name='hypercholesterolemia', interpretation=Interpretation.CATEGORICAL, + channel_map={'no_hypercholesterolemia': 0, 'hypercholesterolemia': 1}) + + +def ecg_median_biosppy(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray: + tensor = np.zeros(tm.shape, dtype=np.float32) + for lead in tm.channel_map: + tensor[:, tm.channel_map[lead]] = hd5[f'{tm.path_prefix}{lead}'] + tensor = np.nan_to_num(tensor) + return tensor + +ecg_channel_map = { + 'I': 0, 'II': 1, 'III': 2, 'aVR': 3, 'aVL': 4, 'aVF': 5, + 'V1': 6, 'V2': 7, 'V3': 8, 'V4': 9, 'V5': 10, 'V6': 11, +} + +ecg_biosppy_median_60bpm = TensorMap( + 'median', Interpretation.CONTINUOUS, path_prefix='median_60bpm_', shape=(600, 12), + tensor_from_file=ecg_median_biosppy, + channel_map=ecg_channel_map, +) \ No newline at end of file diff --git a/ml4h/tensormap/ukb/demographics.py b/ml4h/tensormap/ukb/demographics.py index 82bf167c2..479a59942 100755 --- a/ml4h/tensormap/ukb/demographics.py +++ b/ml4h/tensormap/ukb/demographics.py @@ -341,10 +341,11 @@ def alcohol_from_file(tm, hd5, dependents={}): path_prefix='categorical', annotation_units=2, channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy', ) -# sex = TensorMap( -# 'Sex_Male_0_0', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG, path_prefix='categorical', annotation_units=2, -# channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy', -# ) +sex_dummy1 = TensorMap( + 'sex', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG, + path_prefix='categorical', annotation_units=2, + channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy', +) af_dummy2 = TensorMap( 'af_in_read', Interpretation.CATEGORICAL, path_prefix='categorical', storage_type=StorageType.CATEGORICAL_FLAG, channel_map={'no_atrial_fibrillation': 0, 'atrial_fibrillation': 1}, @@ -354,7 +355,11 @@ def alcohol_from_file(tm, hd5, dependents={}): path_prefix='categorical', annotation_units=2, channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy', ) - +sex_dummy3 = TensorMap( + 'sex_from_wide', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG, + path_prefix='categorical', annotation_units=2, + channel_map={'female': 0, 'male': 1}, loss='categorical_crossentropy', +) brain_volume = TensorMap( '25010_Volume-of-brain-greywhite-matter_2_0', Interpretation.CONTINUOUS, path_prefix='continuous', normalization={'mean': 1165940.0, 'std': 111511.0}, channel_map={'25010_Volume-of-brain-greywhite-matter_2_0': 0}, loss='logcosh', loss_weight=0.1, diff --git a/ml4h/tensormap/ukb/dxa.py b/ml4h/tensormap/ukb/dxa.py index 7a6cf73fb..462723c46 100755 --- a/ml4h/tensormap/ukb/dxa.py +++ b/ml4h/tensormap/ukb/dxa.py @@ -98,7 +98,7 @@ def dxa_background_erase(tm, hd5, dependents={}): ) dxa_11 = TensorMap( 'dxa_1_11', - shape=(896, 352, 1), + shape=(896, 384, 1), path_prefix='ukb_dxa', tensor_from_file=dxa_background_erase, normalization=ZeroMeanStd1(), diff --git a/ml4h/tensormap/ukb/ecg.py b/ml4h/tensormap/ukb/ecg.py index 5305de968..c71735f65 100755 --- a/ml4h/tensormap/ukb/ecg.py +++ b/ml4h/tensormap/ukb/ecg.py @@ -574,6 +574,7 @@ def ecg_rest_section_to_segment(tm, hd5, dependents={}): metrics=['mse', 'mae'], channel_map=ECG_REST_MEDIAN_LEADS, normalization=Standardize(mean=0, std=10), ) + ecg_rest_median_576 = TensorMap( 'ecg_rest_median_576', Interpretation.CONTINUOUS, path_prefix='ukb_ecg_rest', shape=(576, 12), loss='logcosh', activation='linear', tensor_from_file=_make_ecg_rest(), channel_map=ECG_REST_MEDIAN_LEADS, @@ -595,8 +596,10 @@ def ecg_rest_section_to_segment(tm, hd5, dependents={}): ) ecg_rest_median_raw_10_prediction = TensorMap( - 'ecg_rest_median_raw_10', Interpretation.CONTINUOUS, shape=(600, 12), loss='logcosh', activation='linear', normalization=ZeroMeanStd1(), - tensor_from_file=named_tensor_from_hd5('ecg_rest_median_raw_10_prediction'), metrics=['mse', 'mae'], channel_map=ECG_REST_MEDIAN_LEADS, + 'ecg_rest_median_raw_10', Interpretation.CONTINUOUS, shape=(600, 12), loss='logcosh', activation='linear', + normalization=ZeroMeanStd1(), + tensor_from_file=named_tensor_from_hd5('ecg_rest_median_raw_10_prediction'), metrics=['mse', 'mae'], + channel_map=ECG_REST_MEDIAN_LEADS, ) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index 788b768ba..2af000328 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -253,6 +253,25 @@ def _slice_tensor_from_file(tm, hd5, dependents={}): return _slice_tensor_from_file +def _random_slice_tensor(tensor_key, max_random=50): + def _slice_tensor_from_file(tm, hd5, dependents={}): + slice_index = np.random.randint(max_random) + if tm.shape[-1] == 1: + t = pad_or_crop_array_to_shape( + tm.shape[:-1], + np.array(hd5[tensor_key][..., slice_index], dtype=np.float32), + ) + tensor = np.expand_dims(t, axis=-1) + else: + tensor = pad_or_crop_array_to_shape( + tm.shape, + np.array(hd5[tensor_key][..., slice_index], dtype=np.float32), + ) + return tensor + + return _slice_tensor_from_file + + def _segmented_dicom_slices(dicom_key_prefix, path_prefix='ukb_cardiac_mri', step=1, total_slices=50): def _segmented_dicom_tensor_from_file(tm, hd5, dependents={}): tensor = np.zeros(tm.shape, dtype=np.float32) @@ -389,6 +408,12 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/2/instance_0', 0), ) +lax_4ch_random_slice_3d = TensorMap( + 'lax_4ch_random_slice_3d', Interpretation.CONTINUOUS, shape=(160, 224, 1), + normalization=ZeroMeanStd1(), + tensor_from_file=_random_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/2/instance_0'), +) + lax_4ch_diastole_slice0_224_3d_augmented = TensorMap( 'lax_4ch_diastole_slice0_224_3d_augmented', Interpretation.CONTINUOUS, shape=(160, 224, 1), normalization=ZeroMeanStd1(), augmentations=[_gaussian_noise, _make_rotate(-15, 15)], @@ -415,6 +440,36 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): 'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0, ), ) +lax_2ch_diastole_slice_224_160_3d = TensorMap( + 'lax_2ch_diastole_slice_224_160_3d', + Interpretation.CONTINUOUS, + shape=(224, 160, 1), + loss='logcosh', + normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor( + 'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0, + ), +) +lax_2ch_diastole_slice_224_192_3d = TensorMap( + 'lax_2ch_diastole_slice_224_192_3d', + Interpretation.CONTINUOUS, + shape=(224, 192, 1), + loss='logcosh', + normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor( + 'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0, + ), +) +lax_2ch_diastole_slice_224_224_3d = TensorMap( + 'lax_2ch_diastole_slice_224_224_3d', + Interpretation.CONTINUOUS, + shape=(224, 224, 1), + loss='logcosh', + normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor( + 'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0, + ), +) lax_3ch_diastole_slice0_3d = TensorMap( 'lax_3ch_diastole_slice0_3d', Interpretation.CONTINUOUS, @@ -425,6 +480,16 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): 'ukb_cardiac_mri/cine_segmented_lax_3ch/2/instance_0', 0, ), ) +lax_3ch_diastole_slice_224_160_3d = TensorMap( + 'lax_3ch_diastole_slice_224_160_3d', + Interpretation.CONTINUOUS, + shape=(224, 160, 1), + loss='logcosh', + normalization=ZeroMeanStd1(), + tensor_from_file=_slice_tensor( + 'ukb_cardiac_mri/cine_segmented_lax_3ch/2/instance_0', 0, + ), +) cine_segmented_ao_dist_slice0_3d = TensorMap( 'cine_segmented_ao_dist_slice0_3d', Interpretation.CONTINUOUS, diff --git a/model_zoo/ECG2AF/README.md b/model_zoo/ECG2AF/README.md index 02f3fae67..99bd724e4 100644 --- a/model_zoo/ECG2AF/README.md +++ b/model_zoo/ECG2AF/README.md @@ -2,7 +2,44 @@ This directory contains models and code for predicting incident atrial fibrillation from 12 lead resting ECGs, as described in our [Circulation paper](https://www.ahajournals.org/doi/full/10.1161/CIRCULATIONAHA.121.057480). -To perform inference with this model run: +The raw model files are stored using `git lfs` so you must have it installed and localize the full ~200MB files with: +```bash +git lfs pull --include model_zoo/ECG2AF/ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5 +git lfs pull --include model_zoo/ECG2AF/strip_* +``` + +To load the 12 lead model in a jupyter notebook (running with the ml4h docker or python library installed) see the [example](./ecg2af_infer.ipynb) or run: + +```python +import numpy as np +from tensorflow.keras.models import load_model +from ml4h.models.model_factory import get_custom_objects +from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2 +from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy, sex_dummy3 + +output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_dummy3]} +custom_dict = get_custom_objects(list(output_tensormaps.values())) +model = load_model('./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5', custom_objects=custom_dict) +ecg = np.random.random((1, 5000, 12)) +prediction = model(ecg) +``` +If above does not work you may need to use an absolute path in `load_model`. + +The model has 4 output heads: the survival curve prediction for incident atrial fibrillation, the classification of atrial fibrillation at the time of ECG, sex, and age regression. Those outputs can be accessed with: +```python +for name, pred in zip(model.output_names, prediction): + otm = output_tensormaps[name] + if otm.is_survival_curve(): + intervals = otm.shape[-1] // 2 + days_per_bin = 1 + otm.days_window // intervals + predicted_survivals = np.cumprod(pred[:, :intervals], axis=1) + print(f'AF Risk {otm} prediction is: {str(1 - predicted_survivals[0, -1])}') + else: + print(f'{otm} prediction is {pred}') +``` + + +To perform command line inference with this model run: ```bash python /path/to/ml4h/ml4h/recipes.py \ --mode infer \ @@ -20,18 +57,23 @@ The model weights for the main model which performs incident atrial fibrillation age regression, sex classification and prevalent (at the time of ECG) atrial fibrillation: [ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5](./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5) -We also include single lead models for lead strip I:[strip_I_survival_curve_af_v2021_06_15.h5](./strip_I_survival_curve_af_v2021_06_15.h5) +We also include single lead models for lead/strip I: [strip_I_survival_curve_af_v2021_06_15.h5](./strip_I_survival_curve_af_v2021_06_15.h5) and II: [strip_II_survival_curve_af_v2021_06_15.h5](./strip_II_survival_curve_af_v2021_06_15.h5) -### Study Design -Flow chart of study design -![Flow chart of study design](./study_design.jpg) +### Study design +