From bb875b5abf690877a32baf86481c81705e182a73 Mon Sep 17 00:00:00 2001 From: Mano Aravindhan Date: Wed, 29 May 2024 10:14:42 +0530 Subject: [PATCH] Updated README, requirements, and code. --- README.md | 2 +- model/build_datasets.py | 4 ++++ model/classes/dataset/Dataset.py | 2 +- model/classes/dataset/Generator.py | 2 +- model/classes/model/pix2code.py | 4 ++-- model/train.py | 4 ++-- requirements.txt | 10 +++++----- 7 files changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 101ee2a..b148cc7 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Both the source code and the datasets are provided to foster future research in ## Setup ### Prerequisites -- Python 2 or 3 +- Python 3.11.x - pip ### Install dependencies diff --git a/model/build_datasets.py b/model/build_datasets.py index 513d883..e269972 100755 --- a/model/build_datasets.py +++ b/model/build_datasets.py @@ -37,6 +37,10 @@ evaluation_samples_number = len(paths) / (distribution + 1) training_samples_number = evaluation_samples_number * distribution +evaluation_samples_number = int(evaluation_samples_number+0.5) +training_samples_number = int(training_samples_number+0.5) + + assert training_samples_number + evaluation_samples_number == len(paths) print("Splitting datasets, training samples: {}, evaluation samples: {}".format(training_samples_number, evaluation_samples_number)) diff --git a/model/classes/dataset/Dataset.py b/model/classes/dataset/Dataset.py index af470b8..444569a 100644 --- a/model/classes/dataset/Dataset.py +++ b/model/classes/dataset/Dataset.py @@ -141,4 +141,4 @@ def sparsify_labels(next_words, voc): return temp def save_metadata(self, path): - np.save("{}/meta_dataset".format(path), np.array([self.input_shape, self.output_size, self.size])) + np.save("{}/meta_dataset".format(path), np.array([self.input_shape, self.output_size, self.size], dtype=object)) diff --git a/model/classes/dataset/Generator.py b/model/classes/dataset/Generator.py index 2fc8e43..73f7107 100644 --- a/model/classes/dataset/Generator.py +++ b/model/classes/dataset/Generator.py @@ -65,7 +65,7 @@ def data_generator(voc, gui_paths, img_paths, batch_size, generate_binary_sequen if verbose: print("Yield batch") - yield ([batch_input_images, batch_partial_sequences], batch_next_words) + yield ((batch_input_images, batch_partial_sequences), batch_next_words) batch_input_images = [] batch_partial_sequences = [] diff --git a/model/classes/model/pix2code.py b/model/classes/model/pix2code.py index e8b2bcd..e44e7d8 100644 --- a/model/classes/model/pix2code.py +++ b/model/classes/model/pix2code.py @@ -58,7 +58,7 @@ def __init__(self, input_shape, output_size, output_path): self.model = Model(inputs=[visual_input, textual_input], outputs=decoder) - optimizer = RMSprop(lr=0.0001, clipvalue=1.0) + optimizer = RMSprop(learning_rate=0.0001, clipvalue=1.0) self.model.compile(loss='categorical_crossentropy', optimizer=optimizer) def fit(self, images, partial_captions, next_words): @@ -66,7 +66,7 @@ def fit(self, images, partial_captions, next_words): self.save() def fit_generator(self, generator, steps_per_epoch): - self.model.fit_generator(generator, steps_per_epoch=steps_per_epoch, epochs=EPOCHS, verbose=1) + self.model.fit(generator, steps_per_epoch=steps_per_epoch, epochs=EPOCHS, verbose=1) self.save() def predict(self, image, partial_caption): diff --git a/model/train.py b/model/train.py index afa4424..6700d9d 100755 --- a/model/train.py +++ b/model/train.py @@ -4,7 +4,7 @@ __author__ = 'Tony Beltramelli - www.tonybeltramelli.com' import tensorflow as tf -sess = tf.Session(config=tf.ConfigProto(log_device_placement=True)) +sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(log_device_placement=True)) import sys @@ -33,7 +33,7 @@ def run(input_path, output_path, is_memory_intensive=False, pretrained_model=Non input_shape = dataset.input_shape output_size = dataset.output_size - steps_per_epoch = dataset.size / BATCH_SIZE + steps_per_epoch = int(dataset.size / BATCH_SIZE) voc = Vocabulary() voc.retrieve(output_path) diff --git a/requirements.txt b/requirements.txt index 3071c5e..45cb617 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -Keras==2.1.2 -numpy==1.13.3 -opencv-python==3.3.0.10 -h5py==2.7.1 -tensorflow==1.4.0 +Keras==3.0.0 +numpy==1.23.5 +opencv-python==4.9.0.80 +h5py==3.10.0 +tensorflow==2.16.1