Skip to content

Commit

Permalink
set model = true
Browse files Browse the repository at this point in the history
  • Loading branch information
mrityunjay-tripathi committed Aug 26, 2020
1 parent a253a6a commit fbdd4ff
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 29 deletions.
32 changes: 16 additions & 16 deletions models/transformer/decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,18 @@ class TransformerDecoder
*/
Sequential<>* AttentionBlock()
{
Sequential<>* decoderBlockBottom = new Sequential<>(false);
Sequential<>* decoderBlockBottom = new Sequential<>();
decoderBlockBottom->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);

// Broadcast the incoming input to decoder
// i.e. query into (query, key, value).
Concat<>* decoderInput = new Concat<>();
Concat<>* decoderInput = new Concat<>(true);
decoderInput->Add<IdentityLayer<>>();
decoderInput->Add<IdentityLayer<>>();
decoderInput->Add<IdentityLayer<>>();

// Masked Self attention layer.
Sequential<>* maskedSelfAttention = new Sequential<>(false);
Sequential<>* maskedSelfAttention = new Sequential<>();
maskedSelfAttention->Add(decoderInput);

MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
Expand All @@ -157,7 +157,7 @@ class TransformerDecoder
maskedSelfAttention->Add(mha1);

// Residual connection.
AddMerge<>* residualAdd1 = new AddMerge<>();
AddMerge<>* residualAdd1 = new AddMerge<>(true);
residualAdd1->Add(maskedSelfAttention);
residualAdd1->Add<IdentityLayer<>>();

Expand All @@ -167,19 +167,19 @@ class TransformerDecoder
decoderBlockBottom->Add<LayerNorm<>>(dModel * tgtSeqLen);

// This layer broadcasts the output of encoder i.e. key into (key, value).
Concat<>* broadcastEncoderOutput = new Concat<>();
Concat<>* broadcastEncoderOutput = new Concat<>(true);
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);

// This layer concatenates the output of the bottom decoder block (query)
// and the output of the encoder (key, value).
Concat<>* encoderDecoderAttentionInput = new Concat<>();
encoderDecoderAttentionInput->Add(decoderBlockBottom);
encoderDecoderAttentionInput->Add(broadcastEncoderOutput);
Concat<>* encDecAttnInput = new Concat<>(true);
encDecAttnInput->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);
encDecAttnInput->Add(broadcastEncoderOutput);

// Encoder-decoder attention.
Sequential<>* encoderDecoderAttention = new Sequential<>(false);
encoderDecoderAttention->Add(encoderDecoderAttentionInput);
Sequential<>* encoderDecoderAttention = new Sequential<>();
encoderDecoderAttention->Add(encDecAttnInput);

