-
Notifications
You must be signed in to change notification settings - Fork 17
/
flash_attention.py
182 lines (136 loc) · 5.99 KB
/
flash_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import torch
import torch.nn as nn
import numpy as np
import sys
import time
from einops import rearrange
# Tiling and numerics constants shared by the attention routines in this file.
BLOCK_SIZE = 1_024  # upper bound on rows per Q tile and per K/V tile (split along dim 2)
NEG_INF = -1.0e10  # large finite negative used in place of -inf for masked scores
EPSILON = 1.0e-10  # added to each tile's softmax denominator in the forward pass
def normal_attention(Q, K, V, mask=None):
    """Reference (non-tiled) scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d)) V`` over the last two dimensions,
    where ``d = Q.shape[-1]``. This materialises the full score matrix and is
    used as the ground truth that the tiled ``flash_attention`` is compared to.

    Args:
        Q, K, V: tensors whose last two dims are (sequence, feature); the
            leading dims must broadcast against each other.
        mask: optional 2-D ``(batch, seq_k)`` key mask. It is reshaped to
            ``(batch, 1, 1, seq_k)``, so with a mask the scores are expected to
            be 4-D (presumably batch, heads, seq_q, seq_k — TODO confirm).
            Keys where ``mask > 0`` are attended to; the others get score
            ``NEG_INF``. ``None`` (the default) disables masking.

    Returns:
        ``attn @ V``: Q's leading dims with V's feature dim last.
    """
    scale = 1 / np.sqrt(Q.shape[-1])
    scores = torch.einsum('... i d, ... j d -> ... i j', Q * scale, K)
    if mask is not None:
        # Broadcast the per-key mask across heads and query positions.
        key_mask = rearrange(mask, 'b j -> b 1 1 j')
        scores = torch.where(key_mask > 0, scores, NEG_INF)
    attn = nn.functional.softmax(scores, dim=-1)
    return attn @ V
