Skip to content

Commit

Permalink
DROID-MVP and DROID-RV (#578)
Browse files Browse the repository at this point in the history
DROID-MVP and DROID-RV in the model zoo, updates and cleanup to DROID code base
  • Loading branch information
alalusim authored Dec 20, 2024
1 parent 3ced397 commit ae7848b
Show file tree
Hide file tree
Showing 28 changed files with 1,237 additions and 58 deletions.
15 changes: 15 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ notebooks/**.ipynb filter=nbstripout
*.genes filter=lfs diff=lfs merge=lfs -text
model_zoo/ECG_PheWAS/*.h5 filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID/encoders/**/* filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.index filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-MVP/movinet_a2_base/checkpoint filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-MVP/droid_mvp_checkpoint/checkpoint filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.index filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/droid_rv_checkpoint/checkpoint filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/droid_rv_checkpoint/chkp.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/droid_rv_checkpoint/chkp.index filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/droid_rvef_checkpoint/checkpoint filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.index filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/movinet_a2_base/checkpoint filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
model_zoo/DROID-RV/movinet_a2_base/ckpt-1.index filter=lfs diff=lfs merge=lfs -text
model_zoo/ECG2AF/ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5 filter=lfs diff=lfs merge=lfs -text
model_zoo/ECG2AF/strip_I_survival_curve_af_v2021_06_15.h5 filter=lfs diff=lfs merge=lfs -text
model_zoo/ECG2AF/strip_II_survival_curve_af_v2021_06_15.h5 filter=lfs diff=lfs merge=lfs -text
Expand Down
3 changes: 3 additions & 0 deletions model_zoo/DROID-MVP/droid_mvp_checkpoint/checkpoint
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.index
Git LFS file not shown
40 changes: 40 additions & 0 deletions model_zoo/DROID-MVP/droid_mvp_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python
# coding: utf-8

import numpy as np
import tensorflow as tf
from droid_mvp_model_description import create_movinet_classifier, create_regressor_classifier
import logging
tf.get_logger().setLevel(logging.ERROR)

pretrained_chkp_dir = "droid_mvp_checkpoint/chkp"
movinet_chkp_dir = 'movinet_a2_base/'

movinet_model, backbone = create_movinet_classifier(
n_input_frames=16,
batch_size=16,
num_classes=600,
checkpoint_dir=movinet_chkp_dir,
)

backbone_output = backbone.layers[-1].output[0]
flatten = tf.keras.layers.Flatten()(backbone_output)
encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten])

func_args = {
'input_shape': (16, 224, 224, 3),
'n_output_features': 0, # number of regression features
'categories': {"mvp_status_binary":2, "mvp_status_detailed":6},
'category_order': ["mvp_status_binary", "mvp_status_detailed"],
}

model_plus_head = create_regressor_classifier(encoder, **func_args)

model_plus_head.load_weights(pretrained_chkp_dir)

random_video = np.random.random((1, 16, 224, 224, 3))

print(f"""
DROID-MVP Predictions:
{model_plus_head.predict(random_video)}
""")
66 changes: 66 additions & 0 deletions model_zoo/DROID-MVP/droid_mvp_model_description.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet, movinet_model

hidden_units = 256
dropout_rate = 0.5

def create_movinet_classifier(
n_input_frames,
batch_size,
checkpoint_dir,
num_classes,
freeze_backbone=False
):
backbone = movinet.Movinet(model_id='a2')
model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)
model.build([1, 1, 1, 1, 3])
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched()

model = movinet_model.MovinetClassifier(
backbone=backbone,
num_classes=num_classes
)
model.build([batch_size, n_input_frames, 224, 224, 3])

if freeze_backbone:
for layer in model.layers[:-1]:
layer.trainable = False
model.layers[-1].trainable = True

return model, backbone

def create_regressor_classifier(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=0, categories={},
category_order=None, add_dense={'regressor': False, 'classifier': False}):
for layer in encoder.layers:
layer.trainable = trainable

