Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformer #16

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
build*
xcode*
.vscode/
.DS_Store
.idea
cmake-build-*
Expand Down
12 changes: 10 additions & 2 deletions models/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
project(models)

add_subdirectory(darknet)
add_subdirectory(yolo)
# Recurse into each model mlpack provides.
set(DIRS
darknet
transformer
yolo
)

foreach(dir ${DIRS})
add_subdirectory(${dir})
endforeach()

# Add directory name to sources.
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
Expand Down
20 changes: 20 additions & 0 deletions models/transformer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
project(transformer)

set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")

set(SOURCES
decoder.hpp
decoder_impl.hpp
encoder.hpp
encoder_impl.hpp
transformer.hpp
transformer_impl.hpp
)

foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()

set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
267 changes: 267 additions & 0 deletions models/transformer/decoder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
/**
* @file models/transformer/decoder.hpp
* @author Mikhail Lozhnikov
* @author Mrityunjay Tripathi
*
* Definition of the Transformer Decoder layer.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MODELS_TRANSFORMER_DECODER_HPP
#define MODELS_TRANSFORMER_DECODER_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/ann/layer/layer_types.hpp>
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/regularizer/no_regularizer.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
* In addition to the two sub-layers in each encoder layer, the decoder inserts
* a third sub-layer, which performs multi-head attention over the output of the
* encoder stack. Similar to the encoder, we employ residual connections around
* each of the sub-layers, followed by layer normalization. We also modify the
* self-attention sub-layer in the decoder stack to prevent positions from
* attending to subsequent positions. This masking, combined with fact that the
* output embeddings are offset by one position, ensures that the predictions
* for position i can depend only on the known outputs at positions less than i.
*
* @tparam ActivationFunction The type of the activation function to be used in
* the position-wise feed forward neural network.
* @tparam RegularizerType The type of regularizer to be applied to layer
* parameters.
*/
template <
typename ActivationFunction = ReLULayer<>,
typename RegularizerType = NoRegularizer
>
class TransformerDecoder
{
public:
TransformerDecoder();

/**
* Create the TransformerDecoder object using the specified parameters.
*
* @param numLayers The number of decoder blocks.
* @param tgtSeqLen Target Sequence Length.
* @param srcSeqLen Source Sequence Length.
* @param dModel The number of features in the input. Also, same as the
* `embedDim` in `MultiheadAttention` layer.
* @param numHeads The number of attention heads.
* @param dimFFN The dimentionality of feedforward network.
* @param dropout The dropout rate.
* @param attentionMask The attention mask used to black-out future sequences.
* @param keyPaddingMask The padding mask used to black-out particular token.
* @param ownMemory Whether to delete the pointer-type decoder object.
*/
TransformerDecoder(const size_t numLayers,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t dModel = 512,
const size_t numHeads = 8,
const size_t dimFFN = 1024,
const double dropout = 0.1,
const arma::mat& attentionMask = arma::mat(),
const arma::mat& keyPaddingMask = arma::mat(),
const bool ownMemory = false);

/**
* Destructor.
*/
~TransformerDecoder()
{
if (ownMemory)
delete decoder;
}

/**
* Copy constructor.
*/
TransformerDecoder(const TransformerDecoder& ) = delete;

/**
* Move constructor.
*/
TransformerDecoder(TransformerDecoder&& ) = delete;

/**
* Copy assignment operator.
*/
TransformerDecoder& operator = (const TransformerDecoder& ) = delete;
mrityunjay-tripathi marked this conversation as resolved.
Show resolved Hide resolved

/**
* Move assignment operator.
*/
TransformerDecoder& operator = (TransformerDecoder&& ) = delete;

/**
* Get the Transformer Decoder model.
*/
Sequential<>* Model() { return decoder; }

/**
* Load the network from a local directory.
*
* @param filepath The location of the stored model.
*/
void LoadModel(const std::string& filepath);

/**
* Save the network locally.
*
* @param filepath The location where the model is to be saved.
*/
void SaveModel(const std::string& filepath);

//! Get the attention mask.
arma::mat const& AttentionMask() const { return attentionMask; }

//! Modify the attention mask.
arma::mat& AttentionMask() { return attentionMask; }

//! Get the key padding mask.
arma::mat const& KeyPaddingMask() const { return keyPaddingMask; }

//! Modify the key padding mask.
arma::mat& KeyPaddingMask() { return keyPaddingMask; }

private:
/**
* This method adds the attention block to the decoder.
*/
Sequential<>* AttentionBlock()
{
Sequential<>* decoderBlockBottom = new Sequential<>();
decoderBlockBottom->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);

// Broadcast the incoming input to decoder
// i.e. query into (query, key, value).
Concat<>* decoderInput = new Concat<>(true);
decoderInput->Add<IdentityLayer<>>();
decoderInput->Add<IdentityLayer<>>();
decoderInput->Add<IdentityLayer<>>();

// Masked Self attention layer.
Sequential<>* maskedSelfAttention = new Sequential<>();
maskedSelfAttention->Add(decoderInput);

MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
tgtSeqLen,
dModel,
numHeads);
mha1->AttentionMask() = attentionMask;

