From ae7848beabcf3b6d36daecdf6dd4a27781131821 Mon Sep 17 00:00:00 2001 From: Mostafa Amer Al-Alusi Date: Fri, 20 Dec 2024 10:02:07 -0500 Subject: [PATCH] DROID-MVP and DROID-RV (#578) DROID-MVP and DROID-RV in the model zoo, updates and cleanup to DROID code base --- .gitattributes | 15 + .../DROID-MVP/droid_mvp_checkpoint/checkpoint | 3 + .../chkp.data-00000-of-00001 | 3 + .../DROID-MVP/droid_mvp_checkpoint/chkp.index | 3 + model_zoo/DROID-MVP/droid_mvp_inference.py | 40 ++ .../DROID-MVP/droid_mvp_model_description.py | 66 +++ .../DROID-MVP/movinet_a2_base/checkpoint | 3 + .../ckpt-1.data-00000-of-00001 | 3 + .../DROID-MVP/movinet_a2_base/ckpt-1.index | 3 + model_zoo/DROID-MVP/readme.md | 36 ++ .../DROID-RV/droid_rv_checkpoint/checkpoint | 3 + .../chkp.data-00000-of-00001 | 3 + .../DROID-RV/droid_rv_checkpoint/chkp.index | 3 + model_zoo/DROID-RV/droid_rv_inference.py | 58 ++ .../DROID-RV/droid_rv_model_description.py | 78 +++ .../DROID-RV/droid_rvef_checkpoint/checkpoint | 3 + .../chkp.data-00000-of-00001 | 3 + .../DROID-RV/droid_rvef_checkpoint/chkp.index | 3 + model_zoo/DROID-RV/movinet_a2_base/checkpoint | 3 + .../ckpt-1.data-00000-of-00001 | 3 + .../DROID-RV/movinet_a2_base/ckpt-1.index | 3 + model_zoo/DROID-RV/readme.md | 45 ++ model_zoo/DROID/README.md | 4 +- model_zoo/DROID/data_descriptions/echo.py | 25 +- .../DROID/data_descriptions/wide_file.py | 96 ++++ .../DROID/echo_supervised_inference_recipe.py | 193 +++++-- .../DROID/echo_supervised_training_recipe.py | 511 ++++++++++++++++++ model_zoo/DROID/model_descriptions/echo.py | 83 +++ 28 files changed, 1237 insertions(+), 58 deletions(-) create mode 100755 model_zoo/DROID-MVP/droid_mvp_checkpoint/checkpoint create mode 100755 model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.data-00000-of-00001 create mode 100755 model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.index create mode 100644 model_zoo/DROID-MVP/droid_mvp_inference.py create mode 100755 model_zoo/DROID-MVP/droid_mvp_model_description.py create mode 100755 model_zoo/DROID-MVP/movinet_a2_base/checkpoint create mode 100755 model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.data-00000-of-00001 create mode 100755 model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.index create mode 100644 model_zoo/DROID-MVP/readme.md create mode 100755 model_zoo/DROID-RV/droid_rv_checkpoint/checkpoint create mode 100755 model_zoo/DROID-RV/droid_rv_checkpoint/chkp.data-00000-of-00001 create mode 100755 model_zoo/DROID-RV/droid_rv_checkpoint/chkp.index create mode 100644 model_zoo/DROID-RV/droid_rv_inference.py create mode 100755 model_zoo/DROID-RV/droid_rv_model_description.py create mode 100755 model_zoo/DROID-RV/droid_rvef_checkpoint/checkpoint create mode 100755 model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.data-00000-of-00001 create mode 100755 model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.index create mode 100755 model_zoo/DROID-RV/movinet_a2_base/checkpoint create mode 100755 model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001 create mode 100755 model_zoo/DROID-RV/movinet_a2_base/ckpt-1.index create mode 100644 model_zoo/DROID-RV/readme.md create mode 100644 model_zoo/DROID/data_descriptions/wide_file.py create mode 100644 model_zoo/DROID/echo_supervised_training_recipe.py diff --git a/.gitattributes b/.gitattributes index e056953b3..2ec4f5c56 100644 --- a/.gitattributes +++ b/.gitattributes @@ -8,6 +8,21 @@ notebooks/**.ipynb filter=nbstripout *.genes filter=lfs diff=lfs merge=lfs -text model_zoo/ECG_PheWAS/*.h5 filter=lfs diff=lfs merge=lfs -text model_zoo/DROID/encoders/**/* filter=lfs diff=lfs 
merge=lfs -text +model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.index filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-MVP/movinet_a2_base/checkpoint filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-MVP/droid_mvp_checkpoint/checkpoint filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.index filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/droid_rv_checkpoint/checkpoint filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/droid_rv_checkpoint/chkp.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/droid_rv_checkpoint/chkp.index filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/droid_rvef_checkpoint/checkpoint filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.index filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/movinet_a2_base/checkpoint filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text +model_zoo/DROID-RV/movinet_a2_base/ckpt-1.index filter=lfs diff=lfs merge=lfs -text model_zoo/ECG2AF/ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5 filter=lfs diff=lfs merge=lfs -text model_zoo/ECG2AF/strip_I_survival_curve_af_v2021_06_15.h5 filter=lfs diff=lfs merge=lfs -text model_zoo/ECG2AF/strip_II_survival_curve_af_v2021_06_15.h5 filter=lfs diff=lfs merge=lfs -text diff --git a/model_zoo/DROID-MVP/droid_mvp_checkpoint/checkpoint b/model_zoo/DROID-MVP/droid_mvp_checkpoint/checkpoint new file mode 100755 index 000000000..68c9e0245 --- /dev/null +++ b/model_zoo/DROID-MVP/droid_mvp_checkpoint/checkpoint @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb5a38d5763dc7f41acdb9468670ec94f1b31846285558ab69e302bd01917962 +size 65 diff --git a/model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.data-00000-of-00001 b/model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.data-00000-of-00001 new file mode 100755 index 000000000..5bffdcef0 --- /dev/null +++ b/model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2196f91e9ae5e932b9c3dd546e3ce31e1a509fae5c3c13291bb0e496fb4e33a +size 34826365 diff --git a/model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.index b/model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.index new file mode 100755 index 000000000..dc60179e3 --- /dev/null +++ b/model_zoo/DROID-MVP/droid_mvp_checkpoint/chkp.index @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0af27bb4a0285fadf47c471f4f9047d86d767e6f697d9ec40cccd246b20f9cb7 +size 99789 diff --git a/model_zoo/DROID-MVP/droid_mvp_inference.py b/model_zoo/DROID-MVP/droid_mvp_inference.py new file mode 100644 index 000000000..5b792d3ba --- /dev/null +++ b/model_zoo/DROID-MVP/droid_mvp_inference.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# coding: utf-8 + +import numpy as np +import tensorflow as tf +from droid_mvp_model_description import create_movinet_classifier, create_regressor_classifier +import logging +tf.get_logger().setLevel(logging.ERROR) + +pretrained_chkp_dir = "droid_mvp_checkpoint/chkp" +movinet_chkp_dir = 'movinet_a2_base/' + +movinet_model, backbone = create_movinet_classifier( + n_input_frames=16, + batch_size=16, 
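+    # num_classes=600 matches the Kinetics-600 head of the pretrained
+    # MoViNet-A2 checkpoint restored inside create_movinet_classifier; that
+    # classifier head is discarded below in favor of the flattened backbone output.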
+ num_classes=600, + checkpoint_dir=movinet_chkp_dir, +) + +backbone_output = backbone.layers[-1].output[0] +flatten = tf.keras.layers.Flatten()(backbone_output) +encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten]) + +func_args = { + 'input_shape': (16, 224, 224, 3), + 'n_output_features': 0, # number of regression features + 'categories': {"mvp_status_binary":2, "mvp_status_detailed":6}, + 'category_order': ["mvp_status_binary", "mvp_status_detailed"], +} + +model_plus_head = create_regressor_classifier(encoder, **func_args) + +model_plus_head.load_weights(pretrained_chkp_dir) + +random_video = np.random.random((1, 16, 224, 224, 3)) + +print(f""" +DROID-MVP Predictions: +{model_plus_head.predict(random_video)} +""") \ No newline at end of file diff --git a/model_zoo/DROID-MVP/droid_mvp_model_description.py b/model_zoo/DROID-MVP/droid_mvp_model_description.py new file mode 100755 index 000000000..eb78d95e8 --- /dev/null +++ b/model_zoo/DROID-MVP/droid_mvp_model_description.py @@ -0,0 +1,66 @@ +import numpy as np +import tensorflow as tf +from official.vision.beta.projects.movinet.modeling import movinet, movinet_model + +hidden_units = 256 +dropout_rate = 0.5 + +def create_movinet_classifier( + n_input_frames, + batch_size, + checkpoint_dir, + num_classes, + freeze_backbone=False +): + backbone = movinet.Movinet(model_id='a2') + model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600) + model.build([1, 1, 1, 1, 3]) + checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) + checkpoint = tf.train.Checkpoint(model=model) + status = checkpoint.restore(checkpoint_path) + status.assert_existing_objects_matched() + + model = movinet_model.MovinetClassifier( + backbone=backbone, + num_classes=num_classes + ) + model.build([batch_size, n_input_frames, 224, 224, 3]) + + if freeze_backbone: + for layer in model.layers[:-1]: + layer.trainable = False + model.layers[-1].trainable = True + + return model, backbone + +def create_regressor_classifier(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=0, categories={}, + category_order=None, add_dense={'regressor': False, 'classifier': False}): + for layer in encoder.layers: + layer.trainable = trainable + + inputs = tf.keras.Input(shape=input_shape, name='image') + features = encoder(inputs) + features = tf.keras.layers.Dropout(dropout_rate)(features) + features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) + features = tf.keras.layers.Dropout(dropout_rate)(features) + + outputs = [] + if n_output_features > 0: + if add_dense['regressor']: + features_reg = tf.keras.layers.Dense(hidden_units, activation="relu")(features) + features_reg = tf.keras.layers.Dropout(dropout_rate)(features_reg) + outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features_reg)) + else: + outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features)) + if len(categories.keys()) > 0: + if add_dense['classifier']: + features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) + features = tf.keras.layers.Dropout(dropout_rate)(features) + for category in category_order: + activation = 'softmax' + n_classes = categories[category] + outputs.append(tf.keras.layers.Dense(n_classes, name='cls_'+category, activation=activation)(features)) + + model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor_classifier") + + return model diff --git a/model_zoo/DROID-MVP/movinet_a2_base/checkpoint 
b/model_zoo/DROID-MVP/movinet_a2_base/checkpoint new file mode 100755 index 000000000..6f975fdec --- /dev/null +++ b/model_zoo/DROID-MVP/movinet_a2_base/checkpoint @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78fdb1e081e9fc8d4e10e3bca4fe00117a236ddc4726bbf75594db19ae1be665 +size 69 diff --git a/model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.data-00000-of-00001 b/model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.data-00000-of-00001 new file mode 100755 index 000000000..cb729253e --- /dev/null +++ b/model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f393b7ef377ffaf59bd8bf081c72d05e74a576c5bba0d4bc180315432e49e557 +size 21240182 diff --git a/model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.index b/model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.index new file mode 100755 index 000000000..7bcd8b097 --- /dev/null +++ b/model_zoo/DROID-MVP/movinet_a2_base/ckpt-1.index @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d801f29eace2f39bcc7b268ecf0d1bd117d9b3881cbcd810721c8c1b1f6c161 +size 10102 diff --git a/model_zoo/DROID-MVP/readme.md b/model_zoo/DROID-MVP/readme.md new file mode 100644 index 000000000..31f4037a6 --- /dev/null +++ b/model_zoo/DROID-MVP/readme.md @@ -0,0 +1,36 @@ +# DROID-MVP Inference Example + +This is a simple example script demonstrating how to load and run the DROID-MVP model. Model training and inference were performed using the code provided in the ML4H [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID). The example below was adapted from the DROID inference code. + +1. Download the DROID docker image. Note: the docker image is not compatible with Apple Silicon. + +`docker pull alalusim/droid:latest` + +2. Clone the GitHub repo and pull the DROID-MVP model checkpoints stored using Git LFS. + +``` +git clone https://github.com/broadinstitute/ml4h.git +git lfs pull --include ml4h/model_zoo/DROID-MVP/droid_mvp_checkpoint/* +git lfs pull --include ml4h/model_zoo/DROID-MVP/movinet_a2_base/* +``` + +3. Run the docker image, mounting the ml4h directory, and run the example inference script. + +`docker run -it -v {PATH TO CLONED ML4H DIRECTORY}:/ml4h/ alalusim/droid:latest` + +``` +cd /ml4h/model_zoo/DROID-MVP/ +python droid_mvp_inference.py +``` + +To use with your own data, format echocardiogram videos as tensors with shape (16, 224, 224, 3) before passing to the model. Code for data preprocessing, storage, loading, training, and inference can be found in the ML4H [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID). + +Model outputs for DROID-MVP take the form: +``` +[ + [["MVP", "Not MVP"]], + [["Anterior ", "Bileaflet", "Not MVP", "Posterior", "Superior Displacement", "MVP not otherwise specified"]], +] +``` + +Note that the model was optimized for predicting binary MVP status (the primary task) and that detailed MVP status was used as an auxiliary task to improve performance on the primary classification task.
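+
+For reference, here is a minimal sketch of one way to build such an input tensor with PyAV (the library the DROID data loaders use). The file name, the uniform 16-frame sampling, and the [0, 1) scaling (chosen to match the random example input in `droid_mvp_inference.py`) are illustrative assumptions, not the exact preprocessing used during training:
+
+```
+import av
+import numpy as np
+import tensorflow as tf
+
+container = av.open("example_echo.avi")  # hypothetical input video
+frames = [np.array(f.to_image()) for f in container.decode(video=0)]
+idx = np.linspace(0, len(frames) - 1, 16).astype(int)  # sample 16 frames
+video = np.stack([frames[i] for i in idx]).astype(np.float32)
+video = tf.image.resize(video, (224, 224)).numpy() / 255.0  # scale to [0, 1)
+video = video[None]  # add a batch dimension -> (1, 16, 224, 224, 3)
+```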
\ No newline at end of file diff --git a/model_zoo/DROID-RV/droid_rv_checkpoint/checkpoint b/model_zoo/DROID-RV/droid_rv_checkpoint/checkpoint new file mode 100755 index 000000000..68c9e0245 --- /dev/null +++ b/model_zoo/DROID-RV/droid_rv_checkpoint/checkpoint @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb5a38d5763dc7f41acdb9468670ec94f1b31846285558ab69e302bd01917962 +size 65 diff --git a/model_zoo/DROID-RV/droid_rv_checkpoint/chkp.data-00000-of-00001 b/model_zoo/DROID-RV/droid_rv_checkpoint/chkp.data-00000-of-00001 new file mode 100755 index 000000000..7a3271b58 --- /dev/null +++ b/model_zoo/DROID-RV/droid_rv_checkpoint/chkp.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2af9efac9cec47cdd0bb4ca0c539b153657583fc6f261ac42bf5ef01031792f0 +size 34827706 diff --git a/model_zoo/DROID-RV/droid_rv_checkpoint/chkp.index b/model_zoo/DROID-RV/droid_rv_checkpoint/chkp.index new file mode 100755 index 000000000..7ecb4fb16 --- /dev/null +++ b/model_zoo/DROID-RV/droid_rv_checkpoint/chkp.index @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a6c10fc7bfb8667a75ae8faeb19ed01278ca6d3fbf67488dd35834209921d17 +size 100586 diff --git a/model_zoo/DROID-RV/droid_rv_inference.py b/model_zoo/DROID-RV/droid_rv_inference.py new file mode 100644 index 000000000..20c5617a3 --- /dev/null +++ b/model_zoo/DROID-RV/droid_rv_inference.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# coding: utf-8 + +import numpy as np +import tensorflow as tf +from droid_rv_model_description import create_movinet_classifier, create_regressor_classifier, rescale_droid_rv_outputs, rescale_droid_rvef_outputs +import logging +tf.get_logger().setLevel(logging.ERROR) + +droid_rv_checkpoint = "droid_rv_checkpoint/chkp" +droid_rvef_checkpoint = "droid_rvef_checkpoint/chkp" +movinet_chkp_dir = 'movinet_a2_base/' + +movinet_model, backbone = create_movinet_classifier( + n_input_frames=16, + batch_size=16, + num_classes=600, + checkpoint_dir=movinet_chkp_dir, +) + +backbone_output = backbone.layers[-1].output[0] +flatten = tf.keras.layers.Flatten()(backbone_output) +encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten]) + +droid_rv_func_args = { + 'input_shape': (16, 224, 224, 3), + 'n_output_features': 2, # number of regression features + 'categories': {"RV_size":2, "RV_function":2, "Sex":2}, + 'category_order': ["RV_size", "RV_function", "Sex"], +} + +droid_rvef_func_args = { + 'input_shape': (16, 224, 224, 3), + 'n_output_features': 4, # number of regression features + 'categories': {"Sex":2}, + 'category_order': ["Sex"], +} + +droid_rv_model = create_regressor_classifier(encoder, **droid_rv_func_args) +droid_rv_model.load_weights(droid_rv_checkpoint) + +droid_rvef_model = create_regressor_classifier(encoder, **droid_rvef_func_args) +droid_rvef_model.load_weights(droid_rvef_checkpoint) + +random_video = np.random.random((1, 16, 224, 224, 3)) + +droid_rv_pred = droid_rv_model.predict(random_video) +droid_rvef_pred = droid_rvef_model.predict(random_video) + +print(f""" + +DROID-RV Predictions: +{rescale_droid_rv_outputs(droid_rv_pred)} + +DROID-RVEF Predictions: +{rescale_droid_rvef_outputs(droid_rvef_pred)} + +""") \ No newline at end of file diff --git a/model_zoo/DROID-RV/droid_rv_model_description.py b/model_zoo/DROID-RV/droid_rv_model_description.py new file mode 100755 index 000000000..19b583f42 --- /dev/null +++ b/model_zoo/DROID-RV/droid_rv_model_description.py @@ -0,0 +1,78 @@ +import numpy as np +import tensorflow as tf +from 
official.vision.beta.projects.movinet.modeling import movinet, movinet_model + +hidden_units = 256 +dropout_rate = 0.5 + +def create_movinet_classifier( + n_input_frames, + batch_size, + checkpoint_dir, + num_classes, + freeze_backbone=False +): + backbone = movinet.Movinet(model_id='a2') + model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600) + model.build([1, 1, 1, 1, 3]) + checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) + checkpoint = tf.train.Checkpoint(model=model) + status = checkpoint.restore(checkpoint_path) + status.assert_existing_objects_matched() + + model = movinet_model.MovinetClassifier( + backbone=backbone, + num_classes=num_classes + ) + model.build([batch_size, n_input_frames, 224, 224, 3]) + + if freeze_backbone: + for layer in model.layers[:-1]: + layer.trainable = False + model.layers[-1].trainable = True + + return model, backbone + +def create_regressor_classifier(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=0, categories={}, + category_order=None, add_dense={'regressor': False, 'classifier': False}): + for layer in encoder.layers: + layer.trainable = trainable + + inputs = tf.keras.Input(shape=input_shape, name='image') + features = encoder(inputs) + features = tf.keras.layers.Dropout(dropout_rate)(features) + features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) + features = tf.keras.layers.Dropout(dropout_rate)(features) + + outputs = [] + if n_output_features > 0: + if add_dense['regressor']: + features_reg = tf.keras.layers.Dense(hidden_units, activation="relu")(features) + features_reg = tf.keras.layers.Dropout(dropout_rate)(features_reg) + outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features_reg)) + else: + outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features)) + if len(categories.keys()) > 0: + if add_dense['classifier']: + features = tf.keras.layers.Dense(hidden_units, activation="relu")(features) + features = tf.keras.layers.Dropout(dropout_rate)(features) + for category in category_order: + activation = 'softmax' + n_classes = categories[category] + outputs.append(tf.keras.layers.Dense(n_classes, name='cls_'+category, activation=activation)(features)) + + model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor_classifier") + + return model + +def rescale_droid_rv_outputs(droid_rv_output): + droid_rv_output[0][0,0] = droid_rv_output[0][0,0] * 15.51761856 + 64.43979878 + droid_rv_output[0][0,1] = droid_rv_output[0][0,1] * 6.88963822 + 42.52320993 + return droid_rv_output + +def rescale_droid_rvef_outputs(droid_rvef_output): + droid_rvef_output[0][0,0] = droid_rvef_output[0][0,0] * 8.658711 + 53.40699 + droid_rvef_output[0][0,1] = droid_rvef_output[0][0,1] * 46.5734 + 130.8913 + droid_rvef_output[0][0,2] = droid_rvef_output[0][0,2] * 31.6643 + 62.87321 + droid_rvef_output[0][0,3] = droid_rvef_output[0][0,3] * 22.99643 + 47.18989 + return droid_rvef_output diff --git a/model_zoo/DROID-RV/droid_rvef_checkpoint/checkpoint b/model_zoo/DROID-RV/droid_rvef_checkpoint/checkpoint new file mode 100755 index 000000000..68c9e0245 --- /dev/null +++ b/model_zoo/DROID-RV/droid_rvef_checkpoint/checkpoint @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb5a38d5763dc7f41acdb9468670ec94f1b31846285558ab69e302bd01917962 +size 65 diff --git a/model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.data-00000-of-00001 
b/model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.data-00000-of-00001 new file mode 100755 index 000000000..c124a0a0c --- /dev/null +++ b/model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b635997716fc47e4d7694ed4dffb1fb19d32a3e8731bb250ec9bca9dc3283eba +size 34820197 diff --git a/model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.index b/model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.index new file mode 100755 index 000000000..78259797c --- /dev/null +++ b/model_zoo/DROID-RV/droid_rvef_checkpoint/chkp.index @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5cae8c8fadd814be6d5d2fd967d6479d08edcae3f40f842f107dda63effd935 +size 99789 diff --git a/model_zoo/DROID-RV/movinet_a2_base/checkpoint b/model_zoo/DROID-RV/movinet_a2_base/checkpoint new file mode 100755 index 000000000..6f975fdec --- /dev/null +++ b/model_zoo/DROID-RV/movinet_a2_base/checkpoint @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78fdb1e081e9fc8d4e10e3bca4fe00117a236ddc4726bbf75594db19ae1be665 +size 69 diff --git a/model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001 b/model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001 new file mode 100755 index 000000000..cb729253e --- /dev/null +++ b/model_zoo/DROID-RV/movinet_a2_base/ckpt-1.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f393b7ef377ffaf59bd8bf081c72d05e74a576c5bba0d4bc180315432e49e557 +size 21240182 diff --git a/model_zoo/DROID-RV/movinet_a2_base/ckpt-1.index b/model_zoo/DROID-RV/movinet_a2_base/ckpt-1.index new file mode 100755 index 000000000..7bcd8b097 --- /dev/null +++ b/model_zoo/DROID-RV/movinet_a2_base/ckpt-1.index @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d801f29eace2f39bcc7b268ecf0d1bd117d9b3881cbcd810721c8c1b1f6c161 +size 10102 diff --git a/model_zoo/DROID-RV/readme.md b/model_zoo/DROID-RV/readme.md new file mode 100644 index 000000000..004d3e585 --- /dev/null +++ b/model_zoo/DROID-RV/readme.md @@ -0,0 +1,45 @@ +# DROID-RV Inference Example + +This is a simple example script demonstrating how to load and run the DROID-RV and DROID-RVEF models. Model training and inference were performed using the code provided in the ML4H [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID). The example below was adapted from the DROID inference code. + +1. Download the DROID docker image. Note: the docker image is not compatible with Apple Silicon. + +`docker pull alalusim/droid:latest` + +2. Clone the GitHub repo and pull the DROID-RV model checkpoints stored using Git LFS. + +``` +git clone https://github.com/broadinstitute/ml4h.git +git lfs pull --include ml4h/model_zoo/DROID-RV/droid_rv_checkpoint/* +git lfs pull --include ml4h/model_zoo/DROID-RV/droid_rvef_checkpoint/* +git lfs pull --include ml4h/model_zoo/DROID-RV/movinet_a2_base/* +``` + +3. Run the docker image, mounting the ml4h directory, and run the example inference script. + +`docker run -it -v {PATH TO CLONED ML4H DIRECTORY}:/ml4h/ alalusim/droid:latest` + +``` +cd /ml4h/model_zoo/DROID-RV/ +python droid_rv_inference.py +``` + +To use with your own data, format echocardiogram videos as tensors with shape (16, 224, 224, 3) before passing to the model. Code for data preprocessing, storage, loading, training, and inference can be found in the ML4H [model zoo](https://github.com/broadinstitute/ml4h/tree/master/model_zoo/DROID).
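+
+The regression outputs of both models are produced on a normalized scale; `rescale_droid_rv_outputs` and `rescale_droid_rvef_outputs` in `droid_rv_model_description.py` map them back to physical units as `value = output * scale + offset`, where the constants appear to be the standard deviation and mean of each target in the training set. For DROID-RV, for example:
+
+```
+Age   = output * 15.51761856 + 64.43979878
+RVEDD = output *  6.88963822 + 42.52320993
+```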
+ +Model outputs for DROID-RV take the form: +``` +[ + [["Age", "RVEDD"]], + [["Dilated", "Not Dilated"]], + [["Hypokinetic", "Not Hypokinetic"]], + [["Female", "Male"]] +] +``` + +Model outputs for DROID-RVEF take the form: +``` +[ + [["RVEF", "RV End-Diastolic Volume", "RV End-Systolic Volume", "Age"]], + [["Female", "Male"]] +] +``` \ No newline at end of file diff --git a/model_zoo/DROID/README.md b/model_zoo/DROID/README.md index 3129ba731..21b00d191 100644 --- a/model_zoo/DROID/README.md +++ b/model_zoo/DROID/README.md @@ -63,7 +63,7 @@ python echo_supervised_inference_recipe.py \ --wide_file {WIDE_FILE_PATH} \ --splits_file {SPLITS_JSON} \ --lmdb_folder {LMDB_DIRECTORY_PATH} \ - --pretrained_ckpt_dir {SPECIALIZED_ENCODER_PATH} \ - --movinet_ckpt_dir {MoViNet-A2-Base_PATH} \ + --pretrained_chkp_dir {SPECIALIZED_ENCODER_PATH} \ + --movinet_chkp_dir {MoViNet-A2-Base_PATH} \ --output_dir {WHERE_TO_STORE_PREDICTIONS} ``` \ No newline at end of file diff --git a/model_zoo/DROID/data_descriptions/echo.py b/model_zoo/DROID/data_descriptions/echo.py index 864c77d94..a78145050 100644 --- a/model_zoo/DROID/data_descriptions/echo.py +++ b/model_zoo/DROID/data_descriptions/echo.py @@ -40,18 +40,19 @@ def __init__( transforms=None, nframes: int = None, skip_modulo: int = 1, - start_beat=0, + start_frame=0, + randomize_start_frame=False ): self.local_lmdb_dir = local_lmdb_dir self._name = name + self.start_frame = start_frame self.nframes = nframes - self.nframes = (nframes + start_beat) * skip_modulo - self.start_beat = start_beat # transformations self.transforms = transforms or [] self.skip_modulo = skip_modulo - + self.randomize_start_frame = randomize_start_frame + def get_loading_options(self, sample_id): _, study, view = sample_id.split('_') lmdb_folder = os.path.join(self.local_lmdb_dir, f"{study}.lmdb") @@ -82,13 +83,23 @@ def get_raw_data(self, sample_id, loading_option=None): in_mem_bytes_io = io.BytesIO(txn.get(view.encode('utf-8'))) video_container = av.open(in_mem_bytes_io, metadata_errors="ignore") video_frames = itertools.cycle(video_container.decode(video=0)) + + total_frames = len(list(video_container.decode(video=0))) + video_container.seek(0) + + if self.randomize_start_frame: + frame_range = total_frames - (self.nframes * self.skip_modulo) + if frame_range > 0: + self.start_frame = np.random.randint(frame_range) + end_frame = self.start_frame + (self.nframes * self.skip_modulo) + frames = [] for i, frame in enumerate(video_frames): - if i == nframes: + if len(frames) == self.nframes: break - if i < (self.start_beat * self.skip_modulo): + if i < (self.start_frame): continue if self.skip_modulo > 1: - if (i % self.skip_modulo) != 0: + if ((i - self.start_frame) % self.skip_modulo) != 0: continue frame = np.array(frame.to_image()) for transform in self.transforms: diff --git a/model_zoo/DROID/data_descriptions/wide_file.py b/model_zoo/DROID/data_descriptions/wide_file.py new file mode 100644 index 000000000..c8fe1c7ef --- /dev/null +++ b/model_zoo/DROID/data_descriptions/wide_file.py @@ -0,0 +1,96 @@ +from typing import Dict, List + +import numpy as np +import pandas as pd +import tensorflow as tf + +from ml4ht.data.data_description import DataDescription + +from data_descriptions.echo import VIEW_OPTION_KEY + + +class EcholabDataDescription(DataDescription): + # DataDescription for a wide file + + def __init__( + self, + wide_df: pd.DataFrame, + sample_id_column: str, + column_names: List[str], + name: str, + categories: Dict = None, + cls_categories_map: Dict = None, + transforms=None, +
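+        # wide_df: one row per sample_id; column_names selects the label columns
+        # cls_categories_map: for each classification task, a {value: class_index}
+        # mapping plus a 'cls_output_order' key; get_raw_data uses it to emit
+        # one-hot targets after any regression targets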
): + """ + """ + self.wide_df = wide_df + self._name = name + self.sample_id_column = sample_id_column + self.column_names = column_names + self.categories = categories + self.prep_df() + self.transforms = transforms or [] + self.cls_categories_map = cls_categories_map + + def prep_df(self): + self.wide_df.index = self.wide_df[self.sample_id_column] + self.wide_df = self.wide_df.drop_duplicates() + + def get_loading_options(self, sample_id): + row = self.wide_df.loc[sample_id] + + # a loading option is a dictionary of options to use at loading time + # we use DATE_OPTION_KEY to make the date selection utilities work + loading_options = [{VIEW_OPTION_KEY: row}] + + # it's get_loading_options, not get loading_option, so we return a list + return loading_options + + def get_raw_data(self, sample_id, loading_option=None): + try: + if sample_id.shape[0] > 1: + sample_id = sample_id[0] + except AttributeError: + pass + try: + sample_id = sample_id.decode('UTF-8') + except (UnicodeDecodeError, AttributeError): + pass + row = self.wide_df.loc[sample_id] + data = row[self.column_names].values + label_noise = np.zeros(len(self.column_names)) + for transform in self.transforms: + label_noise += transform() + if self.categories: + output_data = np.zeros(len(self.categories), dtype=np.float32) + output_data[self.categories[data[0]]['index']] = 1.0 + return output_data + # ---------- Adaptation for regression + classification ---------- # + if self.cls_categories_map: + # If training include classification tasks: + data = [] + reg_data = row[self.column_names].drop(self.cls_categories_map['cls_output_order']).values + if len(reg_data) > 0: + data.append(np.squeeze(np.array(reg_data, dtype=np.float32))) + + for k in self.cls_categories_map['cls_output_order']: + # Changing values to class labels: + row_cls_lbl = self.cls_categories_map[k][row[k]] + # Changing class indices to one hot vectors + cls_one_hot = tf.keras.utils.to_categorical(row_cls_lbl, + num_classes=len(self.cls_categories_map[k])) + data.append(cls_one_hot) + + if len(data) == 1: + data = data[0] + + return data + # ---------------------------------------------------------------- # + return np.squeeze(np.array(data, dtype=np.float32)) + + @property + def name(self): + # if we have multiple wide file DataDescriptions at the same time, + # this will allow us to differentiate between them + return self._name diff --git a/model_zoo/DROID/echo_supervised_inference_recipe.py b/model_zoo/DROID/echo_supervised_inference_recipe.py index c8fb83d1f..5bcd20adf 100644 --- a/model_zoo/DROID/echo_supervised_inference_recipe.py +++ b/model_zoo/DROID/echo_supervised_inference_recipe.py @@ -2,6 +2,7 @@ import json import logging import os +import sys import numpy as np import pandas as pd @@ -9,11 +10,12 @@ from data_descriptions.echo import LmdbEchoStudyVideoDataDescription from echo_defines import category_dictionaries -from model_descriptions.echo import DDGenerator, create_movinet_classifier, create_regressor +from model_descriptions.echo import DDGenerator, create_movinet_classifier, create_regressor, create_regressor_classifier logging.basicConfig(level=logging.INFO) tf.get_logger().setLevel(logging.ERROR) +SAVE_ONEHOT_DF_FOR_EACH_CLASS = True def main( n_input_frames, @@ -30,12 +32,52 @@ def main( batch_size, skip_modulo, lmdb_folder, - pretrained_ckpt_dir, - movinet_ckpt_dir, + pretrained_chkp_dir, + movinet_chkp_dir, output_dir, extract_embeddings, start_beat, ): + # Loading information on saved model: + model_param_path = 
os.path.join(os.path.split(os.path.dirname(pretrained_chkp_dir))[0], 'model_params.json') + with open(model_param_path, 'r') as json_file: + model_params = json.load(json_file) + + output_labels = model_params['output_labels'] if not output_labels else output_labels + selected_views = model_params['selected_views'] if not selected_views else selected_views + selected_doppler = model_params['selected_doppler'] if not selected_doppler else selected_doppler + selected_quality = model_params['selected_quality'] if not selected_quality else selected_quality + selected_canonical = model_params['selected_canonical'] if not selected_canonical else selected_canonical + logging.info(f'Loaded model with output labels: {output_labels}, views: {selected_views}, doppler: {selected_doppler}, quality: {selected_quality}, canonical: {selected_canonical}') + + # ---------- Adaptation for regression + classification ---------- # + if ('output_labels_types' in model_params.keys()) and ('c' in model_params['output_labels_types'].lower()): + cls_lbl_map_path = os.path.join(os.path.split(os.path.dirname(pretrained_chkp_dir))[0], + 'classification_class_label_mapping_per_output.json') + with open(cls_lbl_map_path, 'r') as json_file: + cls_category_map_dicts = json.load(json_file) + cls_category_len_dict = {} + for c_lbl in cls_category_map_dicts['cls_output_order']: + cls_category_len_dict[c_lbl] = len(cls_category_map_dicts[c_lbl]) + # Reordering output labels to fit the regression-classification output order during training (assuming correct + # output_labels that include all saved classification output names - if not, the classification output names + # are added next anyway): + output_labels = ([i for i in output_labels if i not in cls_category_map_dicts['cls_output_order']] + + cls_category_map_dicts['cls_output_order']) + logging.info(f'Loaded model contains classification heads. Updated output_label_order: {output_labels}, with classification heads for: {cls_category_map_dicts["cls_output_order"]}') + output_reg_len = len(output_labels) - len(cls_category_map_dicts['cls_output_order']) + add_separate_dense_reg = cls_category_map_dicts['add_separate_dense_reg'] + add_separate_dense_cls = cls_category_map_dicts['add_separate_dense_cls'] + else: + logging.info(f'Loaded model contains only regression variables.') + output_reg_len = len(output_labels) + cls_category_len_dict = {} + add_separate_dense_reg = model_params[ + 'add_separate_dense_reg'] if 'add_separate_dense_reg' in model_params.keys() else False + add_separate_dense_cls = model_params[ + 'add_separate_dense_cls'] if 'add_separate_dense_cls' in model_params.keys() else False + # ---------------------------------------------------------------- # + # Hide devices based on split physical_devices = tf.config.list_physical_devices('GPU') tf.config.set_visible_devices([physical_devices[split_idx % 4]], 'GPU') @@ -77,7 +119,18 @@ def main( else: patient_inference = splits['patient_test'] + # Testing a random subset of IDs (for speed) to see if there are matching ids in working_ids and chosen split + random_ids_for_test = np.random.permutation(len(working_ids))[:min(len(working_ids), 500)] + working_ids_subset_test = [working_ids[i] for i in random_ids_for_test] + inference_ids_match_test = [t for t in working_ids_subset_test if int(t.split('_')[0]) in patient_inference] + if len(inference_ids_match_test) == 0: + logging.warning( + f'A random test of indices showed no match between {wide_file} indices and {splits_file} indices. 
It is possible that there are still matches, but please verify the file names. If there are truly no matches, this process may run for a long time before failing; consider stopping it manually.') + inference_ids = sorted([t for t in working_ids if int(t.split('_')[0]) in patient_inference]) + if len(inference_ids) == 0: + logging.error(f'No matches found between {wide_file} indices and the {splits_file} indices!') + sys.exit() INPUT_DD = LmdbEchoStudyVideoDataDescription( lmdb_folder, @@ -85,7 +138,7 @@ def main( [], n_input_frames, skip_modulo, - start_beat=start_beat, + start_frame=start_beat ) inference_ids_split = np.array_split(inference_ids, n_splits)[split_idx] @@ -101,41 +154,62 @@ def main( output_signature=( tf.TensorSpec(shape=(None, n_input_frames, 224, 224, 3), dtype=tf.float32), ), - args=(sample_ids,), + args=(sample_ids,) ), - ) + cycle_length = 2, + num_parallel_calls = 2 + ).prefetch(8) model, backbone = create_movinet_classifier( n_input_frames, batch_size, num_classes=600, - checkpoint_dir=movinet_ckpt_dir, + checkpoint_dir=movinet_chkp_dir, ) backbone_output = backbone.layers[-1].output[0] flatten = tf.keras.layers.Flatten()(backbone_output) encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten]) - model_plus_head = create_regressor( - encoder, - input_shape=(n_input_frames, 224, 224, 3), - n_output_features=len(output_labels), - ) - model_plus_head.load_weights(pretrained_ckpt_dir) + + # ---------- Adaptation for regression + classification ---------- # + # Organize regressor/classifier inputs: + func_args = {'input_shape': (n_input_frames, 224, 224, 3), + 'n_output_features': output_reg_len, + 'categories': cls_category_len_dict, + 'category_order': cls_category_map_dicts['cls_output_order'] if cls_category_len_dict else None, + 'add_dense': {'regressor': add_separate_dense_reg, 'classifier': add_separate_dense_cls}} + + model_plus_head = create_regressor_classifier(encoder, **func_args) + # ---------------------------------------------------------------- # + model_plus_head.load_weights(pretrained_chkp_dir) vois = '_'.join(selected_views) ufm = 'conv7' if extract_embeddings: - output_folder = os.path.join( - output_dir, - f'inference_embeddings_{vois}_{ufm}_{lmdb_folder.split("/")[-1]}_{splits_file}_{start_beat}', - ) + output_folder = os.path.join(output_dir, + f'inference_embeddings_{vois}_{ufm}_{lmdb_folder.split("/")[-1]}_{splits_file.split("/")[-1]}_{start_beat}') else: - output_folder = os.path.join( - output_dir, - f'inference_{vois}_{ufm}_{lmdb_folder.split("/")[-1]}_{splits_file}_{start_beat}', - ) + output_folder = os.path.join(output_dir, + f'inference_{vois}_{ufm}_{lmdb_folder.split("/")[-1]}_{splits_file.split("/")[-1]}_{start_beat}') os.makedirs(output_folder, exist_ok=True) + wide_df_selected.to_csv(f'{output_folder}/wide_df_selected.csv') + + def save_model_pred_as_df(pred, fname_suffix='', pred_col_names=[]): + save_df = pd.DataFrame() + save_df['sample_id'] = inference_ids_split + if len(pred_col_names) == pred.shape[1]: + use_pred_col_names = True + else: + use_pred_col_names = False + for i_p in range(pred.shape[1]): + if use_pred_col_names: + save_df[pred_col_names[i_p]] = pred[:, i_p] + else: + save_df[f'prediction_{i_p}'] = pred[:, i_p] + + save_df.to_parquet(os.path.join(output_folder, f'prediction_{split_idx}' + fname_suffix + '.pq')) + if extract_embeddings: embeddings = encoder.predict(io_inference_ds, steps=n_inference_steps, verbose=1) df = pd.DataFrame() @@ -146,45 +220,66 @@ def main( df.to_parquet(os.path.join(output_folder,
f'prediction_{split_idx}.pq')) else: predictions = model_plus_head.predict(io_inference_ds, steps=n_inference_steps, verbose=1) - df = pd.DataFrame() - df['sample_id'] = inference_ids_split - for j, _ in enumerate(range(predictions.shape[1])): - df[f'prediction_{j}'] = predictions[:, j] - - df.to_parquet(os.path.join(output_folder, f'prediction_{split_idx}.pq')) + # predictions is a list of length = number of outputs in list, where all regression variables are in a single + # list element and each classification task has a separate list element. + # Each list element is of size: + # len(inference_ids_split) X number of output variables (total number of regression vars or number of classes) + if len(cls_category_len_dict) > 0: + # Case: regression + classification or classification only + # Currently saving actual class predictions jointly with the regression variables if exist + # and for each class one-hot predictions are saved in a separate pq file (flag dependent) + if output_reg_len > 0: + reg_pred = predictions[0] + cls_pred = predictions[1:] + else: + reg_pred = np.zeros((0, 0)) + cls_pred = predictions + if len(cls_category_len_dict) == 1: + cls_pred = [predictions] + df = pd.DataFrame() + df['sample_id'] = inference_ids_split + for i_p in range(reg_pred.shape[1]): + df[f'prediction_{i_p}'] = reg_pred[:, i_p] + for i in range(len(cls_pred)): + curr_cls_name = cls_category_map_dicts['cls_output_order'][i] + if SAVE_ONEHOT_DF_FOR_EACH_CLASS: + save_model_pred_as_df(cls_pred[i], fname_suffix='_one_hot_' + curr_cls_name) + cls_pred_vals_curr = cls_pred[i].argmax(axis=1) + cls_map_inv = {v: k for k, v in zip(cls_category_map_dicts[curr_cls_name].keys(), + cls_category_map_dicts[curr_cls_name].values())} + df[cls_category_map_dicts['cls_output_order'][i]] = cls_pred_vals_curr + df.replace({curr_cls_name: cls_map_inv}, inplace=True) + df.to_parquet(os.path.join(output_folder, f'prediction_{split_idx}.pq')) + else: + # Case: regression only + save_model_pred_as_df(predictions) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--n_input_frames', type=int, default=50) - parser.add_argument('-o', '--output_labels', action='append') + parser.add_argument('-o', '--output_labels', action='append', required=False) parser.add_argument('--wide_file', type=str) parser.add_argument('--splits_file') - parser.add_argument( - '-v', '--selected_views', action='append', choices=category_dictionaries['view'].keys(), - required=True, - ) - parser.add_argument( - '-d', '--selected_doppler', action='append', choices=category_dictionaries['doppler'].keys(), - required=True, - ) - parser.add_argument( - '-q', '--selected_quality', action='append', choices=category_dictionaries['quality'].keys(), - required=True, - ) - parser.add_argument( - '-c', '--selected_canonical', action='append', - choices=category_dictionaries['canonical'].keys(), required=True, - ) + + parser.add_argument('-v', '--selected_views', action='append', choices=category_dictionaries['view'].keys(), + required=False) + parser.add_argument('-d', '--selected_doppler', action='append', choices=category_dictionaries['doppler'].keys(), + required=False) + parser.add_argument('-q', '--selected_quality', action='append', choices=category_dictionaries['quality'].keys(), + required=False) + parser.add_argument('-c', '--selected_canonical', action='append', + choices=category_dictionaries['canonical'].keys(), required=False) + parser.add_argument('-n', '--n_train_patients', default='all') 
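+    # split_idx / n_splits shard the inference IDs across parallel processes;
+    # each process also pins itself to GPU physical_devices[split_idx % 4] (see top of main)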
parser.add_argument('--split_idx', type=int, choices=range(4)) parser.add_argument('--n_splits', type=int, default=4) parser.add_argument('--batch_size', default=16, type=int) parser.add_argument('--skip_modulo', type=int, default=1) - parser.add_argument('--lmdb_folder') - parser.add_argument('--pretrained_ckpt_dir', type=str) - parser.add_argument('--movinet_ckpt_dir', type=str) + parser.add_argument('--lmdb_folder', type=str) + parser.add_argument('--pretrained_chkp_dir', type=str) + parser.add_argument('--movinet_chkp_dir', type=str) parser.add_argument('--output_dir', type=str) parser.add_argument('--extract_embeddings', action='store_true') parser.add_argument('--start_beat', type=int, default=0) @@ -211,8 +306,8 @@ def main( batch_size=args.batch_size, skip_modulo=args.skip_modulo, lmdb_folder=args.lmdb_folder, - pretrained_ckpt_dir=args.pretrained_ckpt_dir, - movinet_ckpt_dir=args.movinet_ckpt_dir, + pretrained_chkp_dir=args.pretrained_chkp_dir, + movinet_chkp_dir=args.movinet_chkp_dir, output_dir=args.output_dir, extract_embeddings=args.extract_embeddings, start_beat=args.start_beat, diff --git a/model_zoo/DROID/echo_supervised_training_recipe.py b/model_zoo/DROID/echo_supervised_training_recipe.py new file mode 100644 index 000000000..b0d3b646a --- /dev/null +++ b/model_zoo/DROID/echo_supervised_training_recipe.py @@ -0,0 +1,511 @@ +import argparse +import datetime +import json +import logging +import os + +import numpy as np +import pandas as pd +import tensorflow as tf + +from data_descriptions.echo import LmdbEchoStudyVideoDataDescription +from data_descriptions.wide_file import EcholabDataDescription +from echo_defines import category_dictionaries +from model_descriptions.echo import create_movinet_classifier, create_regressor, create_regressor_classifier, train_model, DDGenerator + +logging.basicConfig(level=logging.INFO) +tf.get_logger().setLevel(logging.ERROR) + +USER = os.getenv('USER') + +def main( + n_input_frames, + output_labels, + wide_file, + splits_file, + selected_views, + selected_doppler, + selected_quality, + selected_canonical, + n_train_patients, + batch_size, + epochs, + skip_modulo, + lmdb_folder, + fine_tune, + model_params, + pretrained_chkp_dir, + movinet_chkp_dir, + output_dir, + adam, + scale_outputs, + es_patience, + es_loss2monitor, + output_labels_types, + add_separate_dense_reg, + add_separate_dense_cls, + loss_weights, + randomize_start_frame +): + lmdb_vois = '_'.join(selected_views) + olabels = '_'.join(output_labels) + + # ---------- Adaptation for regression + classification ---------- # + def process_labels_types(o_lbls, o_lbls_types, var_type='output_labels'): + # ---- Processing input values and handling incorrect inputs ----- # + # Specify parameters for regression and classification heads: + if len(o_lbls_types) == len(o_lbls): + # Number of task types labels (regression/classification) is equal to the number of output variables + unq_lbl_types = set([ch for ch in o_lbls_types.lower()]) + elif len(o_lbls_types) == 1: + # Only one task type label (regression/classification) is given for all output variables + unq_lbl_types = o_lbls_types.lower() + else: + # A wrong number of task type labels was given (empty or different from 1 or 'len(output_labels)') + raise TypeError( + f"The lengths of '{var_type}' and '{var_type}_types' do not match (should be equal or 'len({var_type}_types)=1').") + if not set(unq_lbl_types) <= {'r', 'c'}: + # Wrong task type labels were given (letters other than 'r' for regression and 'c' for classification) + raise 
TypeError(f"'{var_type}_types' contains unrecognized letters (should include 'r' and/or 'c' only).") + + # Grouping regression tasks together and classification tasks together + # and computing output lengths (number of regression variables and number of classification tasks) + if len(unq_lbl_types) > 1: + # Both regression and classification tasks were specified + output_label_types_int = [0 if (ch == 'r') else 1 for ch in o_lbls_types.lower()] + o_reg_len = len(output_label_types_int) - sum(output_label_types_int) + cls_o_names = [o_lbls[i_c] for i_c, c in enumerate(output_label_types_int) if c == 1] + output_order = np.argsort(output_label_types_int) + o_lbls = [o_lbls[i] for i in output_order] + if var_type == 'output_labels': + logging.info('Training with regression and classification heads') + else: + logging.info('Loaded model has regression and classification heads') + logging.info(f'Updated {var_type} order: {o_lbls}') + elif 'r' in unq_lbl_types: + # Only one task type specified - regression + o_reg_len = len(o_lbls) + cls_o_names = [] + if var_type == 'output_labels': + logging.info('Training only with a regression head') + else: + logging.info('Loaded model has only a regression head') + else: + # Only one task type - classification + o_reg_len = 0 + cls_o_names = o_lbls + if var_type == 'output_labels': + logging.info('Training only with a classification head') + else: + logging.info('Loaded model has only a classification head') + + return o_lbls, o_reg_len, cls_o_names + + def process_class_categories(df, cls_o_names, var_type='output_labels'): + # Creating dictionaries specifying number of classes for each output_label name + # and mapping between wide_file values to class labels: + clsc_map_dicts = {} + clsc_len_dict = {} + for c_lbl in cls_o_names: + all_cls_vals = np.sort(df[c_lbl].drop_duplicates().tolist()) + val2clsind_map_dict = {val: c_ind for val, c_ind in zip(all_cls_vals, range(len(all_cls_vals)))} + clsc_map_dicts[c_lbl] = val2clsind_map_dict + clsc_len_dict[c_lbl] = len(df[c_lbl].drop_duplicates()) + if clsc_len_dict[c_lbl] < 2: + logging.error( + f'Error: Output variable {c_lbl} has a constant value in the train and validation sets - might cause errors in the classifier. 
Error raised when processing {var_type} related classification variables.') + clsc_map_dicts['cls_output_order'] = cls_o_names + return clsc_map_dicts, clsc_len_dict + + # ---------------------------------------------------------------- # + output_labels, output_reg_len, cls_output_names = process_labels_types(output_labels, output_labels_types, + var_type='output_labels') + # ---------------------------------------------------------------- # + wide_df = pd.read_parquet(wide_file) + + # Select only view(s) of interest + selected_views_idx = [category_dictionaries['view'][v] for v in selected_views] + selected_doppler_idx = [category_dictionaries['doppler'][v] for v in selected_doppler] + selected_quality_idx = [category_dictionaries['quality'][v] for v in selected_quality] + selected_canonical_idx = [category_dictionaries['canonical'][v] for v in selected_canonical] + wide_df_selected = wide_df[ + (wide_df['view_prediction'].isin(selected_views_idx)) & + (wide_df['doppler_prediction'].isin(selected_doppler_idx)) & + (wide_df['quality_prediction'].isin(selected_quality_idx)) & + (wide_df['canonical_prediction'].isin(selected_canonical_idx)) + ] + + # Drop entries without echolab measurements and get all sample_ids + wide_df_selected = wide_df_selected.dropna(subset=output_labels) + working_ids = wide_df_selected['sample_id'].values.tolist() + + # Read splits and partition dataset + with open(splits_file, 'r') as json_file: + splits = json.load(json_file) + + patient_train = splits['patient_train'] + patient_valid = splits['patient_valid'] + + if n_train_patients != 'all': + patient_train = patient_train[:int(int(n_train_patients) * 0.9)] + patient_valid = patient_valid[:int(int(n_train_patients) * 0.1)] + + train_ids = [t for t in working_ids if int(t.split('_')[0]) in patient_train] + valid_ids = [t for t in working_ids if int(t.split('_')[0]) in patient_valid] + print(f"train_ids: {len(train_ids)}") + print(f"valid_ids: {len(valid_ids)}") + + # If scale_outputs, normalize by summary stats of training set + if scale_outputs: + wide_df_train = wide_df_selected[wide_df_selected['sample_id'].isin(train_ids)] + output_labels_to_scale = np.array([l for l in output_labels if l not in cls_output_names]) + output_labels_to_scale = list(output_labels_to_scale[ + np.logical_and(wide_df_train[output_labels_to_scale].dtypes != 'object', + wide_df_train[output_labels_to_scale].dtypes != 'string')]) + logging.info( + f'Not scaling classification columns and columns containing strings/objects, unscaled columns: {[l for l in output_labels if l not in output_labels_to_scale]}') + mean_outputs = np.mean(wide_df_train[output_labels_to_scale].values, axis=0) + std_outputs = np.std(wide_df_train[output_labels_to_scale].values, axis=0) + wide_df_selected.loc[:, output_labels_to_scale] = (wide_df_selected[output_labels_to_scale].values - mean_outputs) / std_outputs + logging.info(mean_outputs) + logging.info(std_outputs) + + valid_ids = list(set(valid_ids).intersection(set(working_ids))) + print(f"valid_ids: {len(valid_ids)}") + + # ---------- Adaptation for regression + classification ---------- # + cls_category_map_dicts, cls_category_len_dict = process_class_categories(wide_df_selected, cls_output_names, + var_type='output_labels') + + if pretrained_chkp_dir: + cls_lbl_map_path = os.path.join(os.path.split(os.path.dirname(pretrained_chkp_dir))[0], + 'classification_class_label_mapping_per_output.json') + define_new_heads = False + if os.path.isfile(cls_lbl_map_path): + with open(cls_lbl_map_path, 'r') as 
json_file: + cls_category_signature_map_dicts = json.load(json_file) + similar_cls = [c for c in cls_output_names if c in cls_category_signature_map_dicts.keys()] + for c in similar_cls: + if (len(cls_category_signature_map_dicts[c]) > len(cls_category_map_dicts[c])) and set( + cls_category_map_dicts[c].keys()).issubset(set(cls_category_signature_map_dicts[c].keys())): + cls_category_map_dicts[c] = cls_category_signature_map_dicts[c] + cls_category_len_dict[c] = len(cls_category_map_dicts[c]) + logging.info(f'Using mapping from pretrained_chkp_dir for classification task on {c}') + elif not set( + cls_category_map_dicts[c].keys()).issubset(set(cls_category_signature_map_dicts[c].keys())): + define_new_heads = True + # ---------------------------------------------------------------- # + + INPUT_DD_TRAIN = LmdbEchoStudyVideoDataDescription( + lmdb_folder, + 'image', + [], + n_input_frames, + skip_modulo, + randomize_start_frame=randomize_start_frame + ) + + INPUT_DD_VALID = LmdbEchoStudyVideoDataDescription( + lmdb_folder, + 'image', + [], + n_input_frames, + skip_modulo, + randomize_start_frame = False + ) + + OUTPUT_DD = EcholabDataDescription( + wide_df=wide_df_selected[['sample_id'] + output_labels].drop_duplicates(), + sample_id_column='sample_id', + column_names=output_labels, + name='echolab', + # ---------- Adaptation for regression + classification ---------- # + cls_categories_map=cls_category_map_dicts if cls_output_names else None + # ---------------------------------------------------------------- # + ) + + body_train_ids = tf.data.Dataset.from_tensor_slices(working_ids).shuffle(len(working_ids), + reshuffle_each_iteration=True).batch( + batch_size, drop_remainder=True) + print(f"body_train_ids: {len(body_train_ids)}") + + body_valid_ids = tf.data.Dataset.from_tensor_slices(valid_ids).shuffle(len(valid_ids), + reshuffle_each_iteration=True).batch( + batch_size, drop_remainder=True) + print(f"body_valid_ids: {len(body_valid_ids)}") + + n_train_steps = len(working_ids) // batch_size + n_valid_steps = len(valid_ids) // batch_size + print(f"n_train_steps: {n_train_steps}") + print(f"n_valid_steps: {n_valid_steps}") + + # ---------- Adaptation for regression + classification ---------- # + # Adapting tensor output sizes for classification heads + num_classes = [output_reg_len] + [cls_category_len_dict[c] for c in + cls_category_map_dicts['cls_output_order']] if output_reg_len > 0 else [ + cls_category_len_dict[c] for c in cls_category_map_dicts['cls_output_order']] + if len(num_classes) > 1: + output_signatures = ( + tf.TensorSpec(shape=(batch_size, n_input_frames, 224, 224, 3), dtype=tf.float32), + tuple([tf.TensorSpec(shape=(batch_size, n_c), dtype=tf.float32) + for n_c in num_classes]) + ) + else: + output_signatures = ( + tf.TensorSpec(shape=(batch_size, n_input_frames, 224, 224, 3), dtype=tf.float32), + tf.TensorSpec(shape=(batch_size, num_classes[0]) if num_classes[0] > 1 else (batch_size,), dtype=tf.float32) + ) + # ---------------------------------------------------------------- # + + io_train_ds = body_train_ids.interleave( + lambda sample_ids: tf.data.Dataset.from_generator( + DDGenerator( + INPUT_DD_TRAIN, + OUTPUT_DD + ), + output_signature=output_signatures, + args=(sample_ids,) + ), + num_parallel_calls=tf.data.AUTOTUNE + ).repeat(epochs).prefetch(tf.data.AUTOTUNE) + + io_valid_ds = body_valid_ids.interleave( + lambda sample_ids: tf.data.Dataset.from_generator( + DDGenerator( + INPUT_DD_VALID, + OUTPUT_DD + ), + output_signature=output_signatures, + args=(sample_ids,) + 
), + num_parallel_calls=tf.data.AUTOTUNE + ).repeat(epochs).prefetch(tf.data.AUTOTUNE) + + mirrored_strategy = tf.distribute.MirroredStrategy() + with mirrored_strategy.scope(): + _, backbone = create_movinet_classifier( + n_input_frames, + batch_size, + num_classes=600, + checkpoint_dir=movinet_chkp_dir, + freeze_backbone=fine_tune + ) + backbone_output = backbone.layers[-1].output[0] + flatten = tf.keras.layers.Flatten()(backbone_output) + encoder = tf.keras.Model(inputs=[backbone.input], outputs=[flatten]) + + # ---------- Adaptation for regression + classification ---------- # + # Organize regressor/classifier inputs: + func_args = {'input_shape': (n_input_frames, 224, 224, 3), 'trainable': not fine_tune, + 'n_output_features': output_reg_len, + 'categories': cls_category_len_dict, + 'category_order': cls_category_map_dicts['cls_output_order'] if cls_category_len_dict else None, + 'add_dense': {'regressor': add_separate_dense_reg, 'classifier': add_separate_dense_cls}} + + model = create_regressor_classifier(encoder, **func_args) + # ---------------------------------------------------------------- # + + if pretrained_chkp_dir: + signature_model_param_path = os.path.join(os.path.split(os.path.dirname(pretrained_chkp_dir))[0], + 'model_params.json') + f = open(signature_model_param_path) + signature_model_params = json.load(f) + sig_add_separate_dense_reg = signature_model_params[ + 'add_separate_dense_reg'] if 'add_separate_dense_reg' in signature_model_params.keys() else False + sig_add_separate_dense_cls = signature_model_params[ + 'add_separate_dense_cls'] if 'add_separate_dense_cls' in signature_model_params.keys() else False + output_signature_labels_types = signature_model_params[ + 'output_labels_types'] if 'output_labels_types' in signature_model_params.keys() else 'r' + output_signature_labels = signature_model_params['output_labels'] + logging.info(f'output_labels of loaded model: {output_signature_labels}') + + output_signature_labels, output_signature_reg_len, cls_output_signature_names = process_labels_types( + output_signature_labels, output_signature_labels_types, var_type='output_signature_labels') + + if 'c' in output_signature_labels_types.lower(): + cls_category_signature_len_dict = {} + for c_lbl in cls_category_signature_map_dicts['cls_output_order']: + cls_category_signature_len_dict[c_lbl] = len(cls_category_signature_map_dicts[c_lbl]) + else: + cls_category_signature_len_dict = {} + + model = create_regressor_classifier( + encoder, + input_shape=(n_input_frames, 224, 224, 3), + trainable=not fine_tune, + n_output_features=output_signature_reg_len, + categories=cls_category_signature_len_dict, + category_order=cls_category_signature_map_dicts[ + 'cls_output_order'] if cls_category_signature_len_dict else None, + add_dense={'regressor': sig_add_separate_dense_reg, 'classifier': sig_add_separate_dense_cls} + ) + model.load_weights(pretrained_chkp_dir) + + if (output_labels != output_signature_labels) or (output_signature_reg_len != output_reg_len) or ( + cls_output_signature_names != cls_output_names) or define_new_heads: + logging.info('Redefining regression and/or classification heads due to differences in outputs used') + # ---------- Adaptation for regression + classification ---------- # + model = create_regressor_classifier(encoder, **func_args) + # ---------------------------------------------------------------- # + + if adam: + optimizer = tf.keras.optimizers.Adam(learning_rate=adam) + else: + initial_learning_rate = 0.00005 * batch_size + learning_rate = 
+
+        if adam:
+            optimizer = tf.keras.optimizers.Adam(learning_rate=adam)
+        else:
+            initial_learning_rate = 0.00005 * batch_size
+            learning_rate = tf.keras.optimizers.schedules.CosineDecay(
+                initial_learning_rate,
+                decay_steps=n_train_steps * epochs,
+            )
+            optimizer = tf.keras.optimizers.RMSprop(
+                learning_rate,
+                rho=0.9,
+                momentum=0.9,
+                epsilon=1.0,
+                clipnorm=1.0
+            )
+
+        classification_metrics = [
+            tf.keras.metrics.CategoricalAccuracy(),
+            tf.keras.metrics.AUC(name='AUROC'),
+            tf.keras.metrics.AUC(curve="PR", name='AUPRC')
+        ]
+
+        loss = {'cls_' + k: tf.keras.losses.CategoricalCrossentropy() for k in cls_category_len_dict.keys()}
+        metrics = {'cls_' + k: classification_metrics for k in cls_category_len_dict.keys()}
+        if output_reg_len > 0:
+            loss['echolab'] = tf.keras.losses.MeanSquaredError()
+            metrics['echolab'] = tf.keras.metrics.MeanAbsoluteError()
+
+        model.compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            loss_weights=loss_weights if loss_weights else None
+        )
+
+    options = tf.data.Options()
+    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
+
+    io_train_ds = io_train_ds.with_options(options)
+    io_valid_ds = io_valid_ds.with_options(options)
+
+    fine_tune_string = '_fine_tune' if fine_tune else ''
+    output_folder = os.path.join(
+        output_dir,
+        f'{datetime.datetime.now().strftime("%Y%m%d%H%M")}_{lmdb_vois}_{olabels}_{n_input_frames}frames{fine_tune_string}_{n_train_patients}'
+    )
+    os.makedirs(output_folder, exist_ok=True)
+
+    with open(f'{output_folder}/model_params.json', 'w') as json_file:
+        json.dump(model_params, json_file)
+
+    wide_df_selected.to_parquet(f'{output_folder}/wide_df_selected.pq')
+
+    # ---------- Adaptation for regression + classification ---------- #
+    # Record output labels new order (after possible reordering of regression and classification):
+    with open(f'{output_folder}/output_labels_final_ordering.json', 'w') as json_file:
+        json.dump(output_labels, json_file)
+    # Record output mapping for classification tasks (dictionary that contains column names as well):
+    if cls_output_names:
+        cls_category_map_dicts['add_separate_dense_cls'] = add_separate_dense_cls
+        cls_category_map_dicts['add_separate_dense_reg'] = add_separate_dense_reg
+        with open(f'{output_folder}/classification_class_label_mapping_per_output.json', 'w') as json_file:
+            json.dump(cls_category_map_dicts, json_file)
+    # ---------------------------------------------------------------- #
+
+    es_flags = {'es_patience': es_patience, 'es_loss2monitor': es_loss2monitor}
+
+    model.summary(print_fn=logging.info)
+    trained_model = train_model(
+        model,
+        io_train_ds,
+        io_valid_ds,
+        epochs,
+        n_train_steps,
+        n_valid_steps,
+        output_folder,
+        es_flags
+    )
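+
+
+# Illustrative invocation (a sketch only; the paths, labels, and category choices
+# below are hypothetical and depend on your wide file and lmdb setup):
+#   python echo_supervised_training_recipe.py \
+#       --wide_file wide.pq --lmdb_folder lmdbs/ --output_dir runs/ \
+#       --splits_file splits.json -n all \
+#       -o LVEF -o SeverityGrade --output_labels_types rc \
+#       -v <view> -d <doppler> -q <quality> -c <canonical> \
+#       --batch_size 16 --epochs 50 -lw 1.0 -lw 0.5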
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--n_input_frames', type=int, default=32)
+    parser.add_argument('-o', '--output_labels', action='append', type=str)
+    parser.add_argument('--wide_file', type=str)
+    parser.add_argument('-v', '--selected_views', action='append', choices=category_dictionaries['view'].keys(),
+                        required=True)
+    parser.add_argument('-d', '--selected_doppler', action='append', choices=category_dictionaries['doppler'].keys(),
+                        required=True)
+    parser.add_argument('-q', '--selected_quality', action='append', choices=category_dictionaries['quality'].keys(),
+                        required=True)
+    parser.add_argument('-c', '--selected_canonical', action='append',
+                        choices=category_dictionaries['canonical'].keys(), required=True)
+    parser.add_argument('-n', '--n_train_patients', type=str, required=True)
+    parser.add_argument('--batch_size', default=16, type=int)
+    parser.add_argument('--epochs', default=50, type=int)
+    parser.add_argument('--skip_modulo', type=int, default=2)
+    parser.add_argument('--lmdb_folder', type=str)
+    parser.add_argument('--fine_tune', action='store_true')
+    parser.add_argument('--pretrained_chkp_dir', type=str)
+    parser.add_argument('--movinet_chkp_dir', type=str)
+    parser.add_argument('--output_dir', type=str)
+    parser.add_argument('--splits_file', type=str)
+    parser.add_argument('--adam', default=None, type=float)
+    parser.add_argument('--scale_outputs', action='store_true')
+    parser.add_argument('--es_patience', default=3, type=int,
+                        help='Number of epochs with no improvement before early stopping.')
+    parser.add_argument('--es_loss2monitor', default='val_loss', type=str,
+                        help='Loss on which early stopping is based; options are "val_loss", '
+                             '"val_echolab_loss" for the regression loss, or "val_cls_COLUMN-NAME_loss" '
+                             'for a classification loss.')
+    parser.add_argument('--randomize_start_frame', action='store_true')
+    # ---------- Adaptation for regression + classification ---------- #
+    parser.add_argument('--output_labels_types', default='r', type=str,
+                        help='A string indicating task types: "r" for regression, "c" for classification. '
+                             'Must have length 1 or the same length as output_labels, e.g. "r" or "rrcr".')
+    parser.add_argument('--add_separate_dense_reg', action='store_true',
+                        help='Adds an additional dense layer trained separately for the regression head.')
+    parser.add_argument('--add_separate_dense_cls', action='store_true',
+                        help='Adds an additional dense layer trained separately for the classification head.')
+    parser.add_argument('-lw', '--loss_weights', action='append', type=float,
+                        help='Loss weights; specify one weight per classification task (column), '
+                             'plus one more if there are regression variables. For example, for '
+                             'output_labels_types="rrcc" the length should be 2+1=3.')
+    # ---------------------------------------------------------------- #
+    args = parser.parse_args()
+
+    root = logging.getLogger()
+    root.setLevel(logging.INFO)
+
+    model_params_dict = {}
+    for arg, value in sorted(vars(args).items()):
+        logging.info(f"Argument {arg}: {value}")
+        model_params_dict[arg] = value
+
+    main(
+        n_input_frames=args.n_input_frames,
+        output_labels=args.output_labels,
+        wide_file=args.wide_file,
+        splits_file=args.splits_file,
+        selected_views=args.selected_views,
+        selected_doppler=args.selected_doppler,
+        selected_quality=args.selected_quality,
+        selected_canonical=args.selected_canonical,
+        n_train_patients=args.n_train_patients,
+        batch_size=args.batch_size,
+        epochs=args.epochs,
+        skip_modulo=args.skip_modulo,
+        lmdb_folder=args.lmdb_folder,
+        fine_tune=args.fine_tune,
+        model_params=model_params_dict,
+        pretrained_chkp_dir=args.pretrained_chkp_dir,
+        movinet_chkp_dir=args.movinet_chkp_dir,
+        output_dir=args.output_dir,
+        adam=args.adam,
+        scale_outputs=args.scale_outputs,
+        es_patience=args.es_patience,
+        es_loss2monitor=args.es_loss2monitor,
+        # ---------- Adaptation for regression + classification ---------- #
+        output_labels_types=args.output_labels_types,
+        add_separate_dense_reg=args.add_separate_dense_reg,
+        add_separate_dense_cls=args.add_separate_dense_cls,
+        loss_weights=args.loss_weights,
+        # ---------------------------------------------------------------- #
+        randomize_start_frame=args.randomize_start_frame
+    )
diff --git a/model_zoo/DROID/model_descriptions/echo.py b/model_zoo/DROID/model_descriptions/echo.py
index 6759871c9..1993130c0 100644
--- a/model_zoo/DROID/model_descriptions/echo.py
+++ b/model_zoo/DROID/model_descriptions/echo.py
@@ -30,6 +30,11 @@ def __call__(self, sample_ids):
         if self.fill_empty:
             ret_output.append(np.NaN)
 
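+        # Multi-head case: with several output heads, each sample's output arrives as a
+        # list holding one array per head; re-stack into one batched array per head and
+        # emit a tuple, the layout Keras expects from multi-output generators.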
+        if self.output_dd is not None and isinstance(ret_output[0], list):
+            ret_output = [np.vstack([ret_output[i][j] for i in range(len(sample_ids))])
+                          for j in range(len(ret_output[0]))]
+            ret_output = tuple(ret_output)
+
         if self.output_dd is None and self.fill_empty == False:
             yielded = (ret_input,)
         else:
@@ -80,3 +85,81 @@ def create_regressor(encoder, trainable=True, input_shape=(224, 224, 3), n_outpu
     model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor")
 
     return model
+
+
+# ---------- Adaptation for regression + classification ---------- #
+def create_regressor_classifier(encoder, trainable=True, input_shape=(224, 224, 3), n_output_features=0,
+                                categories={}, category_order=None,
+                                add_dense={'regressor': False, 'classifier': False}):
+    for layer in encoder.layers:
+        layer.trainable = trainable
+
+    inputs = tf.keras.Input(shape=input_shape, name='image')
+    features = encoder(inputs)
+    features = tf.keras.layers.Dropout(dropout_rate)(features)
+    features = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
+    features = tf.keras.layers.Dropout(dropout_rate)(features)
+
+    outputs = []
+    if n_output_features > 0:
+        if add_dense['regressor']:
+            features_reg = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
+            features_reg = tf.keras.layers.Dropout(dropout_rate)(features_reg)
+            outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features_reg))
+        else:
+            outputs.append(tf.keras.layers.Dense(n_output_features, activation=None, name='echolab')(features))
+    if len(categories.keys()) > 0:
+        if add_dense['classifier']:
+            features = tf.keras.layers.Dense(hidden_units, activation="relu")(features)
+            features = tf.keras.layers.Dropout(dropout_rate)(features)
+        # category_order fixes the head ordering explicitly
+        # (dictionary item ordering is not guaranteed to be consistent):
+        for category in category_order:
+            activation = 'softmax'
+            n_classes = categories[category]
+            outputs.append(tf.keras.layers.Dense(n_classes, name='cls_' + category, activation=activation)(features))
+
+    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="regressor_classifier")
+
+    return model
+# ---------------------------------------------------------------- #
+
+
+def train_model(
+        model,
+        train_loader,
+        valid_loader,
+        epochs,
+        n_train_steps,
+        n_valid_steps,
+        output_folder,
+        es_flags,
+        class_weight=None
+):
+    tb_callback = tf.keras.callbacks.TensorBoard(f'{output_folder}/logs', profile_batch=[160, 170])
+    es_callback = tf.keras.callbacks.EarlyStopping(monitor=es_flags['es_loss2monitor'],
+                                                   patience=es_flags['es_patience'])
+    cp_callback = tf.keras.callbacks.ModelCheckpoint(
+        filepath=f'{output_folder}/model/chkp',
+        monitor=es_flags['es_loss2monitor'],
+        save_best_only=True,
+        save_weights_only=True,
+        mode='min'
+    )
+    model.fit(
+        train_loader,
+        validation_data=valid_loader,
+        callbacks=[tb_callback, es_callback, cp_callback],
+        epochs=epochs,
+        steps_per_epoch=n_train_steps,
+        validation_steps=n_valid_steps,
+        workers=1,
+        max_queue_size=1,
+        use_multiprocessing=False,
+        class_weight=class_weight
+    )
+
+    # reload the best checkpoint saved during training before returning the model
+    model.load_weights(f'{output_folder}/model/chkp')
+
+    return model
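+
+
+# Example head wiring (an illustrative sketch; the shapes and label names here are
+# hypothetical):
+#   model = create_regressor_classifier(
+#       encoder,
+#       input_shape=(16, 224, 224, 3),
+#       n_output_features=2,                  # two regression targets -> 'echolab' head
+#       categories={'rv_function': 3},        # one 3-class classification head
+#       category_order=['rv_function'],
+#       add_dense={'regressor': False, 'classifier': False},
+#   )
+#   # outputs, in order: echolab (batch, 2), cls_rv_function (batch, 3)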