From 57f4f38c3533f5286c7f03b4b82842f2be14d1e7 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 25 Oct 2023 08:25:31 -0400 Subject: [PATCH] xdl recipe --- ml4h/recipes.py | 126 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 125 insertions(+), 1 deletion(-) diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 6bd44e2af..1f37634c9 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -7,6 +7,7 @@ import h5py import glob import logging +import datetime import numpy as np import pandas as pd @@ -15,6 +16,8 @@ from timeit import default_timer as timer from collections import Counter, defaultdict +from torch.utils.data import DataLoader + from ml4h.arguments import parse_args from ml4h.models.inspect import saliency_map from ml4h.optimizers import find_learning_rate @@ -26,12 +29,12 @@ from ml4h.ml4ht_integration.tensor_generator import TensorMapDataLoader2 from ml4h.explorations import test_labels_to_label_map, infer_with_pixels from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX -from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model from ml4h.plots import plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv from ml4h.models.legacy_models import get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator +from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients from ml4h.plots import subplot_rocs, 
# Reconstructed from a whitespace-collapsed patch of ml4h/recipes.py.
# The patch also touches definitions that are only partially visible here;
# those hunks are recorded as-is instead of being rewritten:
#   run(args):   + elif 'infer_xdl' == args.mode:
#                +     infer_xdl(args)
#   train_xdl:   + one blank line inserted above `def train_xdl(args):`


# Target keys whose per-sample value is read as `target[k][b]`; every other
# non-datetime key is read as `target[k][b, -1]`.
_SCALAR_TARGET_KEYS = ('MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous')


def datetime_to_float(d):
    """Convert a datetime-like value to seconds since the Unix epoch.

    :param d: A single value that ``pd.to_datetime`` can parse; it is parsed with ``utc=True``.
    :return: The POSIX timestamp of the parsed value as a float.
    """
    return pd.to_datetime(d, utc=True).timestamp()


def float_to_datetime(fl):
    """Inverse of ``datetime_to_float``: seconds since the epoch to a UTC timestamp.

    :param fl: Number of seconds since the Unix epoch.
    :return: The corresponding timezone-aware (UTC) pandas timestamp.
    """
    return pd.to_datetime(fl, unit='s', utc=True)


def _to_numpy(value):
    """Return ``value`` as NumPy data; uses ``.numpy()`` when present, else ``np.asarray``."""
    as_numpy = getattr(value, 'numpy', None)
    return as_numpy() if callable(as_numpy) else np.asarray(value)


def infer_from_dataloader(dataloader, model, tensor_maps_out, max_batches=125000):
    """Run ``model`` over the batches of ``dataloader`` and tabulate predictions with their targets.

    For each output TensorMap a ``<name>_prediction`` column is written
    (``y[b, 1]`` for categorical maps, ``y[b, 0]`` for continuous maps). Survival-curve maps
    additionally get ``<name>_event`` and ``<name>_follow_up``; their three columns hold strings.
    Every key of the target dictionary also becomes a column.

    :param dataloader: Iterable yielding ``(data, target)`` pairs of dictionaries keyed by name.
                       Target values may be objects exposing ``.numpy()`` or array-likes.
    :param model: Object exposing ``predict(inputs, verbose=0)`` and ``output_names``.
    :param tensor_maps_out: Output TensorMaps describing how each model output is interpreted.
    :param max_batches: Upper bound on the number of batches consumed.
    :return: DataFrame with one row per inferred sample.
    """
    batches = iter(dataloader)
    space_dict = defaultdict(list)
    for i in range(max_batches):
        # Keep the try narrow so a StopIteration raised while predicting is not
        # mistaken for the end of the data.
        try:
            data, target = next(batches)
        except StopIteration:
            print('loaded all batches')
            break

        inputs = {k: np.array(v) for k, v in data.items()}
        prediction = model.predict(inputs, verbose=0)
        if len(model.output_names) == 1:
            prediction = [prediction]  # single-output models return a bare array
        predictions_dict = dict(zip(model.output_names, prediction))
        batch_size = prediction[0].shape[0]

        for otm in tensor_maps_out:
            y = predictions_dict[otm.output_name()]
            if otm.is_categorical():
                space_dict[f'{otm.name}_prediction'].extend(y[b, 1] for b in range(batch_size))
            elif otm.is_continuous():
                space_dict[f'{otm.name}_prediction'].extend(y[b, 0] for b in range(batch_size))
            elif otm.is_survival_curve():
                # Per-batch quantities are computed once, then indexed per sample.
                # (Previously index 0 was used for every sample of the batch.)
                intervals = otm.shape[-1] // 2
                days_per_bin = 1 + (2 * otm.days_window) // intervals
                truth = _to_numpy(target[otm.output_name()])
                predicted_survivals = np.cumprod(y[:, :intervals], axis=1)
                sick = np.sum(truth[:, intervals:], axis=-1)
                follow_up = np.cumsum(truth[:, :intervals], axis=-1)[:, -1] * days_per_bin
                for b in range(batch_size):
                    space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[b, -1]))
                    space_dict[f'{otm.name}_event'].append(str(sick[b]))
                    space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b]))

        for k, values in target.items():
            column = space_dict[f'{k}']
            for b in range(batch_size):
                if k in _SCALAR_TARGET_KEYS:
                    column.append(_to_numpy(values[b]))
                elif k == 'datetime':
                    # Stored as epoch seconds by datetime_to_float; restore a timestamp.
                    column.append(float_to_datetime(int(_to_numpy(values[b]))))
                else:
                    column.append(_to_numpy(values[b, -1]))

        if i % 100 == 0:
            # Count rows without touching the defaultdict (a lookup would create a column).
            rows = max((len(col) for col in space_dict.values()), default=0)
            print(f'Inferred on {i} batches, {rows} rows')
    return pd.DataFrame.from_dict(space_dict)


