From 5e572fa7923d0721fbe53a7a09b00709d08448ef Mon Sep 17 00:00:00 2001
From: Mrityunjay Tripathi
Date: Tue, 25 Aug 2020 23:20:03 +0530
Subject: [PATCH] use mutator method to set mask in mha

---
 models/transformer/decoder.hpp      | 30 ++++++++++++++---------------
 models/transformer/encoder.hpp      | 17 ++++++++--------
 models/transformer/encoder_impl.hpp |  2 +-
 tests/ffn_model_tests.cpp           |  6 +++---
 4 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/models/transformer/decoder.hpp b/models/transformer/decoder.hpp
index 012d7f3d..d486c6f7 100644
--- a/models/transformer/decoder.hpp
+++ b/models/transformer/decoder.hpp
@@ -147,13 +147,14 @@ class TransformerDecoder
     // Masked Self attention layer.
     Sequential<>* maskedSelfAttention = new Sequential<>(false);
     maskedSelfAttention->Add(decoderInput);
-    maskedSelfAttention->Add<MultiheadAttention<>>(
-        tgtSeqLen,
-        tgtSeqLen,
-        dModel,
-        numHeads,
-        attentionMask);
+
+    MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
+        tgtSeqLen,
+        dModel,
+        numHeads);
+    mha1->AttentionMask() = attentionMask;
+
+    maskedSelfAttention->Add(mha1);
 
     // Residual connection.
     AddMerge<>* residualAdd1 = new AddMerge<>();
@@ -179,14 +179,13 @@ class TransformerDecoder
     // Encoder-decoder attention.
     Sequential<>* encoderDecoderAttention = new Sequential<>(false);
     encoderDecoderAttention->Add(encoderDecoderAttentionInput);
-    encoderDecoderAttention->Add<MultiheadAttention<>>(
-        tgtSeqLen,
-        srcSeqLen,
-        dModel,
-        numHeads,
-        arma::mat(), // No attention mask to encoder-decoder attention.
-        keyPaddingMask);
+
+    MultiheadAttention<>* mha2 = new MultiheadAttention<>(tgtSeqLen,
+        srcSeqLen,
+        dModel,
+        numHeads);
+    mha2->KeyPaddingMask() = keyPaddingMask;
+    encoderDecoderAttention->Add(mha2);
 
     // Residual connection.
     AddMerge<>* residualAdd2 = new AddMerge<>();
diff --git a/models/transformer/encoder.hpp b/models/transformer/encoder.hpp
index 50db3b70..a54d98b9 100644
--- a/models/transformer/encoder.hpp
+++ b/models/transformer/encoder.hpp
@@ -95,7 +95,7 @@ class TransformerEncoder
   /**
    * Get the Transformer Encoder Model.
    */
-  Sequential* Model()
+  Sequential<>* Model()
   {
     return encoder;
   }
@@ -140,13 +140,14 @@ class TransformerEncoder
     /* Self attention layer. */
     Sequential<>* selfAttn = new Sequential<>(false);
     selfAttn->Add(input);
-    selfAttn->Add<MultiheadAttention<>>(
-        srcSeqLen,
-        srcSeqLen,
-        dModel,
-        numHeads,
-        attentionMask,
-        keyPaddingMask);
+
+    MultiheadAttention<>* mha = new MultiheadAttention<>(srcSeqLen,
+        srcSeqLen,
+        dModel,
+        numHeads);
+    mha->AttentionMask() = attentionMask;
+    mha->KeyPaddingMask() = keyPaddingMask;
+    selfAttn->Add(mha);
 
     /* This layer adds a residual connection. */
     AddMerge<>* residualAdd = new AddMerge<>();
diff --git a/models/transformer/encoder_impl.hpp b/models/transformer/encoder_impl.hpp
index 9da43258..05bfb058 100644
--- a/models/transformer/encoder_impl.hpp
+++ b/models/transformer/encoder_impl.hpp
@@ -40,7 +40,7 @@ TransformerEncoder::TransformerEncoder(
     keyPaddingMask(keyPaddingMask),
     ownMemory(ownMemory)
 {
-  encoder = new Sequential(false);
+  encoder = new Sequential<>(false);
 
   for (size_t n = 0; n < numLayers; ++n)
   {
diff --git a/tests/ffn_model_tests.cpp b/tests/ffn_model_tests.cpp
index 3f4601fd..f2c529d9 100644
--- a/tests/ffn_model_tests.cpp
+++ b/tests/ffn_model_tests.cpp
@@ -68,7 +68,7 @@ BOOST_AUTO_TEST_CASE(TransformerEncoderTest)
   mlpack::ann::TransformerEncoder<> encoder(numLayers, srcSeqLen, dModel,
       numHeads, dimFFN, dropout);
 
-  FFN<> model;
+  FFN, XavierInitialization> model;
   model.Add(encoder.Model());
   model.Add<Linear<>>(dModel * srcSeqLen, vocabSize);
@@ -103,7 +103,7 @@ BOOST_AUTO_TEST_CASE(TransformerDecoderTest)
   mlpack::ann::TransformerDecoder<> decoder(numLayers, tgtSeqLen, srcSeqLen,
       dModel, numHeads, dimFFN, dropout);
 
-  FFN<> model;
+  FFN, XavierInitialization> model;
   model.Add(decoder.Model());
   model.Add<Linear<>>(dModel * tgtSeqLen, vocabSize);
@@ -148,7 +148,7 @@ BOOST_AUTO_TEST_CASE(TransformerTest)
   mlpack::ann::Transformer<> transformer(numLayers, tgtSeqLen, srcSeqLen,
       tgtVocabSize, srcVocabSize, dModel, numHeads, dimFFN, dropout);
 
-  FFN<> model;
+  FFN, XavierInitialization> model;
   model.Add(transformer.Model());
   model.Add<Linear<>>(dModel * tgtSeqLen, tgtVocabSize);
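
Note (not part of the patch): a minimal sketch of the resulting usage, for reviewers. The header
paths, the sizes, and the zero-filled placeholder mask are assumptions for illustration; only the
MultiheadAttention constructor and the AttentionMask()/KeyPaddingMask() mutators shown in the
diff above are taken from the patch itself.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/sequential.hpp>
#include <mlpack/methods/ann/layer/multihead_attention.hpp>

using namespace mlpack::ann;

int main()
{
  const size_t srcSeqLen = 8, dModel = 16, numHeads = 2;

  // Placeholder masks; actual shapes/values follow the layer's documentation.
  arma::mat attentionMask(srcSeqLen, srcSeqLen, arma::fill::zeros);
  arma::mat keyPaddingMask;

  // Construct the attention layer first, then assign the masks through its
  // mutators, mirroring the encoder/decoder code after this patch.
  MultiheadAttention<>* mha = new MultiheadAttention<>(srcSeqLen, srcSeqLen,
      dModel, numHeads);
  mha->AttentionMask() = attentionMask;
  mha->KeyPaddingMask() = keyPaddingMask;

  Sequential<>* selfAttention = new Sequential<>(false);
  selfAttention->Add(mha);

  // Ownership/cleanup of the added layers follows mlpack's usual semantics
  // and is omitted here.
  return 0;
}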