Softmax与Flash-Attention
本文最后更新于 2026年2月20日
Softmax
Flash-Attention的算法根基来自online-softmax。
原始的 softmax 公式:
数值稳定的 safe-softmax 公式:
safe-softmax 与原始的 softmax完全等价,但是对于计算效率来说并不友好,需要三次循环才能得到最后的结果:
- 求最大值
- 相减后计算指数和
- 写回结果
1 | |
幸运的是,online softmax可以将其减少至两个循环。
我们不妨先考虑前
【最大值的更新】 【指数和的更新】
我们理想中的算法是只遍历一次,就可以得到最大值和
前
第
其中对于前
但是这跟上面前
所以:
这里可以这么理解这个公式:
对应的 triton kernel
1 | |
可以看到,这个 triton 算子的核心 1
x_exp_sum = x_exp_sum * tl.exp(x_max - x_max_new) + tl.sum(tl.exp(x - x_max_new), axis=-1)
就是
Flash-Attention
相信读者在看完上面的内容后能够一定程度上读懂下面的flash attn 2 论文中的 forward 算法。
Algorithm FlashAttention-2 forward pass
Require: Matrices
Divide
into blocks of size each, and divide in to blocks and , of size each. Divide the output
into blocks of size each, and divide the logsumexp into blocks of size each. for
do - Load
from HBM to on-chip SRAM. - On chip, initialize
. - for
do - Load
from HBM to on-chip SRAM. - On chip, compute
. - On chip, compute
, (pointwise), . - On chip, compute
.
- Load
- end for
- On chip, compute
. - On chip, compute
. - Write
to HBM as the -th block of . - Write
to HBM as the -th block of .
- Load
end for
Return the output
and the logsumexp .
这里的forward和推理框架中的并不完全相同,比如这里的
triton 版: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49@triton.jit
def flash_attention_kernel(q_ptr, k_ptr, v_ptr, output_ptr, n, d: tl.constexpr, BR: tl.constexpr, BC: tl.constexpr):
pid = tl.program_id(0)
q_row_range = tl.arange(0, BR) + pid * BR
q_range = q_row_range[:, None] * d + tl.arange(0, d)
q_mask = (q_row_range[:, None] < n) & (tl.arange(0, d) < d)
q = tl.load(q_ptr + q_range, mask=q_mask)
o = tl.full((BR, d), 0, dtype=tl.float32)
l = tl.full((BR,), 0, dtype=tl.float32)
m = tl.full((BR,), -float('inf'), dtype=tl.float32)
for j in range(0, n, BC):
kv_row_range = tl.arange(0, BC) + j
kv_range = kv_row_range[:, None] * d + tl.arange(0, d)
kv_mask = (kv_row_range[:, None] < n) & (tl.arange(0, d) < d)
k = tl.load(k_ptr + kv_range, mask=kv_mask, other=0)
v = tl.load(v_ptr + kv_range, mask=kv_mask, other=0)
s = tl.dot(q, tl.trans(k))
s_mask = (q_row_range[:, None] < n) & (kv_row_range < n)
s = tl.where(s_mask, s, -float('inf'))
m_new = tl.maximum(m, tl.max(s, axis=-1))
p = tl.exp(s - m_new[:, None])
l = tl.exp(m - m_new) * l + tl.sum(p, axis=-1)
o = tl.exp(m - m_new)[:, None] * o + tl.dot(p, v)
m = m_new
tl.store(output_ptr + q_range, o / l[:, None], mask=q_mask)
def flash_attention(q, k, v):
n, d = q.shape
BR = 32
BC = 64
output = torch.empty((n, d), device=q.device)
grid = lambda meta: (triton.cdiv(n, BR),)
flash_attention_kernel[grid](
q,
k,
v,
output,
n,
d,
BR,
BC,
)
return output
这里可以额外讲一点广播技巧,比如代码中的 1
q_range = q_row_range[:, None] * d + tl.arange(0, d)
q_row_range[:, None] 相当于在最后加上一维,e.g.,
(n,) -> (n, 1),
[0,1,2] -> [[0],[1],[2]]
当与tl.arange(0, d)(shape为(m,))运算时,结果变为
1 | |
附录
本文参考