inputs = tf.keras.Input(shape=input_shape, name='image')
features = encoder(inputs)
features = tf.keras.layers.Dropout(dropout_rate)(features)
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
features = tf.keras.layers.Dropout(dropout_rate)(features)

outputs = []
if n_output_features > 0:
if add_dense['regressor']:
features_reg = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
features_reg = tf.keras.layers.Dropout(dropout_rate)(features_reg)
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features_reg))
else:
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features))
if len(categories.keys()) > 0:
if add_dense['classifier']:
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
features = tf.keras.layers.Dropout(dropout_rate)(features)
for category in category_order:
activation = 'softmax'
n_classes = categories[category]
outputs.append(tf.keras.layers.Dense(n_classes, name='cls_'+category, activation=activation)(features))

model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor_classifier")

return model
3 changes: 3 additions & 0 deletions model_zoo/DROID-MVP/movinet_a2_base/checkpoint
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.index
Git LFS file not shown
36 changes: 36 additions & 0 deletions model_zoo/DROID-MVP/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# DROID-MVP Inference Example

This is a simple example script demonstrated how to load and run the DROID-MVP model. Model training and inference was performed using the code provided in the ML4H [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID). The example below was adapted from the DROID inference code.

1. Download DROID docker image. Note: docker image is not compatible with Apple Silicon.

`docker pull alalusim/droid:latest`

2. Pull github repo, including DROID-MVP model checkpoints stored using git lfs.

```
github clone https://github.com/broadinstitute/ml4h.git
git lfs pull --include ml4h/model_zoo/DROID-MVP/droid_mvp_checkpoint/*
git lfs pull --include ml4h/model_zoo/DROID-MVP/movinet_a2_base/*
```

3. Run docker image while mounting ml4h directory and run example inference script.

`docker run -it -v {PATH TO CLONED ML4H DIRECTORY}:/ml4h/ alalusim/droid:latest`

```
cd /ml4h/model_zoo/DROID-MVP/
python droid_mvp_inference.py
```

To use with your own data, format echocardiogram videos as tensors with shape (16, 224, 224, 3) before passing to the model. Code for data preprocessing, storage, loading, training, and inference can be found in the ml4h [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID).

Model outputs for DROID-MVP take the form:
```
[
[["MVP", "Not MVP"]],
[["Anterior ", "Bileaflet", "Not MVP", "Posterior", "Superior Displacement", "MVP not otherwise specified"]],
]
```

Note that the model was optimized for predicting binary MVP status (the primary task) and that detailed MVP status was used as an auxiliary task to improve performance on the primary classification task.
3 changes: 3 additions & 0 deletions model_zoo/DROID-RV/droid_rv_checkpoint/checkpoint
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions model_zoo/DROID-RV/droid_rv_checkpoint/chkp.index
Git LFS file not shown
58 changes: 58 additions & 0 deletions model_zoo/DROID-RV/droid_rv_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# coding: utf-8

import numpy as np
import tensorflow as tf
from droid_rv_model_description import create_movinet_classifier, create_regressor_classifier, rescale_droid_rv_outputs, rescale_droid_rvef_outputs
import logging
tf.get_logger().setLevel(logging.ERROR)

droid_rv_checkpoint = "droid_rv_checkpoint/chkp"
droid_rvef_checkpoint = "droid_rvef_checkpoint/chkp"
movinet_chkp_dir = 'movinet_a2_base/'

movinet_model, backbone = create_movinet_classifier(
n_input_frames=16,
batch_size=16,
num_classes=600,
checkpoint_dir=movinet_chkp_dir,
)

backbone_output = backbone.layers[-1].output[0]
flatten = tf.keras.layers.Flatten()(backbone_output)
encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten])

droid_rv_func_args = {
'input_shape': (16, 224, 224, 3),
'n_output_features': 2, # number of regression features
'categories': {"RV_size":2, "RV_function":2, "Sex":2},
'category_order': ["RV_size", "RV_function", "Sex"],
}

