-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
complete encoder decoder and transformer model
- Loading branch information
1 parent
0a307db
commit 0c08f69
Showing
11 changed files
with
1,023 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
build* | ||
xcode* | ||
.vscode/ | ||
.DS_Store | ||
.idea | ||
cmake-build-* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
/** | ||
* @file models.hpp | ||
* @author Mrityunjay Tripathi | ||
* | ||
* This includes various models. | ||
*/ | ||
|
||
#include "transformer/encoder.hpp" | ||
#include "transformer/decoder.hpp" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR) | ||
project(transformer) | ||
|
||
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/) | ||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../") | ||
|
||
set(SOURCES | ||
decoder.hpp | ||
decoder_impl.hpp | ||
encoder.hpp | ||
encoder_impl.hpp | ||
transformer.hpp | ||
transformer_impl.hpp | ||
) | ||
|
||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
|
||
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
/** | ||
* @file models/transformer/decoder.hpp | ||
* @author Mikhail Lozhnikov | ||
* @author Mrityunjay Tripathi | ||
* | ||
* Definition of the Transformer Decoder layer. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef MODELS_TRANSFORMER_DECODER_HPP | ||
#define MODELS_TRANSFORMER_DECODER_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
#include <mlpack/methods/ann/layer/layer_types.hpp> | ||
#include <mlpack/methods/ann/layer/base_layer.hpp> | ||
#include <mlpack/methods/ann/regularizer/no_regularizer.hpp> | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */ { | ||
|
||
/** | ||
* In addition to the two sub-layers in each encoder layer, the decoder inserts | ||
* a third sub-layer, which performs multi-head attention over the output of the | ||
* encoder stack. Similar to the encoder, we employ residual connections around | ||
* each of the sub-layers, followed by layer normalization. We also modify the | ||
* self-attention sub-layer in the decoder stack to prevent positions from | ||
* attending to subsequent positions. This masking, combined with fact that the | ||
* output embeddings are offset by one position, ensures that the predictions | ||
* for position i can depend only on the known outputs at positions less than i. | ||
* | ||
* @tparam ActivationFunction The type of the activation function to be used in | ||
* the position-wise feed forward neural network. | ||
* @tparam RegularizerType The type of regularizer to be applied to layer | ||
* parameters. | ||
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
*/ | ||
template < | ||
typename ActivationFunction = ReLULayer<>, | ||
typename RegularizerType = NoRegularizer, | ||
typename InputDataType = arma::mat, | ||
typename OutputDataType = arma::mat | ||
> | ||
class TransformerDecoder | ||
{ | ||
public: | ||
TransformerDecoder(); | ||
|
||
/** | ||
* Create the TransformerDecoder object using the specified parameters. | ||
* | ||
* @param numLayers The number of decoder blocks. | ||
* @param tgtSeqLen Target Sequence Length. | ||
* @param srcSeqLen Source Sequence Length. | ||
* @param memoryModule The last Encoder module. | ||
* @param dModel The number of features in the input. Also, same as the | ||
* 'embedDim' in 'MultiheadAttention' layer. | ||
* @param numHeads The number of attention heads. | ||
* @param dimFFN The dimentionality of feedforward network. | ||
* @param dropout The dropout rate. | ||
* @param attentionMask The attention mask used to black-out future sequences. | ||
* @param keyPaddingMask The padding mask used to black-out particular token. | ||
*/ | ||
TransformerDecoder(const size_t numLayers, | ||
const size_t tgtSeqLen, | ||
const size_t srcSeqLen, | ||
const size_t dModel = 512, | ||
const size_t numHeads = 8, | ||
const size_t dimFFN = 1024, | ||
const double dropout = 0.1, | ||
const InputDataType& attentionMask = InputDataType(), | ||
const InputDataType& keyPaddingMask = InputDataType()); | ||
|
||
/** | ||
* Get the Transformer Decoder model. | ||
*/ | ||
Sequential<>* Model() { return decoder; } | ||
/** | ||
* Load the network from a local directory. | ||
* | ||
* @param filepath The location of the stored model. | ||
*/ | ||
void LoadModel(const std::string& filepath); | ||
|
||
/** | ||
* Save the network locally. | ||
* | ||
* @param filepath The location where the model is to be saved. | ||
*/ | ||
void SaveModel(const std::string& filepath); | ||
|
||
//! Get the key matrix, the output of the Transformer Encoder. | ||
InputDataType const& Key() const { return key; } | ||
|
||
//! Modify the key matrix. | ||
InputDataType& Key() { return key; } | ||
|
||
private: | ||
/** | ||
* This method adds the attention block to the decoder. | ||
*/ | ||
void AttentionBlock() | ||
{ | ||
Sequential<>* decoderBlockBottom = new Sequential<>(); | ||
decoderBlockBottom->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1); | ||
|
||
// Broadcast the incoming input to decoder | ||
// i.e. query into (query, key, value). | ||
Concat<>* decoderInput = new Concat<>(); | ||
decoderInput->Add<IdentityLayer<>>(); | ||
decoderInput->Add<IdentityLayer<>>(); | ||
decoderInput->Add<IdentityLayer<>>(); | ||
|
||
// Masked Self attention layer. | ||
Sequential<>* maskedSelfAttention = new Sequential<>(); | ||
maskedSelfAttention->Add(decoderInput); | ||
maskedSelfAttention->Add<MultiheadAttention< | ||
InputDataType, OutputDataType, RegularizerType>>( | ||
tgtSeqLen, | ||
tgtSeqLen, | ||
dModel, | ||
numHeads, | ||
attentionMask | ||
); | ||
|
||
// Residual connection. | ||
AddMerge<>* residualAdd = new AddMerge<>(); | ||
residualAdd->Add(maskedSelfAttention); | ||
residualAdd->Add<IdentityLayer<>>(); | ||
|
||
decoderBlockBottom->Add(residualAddMerge); | ||
|
||
// Add the LayerNorm layer with required parameters. | ||
decoderBlockBottom->Add<LayerNorm<>>(dModel * tgtSeqLen); | ||
|
||
// This layer broadcasts the output of encoder i.e. key into (key, value). | ||
Concat<>* broadcastEncoderOutput = new Concat<>(); | ||
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1); | ||
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1); | ||
|
||
// This layer concatenates the output of the bottom decoder block (query) | ||
// and the output of the encoder (key, value). | ||
Concat<>* encoderDecoderAttentionInput = new Concat<>(); | ||
encoderDecoderAttentionInput->Add(decoderBlockBottom); | ||
encoderDecoderAttentionInput->Add(broadcastEncoderOutput); | ||
|
||
// Encoder-decoder attention. | ||
Sequential<>* encoderDecoderAttention = new Sequential<>(); | ||
encoderDecoderAttention->Add(encoderDecoderAttentionInput); | ||
encoderDecoderAttention->Add<MultiheadAttention< | ||
InputDataType, OutputDataType, RegularizerType>>( | ||
tgtSeqLen, | ||
srcSeqLen, | ||
dModel, | ||
numHeads, | ||
InputDatatype(), // No attention mask to encoder-decoder attention. | ||
keyPaddingMask); | ||
|
||
// Residual connection. | ||
AddMerge<>* residualAdd = new AddMerge<>(); | ||
residualAdd->Add(encoderDecoderAttention); | ||
residualAdd->Add<IdentityLayer<>>(); | ||
|
||
decoder->Add(residualAdd); | ||
decoder->Add<LayerNorm<>>(dModel * tgtSeqLen); | ||
} | ||
|
||
/** | ||
* This method adds the position-wise feed forward network to the decoder. | ||
*/ | ||
void PositionWiseFFNBlock() | ||
{ | ||
Sequential<>* positionWiseFFN = new Sequential<>(); | ||
positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN); | ||
positionWiseFFN->Add<ActivationFunction>(); | ||
positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel); | ||
positionWiseFFN->Add<Dropout<>>(dropout); | ||
|
||
/* Residual connection. */ | ||
AddMerge<>* residualAdd = new AddMerge<>(); | ||
residualAdd->Add(positionWiseFFN); | ||
residualAdd->Add<IdentityLayer<>>(); | ||
decoder->Add(residualAdd); | ||
} | ||
|
||
//! Locally-stored number of decoder layers. | ||
size_t numLayers; | ||
|
||
//! Locally-stored target sequence length. | ||
size_t tgtSeqLen; | ||
|
||
//! Locally-stored source sequence length. | ||
size_t srcSeqLen; | ||
|
||
//! Locally-stored number of input units. | ||
size_t dModel; | ||
|
||
//! Locally-stored number of output units. | ||
size_t numHeads; | ||
|
||
//! Locally-stored weight object. | ||
size_t dimFFN; | ||
|
||
//! Locally-stored weight parameters. | ||
double dropout; | ||
|
||
//! Locally-stored attention mask. | ||
InputDataType attentionMask; | ||
|
||
//! Locally-stored key padding mask. | ||
InputDataType keyPaddingMask; | ||
|
||
//! Locally-stored complete decoder network. | ||
Sequential<InputDataType, OutputDataType, false>* decoder; | ||
|
||
}; // class TransformerDecoder | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
// Include implementation. | ||
#include "decoder_impl.hpp" | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/** | ||
* @file models/transformer/decoder_impl.hpp | ||
* @author Mikhail Lozhnikov | ||
* @author Mrityunjay Tripathi | ||
* | ||
* Implementation of the Transformer Decoder class. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef MODELS_TRANSFORMER_DECODER_IMPL_HPP | ||
#define MODELS_TRANSFORMER_DECODER_IMPL_HPP | ||
|
||
#include "decoder.hpp" | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */ { | ||
|
||
template<typename ActivationFunction, typename RegularizerType, | ||
typename InputDataType, typename OutputDataType> | ||
TransformerDecoder<ActivationFunction, RegularizerType, InputDataType, | ||
OutputDataType>::TransformerDecoder() : | ||
tgtSeqLen(0), | ||
srcSeqLen(0), | ||
memoryModule(NULL), | ||
dModel(0), | ||
numHeads(0), | ||
dimFFN(0), | ||
dropout(0) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename ActivationFunction, typename RegularizerType, | ||
typename InputDataType, typename OutputDataType> | ||
TransformerDecoder<ActivationFunction, RegularizerType, InputDataType, | ||
OutputDataType>::TransformerDecoder( | ||
const size_t numLayers, | ||
const size_t tgtSeqLen, | ||
const size_t srcSeqLen, | ||
const size_t dModel, | ||
const size_t numHeads, | ||
const size_t dimFFN, | ||
const double dropout, | ||
const InputDataType& attentionMask, | ||
const InputDataType& keyPaddingMask) : | ||
numLayers(numLayers), | ||
tgtSeqLen(tgtSeqLen), | ||
srcSeqLen(srcSeqLen), | ||
dModel(dModel), | ||
numHeads(numHeads), | ||
dimFFN(dimFFN), | ||
dropout(dropout), | ||
attentionMask(attentionMask), | ||
keyPaddingMask(keyPaddingMask) | ||
{ | ||
decoder = new Sequential<InputDataType, OutputDataType, false>(); | ||
|
||
for (size_t N = 0; N < numLayers; ++N) | ||
{ | ||
AttentionBlock(); | ||
PositionWiseFFNBlock(); | ||
} | ||
} | ||
|
||
template<typename ActivationFunction, typename RegularizerType, | ||
typename InputDataType, typename OutputDataType> | ||
void TransformerDecoder<ActivationFunction, RegularizerType, | ||
InputDataType, OutputDataType>::LoadModel(const std::string& filepath) | ||
{ | ||
data::Load(filepath, "TransformerDecoder", decoder); | ||
std::cout << "Loaded model" << std::endl; | ||
} | ||
|
||
template<typename ActivationFunction, typename RegularizerType, | ||
typename InputDataType, typename OutputDataType> | ||
void TransformerDecoder<ActivationFunction, RegularizerType, | ||
InputDataType, OutputDataType>::SaveModel(const std::string& filepath) | ||
{ | ||
std::cout << "Saving model" << std::endl; | ||
data::Save(filepath, "TransformerDecoder", decoder); | ||
std::cout << "Model saved in " << filepath << std::endl; | ||
} | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
#endif |
Oops, something went wrong.