maskedSelfAttention->Add(mha1);

// Residual connection.
AddMerge<>* residualAdd1 = new AddMerge<>(true);
residualAdd1->Add(maskedSelfAttention);
residualAdd1->Add<IdentityLayer<>>();

decoderBlockBottom->Add(residualAdd1);

// Add the LayerNorm layer with required parameters.
decoderBlockBottom->Add<LayerNorm<>>(dModel * tgtSeqLen);

// This layer broadcasts the output of encoder i.e. key into (key, value).
Concat<>* broadcastEncoderOutput = new Concat<>(true);
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);

// This layer concatenates the output of the bottom decoder block (query)
// and the output of the encoder (key, value).
Concat<>* encDecAttnInput = new Concat<>(true);
encDecAttnInput->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is incorrect. It's the decoder bottom input. But the encoder-decoder attention block should receive the output of the decoder bottom.

Suggested change
encDecAttnInput->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);
encDecAttnInput->Add(decoderBlockBottom);

encDecAttnInput->Add(broadcastEncoderOutput);

// Encoder-decoder attention.
Sequential<>* encoderDecoderAttention = new Sequential<>();
encoderDecoderAttention->Add(encDecAttnInput);

MultiheadAttention<>* mha2 = new MultiheadAttention<>(tgtSeqLen,
srcSeqLen,
dModel,
numHeads);
mha2->KeyPaddingMask() = keyPaddingMask;
encoderDecoderAttention->Add(mha2);

// Residual connection.
AddMerge<>* residualAdd2 = new AddMerge<>(true);
residualAdd2->Add(encoderDecoderAttention);
residualAdd2->Add(decoderBlockBottom);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't pass the same block twice (see the comment to encDecAttnInput). Looks like we need to change the model a bit. I have to go now. I'll come up with the idea in the evening.


Sequential<>* decoderBlock = new Sequential<>();
decoderBlock->Add(residualAdd2);
decoderBlock->Add<LayerNorm<>>(dModel * tgtSeqLen);
return decoderBlock;
}

/**
* This method adds the position-wise feed forward network to the decoder.
*/
Sequential<>* PositionWiseFFNBlock()
{
Sequential<>* positionWiseFFN = new Sequential<>();
positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN);
positionWiseFFN->Add<ActivationFunction>();
positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel);
positionWiseFFN->Add<Dropout<>>(dropout);

/* Residual connection. */
AddMerge<>* residualAdd = new AddMerge<>(true);
residualAdd->Add(positionWiseFFN);
residualAdd->Add<IdentityLayer<>>();

Sequential<>* decoderBlock = new Sequential<>();
decoderBlock->Add(residualAdd);
return decoderBlock;
}

//! Locally-stored number of decoder layers.
size_t numLayers;

//! Locally-stored target sequence length.
size_t tgtSeqLen;

//! Locally-stored source sequence length.
size_t srcSeqLen;

//! Locally-stored number of features in the input.
size_t dModel;

//! Locally-stored number of attention heads.
size_t numHeads;

//! Locally-stored dimensionality of position-wise feed forward network.
size_t dimFFN;

//! Locally-stored dropout rate.
double dropout;

//! Locally-stored attention mask.
arma::mat attentionMask;

//! Locally-stored key padding mask.
arma::mat keyPaddingMask;

//! Whether to delete pointer-type decoder object.
bool ownMemory;

//! Locally-stored complete decoder network.
Sequential<>* decoder;
}; // class TransformerDecoder

} // namespace ann
} // namespace mlpack

// Include implementation.
#include "decoder_impl.hpp"

#endif
Loading