High GPU memory consumption #6

Open
saareliad opened this issue Aug 8, 2019 · 4 comments

@saareliad

saareliad commented Aug 8, 2019

Hi,
I tried to integrate the TTLayer into Transformer-XL; however, I found that it consumes much more memory than usual.
Did you experience such problems? Do you know any way around this?

(BTW, I also applied a few fixes for multi-GPU training, e.g. tensor-train objects are not moved to the GPU when you call model.to(device), which breaks the model in distributed training.)

@AlexGrinch
Collaborator

Hello,

It seems that this is a fundamental problem with the TTLayer and the way optimization is done in autograd frameworks. In addition to the memory footprint of the model weights, we also store activations on the GPU during optimization.

In the case of a single FC layer of size d^3 x d^3 and a batch of size B x d^3, the weight storage footprint is d^6 and the activation footprint is Bd^3. In the case of a TTLayer with 3 cores, the weight storage footprint is d^2(r^2 + 2r), but the activation footprint is Bd^3(2r + 1). The batch size (number of tokens, B x L) used in Transformers is usually pretty big, so the increase in the activation memory footprint (2Brd^3) outweighs the saving in weight memory (d^6).

To be precise, this happens when 2Br > d^3. Common values of d^3 are ~1000-4000, common batch sizes are around 5000, and typical ranks are 8-32. As you can see, there is a big activation memory overhead. Most likely, TTLayers are not applicable to the FC layers used in Transformers.
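To make this concrete, here is a quick back-of-the-envelope count of floats (illustrative values picked from the ranges above, not measurements):

```python
# rough weight/activation counts (in floats) for a d^3 x d^3 FC layer vs. a 3-core TTLayer;
# B, d, r are illustrative values from the ranges mentioned above
B = 5000                              # tokens per batch (batch_size x seq_len)
d = 16                                # hidden size d**3 = 4096
r = 16                                # TT-rank

fc_weights = d**6                     # dense weight matrix
fc_acts    = B * d**3                 # output activations

tt_weights = d**2 * (r**2 + 2 * r)    # cores of shape (1,d,d,r), (r,d,d,r), (r,d,d,1)
tt_acts    = B * d**3 * (2 * r + 1)   # intermediate + output activations

print(f"FC: weights {fc_weights:,}  activations {fc_acts:,}")
print(f"TT: weights {tt_weights:,}  activations {tt_acts:,}")
# the extra activations, 2*B*r*d**3 (~655M floats here), dwarf the weight
# saving of d**6 - tt_weights (~16.7M floats), since 2*B*r >> d**3
```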

@saareliad
Author

saareliad commented Aug 14, 2019

Your explanation about the activations makes sense; I also went over the math and it's correct.

However, the compressed model also consumes much more memory during inference, i.e. in eval mode and with torch.no_grad().

Peak memory consumption was 3021 MB for the compressed model versus 2132 MB for the normal model.
(Transformer-XL: 151M parameters before compression, 124M after; all position-wise FF layers compressed.)

I also tried to write the forward method more efficiently (e.g. with bmm or einsum); it didn't help either.
I suspect it's due to all the reshapes happening under the surface.
What do you think?
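For reference, this kind of comparison can be reproduced with the standard PyTorch CUDA memory counters; a minimal sketch (the helper and the model/batch names below are placeholders, not code from this repo):

```python
import torch

def peak_inference_mb(model, batch, device="cuda"):
    """Peak GPU memory (in MB) for a single forward pass under torch.no_grad()."""
    model = model.to(device).eval()
    batch = batch.to(device)
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated(device)
    with torch.no_grad():
        model(batch)
    return torch.cuda.max_memory_allocated(device) / 1024 ** 2

# e.g. run the compressed and the baseline model on the same batch of tokens:
# print(peak_inference_mb(compressed_model, tokens))
# print(peak_inference_mb(baseline_model, tokens))
```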

@KhrulkovV
Owner

KhrulkovV commented Aug 14, 2019

Hey,
We have tried to implement a naive but more concise version of the forward pass for the TTLayer (so far only for d=3), and it seems to fix the problem: in our example, memory usage fell from 7700 MB to 1500 MB, which is roughly equal to the memory occupied by the standard fully connected layer. The code is in the memory-fix branch; to enable the naive (more efficient) forward pass, supply the corresponding argument: TTLayer(..., naive=True). It can probably be optimized even further; I'll work on it.
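For anyone following along, here is a minimal self-contained sketch of the decompress-then-matmul idea for three cores (random cores and illustrative shapes; this is not the actual TTLayer code from the memory-fix branch):

```python
import torch

# three TT cores of a (m1*m2*m3) x (n1*n2*n3) weight, shapes (r_{k-1}, m_k, n_k, r_k) with r_0 = r_3 = 1
m, n, r = (8, 8, 8), (8, 8, 8), 16
core0 = torch.randn(1, m[0], n[0], r)
core1 = torch.randn(r, m[1], n[1], r)
core2 = torch.randn(r, m[2], n[2], 1)

def naive_tt_forward(x, core0, core1, core2):
    # 1) "decompress": contract the cores back into the full (m1, n1, m2, n2, m3, n3) tensor
    full = torch.einsum('abcd,defg,ghij->bcefhi', core0, core1, core2)
    # 2) reshape it into an ordinary (m1*m2*m3) x (n1*n2*n3) matrix and do a plain matmul
    W = full.permute(0, 2, 4, 1, 3, 5).reshape(m[0] * m[1] * m[2], n[0] * n[1] * n[2])
    return x @ W

x = torch.randn(32, m[0] * m[1] * m[2])                  # a batch of 32 flattened inputs
print(naive_tt_forward(x, core0, core1, core2).shape)    # torch.Size([32, 512])
```

The price is materializing the full matrix once per forward pass, but that is only d^6 floats, and no large per-token intermediates are kept around.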

@saareliad
Author

Excellent, cool!
Looks very promising.

I wonder how the "naive" solution would scale in terms of compute time and memory consumption.

I can prepare code for d>3; I made a working script for this yesterday for something else (a rough sketch of one possible generalization is below).
(My last implementation does the entire TT contraction with a single einsum op; it had high memory consumption too.)

That's your main change:

```python
full = torch.einsum('abcd,defg,ghij->bcefhi', core0, core1, core2)
res = torch.einsum('abcd,bqcsdx->aqsx', input, full)
```

which does

  1. "decompress": restore the full weight matrix from the cores, and
  2. an ordinary matmul between the decompressed matrix and the input, just with more dimensions.

Need to fully understand when einsum does a reshape and whether it broadcasts efficiently before scaling this up.
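If it helps with the d>3 case, here is a rough sketch of folding in an arbitrary number of cores one at a time instead of one big einsum (illustrative only; tt_full_matrix is a made-up helper name, not code from this repo):

```python
import torch

def tt_full_matrix(cores):
    """Contract TT cores of shape (r_{k-1}, m_k, n_k, r_k), with r_0 = r_d = 1,
    into the full (prod(m_k), prod(n_k)) matrix."""
    res = cores[0].squeeze(0)                             # (m_1, n_1, r_1)
    M, N = res.shape[0], res.shape[1]
    for core in cores[1:]:
        _, m_k, n_k, r_next = core.shape
        # contract the shared rank index, then fold the new row/column modes in
        res = torch.einsum('abr,rcds->acbds', res, core)  # (M, m_k, N, n_k, r_next)
        M, N = M * m_k, N * n_k
        res = res.reshape(M, N, r_next)
    return res.squeeze(-1)                                # (prod(m_k), prod(n_k))

# e.g. four cores of a 4096 x 4096 matrix with TT-rank 16:
cores = [torch.randn(1, 8, 8, 16),
         torch.randn(16, 8, 8, 16),
         torch.randn(16, 8, 8, 16),
         torch.randn(16, 8, 8, 1)]
W = tt_full_matrix(cores)           # (4096, 4096)
x = torch.randn(32, 4096)
print((x @ W).shape)                # torch.Size([32, 4096])
```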

There are several issues on the pytorch repo about einsum; as I understand it, they are working on it:

pytorch/pytorch#10661
pytorch/pytorch#15671
