policy class for different attention mechanisms
mrityunjay-tripathi committed Jun 9, 2020
1 parent 8b32f11 commit d692767
Showing 1 changed file with 176 additions and 19 deletions.
195 changes: 176 additions & 19 deletions src/mlpack/methods/ann/layer/multihead_attention.hpp
@@ -43,18 +43,18 @@ namespace ann /** Artificial Neural Network. */ {
* arma::sp_mat or arma::cube).
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam AttentionPolicy The type of multihead attention to be used. Use
* 'SelfAttention' for self-attention and 'EncoderDecoderAttention' for
* encoder-decoder attention.
*/
template <
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat,
typename AttentionPolicy
>
class MultiheadAttention
{
public:
//! Element type of the input data.
typedef typename InputDataType::elem_type ElemType;
/**
* Default constructor.
*/
@@ -63,21 +63,26 @@ class MultiheadAttention
/**
* Create the MultiheadAttention object using the specified modules.
*
* @param tgtSeqLen The length of the target sequence.
* @param srcSeqLen The length of the source sequence.
* @param embedDim Total dimension of the model.
* @param numHeads Number of parallel attention heads.
* @param dropout The dropout rate for attention output weights.
* @param deterministic If true, the dropout layer is omitted; otherwise
* dropout is applied with dropout rate `dropout`.
*/
MultiheadAttention(const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim,
const size_t numHeads,
const ElemType dropout = 0.0,
const bool deterministic = true);

/**
* Destructor for MultiheadAttention object.
*/
~MultiheadAttention();

/**
* Reset the layer parameters.
*/
@@ -168,10 +173,13 @@ class MultiheadAttention
bool& Deterministic() { return deterministic; }

private:
//! Locally-stored value of target sequence length.
size_t tgtSeqLen;

//! Locally-stored value of source sequence length.
size_t srcSeqLen;

//! Locally-stored module output size.
@@ -193,7 +201,7 @@ class MultiheadAttention
InputDataType attnMask2d;

//! Three dimensional Attention Mask.
arma::Cube<ElemType> attnMask3d;

//! Key Padding Mask.
InputDataType keyPaddingMask;
@@ -207,17 +215,20 @@ class MultiheadAttention
//! Locally-stored attention output weight to be fed to last linear layer.
OutputDataType attnWt;

//! Locally-stored weight matrix associated with query.
OutputDataType queryWt;

//! Locally-stored weight matrix associated with key.
OutputDataType keyWt;

//! Locally-stored weight matrix associated with value.
OutputDataType valueWt;

//! Locally-stored weight matrix associated with attnWt.
OutputDataType outWt;

//! Locally-stored weights.
OutputDataType weights;

//! Softmax layer to convert the attention scores into probabilities.
Softmax<InputDataType, OutputDataType>* softmaxModule;
@@ -233,8 +244,154 @@ class MultiheadAttention

//! Locally-stored output parameter.
OutputDataType outputParameter;

//! Locally-stored attention policy object.
AttentionPolicy attnPolicy;
}; // class MultiheadAttention

/**
* Policy class for self-attention. It enforces the input and output shape
* constraints of self-attention, where query, key and value are all equal.
*/
class SelfAttention
{
public:
/**
* Construct the SelfAttention policy object.
*/
SelfAttention() :
allInputsEqual(true),
keyEqualsValue(true)
{
// Nothing to do here.
}

/**
* Checks the validity of the dimensions of the input.
*/
template <typename InputType>
void checkInputDimensions(const InputType& input,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim)
{
Log::Assert(tgtSeqLen == srcSeqLen);
Log::Assert(input.n_rows == tgtSeqLen * embedDim);
}

/**
* Checks the validity of the dimensions of the output.
*/
template <typename OutputType>
void checkOutputDimensions(const OutputType& output,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim)
{
Log::Assert(tgtSeqLen == srcSeqLen);
Log::Assert(output.n_rows == tgtSeqLen * embedDim);
}

/**
* Split the flattened input matrix into query, key and value cubes. For
* self-attention, query, key and value are all equal.
*/
template <typename eT>
void splitInput(const arma::Mat<eT>& input,
arma::Cube<eT>& query,
arma::Cube<eT>& key,
arma::Cube<eT>& value,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim)
{
checkInputDimensions(input, tgtSeqLen, srcSeqLen, embedDim);
query = arma::Cube<eT>(const_cast<arma::Mat<eT>&>(input).memptr(),
embedDim, tgtSeqLen, input.n_cols, false);
key = value = query;
}

//! Get the value of allInputsEqual.
bool AllInputsEqual() const { return allInputsEqual; }

//! Get the value of keyEqualsValue.
bool KeyEqualsValue() const { return keyEqualsValue; }

private:
//! True if query, key and value are equal.
bool allInputsEqual;

//! True if key and value are equal.
bool keyEqualsValue;
};
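
// Usage sketch (illustrative; the sizes below are made up): for
// self-attention the flattened input column stores tgtSeqLen * embedDim
// elements, and splitInput() reshapes it into (embedDim x tgtSeqLen x batch)
// cubes:
//
//   arma::mat input(10 * 4, 2);  // tgtSeqLen = srcSeqLen = 10, embedDim = 4.
//   arma::cube q, k, v;
//   SelfAttention policy;
//   policy.splitInput(input, q, k, v, 10, 10, 4);
//   // q, k and v are each 4 x 10 x 2 and hold identical values.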

/**
* Policy class for encoder-decoder attention. It enforces the input and
* output shape constraints of encoder-decoder attention, where the query
* comes from the target sequence and the key and value come from the source
* sequence.
*/
class EncoderDecoderAttention
{
public:
/**
* Construct the EncoderDecoderAttention policy object.
*/
EncoderDecoderAttention() :
allInputsEqual(false),
keyEqualsValue(true)
{
// Nothing to do here.
}

/**
* Checks the validity of the dimensions of the input.
*/
template <typename InputType>
void checkInputDimensions(const InputType& input,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim)
{
Log::Assert(input.n_rows == (tgtSeqLen + srcSeqLen) * embedDim);
}

/**
* Checks the validity of the dimensions of the output.
*/
template <typename OutputType>
void checkOutputDimensions(const OutputType& output,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim)
{
Log::Assert(output.n_rows == (tgtSeqLen + srcSeqLen) * embedDim);
}

/**
* Split the flattened input matrix into the query cube and the key and
* value cubes. The key and value are equal and come from the source
* sequence.
*/
template <typename eT>
void splitInput(const arma::Mat<eT>& input,
arma::Cube<eT>& query,
arma::Cube<eT>& key,
arma::Cube<eT>& value,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim)
{
checkInputDimensions(input, tgtSeqLen, srcSeqLen, embedDim);
query = arma::Cube<eT>(const_cast<arma::Mat<eT>&>(input).memptr(),
embedDim, tgtSeqLen, input.n_cols, false);
key = arma::Cube<eT>(const_cast<arma::Mat<eT>&>(input).memptr()
+ embedDim * tgtSeqLen, embedDim, srcSeqLen, input.n_cols, false);
value = key;
}

//! Get the value of allInputsEqual.
bool AllInputsEqual() const { return allInputsEqual; }

//! Get the value of keyEqualsValue.
bool KeyEqualsValue() const { return keyEqualsValue; }

private:
//! True if query, key and value are equal.
bool allInputsEqual;

//! True if key and value are equal.
bool keyEqualsValue;
};
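
// Usage sketch (illustrative; the sizes below are made up): for
// encoder-decoder attention the input column concatenates the target (query)
// part and the source (key/value) part, i.e. it stores
// (tgtSeqLen + srcSeqLen) * embedDim elements:
//
//   arma::mat input((10 + 15) * 4, 2);  // tgtSeqLen = 10, srcSeqLen = 15.
//   arma::cube q, k, v;
//   EncoderDecoderAttention policy;
//   policy.splitInput(input, q, k, v, 10, 15, 4);
//   // q is 4 x 10 x 2 (target part); k and v are 4 x 15 x 2 (source part).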

} // namespace ann
} // namespace mlpack
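
For reference, a minimal instantiation sketch of the resulting template
(illustrative only: the sequence lengths, embedding size, head count and
variable names below are made up, and the surrounding network wiring is not
part of this commit):

// Self-attention: target and source sequence lengths must be equal.
MultiheadAttention<arma::mat, arma::mat, SelfAttention>
    selfAttn(10, 10, 512, 8, 0.1, false);

// Encoder-decoder attention: query comes from the target sequence, while
// key and value come from the source sequence.
MultiheadAttention<arma::mat, arma::mat, EncoderDecoderAttention>
    encDecAttn(10, 15, 512, 8, 0.1, false);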

