use mutator method to set mask in mha
mrityunjay-tripathi committed Aug 25, 2020
1 parent 0aa3f28 commit 5e572fa
Showing 4 changed files with 28 additions and 27 deletions.
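
For context, a minimal sketch of the pattern this commit adopts: shape parameters go through the MultiheadAttention constructor, while the optional masks are set afterwards through the AttentionMask() and KeyPaddingMask() mutators, which return mutable references. The "block" variable below is an illustrative stand-in for whichever Sequential<> the layer gets added to; everything else mirrors the diff that follows.

    // Illustrative sketch only -- see the diff below for the actual change.
    MultiheadAttention<>* mha = new MultiheadAttention<>(
        tgtSeqLen, srcSeqLen, dModel, numHeads);
    mha->AttentionMask() = attentionMask;   // optional self-attention mask
    mha->KeyPaddingMask() = keyPaddingMask; // optional key padding mask
    block->Add(mha);                        // block takes ownership of mha

This avoids passing an empty arma::mat() placeholder when only one of the two masks is needed, as the old constructor-argument form required.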
30 changes: 15 additions & 15 deletions models/transformer/decoder.hpp
@@ -147,13 +147,14 @@ class TransformerDecoder
     // Masked Self attention layer.
     Sequential<>* maskedSelfAttention = new Sequential<>(false);
     maskedSelfAttention->Add(decoderInput);
-    maskedSelfAttention->Add<MultiheadAttention<
-        arma::mat, arma::mat, RegularizerType>>(
-        tgtSeqLen,
-        tgtSeqLen,
-        dModel,
-        numHeads,
-        attentionMask);
+
+    MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
+        tgtSeqLen,
+        dModel,
+        numHeads);
+    mha1->AttentionMask() = attentionMask;
+
+    maskedSelfAttention->Add(mha1);
 
     // Residual connection.
     AddMerge<>* residualAdd1 = new AddMerge<>();
@@ -179,14 +180,13 @@ class TransformerDecoder
     // Encoder-decoder attention.
     Sequential<>* encoderDecoderAttention = new Sequential<>(false);
     encoderDecoderAttention->Add(encoderDecoderAttentionInput);
-    encoderDecoderAttention->Add<MultiheadAttention<
-        arma::mat, arma::mat, RegularizerType>>(
-        tgtSeqLen,
-        srcSeqLen,
-        dModel,
-        numHeads,
-        arma::mat(), // No attention mask to encoder-decoder attention.
-        keyPaddingMask);
+
+    MultiheadAttention<>* mha2 = new MultiheadAttention<>(tgtSeqLen,
+        srcSeqLen,
+        dModel,
+        numHeads);
+    mha2->KeyPaddingMask() = keyPaddingMask;
+    encoderDecoderAttention->Add(mha2);
 
     // Residual connection.
     AddMerge<>* residualAdd2 = new AddMerge<>();
17 changes: 9 additions & 8 deletions models/transformer/encoder.hpp
@@ -95,7 +95,7 @@ class TransformerEncoder
   /**
    * Get the Transformer Encoder Model.
    */
-  Sequential<arma::mat, arma::mat, false>* Model()
+  Sequential<>* Model()
   {
     return encoder;
   }
@@ -140,13 +140,14 @@
     /* Self attention layer. */
     Sequential<>* selfAttn = new Sequential<>(false);
     selfAttn->Add(input);
-    selfAttn->Add<MultiheadAttention<arma::mat, arma::mat, RegularizerType>>(
-        srcSeqLen,
-        srcSeqLen,
-        dModel,
-        numHeads,
-        attentionMask,
-        keyPaddingMask);
+
+    MultiheadAttention<>* mha = new MultiheadAttention<>(srcSeqLen,
+        srcSeqLen,
+        dModel,
+        numHeads);
+    mha->AttentionMask() = attentionMask;
+    mha->KeyPaddingMask() = keyPaddingMask;
+    selfAttn->Add(mha);
 
     /* This layer adds a residual connection. */
     AddMerge<>* residualAdd = new AddMerge<>();
2 changes: 1 addition & 1 deletion models/transformer/encoder_impl.hpp
@@ -40,7 +40,7 @@ TransformerEncoder<ActivationFunction, RegularizerType>::TransformerEncoder(
     keyPaddingMask(keyPaddingMask),
     ownMemory(ownMemory)
 {
-  encoder = new Sequential<arma::mat, arma::mat, false>(false);
+  encoder = new Sequential<>(false);
 
   for (size_t n = 0; n < numLayers; ++n)
   {
6 changes: 3 additions & 3 deletions tests/ffn_model_tests.cpp
@@ -68,7 +68,7 @@ BOOST_AUTO_TEST_CASE(TransformerEncoderTest)
   mlpack::ann::TransformerEncoder<> encoder(numLayers, srcSeqLen,
       dModel, numHeads, dimFFN, dropout);
 
-  FFN<> model;
+  FFN<NegativeLogLikelihood<>, XavierInitialization> model;
 
   model.Add(encoder.Model());
   model.Add<Linear<>>(dModel * srcSeqLen, vocabSize);
@@ -103,7 +103,7 @@ BOOST_AUTO_TEST_CASE(TransformerDecoderTest)
   mlpack::ann::TransformerDecoder<> decoder(numLayers, tgtSeqLen, srcSeqLen,
       dModel, numHeads, dimFFN, dropout);
 
-  FFN<> model;
+  FFN<NegativeLogLikelihood<>, XavierInitialization> model;
 
   model.Add(decoder.Model());
   model.Add<Linear<>>(dModel * tgtSeqLen, vocabSize);
@@ -148,7 +148,7 @@ BOOST_AUTO_TEST_CASE(TransformerTest)
   mlpack::ann::Transformer<> transformer(numLayers, tgtSeqLen, srcSeqLen,
       tgtVocabSize, srcVocabSize, dModel, numHeads, dimFFN, dropout);
 
-  FFN<> model;
+  FFN<NegativeLogLikelihood<>, XavierInitialization> model;
 
   model.Add(transformer.Model());
   model.Add<Linear<>>(dModel * tgtSeqLen, tgtVocabSize);
