-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DROID-MVP and DROID-RV in the model zoo, updates and cleanup to DROID code base
- Loading branch information
Showing
28 changed files
with
1,237 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.data-00000-of-00001
Git LFS file not shown
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#!/usr/bin/env python | ||
# coding: utf-8 | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from droid_mvp_model_description import create_movinet_classifier, create_regressor_classifier | ||
import logging | ||
tf.get_logger().setLevel(logging.ERROR) | ||
|
||
pretrained_chkp_dir = "droid_mvp_checkpoint/chkp" | ||
movinet_chkp_dir = 'movinet_a2_base/' | ||
|
||
movinet_model, backbone = create_movinet_classifier( | ||
n_input_frames=16, | ||
batch_size=16, | ||
num_classes=600, | ||
checkpoint_dir=movinet_chkp_dir, | ||
) | ||
|
||
backbone_output = backbone.layers[-1].output[0] | ||
flatten = tf.keras.layers.Flatten()(backbone_output) | ||
encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten]) | ||
|
||
func_args = { | ||
'input_shape': (16, 224, 224, 3), | ||
'n_output_features': 0, # number of regression features | ||
'categories': {"mvp_status_binary":2, "mvp_status_detailed":6}, | ||
'category_order': ["mvp_status_binary", "mvp_status_detailed"], | ||
} | ||
|
||
model_plus_head = create_regressor_classifier(encoder, **func_args) | ||
|
||
model_plus_head.load_weights(pretrained_chkp_dir) | ||
|
||
random_video = np.random.random((1, 16, 224, 224, 3)) | ||
|
||
print(f""" | ||
DROID-MVP Predictions: | ||
{model_plus_head.predict(random_video)} | ||
""") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
from official.vision.beta.projects.movinet.modeling import movinet, movinet_model | ||
|
||
hidden_units = 256 | ||
dropout_rate = 0.5 | ||
|
||
def create_movinet_classifier( | ||
n_input_frames, | ||
batch_size, | ||
checkpoint_dir, | ||
num_classes, | ||
freeze_backbone=False | ||
): | ||
backbone = movinet.Movinet(model_id='a2') | ||
model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600) | ||
model.build([1, 1, 1, 1, 3]) | ||
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) | ||
checkpoint = tf.train.Checkpoint(model=model) | ||
status = checkpoint.restore(checkpoint_path) | ||
status.assert_existing_objects_matched() | ||
|
||
model = movinet_model.MovinetClassifier( | ||
backbone=backbone, | ||
num_classes=num_classes | ||
) | ||
model.build([batch_size, n_input_frames, 224, 224, 3]) | ||
|
||
if freeze_backbone: | ||
for layer in model.layers[:-1]: | ||
layer.trainable = False | ||
model.layers[-1].trainable = True | ||
|
||
return model, backbone | ||
|
||
def create_regressor_classifier(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=0, categories={}, | ||
category_order=None, add_dense={'regressor': False, 'classifier': False}): | ||
for layer in encoder.layers: | ||
layer.trainable = trainable | ||
|
||
inputs = tf.keras.Input(shape=input_shape, name='image') | ||
features = encoder(inputs) | ||
features = tf.keras.layers.Dropout(dropout_rate)(features) | ||
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) | ||
features = tf.keras.layers.Dropout(dropout_rate)(features) | ||
|
||
outputs = [] | ||
if n_output_features > 0: | ||
if add_dense['regressor']: | ||
features_reg = tf.keras.layers.Dense(hidden_units, activation="relu")(features) | ||
features_reg = tf.keras.layers.Dropout(dropout_rate)(features_reg) | ||
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features_reg)) | ||
else: | ||
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features)) | ||
if len(categories.keys()) > 0: | ||
if add_dense['classifier']: | ||
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) | ||
features = tf.keras.layers.Dropout(dropout_rate)(features) | ||
for category in category_order: | ||
activation = 'softmax' | ||
n_classes = categories[category] | ||
outputs.append(tf.keras.layers.Dense(n_classes, name='cls_'+category, activation=activation)(features)) | ||
|
||
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor_classifier") | ||
|
||
return model |
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.data-00000-of-00001
Git LFS file not shown
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# DROID-MVP Inference Example | ||
|
||
This is a simple example script demonstrated how to load and run the DROID-MVP model. Model training and inference was performed using the code provided in the ML4H [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID). The example below was adapted from the DROID inference code. | ||
|
||
1. Download DROID docker image. Note: docker image is not compatible with Apple Silicon. | ||
|
||
`docker pull alalusim/droid:latest` | ||
|
||
2. Pull github repo, including DROID-MVP model checkpoints stored using git lfs. | ||
|
||
``` | ||
github clone https://github.com/broadinstitute/ml4h.git | ||
git lfs pull --include ml4h/model_zoo/DROID-MVP/droid_mvp_checkpoint/* | ||
git lfs pull --include ml4h/model_zoo/DROID-MVP/movinet_a2_base/* | ||
``` | ||
|
||
3. Run docker image while mounting ml4h directory and run example inference script. | ||
|
||
`docker run -it -v {PATH TO CLONED ML4H DIRECTORY}:/ml4h/ alalusim/droid:latest` | ||
|
||
``` | ||
cd /ml4h/model_zoo/DROID-MVP/ | ||
python droid_mvp_inference.py | ||
``` | ||
|
||
To use with your own data, format echocardiogram videos as tensors with shape (16, 224, 224, 3) before passing to the model. Code for data preprocessing, storage, loading, training, and inference can be found in the ml4h [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID). | ||
|
||
Model outputs for DROID-MVP take the form: | ||
``` | ||
[ | ||
[["MVP", "Not MVP"]], | ||
[["Anterior ", "Bileaflet", "Not MVP", "Posterior", "Superior Displacement", "MVP not otherwise specified"]], | ||
] | ||
``` | ||
|
||
Note that the model was optimized for predicting binary MVP status (the primary task) and that detailed MVP status was used as an auxiliary task to improve performance on the primary classification task. |
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
model_zoo/DROID-RV/droid_rv_checkpoint/chkp.data-00000-of-00001
Git LFS file not shown
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#!/usr/bin/env python | ||
# coding: utf-8 | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from droid_rv_model_description import create_movinet_classifier, create_regressor_classifier, rescale_droid_rv_outputs, rescale_droid_rvef_outputs | ||
import logging | ||
tf.get_logger().setLevel(logging.ERROR) | ||
|
||
droid_rv_checkpoint = "droid_rv_checkpoint/chkp" | ||
droid_rvef_checkpoint = "droid_rvef_checkpoint/chkp" | ||
movinet_chkp_dir = 'movinet_a2_base/' | ||
|
||
movinet_model, backbone = create_movinet_classifier( | ||
n_input_frames=16, | ||
batch_size=16, | ||
num_classes=600, | ||
checkpoint_dir=movinet_chkp_dir, | ||
) | ||
|
||
backbone_output = backbone.layers[-1].output[0] | ||
flatten = tf.keras.layers.Flatten()(backbone_output) | ||
encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten]) | ||
|
||
droid_rv_func_args = { | ||
'input_shape': (16, 224, 224, 3), | ||
'n_output_features': 2, # number of regression features | ||
'categories': {"RV_size":2, "RV_function":2, "Sex":2}, | ||
'category_order': ["RV_size", "RV_function", "Sex"], | ||
} | ||
|
||
droid_rvef_func_args = { | ||
'input_shape': (16, 224, 224, 3), | ||
'n_output_features': 4, # number of regression features | ||
'categories': {"Sex":2}, | ||
'category_order': ["Sex"], | ||
} | ||
|
||
droid_rv_model = create_regressor_classifier(encoder, **droid_rv_func_args) | ||
droid_rv_model.load_weights(droid_rv_checkpoint) | ||
|
||
droid_rvef_model = create_regressor_classifier(encoder, **droid_rvef_func_args) | ||
droid_rvef_model.load_weights(droid_rvef_checkpoint) | ||
|
||
random_video = np.random.random((1, 16, 224, 224, 3)) | ||
|
||
droid_rv_pred = droid_rv_model.predict(random_video) | ||
droid_rvef_pred = droid_rvef_model.predict(random_video) | ||
|
||
print(f""" | ||
DROID-RV Predictions: | ||
{rescale_droid_rv_outputs(droid_rv_pred)} | ||
DROID-RVEF Predictions: | ||
{rescale_droid_rvef_outputs(droid_rvef_pred)} | ||
""") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
from official.vision.beta.projects.movinet.modeling import movinet, movinet_model | ||
|
||
hidden_units = 256 | ||
dropout_rate = 0.5 | ||
|
||
def create_movinet_classifier( | ||
n_input_frames, | ||
batch_size, | ||
checkpoint_dir, | ||
num_classes, | ||
freeze_backbone=False | ||
): | ||
backbone = movinet.Movinet(model_id='a2') | ||
model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600) | ||
model.build([1, 1, 1, 1, 3]) | ||
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) | ||
checkpoint = tf.train.Checkpoint(model=model) | ||
status = checkpoint.restore(checkpoint_path) | ||
status.assert_existing_objects_matched() | ||
|
||
model = movinet_model.MovinetClassifier( | ||
backbone=backbone, | ||
num_classes=num_classes | ||
) | ||
model.build([batch_size, n_input_frames, 224, 224, 3]) | ||
|
||
if freeze_backbone: | ||
for layer in model.layers[:-1]: | ||
layer.trainable = False | ||
model.layers[-1].trainable = True | ||
|
||
return model, backbone | ||
|
||
def create_regressor_classifier(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=0, categories={}, | ||
category_order=None, add_dense={'regressor': False, 'classifier': False}): | ||
for layer in encoder.layers: | ||
layer.trainable = trainable | ||
|
||
inputs = tf.keras.Input(shape=input_shape, name='image') | ||
features = encoder(inputs) | ||
features = tf.keras.layers.Dropout(dropout_rate)(features) | ||
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) | ||
features = tf.keras.layers.Dropout(dropout_rate)(features) | ||
|
||
outputs = [] | ||
if n_output_features > 0: | ||
if add_dense['regressor']: | ||
features_reg = tf.keras.layers.Dense(hidden_units, activation="relu")(features) | ||
features_reg = tf.keras.layers.Dropout(dropout_rate)(features_reg) | ||
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features_reg)) | ||
else: | ||
outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features)) | ||
if len(categories.keys()) > 0: | ||
if add_dense['classifier']: | ||
features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) | ||
features = tf.keras.layers.Dropout(dropout_rate)(features) | ||
for category in category_order: | ||
activation = 'softmax' | ||
n_classes = categories[category] | ||
outputs.append(tf.keras.layers.Dense(n_classes, name='cls_'+category, activation=activation)(features)) | ||
|
||
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor_classifier") | ||
|
||
return model | ||
|
||
def rescale_droid_rv_outputs(droid_rv_output): | ||
droid_rv_output[0][0,0] = droid_rv_output[0][0,0] * 15.51761856 + 64.43979878 | ||
droid_rv_output[0][0,1] = droid_rv_output[0][0,1] * 6.88963822 + 42.52320993 | ||
return droid_rv_output | ||
|
||
def rescale_droid_rvef_outputs(droid_rvef_output): | ||
droid_rvef_output[0][0,0] = droid_rvef_output[0][0,0] * 8.658711 + 53.40699 | ||
droid_rvef_output[0][0,1] = droid_rvef_output[0][0,1] * 46.5734 + 130.8913 | ||
droid_rvef_output[0][0,2] = droid_rvef_output[0][0,2] * 31.6643 + 62.87321 | ||
droid_rvef_output[0][0,3] = droid_rvef_output[0][0,3] * 22.99643 + 47.18989 | ||
return droid_rvef_output |
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.data-00000-of-00001
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001
Git LFS file not shown
Git LFS file not shown
Oops, something went wrong.