multihead attention #2375
Conversation
Force-pushed from c818c76 to e51a42f.
Sorry for the slow response. I added a couple of comments.
Force-pushed from d692767 to 6ff1d77.
I added a couple of comments.
I think it's perhaps better to pass the three matrices (query, key, and value), concatenated into one matrix, each time. It would simplify the interface.
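For illustration only, a minimal sketch of how a Forward() call could split one concatenated input back into query, key, and value; the row-wise stacking, names, and dimensions here are assumptions, not the PR's actual interface:

```c++
#include <armadillo>

// Hypothetical helper: the query, key, and value are assumed to be stacked
// row-wise into a single input matrix; the configured row counts tell us
// where to cut it apart again.
void SplitQKV(const arma::mat& input,
              const size_t queryRows,
              const size_t keyRows,
              arma::mat& query,
              arma::mat& key,
              arma::mat& value)
{
  query = input.rows(0, queryRows - 1);
  key   = input.rows(queryRows, queryRows + keyRows - 1);
  value = input.rows(queryRows + keyRows, input.n_rows - 1);
}

int main()
{
  // 4 query rows, 6 key rows, 6 value rows, batch of 3 columns.
  arma::mat input(16, 3, arma::fill::randn);
  arma::mat q, k, v;
  SplitQKV(input, 4, 6, q, k, v);
  return 0;
}
```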
Some comments for the Forward() implementation.
Wow! Tests are failing. Locally they were passing. 😕
Some comments.
I added some minor comments. I'm still looking through the implementation.
I think these memory errors aren't your fault. I verified your tests with
Ohh right. I had to scroll up again and again to see that comment, so I missed the part that
Could you apply this patch, please? restore-ordering.patch.zip It restores the natural ordering.
Wow!!! Thank you so much for working this out.
@lozhnikov I was trying to implement it using the Concat and other layers as you suggested, but I'm running into problems extending it to further decoder blocks.
If I try to extend it to the other decoder blocks, how should I do it?
The input sources are totally different in this case. And if I club those inputs together and then use Subview again to split them, it will run the whole encoder stack again just to concatenate the last encoder output with the output of the first decoder, and the same for the second, third, and so on.
I'm also clueless about how the residual connections would be employed, since the residual connections are made between the query and the output of the attention block. For the encoder side it won't be a problem, but how can I do it for the decoder-side residual connections when we use concatenated
It's hard to explain. Let me write a prototype for
It's easy to avoid. You just need to broadcast the encoder output
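The comment is cut off here; as one possible reading, here is a minimal Armadillo sketch of reusing a single cached encoder output for every decoder block instead of re-running the encoder stack. All names and shapes below are illustrative assumptions, not the thread's actual code:

```c++
#include <armadillo>

int main()
{
  // Illustrative shapes only: embedDim = 4, srcSeqLen = 6, tgtSeqLen = 5,
  // batch of 2 sequences (mlpack stores one sample per column).
  const arma::uword embedDim = 4, srcSeqLen = 6, tgtSeqLen = 5, batchSize = 2;

  // Pretend this is the cached output of a single encoder forward pass.
  arma::mat encoderOutput(embedDim * srcSeqLen, batchSize, arma::fill::randn);

  // Pretend this is the output of the decoder's self-attention sub-block.
  arma::mat decoderHidden(embedDim * tgtSeqLen, batchSize, arma::fill::randn);

  // Reusing the cached encoder output: stack it under the decoder input so
  // the next attention block sees both, without re-running the encoder.
  arma::mat attentionInput = arma::join_cols(decoderHidden, encoderOutput);

  attentionInput.print("concatenated decoder-attention input:");
  return 0;
}
```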
I've implemented a rough model of the encoder block. It should answer your question about the residual connections. I'll continue thinking of the decoder block in the morning.

```c++
Sequential<>* CreateEncoder() {
  Sequential<>* encoder = new Sequential<>;

  // Self-attention sub-block with a residual (identity) branch.
  {
    Concat<>* selfAttentionInput = new Concat<>();
    selfAttentionInput->Add<IdentityLayer<>>();
    selfAttentionInput->Add<IdentityLayer<>>();
    selfAttentionInput->Add<IdentityLayer<>>();

    Sequential<>* selfAttention = new Sequential<>();
    selfAttention->Add(selfAttentionInput);
    selfAttention->Add<MultiheadAttention<>>();

    AddMerge<>* residualAddMerge = new AddMerge<>();
    residualAddMerge->Add(selfAttention);
    residualAddMerge->Add<IdentityLayer<>>();
    encoder->Add(residualAddMerge);
  }
  encoder->Add<LayerNorm<>>();

  // Point-wise feed-forward sub-block with a residual (identity) branch.
  {
    Sequential<>* pointWiseFeedForwardNetwork = new Sequential<>();
    // pointWiseFeedForwardNetwork->Add(......);
    // pointWiseFeedForwardNetwork->Add(......);

    AddMerge<>* residualAddMerge = new AddMerge<>();
    residualAddMerge->Add(pointWiseFeedForwardNetwork);
    residualAddMerge->Add<IdentityLayer<>>();
    encoder->Add(residualAddMerge);
  }
  encoder->Add<LayerNorm<>>();

  return encoder;
}
```

Upd: I meant the encoder. Now I'm thinking about the decoder.
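Not part of the draft above, just a small usage sketch under the assumption that CreateEncoder() and the same mlpack ANN layer types are in scope; it shows how several such encoder blocks could be stacked:

```c++
// Hypothetical helper: stack numBlocks copies of the encoder block built by
// CreateEncoder() into one Sequential<> model, mirroring the N-layer encoder
// stack of the Transformer architecture.
Sequential<>* BuildEncoderStack(const size_t numBlocks)
{
  Sequential<>* stack = new Sequential<>();
  for (size_t i = 0; i < numBlocks; ++i)
    stack->Add(CreateEncoder());
  return stack;
}
```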
@mrityunjay-tripathi Finally I implemented the whole draft of the transformer model. Despite the fact that it's a draft, I think it's quite accurate. It supports an arbitrary number of decoders and encoders. You just need to put the correct arguments into the layer constructors, especially the correct arguments to the https://gist.github.com/lozhnikov/aabb9231c0bb72528ff64a4f9bc19923 Tell me if you need any help with this.
I think the memory failure has something to do with the skipped tests and not the failed tests. It shows some .xml file is invalid. Why??
I think these memory issues are unrelated to your PR. Looks like there are some memory issues in other ANN tests/methods.
Looks good to me. I added some minor style suggestions.
Co-authored-by: Mikhail Lozhnikov <[email protected]>
Second approval provided automatically after 24 hours. 👍
I merged the PR. The memory issues were unrelated to this PR; I checked the tests with valgrind locally. Thanks for the contribution!
Thanks, @lozhnikov! Finally this is done 😅. Thanks for all the reviews, suggestions, and help :)
Hi everyone,
I've worked on the implementation of multihead attention. The multihead attention layer will be required for the Transformer model. Debugging and refactoring of the code will come subsequently, but this is the initial structure I will be working on. The implementation is mostly motivated by PyTorch and TensorFlow.
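For reference, a minimal Armadillo sketch of the scaled dot-product attention that multihead attention is built from, i.e. softmax(Q Kᵀ / sqrt(d_k)) V. This is the standard textbook formulation, not the code in this PR; the matrix layout and names are illustrative:

```c++
#include <armadillo>
#include <cmath>

// Standard scaled dot-product attention: softmax(Q * K^T / sqrt(d_k)) * V.
// Q is (seqLenQ x d_k), K is (seqLenK x d_k), V is (seqLenK x d_v); this
// layout is illustrative, not the one used inside the mlpack layer.
arma::mat ScaledDotProductAttention(const arma::mat& Q,
                                    const arma::mat& K,
                                    const arma::mat& V)
{
  arma::mat scores = Q * K.t() / std::sqrt((double) K.n_cols);

  // Row-wise softmax over the keys, stabilized by subtracting the row max.
  scores.each_row([](arma::rowvec& r)
  {
    r -= r.max();
    r = arma::exp(r);
    r /= arma::accu(r);
  });

  return scores * V;
}

int main()
{
  arma::mat Q(5, 8, arma::fill::randn);   // 5 queries, d_k = 8
  arma::mat K(7, 8, arma::fill::randn);   // 7 keys,    d_k = 8
  arma::mat V(7, 8, arma::fill::randn);   // 7 values,  d_v = 8
  ScaledDotProductAttention(Q, K, V).print("attention output (5 x 8):");
  return 0;
}
```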