Skip to content

Commit

Permalink
add gpustats callback
Browse files Browse the repository at this point in the history
  • Loading branch information
hvgazula committed Apr 1, 2024
1 parent 11347b1 commit 6e1d075
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions 1.2.0/kwyk_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# @Email: [email protected]
# @Create At: 2024-03-29 09:08:29
# @Last Modified By: Harsha
# @Last Modified At: 2024-04-01 17:44:15
# @Last Modified At: 2024-04-01 19:29:32
# @Description: This is description.

import os
Expand All @@ -23,8 +23,9 @@
import tensorflow as tf
from nobrainer.dataset import Dataset
from nobrainer.models import unet
from nobrainer.processing.segmentation import Segmentation
from nobrainer.models.bayesian_meshnet import variational_meshnet
from nobrainer.processing.segmentation import Segmentation
from nvitop.callbacks.keras import GpuStatsLogger

# tf.data.experimental.enable_debug_mode()

Expand Down Expand Up @@ -80,7 +81,6 @@ def create_filepaths(path_to_data: str, sample: bool = False) -> None:

@main_timer
def load_sample_files():

if True:
csv_path = nobrainer.utils.get_data()
filepaths = nobrainer.io.read_csv(csv_path)
Expand Down Expand Up @@ -116,7 +116,6 @@ def load_sample_tfrec(target: str = "train"):

@main_timer
def load_custom_tfrec(target: str = "train"):

if target == "train":
data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*train*"
data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*train*"
Expand Down Expand Up @@ -151,6 +150,7 @@ def get_label_count():
# @main_timer
def main():
gpus = tf.config.list_physical_devices("GPU")
gpu_names = [item.name for item in gpus]
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
NUM_GPUS = len(gpus)
Expand Down Expand Up @@ -198,8 +198,10 @@ def main():
callback_backup = tf.keras.callbacks.BackupAndRestore(
backup_dir=f"output/{model_string}/backup", save_freq=save_freq
)
callback_gpustats = GpuStatsLogger(gpu_names)

callbacks = [
callback_gpustats, # gpu stats callback should be placed before tboard/csvlogger callback
callback_model_checkpoint,
callback_tensorboard,
callback_early_stopping,
Expand All @@ -220,6 +222,7 @@ def main():
dataset_validate=dataset_eval,
epochs=n_epochs,
callbacks=callbacks,
verbose=1,
)

print("Success")
Expand Down

0 comments on commit 6e1d075

Please sign in to comment.