def infer_xdl(args):
    """Infer with an ECG model over the samples listed in ``args.app_csv`` and save a TSV.

    Rows lacking a value for any output TensorMap column are dropped. For each sample one ECG
    is chosen by ``test_option_picker`` and the resulting predictions, together with the
    ``MRN``, ``linker_id``, ``is_c3po`` and ``datetime`` columns, are written to
    ``./ecg_<names>_<hospital>_inference_v<YYYY_MM_DD>.tsv``.

    :param args: Parsed command line arguments; reads ``app_csv``, ``tensors``,
                 ``tensor_maps_in``, ``tensor_maps_out`` and ``batch_size``, and forwards
                 ``args.__dict__`` to ``make_multimodal_multitask_model``.
    """
    mrn_df = pd.read_csv(args.app_csv)
    if 'start_fu_age' in mrn_df:
        mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu_age).dt.days
    elif 'start_fu' in mrn_df:
        mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu).dt.days
    mrn_df = mrn_df.rename(columns={'Dem.Gender.no_filter_x': 'sex'})
    for ot in args.tensor_maps_out:
        mrn_df = mrn_df[mrn_df[ot.name].notna()]
    mrn_df = mrn_df.set_index('sample_id')

    output_dds = [dataframe_data_description_from_tensor_map(tmap, mrn_df) for tmap in args.tensor_maps_out]

    # Identifier columns ride along as extra "outputs" so they end up in the inference table.
    output_dds.append(DataFrameDataDescription(mrn_df, col="MRN"))
    output_dds.append(DataFrameDataDescription(mrn_df, col="linker_id"))
    output_dds.append(DataFrameDataDescription(mrn_df, col="is_c3po"))
    output_dds.append(DataFrameDataDescription(mrn_df, col="datetime", process_col=datetime_to_float))

    ecg_dd = ECGDataDescription(
        args.tensors,
        name=args.tensor_maps_in[0].input_name(),
        ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
        transforms=[standardize_by_sample_ecg],  # these will be applied in order
    )

    def test_option_picker(sample_id, data_descriptions):
        """Pick the last listed ECG dated 1 to 1095 days before the sample's follow-up start.

        :raises ValueError: If the sample has no ECG inside that window.
        """
        ecg_dts = ecg_dd.get_loading_options(sample_id)
        htn_dt = output_dds[0].get_loading_options(sample_id)[0]['start_fu_datetime']
        min_ecg_dt = htn_dt - pd.to_timedelta("1095d")
        max_ecg_dt = htn_dt - pd.to_timedelta("1d")
        dates = [dt for dt in ecg_dts if min_ecg_dt <= dt[DATE_OPTION_KEY] <= max_ecg_dt]
        if not dates:
            raise ValueError('No matching dates')
        chosen_dt = dates[-1]
        return {dd: chosen_dt for dd in data_descriptions}

    sg = DataDescriptionSampleGetter(
        input_data_descriptions=[ecg_dd],  # what we want a model to use as input data
        output_data_descriptions=output_dds,  # what we want a model to predict from the input data
        option_picker=test_option_picker,
    )

    dataset = SampleGetterIterableDataset(sample_ids=list(mrn_df.index), sample_getter=sg, get_epoch=shuffle_get_epoch)
    dataloader = DataLoader(dataset, num_workers=0, batch_size=args.batch_size)

    model, _, _, _ = make_multimodal_multitask_model(**args.__dict__)
    infer_df = infer_from_dataloader(dataloader, model, args.tensor_maps_out)

    hospital = 'mgh' if 'mgh' in args.tensors else 'bwh'
    mrn_column = f'{hospital.upper()}_MRN'
    # rename() returns a new frame, so the cast must read the renamed frame;
    # `infer_df` itself still only has the 'MRN' column.
    df_ecg_2_disease = infer_df.rename(columns={'MRN': mrn_column})
    df_ecg_2_disease[mrn_column] = df_ecg_2_disease[mrn_column].astype(int)
    df_ecg_2_disease['linker_id'] = df_ecg_2_disease['linker_id'].astype(int)

    names = '_'.join([otm.name for otm in args.tensor_maps_out])
    now_string = datetime.datetime.now().strftime('%Y_%m_%d')
    out_file = f'./ecg_{names}_{hospital}_inference_v{now_string}.tsv'
    df_ecg_2_disease.to_csv(out_file, sep='\t', index=False)
    print(f'Saved inferences to: {out_file}')