def flash_attention_forward(Q, K, V, mask=None):
    """Tiled attention forward pass with a running (online) softmax.

    Q is split into tiles of ``min(BLOCK_SIZE, Q.shape[-1])`` rows and K/V
    into tiles of ``BLOCK_SIZE`` rows, all along dim 2. The outer loop walks
    K/V tiles and the inner loop walks Q tiles. For each Q tile a running row
    max ``m`` and denominator ``l`` are kept, so the full score matrix is
    never built.

    Args:
        Q, K, V: tensors tiled along dim 2, presumably
            ``(batch, heads, seq, d)`` — TODO confirm. ``Q.shape[-1]`` sets the
            ``1/sqrt(d)`` scale. All intermediate buffers are created on
            ``Q.device``.
        mask: optional 2-D ``(batch, seq_k)`` key mask; keys where
            ``mask > 0`` are kept, the rest are scored ``NEG_INF`` and given
            zero probability. ``None`` attends to every key.

    Returns:
        ``(O, l, m)``:
        - ``O`` has Q's shape and dtype.
        - ``l`` is the per-row softmax denominator (each tile contributes
          ``EPSILON``), with shape ``Q.shape[:-1] + (1,)``.
        - ``m`` is the per-row running max, with shape ``Q.shape[:-1] + (1,)``.

        ``l`` and ``m`` use the default float dtype and are consumed by
        ``flash_attention_backward``.
    """
    device = Q.device
    # Plain zero accumulator: it is not a parameter, so it must not be an autograd leaf.
    O = torch.zeros_like(Q)
    l = torch.zeros(Q.shape[:-1], device=device)[..., None]
    m = torch.full(Q.shape[:-1], NEG_INF, device=device)[..., None]

    q_block = min(BLOCK_SIZE, Q.shape[-1])
    kv_block = BLOCK_SIZE
    scale = 1 / np.sqrt(Q.shape[-1])  # loop-invariant

    Q_BLOCKS = torch.split(Q, q_block, dim=2)
    K_BLOCKS = torch.split(K, kv_block, dim=2)
    V_BLOCKS = torch.split(V, kv_block, dim=2)
    if mask is None:
        mask_BLOCKS = [None] * len(K_BLOCKS)
    else:
        mask_BLOCKS = torch.split(mask, kv_block, dim=1)

    O_BLOCKS = list(torch.split(O, q_block, dim=2))
    l_BLOCKS = list(torch.split(l, q_block, dim=2))
    m_BLOCKS = list(torch.split(m, q_block, dim=2))

    for Kj, Vj, maskj in zip(K_BLOCKS, V_BLOCKS, mask_BLOCKS):
        # The key mask depends only on the K/V tile; broadcast it over heads and query rows.
        keep = None if maskj is None else rearrange(maskj, 'b j -> b 1 1 j') > 0
        for i, Qi in enumerate(Q_BLOCKS):
            Oi, li, mi = O_BLOCKS[i], l_BLOCKS[i], m_BLOCKS[i]

            S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi * scale, Kj)
            if keep is not None:
                S_ij = torch.where(keep, S_ij, NEG_INF)

            m_block_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
            P_ij = torch.exp(S_ij - m_block_ij)
            if keep is not None:
                # If a row of this tile is fully masked, its max is NEG_INF and
                # exp(0) == 1 for the masked entries, so zero them explicitly.
                P_ij = torch.where(keep, P_ij, 0.)
            l_block_ij = torch.sum(P_ij, dim=-1, keepdim=True) + EPSILON
            P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)

            # Merge this tile's statistics into the running ones.
            mi_new = torch.maximum(m_block_ij, mi)
            alpha = torch.exp(mi - mi_new)          # rescales the previous accumulators
            beta = torch.exp(m_block_ij - mi_new)   # rescales this tile's contribution
            li_new = alpha * li + beta * l_block_ij
            O_BLOCKS[i] = (li / li_new) * alpha * Oi + (beta / li_new) * P_ij_Vj
            l_BLOCKS[i] = li_new
            m_BLOCKS[i] = mi_new

    O = torch.cat(O_BLOCKS, dim=2)
    l = torch.cat(l_BLOCKS, dim=2)
    m = torch.cat(m_BLOCKS, dim=2)
    return O, l, m
