From b214459bc9ee76b5ceb4ea754a7af9ca127de122 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 18 Oct 2023 17:49:13 +0200 Subject: [PATCH] FIX: duplicated code in TF Trainer --- src/itwinai/tensorflow/trainer.py | 146 +++++++++++++++--------------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/src/itwinai/tensorflow/trainer.py b/src/itwinai/tensorflow/trainer.py index ec0f4e09..3f51f000 100644 --- a/src/itwinai/tensorflow/trainer.py +++ b/src/itwinai/tensorflow/trainer.py @@ -138,76 +138,76 @@ def train(self, train_dataset, validation_dataset): return history -class TensorflowTrainer2(Trainer): - def __init__( - self, - epochs, - batch_size, - callbacks, - model_dict: Dict, - compile_conf, - strategy - ): - super().__init__() - self.strategy = strategy - self.epochs = epochs - self.batch_size = batch_size - self.callbacks = callbacks - - # Handle the parsing - model_class = import_class(model_dict["class_path"]) - parser = ArgumentParser() - parser.add_subclass_arguments(model_class, "model") - model_dict = {"model": model_dict} - - # Create distributed TF vars - if self.strategy: - with self.strategy.scope(): - self.model = parser.instantiate_classes(model_dict).model - print(self.model) - self.model.compile(**compile_conf) - # TODO: move loss, optimizer and metrics instantiation under - # here - # Ref: - # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit - else: - self.model = parser.instantiate_classes(model_dict).model - self.model.compile(**compile_conf) - - self.num_devices = ( - self.strategy.num_replicas_in_sync if self.strategy else 1) - print(f"Strategy is working with: {self.num_devices} devices") - - def train(self, train_dataset, validation_dataset): - # TODO: FIX Steps sizes in model.fit - train, test = train_dataset, validation_dataset - - # Set batch size to the dataset - train = train.batch(self.batch_size, drop_remainder=True) - test = test.batch(self.batch_size, drop_remainder=True) - - # Number of samples - n_train = train.cardinality().numpy() - n_test = test.cardinality().numpy() - - # TODO: read - # https://github.com/tensorflow/tensorflow/issues/56773#issuecomment-1188693881 - # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit - - # Distribute dataset - if self.strategy: - train = self.strategy.experimental_distribute_dataset(train) - test = self.strategy.experimental_distribute_dataset(test) - - # train the model - history = self.model.fit( - train, - validation_data=test, - steps_per_epoch=int(n_train // self.num_devices), - validation_steps=int(n_test // self.num_devices), - epochs=self.epochs, - callbacks=self.callbacks, - ) - - print("Model trained") - return history +# class TensorflowTrainer2(Trainer): +# def __init__( +# self, +# epochs, +# batch_size, +# callbacks, +# model_dict: Dict, +# compile_conf, +# strategy +# ): +# super().__init__() +# self.strategy = strategy +# self.epochs = epochs +# self.batch_size = batch_size +# self.callbacks = callbacks + +# # Handle the parsing +# model_class = import_class(model_dict["class_path"]) +# parser = ArgumentParser() +# parser.add_subclass_arguments(model_class, "model") +# model_dict = {"model": model_dict} + +# # Create distributed TF vars +# if self.strategy: +# with self.strategy.scope(): +# self.model = parser.instantiate_classes(model_dict).model +# print(self.model) +# self.model.compile(**compile_conf) +# # TODO: move loss, optimizer and metrics instantiation under +# # here +# # Ref: +# # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit +# else: +# self.model = parser.instantiate_classes(model_dict).model +# self.model.compile(**compile_conf) + +# self.num_devices = ( +# self.strategy.num_replicas_in_sync if self.strategy else 1) +# print(f"Strategy is working with: {self.num_devices} devices") + +# def train(self, train_dataset, validation_dataset): +# # TODO: FIX Steps sizes in model.fit +# train, test = train_dataset, validation_dataset + +# # Set batch size to the dataset +# train = train.batch(self.batch_size, drop_remainder=True) +# test = test.batch(self.batch_size, drop_remainder=True) + +# # Number of samples +# n_train = train.cardinality().numpy() +# n_test = test.cardinality().numpy() + +# # TODO: read +# # https://github.com/tensorflow/tensorflow/issues/56773#issuecomment-1188693881 +# # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit + +# # Distribute dataset +# if self.strategy: +# train = self.strategy.experimental_distribute_dataset(train) +# test = self.strategy.experimental_distribute_dataset(test) + +# # train the model +# history = self.model.fit( +# train, +# validation_data=test, +# steps_per_epoch=int(n_train // self.num_devices), +# validation_steps=int(n_test // self.num_devices), +# epochs=self.epochs, +# callbacks=self.callbacks, +# ) + +# print("Model trained") +# return history