Skip to content

Commit

Permalink
Added project path to system path
Browse files Browse the repository at this point in the history
  • Loading branch information
thushv89 committed Mar 18, 2019
1 parent d25aaaa commit 45a7dd9
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions model/train_nmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

from tensorflow.python.keras.utils import to_categorical
import numpy as np
import os
import os, sys

project_path = os.path.sep.join(os.path.abspath(__file__).split(os.path.sep)[:-2])
if project_path not in sys.path:
sys.path.append(project_path)

from utils.data_helper import read_data, sents2sequences
from model.nmt import define_nmt
import matplotlib.pyplot as plt
Expand All @@ -12,8 +17,8 @@
def get_data(train_size, random_seed=100):

""" Getting randomly shuffled training / testing data """
en_text = read_data(os.path.join('..', 'data', 'small_vocab_en.txt'))
fr_text = read_data(os.path.join('..', 'data', 'small_vocab_fr.txt'))
en_text = read_data(os.path.join(project_path, 'data', 'small_vocab_en.txt'))
fr_text = read_data(os.path.join(project_path, 'data', 'small_vocab_fr.txt'))
print('Length of text: {}'.format(len(en_text)))

fr_text = ['sos ' + sent[:-1] + 'eos .' if sent.endswith('.') else 'sos ' + sent + ' eos .' for sent in fr_text]
Expand Down Expand Up @@ -184,7 +189,6 @@ def plot_attention_weights(encoder_inputs, attention_weights, en_id2word, fr_id2
print('Translating: {}'.format(test_en))

test_en_seq = sents2sequences(en_tokenizer, [test_en], pad_length=en_timesteps)
print(test_en_seq)
test_fr, attn_weights = infer_nmt(
encoder_model=infer_enc_model, decoder_model=infer_dec_model,
test_en_seq=test_en_seq, en_vsize=en_vsize, fr_vsize=fr_vsize)
Expand Down

0 comments on commit 45a7dd9

Please sign in to comment.