def flash_attention_backward(Q, K, V, mask, O, l, m, dO):
    """Tiled attention backward pass that recomputes probabilities per tile.

    Uses the same tiling as ``flash_attention_forward``. Instead of storing
    the attention matrix, each probability tile is rebuilt as
    ``exp(S_ij - m) / l`` from the saved row statistics, and gradients are
    accumulated tile by tile.

    Args:
        Q, K, V: the forward inputs, tiled along dim 2.
        mask: the forward key mask, 2-D ``(batch, seq_k)`` with ``mask > 0``
            marking kept keys, or ``None`` for no masking.
        O, l, m: the outputs of ``flash_attention_forward`` for these inputs.
        dO: gradient of the loss with respect to ``O`` (same shape as ``O``).

    Returns:
        ``(dQ, dK, dV)`` with the shapes and dtypes of Q, K and V, on the
        same devices as those inputs.
    """
    q_block = min(BLOCK_SIZE, Q.shape[-1])
    kv_block = BLOCK_SIZE
    scale = 1 / np.sqrt(Q.shape[-1])  # loop-invariant

    Q_BLOCKS = torch.split(Q, q_block, dim=2)
    K_BLOCKS = torch.split(K, kv_block, dim=2)
    V_BLOCKS = torch.split(V, kv_block, dim=2)
    if mask is None:
        mask_BLOCKS = [None] * len(K_BLOCKS)
    else:
        mask_BLOCKS = torch.split(mask, kv_block, dim=1)
    O_BLOCKS = torch.split(O, q_block, dim=2)
    dO_BLOCKS = torch.split(dO, q_block, dim=2)
    l_BLOCKS = torch.split(l, q_block, dim=2)
    m_BLOCKS = torch.split(m, q_block, dim=2)

    # D_i = rowsum(dO_i * O_i) depends only on the Q tile, so compute it once per tile.
    D_BLOCKS = [
        torch.sum(dOi * Oi, dim=-1, keepdim=True)
        for dOi, Oi in zip(dO_BLOCKS, O_BLOCKS)
    ]

    # Gradient buffers are plain zeros on the inputs' devices (not autograd leaves).
    dQ_BLOCKS = list(torch.split(torch.zeros_like(Q), q_block, dim=2))
    dK_BLOCKS = []
    dV_BLOCKS = []

    for Kj, Vj, maskj in zip(K_BLOCKS, V_BLOCKS, mask_BLOCKS):
        keep = None if maskj is None else rearrange(maskj, 'b j -> b 1 1 j') > 0
        dKj = torch.zeros_like(Kj)
        dVj = torch.zeros_like(Vj)
        for i, Qi in enumerate(Q_BLOCKS):
            dOi = dO_BLOCKS[i]

            S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi * scale, Kj)
            if keep is not None:
                S_ij = torch.where(keep, S_ij, NEG_INF)
            P_ij = (1 / l_BLOCKS[i]) * torch.exp(S_ij - m_BLOCKS[i])
            if keep is not None:
                # exp() can be non-zero at masked keys when the row max is itself NEG_INF.
                P_ij = torch.where(keep, P_ij, 0.)

            dVj = dVj + torch.einsum('... r c, ... r d -> ... c d', P_ij, dOi)
            dP_ij = torch.einsum('... r d, ... c d -> ... r c', dOi, Vj)
            # Softmax backward: dS = P * (dP - rowsum(dO * O)).
            dS_ij = P_ij * (dP_ij - D_BLOCKS[i])
            # S = scale * Q K^T, so both dQ and dK pick up the scale factor.
            dQ_BLOCKS[i] = dQ_BLOCKS[i] + scale * torch.einsum('... r c, ... c d -> ... r d', dS_ij, Kj)
            dKj = dKj + scale * torch.einsum('... r c, ... r d -> ... c d', dS_ij, Qi)
        dK_BLOCKS.append(dKj)
        dV_BLOCKS.append(dVj)

    dQ = torch.cat(dQ_BLOCKS, dim=2)
    dK = torch.cat(dK_BLOCKS, dim=2)
    dV = torch.cat(dV_BLOCKS, dim=2)
    return dQ, dK, dV
def flash_attention(Q, K, V, mask):
    """Run ``flash_attention_forward`` and return just the attention output.

    The softmax row statistics (``l`` and ``m``) that the forward pass also
    returns are discarded here.
    """
    output, _row_sum, _row_max = flash_attention_forward(Q, K, V, mask)
    return output
if __name__ == "__main__":
    # Benchmark the tiled implementation against the reference on random data
    # and check that both produce the same output.
    device = torch.device('cuda')
    # Create inputs directly on the GPU so they stay autograd leaves;
    # ``randn(..., requires_grad=True).to('cuda')`` would return non-leaf copies.
    Q = torch.randn(1, 2, 4096, 1024, device=device, requires_grad=True)
    K = torch.randn(1, 2, 4096, 1024, device=device, requires_grad=True)
    V = torch.randn(1, 2, 4096, 1024, device=device, requires_grad=True)
    mask = torch.randint(0, 2, (1, 4096), device=device)

    for _ in range(10):
        # CUDA kernels run asynchronously. Synchronize around each call so
        # the timer measures the computation, not just kernel launches.
        torch.cuda.synchronize()
        start1 = time.perf_counter_ns()
        out1 = flash_attention(Q, K, V, mask)
        torch.cuda.synchronize()
        end1 = time.perf_counter_ns()

        start2 = time.perf_counter_ns()
        out2 = normal_attention(Q, K, V, mask)
        torch.cuda.synchronize()
        end2 = time.perf_counter_ns()

        t1 = (end1 - start1) / 1000000
        t2 = (end2 - start2) / 1000000
        print(f'{t1}ms, {t2}ms')
        print(torch.allclose(out1, out2, atol=1e-5))