Skip to content

Commit

Permalink
xdl recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Oct 25, 2023
1 parent 302f3db commit 57f4f38
Showing 1 changed file with 125 additions and 1 deletion.
126 changes: 125 additions & 1 deletion ml4h/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
import glob
import logging
import datetime
import numpy as np
import pandas as pd

Expand All @@ -15,6 +16,8 @@
from timeit import default_timer as timer
from collections import Counter, defaultdict

from torch.utils.data import DataLoader

from ml4h.arguments import parse_args
from ml4h.models.inspect import saliency_map
from ml4h.optimizers import find_learning_rate
Expand All @@ -26,12 +29,12 @@
from ml4h.ml4ht_integration.tensor_generator import TensorMapDataLoader2
from ml4h.explorations import test_labels_to_label_map, infer_with_pixels
from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX
from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription
from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model
from ml4h.plots import plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv
from ml4h.models.legacy_models import get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model
from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription
from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients
from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations
from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model
Expand Down Expand Up @@ -77,6 +80,8 @@ def run(args):
infer_with_pixels(args)
elif 'infer_encoders' == args.mode:
infer_encoders_block_multimodal_multitask(args)
elif 'infer_xdl' == args.mode:
infer_xdl(args)
elif 'test_scalar' == args.mode:
test_multimodal_scalar_tasks(args)
elif 'compare_scalar' == args.mode:
Expand Down Expand Up @@ -289,6 +294,7 @@ def standardize_by_sample_ecg(ecg, _):
"""Z Score ECG"""
return (ecg - np.mean(ecg)) / (np.std(ecg) + 1e-6)


def train_xdl(args):
mrn_df = pd.read_csv(args.app_csv)
if 'start_fu_age' in mrn_df:
Expand Down Expand Up @@ -367,6 +373,124 @@ def option_picker(sample_id, data_descriptions):
merger.save(f'{args.output_folder}{args.id}/merger.h5')


def datetime_to_float(d):
    """Convert a datetime-like value to seconds since the Unix epoch.

    The value is parsed with ``pd.to_datetime(..., utc=True)``, so the
    resulting timestamp is timezone-aware (UTC) before being converted.

    :param d: Anything ``pd.to_datetime`` can parse into a single timestamp.
    :return: POSIX timestamp in seconds, as a float.
    """
    as_utc = pd.to_datetime(d, utc=True)
    return as_utc.timestamp()


def float_to_datetime(fl):
    """Convert seconds since the Unix epoch into a UTC pandas timestamp.

    :param fl: Number of seconds since 1970-01-01 (interpreted with ``unit='s'``).
    :return: The result of ``pd.to_datetime`` with ``utc=True``.
    """
    epoch_seconds = fl
    return pd.to_datetime(epoch_seconds, utc=True, unit='s')


def infer_from_dataloader(dataloader, model, tensor_maps_out, max_batches=125000):
    """Run `model` on batches drawn from `dataloader` and tabulate one row per sample.

    Each batch is expected to be a ``(data, target)`` pair of dict-like objects:
    ``data`` is converted to numpy arrays and fed to ``model.predict``;
    ``target`` values must support indexing and ``.numpy()`` on the indexed result
    (presumably torch tensors from a ``DataLoader`` -- confirm against callers).

    Columns written per output tensor map:
      * categorical: ``{name}_prediction`` = ``y[b, 1]``
      * continuous: ``{name}_prediction`` = ``y[b, 0]``
      * survival curve: ``{name}_prediction`` (1 - cumulative survival at the last interval),
        ``{name}_event`` and ``{name}_follow_up``, all stored as strings.
    Every key in ``target`` is also copied into a column of the same name.

    :param dataloader: Iterable yielding ``(data, target)`` batches.
    :param model: Object exposing ``predict(data, verbose=0)`` and ``output_names``.
    :param tensor_maps_out: Output tensor maps describing how to read each prediction.
    :param max_batches: Upper bound on the number of batches consumed.
    :return: ``pd.DataFrame`` with one row per inferred sample.
    """
    dataloader_iterator = iter(dataloader)
    space_dict = defaultdict(list)
    # Target keys whose per-sample value is copied as-is (no trailing-axis selection).
    passthrough_keys = ('MRN', 'linker_id', 'is_c3po', 'output_age_in_days_continuous')
    n_rows = 0
    for i in range(max_batches):
        # Keep the try narrow: only the fetch signals exhaustion, so a StopIteration
        # escaping from model/prediction code is not mistaken for end-of-data.
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            print('loaded all batches')
            break

        for k in data:
            data[k] = np.array(data[k])
        prediction = model.predict(data, verbose=0)
        if len(model.output_names) == 1:
            prediction = [prediction]
        predictions_dict = dict(zip(model.output_names, prediction))
        batch_size = prediction[0].shape[0]

        # Each column is an independent list, so filling column-by-column yields the same
        # table as filling sample-by-sample while letting per-batch work be done once.
        for otm in tensor_maps_out:
            y = predictions_dict[otm.output_name()]
            if otm.is_categorical():
                space_dict[f'{otm.name}_prediction'].extend(y[b, 1] for b in range(batch_size))
            elif otm.is_continuous():
                space_dict[f'{otm.name}_prediction'].extend(y[b, 0] for b in range(batch_size))
            elif otm.is_survival_curve():
                intervals = otm.shape[-1] // 2
                days_per_bin = 1 + (2*otm.days_window) // intervals
                truth = np.asarray(target[otm.output_name()])
                predicted_survivals = np.cumprod(y[:, :intervals], axis=1)
                sick = np.sum(truth[:, intervals:], axis=-1)
                follow_up = np.cumsum(truth[:, :intervals], axis=-1)[:, -1] * days_per_bin
                for b in range(batch_size):
                    # Index by b: each row must carry its own sample's values,
                    # not those of the first sample in the batch.
                    space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[b, -1]))
                    space_dict[f'{otm.name}_event'].append(str(sick[b]))
                    space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b]))

        for k in target:
            column = space_dict[f'{k}']
            for b in range(batch_size):
                if k in passthrough_keys:
                    column.append(target[k][b].numpy())
                elif k == 'datetime':
                    column.append(float_to_datetime(int(target[k][b].numpy())))
                else:
                    # Remaining targets keep only the last element along the second axis.
                    column.append(target[k][b, -1].numpy())

        n_rows += batch_size
        if i % 100 == 0:
            # Count rows explicitly rather than through a leaked loop variable, which could
            # insert a spurious empty column into the defaultdict.
            print(f'Inferred on {i} batches, {n_rows} rows')
    return pd.DataFrame.from_dict(space_dict)


