multihead attention #2375

Merged: 21 commits into mlpack:master, Aug 24, 2020
Conversation

@mrityunjay-tripathi (Member) commented Apr 17, 2020

Hi everyone,
I've worked on an implementation of multihead attention. The multihead attention layer will be required for the Transformer model. Debugging and refactoring of the code will come subsequently, but this is the initial structure on which I will be working. The implementation is mostly motivated by PyTorch and TensorFlow.
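For context, each head of the layer computes scaled dot-product attention, and the head outputs are concatenated and projected, as in "Attention Is All You Need":

\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q W_i^{Q} \,(K W_i^{K})^{\top}}{\sqrt{d_k}}\right) V W_i^{V},
\qquad
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^{O}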

@lozhnikov (Contributor) left a comment

Sorry for the slow response. I added a couple of comments.

src/mlpack/methods/ann/layer/multihead_attention.hpp
src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
@mrityunjay-tripathi marked this pull request as ready for review May 17, 2020 10:56
@birm marked this pull request as draft May 19, 2020 16:06
@lozhnikov (Contributor) left a comment

I added a couple of comments.

src/mlpack/methods/ann/layer/multihead_attention.hpp
src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
@lozhnikov (Contributor) left a comment

I think it's perhaps better to pass the three matrices (query, key, and value), concatenated into one matrix, each time. It would simplify the interface.
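For illustration, a rough sketch of what this single-input convention could look like from the caller's side; this is only an assumption about the eventual interface, and the dimension names here are hypothetical. In the ANN code each column is one batch item, so the flattened query, key, and value would be stacked vertically into one matrix.

// Sketch only (assumes Armadillo / mlpack headers are included).
// Hypothetical dimensions, just for the example.
const size_t embedDim = 4, tgtSeqLen = 3, srcSeqLen = 5, batchSize = 2;

// query is (embedDim * tgtSeqLen) x batchSize; key and value are
// (embedDim * srcSeqLen) x batchSize.
arma::mat query(embedDim * tgtSeqLen, batchSize, arma::fill::randu);
arma::mat key(embedDim * srcSeqLen, batchSize, arma::fill::randu);
arma::mat value(embedDim * srcSeqLen, batchSize, arma::fill::randu);

// Stack them into a single input matrix; the layer would then split them
// back out internally (e.g. with submatrix views).
arma::mat input = arma::join_cols(arma::join_cols(query, key), value);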

src/mlpack/methods/ann/layer/multihead_attention.hpp
@lozhnikov (Contributor) left a comment

Some comments on the Forward() implementation.

src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
@mrityunjay-tripathi marked this pull request as ready for review July 8, 2020 18:12
@mrityunjay-tripathi changed the title from "[WIP] multihead attention" to "multihead attention" Jul 8, 2020
@mrityunjay-tripathi (Member Author)

Wow! Tests are failing. Locally they were passing. 😕

@lozhnikov (Contributor) left a comment

Some comments.

src/mlpack/methods/ann/layer/multihead_attention.hpp
src/mlpack/methods/ann/layer/multihead_attention_impl.hpp
@lozhnikov (Contributor) left a comment

I added some minor comments. I'm still looking through the implementation.

src/mlpack/core/math/multiply_slices.hpp
src/mlpack/core/math/multiply_slices_impl.hpp
src/mlpack/methods/ann/layer/multihead_attention.hpp
src/mlpack/tests/ann_layer_test.cpp
@lozhnikov (Contributor) commented Aug 3, 2020

I think these memory errors aren't your fault. I verified your tests with valgrind and AddressSanitizer, and they didn't show anything wrong. However, other tests in ANNLayerTest have a considerable number of memory leaks (see leaks.zip; I used gcc-6 under Ubuntu 18.04).

@mrityunjay-tripathi (Member Author)

Ohh right. I had to scroll up again and again to see that comment, so I missed the part where selfAttention is actually a Concat layer. Thanks for the clarification :)

@lozhnikov (Contributor)

Could you please apply this patch: restore-ordering.patch.zip? It restores the natural ordering.

@mrityunjay-tripathi (Member Author)

Could you please apply this patch: restore-ordering.patch.zip? It restores the natural ordering.

Wow!!! Thank you so much for working this out.
Looks like we're solving a Rubik's cube 😅

@mrityunjay-tripathi (Member Author)

@lozhnikov I was trying to implement this using the Concat and other layers as you suggested, but I'm running into problems extending it to further decoder blocks.

// This concatenates the encoder output and the output of the first attention block inside the decoder block.
Concat<> encoderAndDecoderBottom;
encoderAndDecoderBottom.Add(encoderSequence);
encoderAndDecoderBottom.Add(decoderBottomSequence);

If I try to extend it to other decoder blocks, how should I do it?

Concat<> encoderAndDecoderSecond;
encoderAndDecoderSecond.Add(encoderSequence); ??
encoderAndDecoderSecond.Add(bottomDecoder);   ??

The input sources are totally different in this case. And if I combine those inputs and then use Subview again to split them, it will run the whole encoder stack again for concatenating the last encoder output and the output of the first decoder, and the same for the second, third, and so on.
Am I doing this the wrong way? Do I need some other method for concatenating the encoder output to further decoder blocks?

@mrityunjay-tripathi (Member Author)

I'm also clueless about how the residual connections would be employed, since residual connections are made between the query and the output of the attention block. For the encoder side it won't be a problem, but how can I do it for the decoder-side residual connections when we use the concatenated [query key value]?

@lozhnikov (Contributor)

If I try to extend it to other decoder blocks, how should I do it?

It's hard to explain. Let me write a prototype for N = 2 decoder blocks. I think it'll be easy to do the same for an arbitrary N. I'll try to do this in the evening.

it will run the whole encoder stack again for concatenating the last encoder output and the output of the first decoder

It's easy to avoid. You just need to broadcast the encoder output 2N times, where N is the number of decoder blocks.
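A minimal sketch of that broadcasting idea, using mlpack's Concat<> and IdentityLayer<> as in the encoder prototype below (numDecoderBlocks is a placeholder name, not part of the PR):

// Produce 2N copies of the encoder output, one per place it is consumed in
// the decoder stack, by concatenating 2N identity copies of the input.
Concat<>* broadcastEncoderOutput = new Concat<>();
for (size_t i = 0; i < 2 * numDecoderBlocks; ++i)
  broadcastEncoderOutput->Add<IdentityLayer<>>();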

@lozhnikov (Contributor) commented Aug 19, 2020

I've implemented a rough model of the encoder block. It should answer your question about the residual connections. I'll continue thinking about the decoder block in the morning.

Sequential<>* CreateEncoder() {
  Sequential<>* encoder = new Sequential<>;

  {
    // Broadcast the block input three times so the attention layer receives
    // the concatenated [query; key; value], all equal to the input here.
    Concat<>* selfAttentionInput = new Concat<>();
    selfAttentionInput->Add<IdentityLayer<>>();
    selfAttentionInput->Add<IdentityLayer<>>();
    selfAttentionInput->Add<IdentityLayer<>>();

    Sequential<>* selfAttention = new Sequential<>();
    selfAttention->Add(selfAttentionInput);
    selfAttention->Add<MultiheadAttention<>>();

    // Residual connection: add the attention output to the block input.
    AddMerge<>* residualAddMerge = new AddMerge<>();
    residualAddMerge->Add(selfAttention);
    residualAddMerge->Add<IdentityLayer<>>();

    encoder->Add(residualAddMerge);
  }

  encoder->Add<LayerNorm<>>();

  {
    Sequential<>* pointWiseFeedForwardNetwork = new Sequential<>();
//    pointWiseFeedForwardNetwork->Add(......);
//    pointWiseFeedForwardNetwork->Add(......);

    // Residual connection around the feed-forward sub-network.
    AddMerge<>* residualAddMerge = new AddMerge<>();
    residualAddMerge->Add(pointWiseFeedForwardNetwork);
    residualAddMerge->Add<IdentityLayer<>>();
    encoder->Add(residualAddMerge);
  }

  encoder->Add<LayerNorm<>>();

  return encoder;
}

Update: I meant the encoder. Now I'm thinking about the decoder.

@lozhnikov (Contributor)

@mrityunjay-tripathi Finally, I implemented a full draft of the transformer model. Despite being a draft, I think it's quite accurate. It supports an arbitrary number of encoders and decoders. You just need to pass the correct arguments to the layer constructors, especially to the Subview<> layers.

https://gist.github.com/lozhnikov/aabb9231c0bb72528ff64a4f9bc19923

Tell me if you need any help with this.
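For reference, a small sketch of how the encoder blocks from the prototype above could be stacked (numEncoderBlocks is a placeholder name; the linked gist is the actual draft):

// Stack an arbitrary number of encoder blocks one after another.
Sequential<>* encoderStack = new Sequential<>();
for (size_t i = 0; i < numEncoderBlocks; ++i)
  encoderStack->Add(CreateEncoder());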

@mrityunjay-tripathi (Member Author)

INFO: Starting to record.
INFO: Processing BoostTest-1.x (default)
INFO: [BoostTest-1.x (default)] - 147 test report file(s) were found with the pattern 'reports/tests/*.boost_test.xml' relative to '/home/jenkins/workspace/pull-requests mlpack memory@2' for the testing framework 'BoostTest-1.x (default)'.
WARNING: The file '/home/jenkins/workspace/pull-requests mlpack memory@2/reports/tests/AsyncLearningTest_OneStepSarsaTest.boost_test.xml' is an invalid file.
WARNING: At line 1 of file:/home/jenkins/workspace/pull-requests%20mlpack%20memory@2/reports/tests/AsyncLearningTest_OneStepSarsaTest.boost_test.xml:XML document structures must start and end within the same entity.
WARNING: Technical validation:XML document structures must start and end within the same entity.
WARNING: The result file '/home/jenkins/workspace/pull-requests mlpack memory@2/reports/tests/AsyncLearningTest_OneStepSarsaTest.boost_test.xml' for the metric 'BoostTest' is not valid. The result file has been skipped.
WARNING: The result file '/home/jenkins/workspace/pull-requests mlpack memory@2/reports/tests/DCGANNetworkTest_DCGANCelebATest.boost_test.xml' for the metric 'BoostTest' is empty. The result file has been skipped.
WARNING: The file '/home/jenkins/workspace/pull-requests mlpack memory@2/reports/tests/InitRulesTest_KathirvalavakumarSubavathiInitTest.boost_test.xml' is an invalid file.
WARNING: At line 1 of file:/home/jenkins/workspace/pull-requests%20mlpack%20memory@2/reports/tests/InitRulesTest_KathirvalavakumarSubavathiInitTest.boost_test.xml:XML document structures must start and end within the same entity.
WARNING: Technical validation:XML document structures must start and end within the same entity.
WARNING: The result file '/home/jenkins/workspace/pull-requests mlpack memory@2/reports/tests/InitRulesTest_KathirvalavakumarSubavathiInitTest.boost_test.xml' for the metric 'BoostTest' is not valid. The result file has been skipped.
WARNING: The result file '/home/jenkins/workspace/pull-requests mlpack memory@2/reports/tests/empty.boost_test.xml' for the metric 'BoostTest' is empty. The result file has been skipped.
INFO: Check 'Failed Tests' threshold.
INFO: Check 'Skipped Tests' threshold.
INFO: Setting the build status to ABORTED
INFO: Stopping recording.
Setting status of 6bbbd94430cc57e4384b1354f3d7caf35b6513f6 to FAILURE with url http://ci.mlpack.org/job/pull-requests%20mlpack%20memory/6288/ and message: 'Build finished. '
Using context: Memory Checks
Finished: ABORTED

I think the memory check failure has something to do with skipped tests and not failed tests. It shows that some .xml files are invalid. Why?

@lozhnikov (Contributor)

I think the memory check failure has something to do with skipped tests and not failed tests. It shows that some .xml files are invalid. Why?

I think these memory issues are unrelated to your PR. It looks like there are some memory issues in other ANN tests/methods.

@lozhnikov (Contributor) left a comment

Looks good to me. I added some minor style suggestions.

src/mlpack/core/math/multiply_slices_impl.hpp
src/mlpack/methods/ann/layer/multihead_attention.hpp
src/mlpack/tests/ann_test_tools.hpp
@mlpack-bot (bot) left a comment

Second approval provided automatically after 24 hours. 👍

@lozhnikov merged commit 944f2b5 into mlpack:master Aug 24, 2020
@lozhnikov (Contributor)

I merged the PR. The memory issues were unrelated to this PR; I checked the tests with valgrind locally. Thanks for the contribution!

@mrityunjay-tripathi (Member Author)

Thanks, @lozhnikov! Finally this is done 😅. Thanks for all the reviews, suggestions, and help :)
