Skip to content

Commit

Permalink
bidirectional rnn
Browse files Browse the repository at this point in the history
Signed-off-by: Anand Avati <[email protected]>
  • Loading branch information
avati committed Apr 27, 2016
1 parent b5ee9e7 commit c77ae22
Show file tree
Hide file tree
Showing 4 changed files with 612 additions and 9 deletions.
8 changes: 4 additions & 4 deletions nlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def train():

# Get a batch and make a step.
start_time = time.time()
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
encoder_inputs, decoder_inputs, target_weights, sequence_length = model.get_batch(
train_set, bucket_id)
_, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
target_weights, bucket_id, False)
Expand Down Expand Up @@ -202,7 +202,7 @@ def train():
if len(dev_set[bucket_id]) == 0:
print(" eval: empty bucket %d" % (bucket_id))
continue
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
encoder_inputs, decoder_inputs, target_weights, sequence_length = model.get_batch(
dev_set, bucket_id)
_, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
target_weights, bucket_id, True)
Expand Down Expand Up @@ -236,7 +236,7 @@ def decode():
bucket_id = min([b for b in xrange(len(_buckets))
if _buckets[b][0] > len(token_ids)])
# Get a 1-element batch to feed the sentence to the model.
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
encoder_inputs, decoder_inputs, target_weights, sequence_length = model.get_batch(
{bucket_id: [(token_ids, [])]}, bucket_id)
# Get output logits for the sentence.
_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
Expand Down Expand Up @@ -267,7 +267,7 @@ def self_test():
[([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
for _ in xrange(5): # Train the fake model for 5 steps.
bucket_id = random.choice([0, 1])
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
encoder_inputs, decoder_inputs, target_weights, sequence_length = model.get_batch(
data_set, bucket_id)
model.step(sess, encoder_inputs, decoder_inputs, target_weights,
bucket_id, False)
Expand Down
Loading

0 comments on commit c77ae22

Please sign in to comment.