MultiheadAttention<>* mha2 = new MultiheadAttention<>(tgtSeqLen,
srcSeqLen,
Expand All @@ -189,11 +189,11 @@ class TransformerDecoder
encoderDecoderAttention->Add(mha2);

// Residual connection.
AddMerge<>* residualAdd2 = new AddMerge<>();
AddMerge<>* residualAdd2 = new AddMerge<>(true);
residualAdd2->Add(encoderDecoderAttention);
residualAdd2->Add<IdentityLayer<>>();
residualAdd2->Add(decoderBlockBottom);

Sequential<>* decoderBlock = new Sequential<>(false);
Sequential<>* decoderBlock = new Sequential<>();
decoderBlock->Add(residualAdd2);
decoderBlock->Add<LayerNorm<>>(dModel * tgtSeqLen);
return decoderBlock;
Expand All @@ -204,18 +204,18 @@ class TransformerDecoder
*/
/**
 * Build the position-wise feed-forward block of one decoder layer:
 * Linear3D(dModel -> dimFFN) -> activation -> Linear3D(dimFFN -> dModel)
 * -> Dropout, wrapped in a residual (add) connection.
 *
 * NOTE(review): the source this was recovered from was a rendered diff that
 * contained both the pre- and post-commit line for three statements; this
 * body keeps only the post-commit (added) lines of commit "set model = true".
 *
 * @return Pointer to the assembled block; the caller takes ownership of the
 *         returned layer and everything added to it.
 */
Sequential<>* PositionWiseFFNBlock()
{
  // Default Sequential<>() (no 'false' argument) per the commit intent of
  // enabling the model/ownership flag -- TODO confirm against the mlpack
  // Sequential constructor semantics.
  Sequential<>* positionWiseFFN = new Sequential<>();
  positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN);
  positionWiseFFN->Add<ActivationFunction>();
  positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel);
  positionWiseFFN->Add<Dropout<>>(dropout);

  // Residual connection: output = FFN(x) + x. The 'true' argument follows
  // the same commit-wide flag change -- presumably the run/model flag of
  // AddMerge; verify against the mlpack AddMerge API.
  AddMerge<>* residualAdd = new AddMerge<>(true);
  residualAdd->Add(positionWiseFFN);
  residualAdd->Add<IdentityLayer<>>();

  // Wrap the residual merge in a Sequential so it composes with the
  // attention block inside the full decoder layer.
  Sequential<>* decoderBlock = new Sequential<>();
  decoderBlock->Add(residualAdd);
  return decoderBlock;
}
Expand Down
6 changes: 3 additions & 3 deletions models/transformer/decoder_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ TransformerDecoder<ActivationFunction, RegularizerType>::TransformerDecoder(
keyPaddingMask(keyPaddingMask),
ownMemory(ownMemory)
{
decoder = new Sequential<>(false);
decoder = new Sequential<>();

for (size_t n = 0; n < numLayers; ++n)
{
Expand All @@ -66,11 +66,11 @@ TransformerDecoder<ActivationFunction, RegularizerType>::TransformerDecoder(
break;
}

Sequential<>* decoderBlock = new Sequential<>(false);
Sequential<>* decoderBlock = new Sequential<>();
decoderBlock->Add(AttentionBlock());
decoderBlock->Add(PositionWiseFFNBlock());

Concat<>* concatQueryKey = new Concat<>();
Concat<>* concatQueryKey = new Concat<>(true);
concatQueryKey->Add(decoderBlock);
concatQueryKey->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);

Expand Down
10 changes: 5 additions & 5 deletions models/transformer/encoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ class TransformerEncoder
*/
void AttentionBlock()
{
Concat<>* input = new Concat<>();
Concat<>* input = new Concat<>(true);
input->Add<IdentityLayer<>>();
input->Add<IdentityLayer<>>();
input->Add<IdentityLayer<>>();

/* Self attention layer. */
Sequential<>* selfAttn = new Sequential<>(false);
Sequential<>* selfAttn = new Sequential<>();
selfAttn->Add(input);

MultiheadAttention<>* mha = new MultiheadAttention<>(srcSeqLen,
Expand All @@ -150,7 +150,7 @@ class TransformerEncoder
selfAttn->Add(mha);

/* This layer adds a residual connection. */
AddMerge<>* residualAdd = new AddMerge<>();
AddMerge<>* residualAdd = new AddMerge<>(true);
residualAdd->Add(selfAttn);
residualAdd->Add<IdentityLayer<>>();

Expand All @@ -163,14 +163,14 @@ class TransformerEncoder
*/
void PositionWiseFFNBlock()
{
Sequential<>* positionWiseFFN = new Sequential<>(false);
Sequential<>* positionWiseFFN = new Sequential<>();
positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN);
positionWiseFFN->Add<ActivationFunction>();
positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel);
positionWiseFFN->Add<Dropout<>>(dropout);

/* This layer adds a residual connection. */
AddMerge<>* residualAdd = new AddMerge<>();
AddMerge<>* residualAdd = new AddMerge<>(true);
residualAdd->Add(positionWiseFFN);
residualAdd->Add<IdentityLayer<>>();

Expand Down
2 changes: 1 addition & 1 deletion models/transformer/encoder_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ TransformerEncoder<ActivationFunction, RegularizerType>::TransformerEncoder(
keyPaddingMask(keyPaddingMask),
ownMemory(ownMemory)
{
encoder = new Sequential<>(false);
encoder = new Sequential<>();

for (size_t n = 0; n < numLayers; ++n)
{
Expand Down
8 changes: 4 additions & 4 deletions models/transformer/transformer_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ Transformer<ActivationFunction, RegularizerType>::Transformer(
keyPaddingMask(keyPaddingMask),
ownMemory(ownMemory)
{
transformer = new Sequential<>(false);
transformer = new Sequential<>();

Sequential<>* encoder = new Sequential<>(false);
Sequential<>* encoder = new Sequential<>();

// Pull out the sequences of source language which is stacked above in the
// input matrix. Here 'lastCol = -1' denotes upto last batch of input matrix.
Expand All @@ -69,7 +69,7 @@ Transformer<ActivationFunction, RegularizerType>::Transformer(

encoder->Add(encoderStack);

Sequential<>* decoderPE = new Sequential<>(false);
Sequential<>* decoderPE = new Sequential<>();

// Pull out the sequences of target language which is stacked below in the
// input matrix. Here 'lastRow = -1' and 'lastCol = -1' denotes upto last
Expand All @@ -78,7 +78,7 @@ Transformer<ActivationFunction, RegularizerType>::Transformer(
decoderPE->Add<Lookup<>>(tgtVocabSize, dModel);
decoderPE->Add<PositionalEncoding<>>(dModel, tgtSeqLen);

Concat<>* encoderDecoderConcat = new Concat<>();
Concat<>* encoderDecoderConcat = new Concat<>(true);
encoderDecoderConcat->Add(encoder);
encoderDecoderConcat->Add(decoderPE);

Expand Down

0 comments on commit fbdd4ff

Please sign in to comment.