FIX: duplicated code in TF Trainer
matbun committed Oct 18, 2023
1 parent a718b7a commit b214459
Showing 1 changed file with 73 additions and 73 deletions.
146 changes: 73 additions & 73 deletions src/itwinai/tensorflow/trainer.py
@@ -138,76 +138,76 @@ def train(self, train_dataset, validation_dataset):
        return history


class TensorflowTrainer2(Trainer):
    def __init__(
        self,
        epochs,
        batch_size,
        callbacks,
        model_dict: Dict,
        compile_conf,
        strategy
    ):
        super().__init__()
        self.strategy = strategy
        self.epochs = epochs
        self.batch_size = batch_size
        self.callbacks = callbacks

        # Handle the parsing
        model_class = import_class(model_dict["class_path"])
        parser = ArgumentParser()
        parser.add_subclass_arguments(model_class, "model")
        model_dict = {"model": model_dict}

        # Create distributed TF vars
        if self.strategy:
            with self.strategy.scope():
                self.model = parser.instantiate_classes(model_dict).model
                print(self.model)
                self.model.compile(**compile_conf)
                # TODO: move loss, optimizer and metrics instantiation under
                # here
                # Ref:
                # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit
        else:
            self.model = parser.instantiate_classes(model_dict).model
            self.model.compile(**compile_conf)

        self.num_devices = (
            self.strategy.num_replicas_in_sync if self.strategy else 1)
        print(f"Strategy is working with: {self.num_devices} devices")

    def train(self, train_dataset, validation_dataset):
        # TODO: FIX Steps sizes in model.fit
        train, test = train_dataset, validation_dataset

        # Set batch size to the dataset
        train = train.batch(self.batch_size, drop_remainder=True)
        test = test.batch(self.batch_size, drop_remainder=True)

        # Number of samples
        n_train = train.cardinality().numpy()
        n_test = test.cardinality().numpy()

        # TODO: read
        # https://github.com/tensorflow/tensorflow/issues/56773#issuecomment-1188693881
        # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit

        # Distribute dataset
        if self.strategy:
            train = self.strategy.experimental_distribute_dataset(train)
            test = self.strategy.experimental_distribute_dataset(test)

        # train the model
        history = self.model.fit(
            train,
            validation_data=test,
            steps_per_epoch=int(n_train // self.num_devices),
            validation_steps=int(n_test // self.num_devices),
            epochs=self.epochs,
            callbacks=self.callbacks,
        )

        print("Model trained")
        return history
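For reference, the `ArgumentParser` used in `__init__` above follows jsonargparse's subclass-instantiation pattern: a `{"class_path": ..., "init_args": ...}` dict is registered against a base class and turned into a live object. Below is a minimal, self-contained sketch of that pattern; the `Base`/`Concrete` classes are invented for illustration, and the sketch validates the dict via `parse_object` before `instantiate_classes`, whereas the trainer passes the dict directly.

from jsonargparse import ArgumentParser

class Base:
    pass

class Concrete(Base):
    def __init__(self, units: int = 8):
        self.units = units

parser = ArgumentParser()
parser.add_subclass_arguments(Base, "model")
# Parse the config dict into a validated Namespace, then instantiate it.
cfg = parser.parse_object(
    {"model": {"class_path": "__main__.Concrete", "init_args": {"units": 16}}}
)
model = parser.instantiate_classes(cfg).model
print(type(model).__name__, model.units)  # -> Concrete 16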
# class TensorflowTrainer2(Trainer):
#     def __init__(
#         self,
#         epochs,
#         batch_size,
#         callbacks,
#         model_dict: Dict,
#         compile_conf,
#         strategy
#     ):
#         super().__init__()
#         self.strategy = strategy
#         self.epochs = epochs
#         self.batch_size = batch_size
#         self.callbacks = callbacks

#         # Handle the parsing
#         model_class = import_class(model_dict["class_path"])
#         parser = ArgumentParser()
#         parser.add_subclass_arguments(model_class, "model")
#         model_dict = {"model": model_dict}

#         # Create distributed TF vars
#         if self.strategy:
#             with self.strategy.scope():
#                 self.model = parser.instantiate_classes(model_dict).model
#                 print(self.model)
#                 self.model.compile(**compile_conf)
#                 # TODO: move loss, optimizer and metrics instantiation under
#                 # here
#                 # Ref:
#                 # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit
#         else:
#             self.model = parser.instantiate_classes(model_dict).model
#             self.model.compile(**compile_conf)

#         self.num_devices = (
#             self.strategy.num_replicas_in_sync if self.strategy else 1)
#         print(f"Strategy is working with: {self.num_devices} devices")

#     def train(self, train_dataset, validation_dataset):
#         # TODO: FIX Steps sizes in model.fit
#         train, test = train_dataset, validation_dataset

#         # Set batch size to the dataset
#         train = train.batch(self.batch_size, drop_remainder=True)
#         test = test.batch(self.batch_size, drop_remainder=True)

#         # Number of samples
#         n_train = train.cardinality().numpy()
#         n_test = test.cardinality().numpy()

#         # TODO: read
#         # https://github.com/tensorflow/tensorflow/issues/56773#issuecomment-1188693881
#         # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit

#         # Distribute dataset
#         if self.strategy:
#             train = self.strategy.experimental_distribute_dataset(train)
#             test = self.strategy.experimental_distribute_dataset(test)

#         # train the model
#         history = self.model.fit(
#             train,
#             validation_data=test,
#             steps_per_epoch=int(n_train // self.num_devices),
#             validation_steps=int(n_test // self.num_devices),
#             epochs=self.epochs,
#             callbacks=self.callbacks,
#         )

#         print("Model trained")
#         return history
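To close the loop, here is a hypothetical usage sketch of the `TensorflowTrainer2` shown in this diff, driven by a `MirroredStrategy`. None of this appears in the commit; the model class path, compile settings and random data are made-up placeholders.

import numpy as np
import tensorflow as tf

from itwinai.tensorflow.trainer import TensorflowTrainer2

# Toy data; the trainer batches the datasets itself, so pass them unbatched.
x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")
train_ds = tf.data.Dataset.from_tensor_slices((x, y))
val_ds = tf.data.Dataset.from_tensor_slices((x[:64], y[:64]))

trainer = TensorflowTrainer2(
    epochs=2,
    batch_size=32,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=1)],
    model_dict={"class_path": "mypackage.models.MyKerasModel"},  # hypothetical
    compile_conf={"optimizer": "adam", "loss": "mse"},
    strategy=tf.distribute.MirroredStrategy(),
)
history = trainer.train(train_ds, val_ds)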
