diff --git a/model/train_nmt.py b/model/train_nmt.py
index c03f184..a1a4057 100644
--- a/model/train_nmt.py
+++ b/model/train_nmt.py
@@ -3,7 +3,12 @@
 from tensorflow.python.keras.utils import to_categorical
 import numpy as np
-import os
+import os, sys
+
+project_path = os.path.sep.join(os.path.abspath(__file__).split(os.path.sep)[:-2])
+if project_path not in sys.path:
+    sys.path.append(project_path)
+
 from utils.data_helper import read_data, sents2sequences
 from model.nmt import define_nmt
 import matplotlib.pyplot as plt
@@ -12,8 +17,8 @@
 def get_data(train_size, random_seed=100):
     """ Getting randomly shuffled training / testing data """
-    en_text = read_data(os.path.join('..', 'data', 'small_vocab_en.txt'))
-    fr_text = read_data(os.path.join('..', 'data', 'small_vocab_fr.txt'))
+    en_text = read_data(os.path.join(project_path, 'data', 'small_vocab_en.txt'))
+    fr_text = read_data(os.path.join(project_path, 'data', 'small_vocab_fr.txt'))
     print('Length of text: {}'.format(len(en_text)))
     fr_text = ['sos ' + sent[:-1] + 'eos .' if sent.endswith('.') else 'sos ' + sent + ' eos .' for sent in fr_text]
@@ -184,7 +189,6 @@ def plot_attention_weights(encoder_inputs, attention_weights, en_id2word, fr_id2
     print('Translating: {}'.format(test_en))
     test_en_seq = sents2sequences(en_tokenizer, [test_en], pad_length=en_timesteps)
-    print(test_en_seq)
     test_fr, attn_weights = infer_nmt(
         encoder_model=infer_enc_model, decoder_model=infer_dec_model, test_en_seq=test_en_seq,
         en_vsize=en_vsize, fr_vsize=fr_vsize)