droid_rvef_func_args = {
'input_shape': (16, 224, 224, 3),
'n_output_features': 4, # number of regression features
'categories': {"Sex":2},
'category_order': ["Sex"],
}

droid_rv_model = create_regressor_classifier(encoder, **droid_rv_func_args)
droid_rv_model.load_weights(droid_rv_checkpoint)

droid_rvef_model = create_regressor_classifier(encoder, **droid_rvef_func_args)
droid_rvef_model.load_weights(droid_rvef_checkpoint)

random_video = np.random.random((1, 16, 224, 224, 3))

droid_rv_pred = droid_rv_model.predict(random_video)
droid_rvef_pred = droid_rvef_model.predict(random_video)

print(f"""
DROID-RV Predictions:
{rescale_droid_rv_outputs(droid_rv_pred)}
DROID-RVEF Predictions:
{rescale_droid_rvef_outputs(droid_rvef_pred)}
""")
78 changes: 78 additions & 0 deletions model_zoo/DROID-RV/droid_rv_model_description.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.movinet.modeling import movinet, movinet_model

hidden_units = 256
dropout_rate = 0.5

def create_movinet_classifier(
n_input_frames,
batch_size,
checkpoint_dir,
num_classes,
freeze_backbone=False
):
backbone = movinet.Movinet(model_id='a2')
model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)
model.build([1, 1, 1, 1, 3])
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched()

model = movinet_model.MovinetClassifier(
backbone=backbone,
num_classes=num_classes
)
model.build([batch_size, n_input_frames, 224, 224, 3])

if freeze_backbone:
for layer in model.layers[:-1]:
layer.trainable = False
model.layers[-1].trainable = True

return model, backbone

def create_regressor_classifier(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=0, categories={},
category_order=None, add_dense={'regressor': False, 'classifier': False}):
for layer in encoder.layers:
layer.trainable = trainable

inputs = tf.keras.Input(shape=input_shape, name='image')
features = encoder(inputs)
features = tf.keras.layers.Dropout(dropout_rate)(features)
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
features = tf.keras.layers.Dropout(dropout_rate)(features)

outputs = []
if n_output_features > 0:
if add_dense['regressor']:
features_reg = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
features_reg = tf.keras.layers.Dropout(dropout_rate)(features_reg)
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features_reg))
else:
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features))
if len(categories.keys()) > 0:
if add_dense['classifier']:
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
features = tf.keras.layers.Dropout(dropout_rate)(features)
for category in category_order:
activation = 'softmax'
n_classes = categories[category]
outputs.append(tf.keras.layers.Dense(n_classes, name='cls_'+category, activation=activation)(features))

model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor_classifier")

return model

def rescale_droid_rv_outputs(droid_rv_output):
droid_rv_output[0][0,0] = droid_rv_output[0][0,0] * 15.51761856 + 64.43979878
droid_rv_output[0][0,1] = droid_rv_output[0][0,1] * 6.88963822 + 42.52320993
return droid_rv_output

def rescale_droid_rvef_outputs(droid_rvef_output):
droid_rvef_output[0][0,0] = droid_rvef_output[0][0,0] * 8.658711 + 53.40699
droid_rvef_output[0][0,1] = droid_rvef_output[0][0,1] * 46.5734 + 130.8913
droid_rvef_output[0][0,2] = droid_rvef_output[0][0,2] * 31.6643 + 62.87321
droid_rvef_output[0][0,3] = droid_rvef_output[0][0,3] * 22.99643 + 47.18989
return droid_rvef_output
3 changes: 3 additions & 0 deletions model_zoo/DROID-RV/droid_rvef_checkpoint/checkpoint
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.index
Git LFS file not shown
3 changes: 3 additions & 0 deletions model_zoo/DROID-RV/movinet_a2_base/checkpoint
Git LFS file not shown
3 changes: 3 additions & 0 deletions model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001
Git LFS file not shown
3 changes: 3 additions & 0 deletions model_zoo/DROID-RV/movinet_a2_base/ckpt-1.index
Git LFS file not shown
Loading

0 comments on commit ae7848b

Please sign in to comment.