
ECG2AF open-source weights and notebook (#543)
* Update ECG2AF README.md
lucidtronix authored Jan 4, 2024
1 parent f65ae7a commit 982b56b
Showing 15 changed files with 537 additions and 42 deletions.
1 change: 0 additions & 1 deletion docker/vm_boot_images/config/tensorflow-requirements.txt
@@ -44,4 +44,3 @@ umap-learn[plot]
neurite
voxelmorph
pystrum

2 changes: 0 additions & 2 deletions ml4h/TensorMap.py
@@ -204,8 +204,6 @@ def __init__(
elif self.activation is None and (self.is_survival_curve() or self.is_time_to_event()):
self.activation = 'sigmoid'



if self.channel_map is None and self.is_time_to_event():
self.channel_map = DEFAULT_TIME_TO_EVENT_CHANNELS

4 changes: 2 additions & 2 deletions ml4h/data_descriptions.py
@@ -5,6 +5,7 @@
from typing import Callable, List, Union, Optional, Tuple, Dict, Any

import h5py
import datetime
import numcodecs
import numpy as np
import pandas as pd
@@ -331,10 +332,9 @@ def __init__(
):
"""
Gets data from a column of the provided DataFrame.
:param df: Must be multi-indexed with sample_id, loading_option
# TODO: allow multiple loading options
:param col: The column name to get data from
:param process_col: Function to turn the column value into Tensor
:param name: Optional overwrite of the df column name
"""
self.process_col = process_col or self._default_process_call
self.df = df
4 changes: 2 additions & 2 deletions ml4h/models/legacy_models.py
@@ -350,7 +350,7 @@ def make_hidden_layer_model(parent_model: Model, tensor_maps_in: List[TensorMap]
dummy_input = {tm.input_name(): np.zeros((1,) + parent_model.get_layer(tm.input_name()).input_shape[0][1:]) for tm in tensor_maps_in}
intermediate_layer_model = Model(inputs=parent_inputs, outputs=target_layer.output)
# If we do not predict here then the graph is disconnected, I do not know why?!
intermediate_layer_model.predict(dummy_input)
intermediate_layer_model.predict(dummy_input, verbose=0)
return intermediate_layer_model


@@ -1344,7 +1344,7 @@ def make_paired_autoencoder_model(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def embed_model_predict(model, tensor_maps_in, embed_layer, test_data, batch_size):
embed_model = make_hidden_layer_model(model, tensor_maps_in, embed_layer)
return embed_model.predict(test_data, batch_size=batch_size)
return embed_model.predict(test_data, batch_size=batch_size, verbose=0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 changes: 2 additions & 2 deletions ml4h/plots.py
@@ -770,7 +770,7 @@ def plot_scatter(

ax1.set_xlabel("Predictions")
ax1.set_ylabel("Actual")
ax1.set_title(title)
ax1.set_title(f'{title} N = {len(prediction)}' )
ax1.legend(loc="lower right")

sns.distplot(prediction, label="Predicted", color="r", ax=ax2)
@@ -2253,7 +2253,7 @@ def plot_ecg_rest(
tensor_paths: List[str],
rows: List[int],
out_folder: str,
is_blind: bool,
is_blind: bool
) -> None:
"""Plots resting ECGs including annotations and LVH criteria
240 changes: 234 additions & 6 deletions ml4h/recipes.py

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions ml4h/tensormap/mgb/xdl.py
@@ -0,0 +1,45 @@
from typing import Dict

import h5py
import numpy as np
from ml4h.TensorMap import TensorMap, Interpretation

ecg_5000_std = TensorMap('ecg_5000_std', Interpretation.CONTINUOUS, shape=(5000, 12))

hypertension_icd_only = TensorMap(name='hypertension_icd_only', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_icd_only': 0, 'hypertension_icd_only': 1})
hypertension_icd_bp = TensorMap(name='hypertension_icd_bp', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_icd_bp': 0, 'hypertension_icd_bp': 1})
hypertension_icd_bp_med = TensorMap(name='hypertension_icd_bp_med', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_icd_bp_med': 0, 'hypertension_icd_bp_med': 1})
hypertension_med = TensorMap(name='start_fu_hypertension_med', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_medication': 0, 'hypertension_medication': 1})

lvef = TensorMap(name='LVEF', interpretation=Interpretation.CONTINUOUS, channel_map={'LVEF': 0})

age = TensorMap(name='age_in_days', interpretation=Interpretation.CONTINUOUS, channel_map={'age_in_days': 0})
sex = TensorMap(name='sex', interpretation=Interpretation.CATEGORICAL, channel_map={'Female': 0, 'Male': 1})

cad = TensorMap(name='cad', interpretation=Interpretation.CATEGORICAL, channel_map={'no_cad': 0, 'cad': 1})
dm = TensorMap(name='dm', interpretation=Interpretation.CATEGORICAL, channel_map={'no_dm': 0, 'dm': 1})
hypercholesterolemia = TensorMap(name='hypercholesterolemia', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypercholesterolemia': 0, 'hypercholesterolemia': 1})


def ecg_median_biosppy(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
tensor = np.zeros(tm.shape, dtype=np.float32)
for lead in tm.channel_map:
tensor[:, tm.channel_map[lead]] = hd5[f'{tm.path_prefix}{lead}']
tensor = np.nan_to_num(tensor)
return tensor

ecg_channel_map = {
'I': 0, 'II': 1, 'III': 2, 'aVR': 3, 'aVL': 4, 'aVF': 5,
'V1': 6, 'V2': 7, 'V3': 8, 'V4': 9, 'V5': 10, 'V6': 11,
}

ecg_biosppy_median_60bpm = TensorMap(
'median', Interpretation.CONTINUOUS, path_prefix='median_60bpm_', shape=(600, 12),
tensor_from_file=ecg_median_biosppy,
channel_map=ecg_channel_map,
)
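
As a quick illustration of how the new TensorMap is consumed, the sketch below (hypothetical file name and HD5 layout, not part of this commit, and assuming `TensorMap` exposes `channel_map` and `tensor_from_file` as plain attributes) writes one median beat per lead under the `median_60bpm_` prefix and reads them back through `ecg_biosppy_median_60bpm`:

```python
import h5py
import numpy as np
from ml4h.tensormap.mgb.xdl import ecg_biosppy_median_60bpm

# 'example.hd5' is a throwaway file created only for this sketch
with h5py.File('example.hd5', 'w') as hd5:
    for lead in ecg_biosppy_median_60bpm.channel_map:
        # one 600-sample median beat per lead, keyed by the TensorMap's path_prefix
        hd5.create_dataset(f'median_60bpm_{lead}', data=np.random.random(600))

with h5py.File('example.hd5', 'r') as hd5:
    tensor = ecg_biosppy_median_60bpm.tensor_from_file(ecg_biosppy_median_60bpm, hd5)
    print(tensor.shape)  # (600, 12)
```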
15 changes: 10 additions & 5 deletions ml4h/tensormap/ukb/demographics.py
@@ -341,10 +341,11 @@ def alcohol_from_file(tm, hd5, dependents={}):
path_prefix='categorical', annotation_units=2,
channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
)
# sex = TensorMap(
# 'Sex_Male_0_0', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG, path_prefix='categorical', annotation_units=2,
# channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
# )
sex_dummy1 = TensorMap(
'sex', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG,
path_prefix='categorical', annotation_units=2,
channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
)
af_dummy2 = TensorMap(
'af_in_read', Interpretation.CATEGORICAL, path_prefix='categorical', storage_type=StorageType.CATEGORICAL_FLAG,
channel_map={'no_atrial_fibrillation': 0, 'atrial_fibrillation': 1},
@@ -354,7 +355,11 @@ def alcohol_from_file(tm, hd5, dependents={}):
path_prefix='categorical', annotation_units=2,
channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
)

sex_dummy3 = TensorMap(
'sex_from_wide', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG,
path_prefix='categorical', annotation_units=2,
channel_map={'female': 0, 'male': 1}, loss='categorical_crossentropy',
)
brain_volume = TensorMap(
'25010_Volume-of-brain-greywhite-matter_2_0', Interpretation.CONTINUOUS, path_prefix='continuous', normalization={'mean': 1165940.0, 'std': 111511.0},
channel_map={'25010_Volume-of-brain-greywhite-matter_2_0': 0}, loss='logcosh', loss_weight=0.1,
2 changes: 1 addition & 1 deletion ml4h/tensormap/ukb/dxa.py
@@ -98,7 +98,7 @@ def dxa_background_erase(tm, hd5, dependents={}):
)
dxa_11 = TensorMap(
'dxa_1_11',
shape=(896, 352, 1),
shape=(896, 384, 1),
path_prefix='ukb_dxa',
tensor_from_file=dxa_background_erase,
normalization=ZeroMeanStd1(),
7 changes: 5 additions & 2 deletions ml4h/tensormap/ukb/ecg.py
@@ -574,6 +574,7 @@ def ecg_rest_section_to_segment(tm, hd5, dependents={}):
metrics=['mse', 'mae'], channel_map=ECG_REST_MEDIAN_LEADS, normalization=Standardize(mean=0, std=10),
)


ecg_rest_median_576 = TensorMap(
'ecg_rest_median_576', Interpretation.CONTINUOUS, path_prefix='ukb_ecg_rest', shape=(576, 12), loss='logcosh',
activation='linear', tensor_from_file=_make_ecg_rest(), channel_map=ECG_REST_MEDIAN_LEADS,
@@ -595,8 +596,10 @@ def ecg_rest_section_to_segment(tm, hd5, dependents={}):
)

ecg_rest_median_raw_10_prediction = TensorMap(
'ecg_rest_median_raw_10', Interpretation.CONTINUOUS, shape=(600, 12), loss='logcosh', activation='linear', normalization=ZeroMeanStd1(),
tensor_from_file=named_tensor_from_hd5('ecg_rest_median_raw_10_prediction'), metrics=['mse', 'mae'], channel_map=ECG_REST_MEDIAN_LEADS,
'ecg_rest_median_raw_10', Interpretation.CONTINUOUS, shape=(600, 12), loss='logcosh', activation='linear',
normalization=ZeroMeanStd1(),
tensor_from_file=named_tensor_from_hd5('ecg_rest_median_raw_10_prediction'), metrics=['mse', 'mae'],
channel_map=ECG_REST_MEDIAN_LEADS,
)


65 changes: 65 additions & 0 deletions ml4h/tensormap/ukb/mri.py
@@ -253,6 +253,25 @@ def _slice_tensor_from_file(tm, hd5, dependents={}):
return _slice_tensor_from_file


def _random_slice_tensor(tensor_key, max_random=50):
def _slice_tensor_from_file(tm, hd5, dependents={}):
slice_index = np.random.randint(max_random)
if tm.shape[-1] == 1:
t = pad_or_crop_array_to_shape(
tm.shape[:-1],
np.array(hd5[tensor_key][..., slice_index], dtype=np.float32),
)
tensor = np.expand_dims(t, axis=-1)
else:
tensor = pad_or_crop_array_to_shape(
tm.shape,
np.array(hd5[tensor_key][..., slice_index], dtype=np.float32),
)
return tensor

return _slice_tensor_from_file


def _segmented_dicom_slices(dicom_key_prefix, path_prefix='ukb_cardiac_mri', step=1, total_slices=50):
def _segmented_dicom_tensor_from_file(tm, hd5, dependents={}):
tensor = np.zeros(tm.shape, dtype=np.float32)
@@ -389,6 +408,12 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}):
tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/2/instance_0', 0),
)

lax_4ch_random_slice_3d = TensorMap(
'lax_4ch_random_slice_3d', Interpretation.CONTINUOUS, shape=(160, 224, 1),
normalization=ZeroMeanStd1(),
tensor_from_file=_random_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/2/instance_0'),
)

lax_4ch_diastole_slice0_224_3d_augmented = TensorMap(
'lax_4ch_diastole_slice0_224_3d_augmented', Interpretation.CONTINUOUS, shape=(160, 224, 1),
normalization=ZeroMeanStd1(), augmentations=[_gaussian_noise, _make_rotate(-15, 15)],
@@ -415,6 +440,36 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}):
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_2ch_diastole_slice_224_160_3d = TensorMap(
'lax_2ch_diastole_slice_224_160_3d',
Interpretation.CONTINUOUS,
shape=(224, 160, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_2ch_diastole_slice_224_192_3d = TensorMap(
'lax_2ch_diastole_slice_224_192_3d',
Interpretation.CONTINUOUS,
shape=(224, 192, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_2ch_diastole_slice_224_224_3d = TensorMap(
'lax_2ch_diastole_slice_224_224_3d',
Interpretation.CONTINUOUS,
shape=(224, 224, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_3ch_diastole_slice0_3d = TensorMap(
'lax_3ch_diastole_slice0_3d',
Interpretation.CONTINUOUS,
@@ -425,6 +480,16 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}):
'ukb_cardiac_mri/cine_segmented_lax_3ch/2/instance_0', 0,
),
)
lax_3ch_diastole_slice_224_160_3d = TensorMap(
'lax_3ch_diastole_slice_224_160_3d',
Interpretation.CONTINUOUS,
shape=(224, 160, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_3ch/2/instance_0', 0,
),
)
cine_segmented_ao_dist_slice0_3d = TensorMap(
'cine_segmented_ao_dist_slice0_3d',
Interpretation.CONTINUOUS,
56 changes: 49 additions & 7 deletions model_zoo/ECG2AF/README.md
@@ -2,7 +2,44 @@
This directory contains models and code for predicting incident atrial fibrillation from 12 lead resting ECGs, as described in our
[Circulation paper](https://www.ahajournals.org/doi/full/10.1161/CIRCULATIONAHA.121.057480).

To perform inference with this model run:
The raw model files are stored with `git lfs`, so you must have it installed in order to fetch the full ~200MB files locally with:
```bash
git lfs pull --include model_zoo/ECG2AF/ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5
git lfs pull --include model_zoo/ECG2AF/strip_*
```
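
A quick, optional sanity check (illustrative only) is to confirm the pull replaced the small git-lfs pointer with the real HDF5 file:

```python
import os
import h5py

path = 'model_zoo/ECG2AF/ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5'
print(f'{os.path.getsize(path) / 1e6:.1f} MB')  # a pointer file is only a few hundred bytes; the real weights are ~200 MB
print(h5py.is_hdf5(path))                       # True once the real file is present
```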

To load the 12-lead model in a Jupyter notebook (running with the ml4h Docker image or the Python library installed), see the [example notebook](./ecg2af_infer.ipynb) or run:

```python
import numpy as np
from tensorflow.keras.models import load_model
from ml4h.models.model_factory import get_custom_objects
from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2
from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy, sex_dummy3

output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_dummy3]}
custom_dict = get_custom_objects(list(output_tensormaps.values()))
model = load_model('./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5', custom_objects=custom_dict)
ecg = np.random.random((1, 5000, 12))
prediction = model(ecg)
```
If the above does not work, you may need to pass an absolute path to `load_model`.

The model has 4 output heads: survival curve prediction for incident atrial fibrillation, classification of atrial fibrillation at the time of the ECG, sex classification, and age regression. Those outputs can be accessed with:
```python
for name, pred in zip(model.output_names, prediction):
otm = output_tensormaps[name]
if otm.is_survival_curve():
intervals = otm.shape[-1] // 2
days_per_bin = 1 + otm.days_window // intervals
predicted_survivals = np.cumprod(pred[:, :intervals], axis=1)
print(f'AF Risk {otm} prediction is: {str(1 - predicted_survivals[0, -1])}')
else:
print(f'{otm} prediction is {pred}')
```
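
To make the survival-curve arithmetic concrete, here is a small worked example with made-up interval probabilities (not model output). The loop above treats each of the first `intervals` outputs as a per-interval survival probability, so their cumulative product is the overall survival curve and one minus its last entry is the predicted risk over the full window:

```python
import numpy as np

# Hypothetical conditional survival probabilities for 3 intervals
interval_survival = np.array([[0.99, 0.98, 0.97]])
cumulative_survival = np.cumprod(interval_survival, axis=1)  # [[0.99, 0.9702, 0.941094]]
af_risk = 1 - cumulative_survival[0, -1]
print(f'{af_risk:.3f}')  # 0.059, i.e. ~5.9% predicted risk over the window
```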


To perform command line inference with this model run:
```bash
python /path/to/ml4h/ml4h/recipes.py \
--mode infer \
@@ -20,18 +57,23 @@ The model weights for the main model which performs incident atrial fibrillation
age regression, sex classification and prevalent (at the time of ECG) atrial fibrillation:
[ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5](./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5)

We also include single lead models for lead strip I:[strip_I_survival_curve_af_v2021_06_15.h5](./strip_I_survival_curve_af_v2021_06_15.h5)
We also include single lead models for lead/strip I: [strip_I_survival_curve_af_v2021_06_15.h5](./strip_I_survival_curve_af_v2021_06_15.h5)
and II: [strip_II_survival_curve_af_v2021_06_15.h5](./strip_II_survival_curve_af_v2021_06_15.h5)

### Study Design
Flow chart of study design
![Flow chart of study design](./study_design.jpg)
### Study design
<div style="padding: 10px; background-color: white; display: inline-block;">
<img src="./study_design.jpg" alt="Flow chart of study design" />
</div>

### Performance
Risk stratification model comparison
![Risk stratification model comparison](./km.jpg)
<div style="padding: 10px; background-color: white; display: inline-block;">
<img src="./km.jpg" alt="Risk stratification model comparison" />
</div>

### Salience
Salience and Median waveforms from predicted risk extremes.
![Salience and Median waveforms](./salience.jpg)
### Architecture
1D Convolutional neural net architecture
![Convolutional neural net architecture](./architecture.png)