def infer_xdl(args):
    """Infer with a trained model on one ECG per sample and save the predictions as a TSV.

    Reads sample metadata from ``args.app_csv``, keeps rows with a non-null value for every
    output tensor map, chooses for each sample an ECG taken between 1095 days and 1 day before
    that sample's ``start_fu_datetime``, runs the model built from ``args`` on those ECGs, and
    writes ``./ecg_{names}_{hospital}_inference_v{date}.tsv`` in the current working directory.

    :param args: Parsed arguments; this function reads ``app_csv``, ``tensors``, ``tensor_maps_in``,
                 ``tensor_maps_out`` and ``batch_size``, and forwards ``args.__dict__`` to
                 ``make_multimodal_multitask_model``.
    """
    mrn_df = pd.read_csv(args.app_csv)
    if 'start_fu_age' in mrn_df:
        mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu_age).dt.days
    elif 'start_fu' in mrn_df:
        mrn_df['age_in_days'] = pd.to_timedelta(mrn_df.start_fu).dt.days
    mrn_df = mrn_df.rename(columns={'Dem.Gender.no_filter_x': 'sex'})
    # Keep only samples that have a label for every requested output.
    for ot in args.tensor_maps_out:
        mrn_df = mrn_df[mrn_df[ot.name].notna()]
    mrn_df = mrn_df.set_index('sample_id')

    output_dds = [dataframe_data_description_from_tensor_map(tmap, mrn_df) for tmap in args.tensor_maps_out]

    # Identifier columns ride along as extra "outputs" so they appear in the inference table.
    output_dds.append(DataFrameDataDescription(mrn_df, col="MRN"))
    output_dds.append(DataFrameDataDescription(mrn_df, col="linker_id"))
    output_dds.append(DataFrameDataDescription(mrn_df, col="is_c3po"))
    output_dds.append(DataFrameDataDescription(mrn_df, col="datetime", process_col=datetime_to_float))

    ecg_dd = ECGDataDescription(
        args.tensors,
        name=args.tensor_maps_in[0].input_name(),
        ecg_len=5000,  # all ECGs will be linearly interpolated to be this length
        transforms=[standardize_by_sample_ecg],  # these will be applied in order
    )

    def test_option_picker(sample_id, data_descriptions):
        """Pick the last listed ECG option dated 1095 to 1 days before start of follow-up."""
        ecg_dts = ecg_dd.get_loading_options(sample_id)
        htn_dt = output_dds[0].get_loading_options(sample_id)[0]['start_fu_datetime']
        min_ecg_dt = htn_dt - pd.to_timedelta("1095d")
        max_ecg_dt = htn_dt - pd.to_timedelta("1d")
        dates = [dt for dt in ecg_dts if min_ecg_dt <= dt[DATE_OPTION_KEY] <= max_ecg_dt]
        if not dates:
            raise ValueError('No matching dates')
        # NOTE(review): this is the most recent ECG only if loading options are date-sorted -- confirm.
        chosen_dt = dates[-1]
        return {dd: chosen_dt for dd in data_descriptions}

    sg = DataDescriptionSampleGetter(
        input_data_descriptions=[ecg_dd],  # what we want a model to use as input data
        output_data_descriptions=output_dds,  # what we want a model to predict from the input data
        option_picker=test_option_picker,
    )

    dataset = SampleGetterIterableDataset(sample_ids=list(mrn_df.index), sample_getter=sg, get_epoch=shuffle_get_epoch)
    dataloader = DataLoader(dataset, num_workers=0, batch_size=args.batch_size)

    model, _, _, _ = make_multimodal_multitask_model(**args.__dict__)
    infer_df = infer_from_dataloader(dataloader, model, args.tensor_maps_out)

    hospital = 'mgh' if 'mgh' in args.tensors else 'bwh'
    mrn_column = f'{hospital.upper()}_MRN'
    # rename() returns a new frame, so the renamed column must be read from that frame;
    # infer_df still only has 'MRN' and would raise AttributeError on MGH_MRN / BWH_MRN.
    df_ecg_2_disease = infer_df.rename(columns={'MRN': mrn_column})
    df_ecg_2_disease[mrn_column] = df_ecg_2_disease[mrn_column].astype(int)
    df_ecg_2_disease['linker_id'] = df_ecg_2_disease['linker_id'].astype(int)

    names = '_'.join([otm.name for otm in args.tensor_maps_out])
    now_string = datetime.datetime.now().strftime('%Y_%m_%d')
    out_file = f'./ecg_{names}_{hospital}_inference_v{now_string}.tsv'
    df_ecg_2_disease.to_csv(out_file, sep='\t', index=False)
    print(f'Saved inferences to: {out_file}')


def _make_tmap_nan_on_fail(tmap):
"""
Builds a copy TensorMap with a tensor_from_file that returns nans on errors instead of raising an error
Expand Down

0 comments on commit 57f4f38

Please sign in to comment.