Softmax与Flash-Attention

本文最后更新于 2026年2月20日

Softmax

Flash-Attention的算法根基来自online-softmax。

原始的 softmax 公式:

数值稳定的 safe-softmax 公式:

safe-softmax 与原始的 softmax完全等价,但是对于计算效率来说并不友好,需要三次循环才能得到最后的结果:

  • 求最大值
  • 相减后计算指数和
  • 写回结果
1
2
3
4
5
6
def softmax(x):
x_max = x.max(dim=-1, keepdim=True)[0]
x = x - x_max
x_exp = x.exp()
x_exp_sum = x_exp.sum(dim=-1, keepdim=True)
return x_exp / x_exp_sum

幸运的是,online softmax可以将其减少至两个循环。

我们不妨先考虑前

  1. 【最大值的更新】

  2. 【指数和的更新】

我们理想中的算法是只遍历一次,就可以得到最大值和 ,即 只与 有关(这样意味着我们可以通过递推在计算最大值的循环中得到

项贡献

项的贡献

其中对于前 项的贡献,这部分指数和本来应该是基于 来计算的:

但是这跟上面前 项的贡献表示不同,所以我们要将 转换为以 为基准:

所以:

这里可以这么理解这个公式: 等于对 的修正加上最后一项

对应的 triton kernel

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
@triton.jit
def softmax_kernel(x_ptr, output_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)

x_max = -float('inf')
x_exp_sum = 0.0
for offset in range(0, n_cols, BLOCK_SIZE):
col_range = tl.arange(0, BLOCK_SIZE)
col_mask = col_range + offset < n_cols
x = tl.load(x_ptr + pid * row_stride + col_range + offset, mask=col_mask, other=-float('inf'))
x_max_new = tl.maximum(x_max, tl.max(x, axis=-1))
x_exp_sum = x_exp_sum * tl.exp(x_max - x_max_new) + tl.sum(tl.exp(x - x_max_new), axis=-1)
x_max = x_max_new

for offset in range(0, n_cols, BLOCK_SIZE):
col_range = tl.arange(0, BLOCK_SIZE)
col_mask = col_range + offset < n_cols
x = tl.load(x_ptr + pid * row_stride + col_range + offset, mask=col_mask)
x_exp = tl.exp(x - x_max)
tl.store(output_ptr + pid * row_stride + col_range + offset, x_exp / x_exp_sum, mask=col_mask)


def triton_softmax(x):
n_rows, n_cols = x.shape
output = torch.empty_like(x)
BLOCK_SIZE = 256
softmax_kernel[(n_rows,)](
x,
output,
x.stride(0),
n_cols,
BLOCK_SIZE
)
return output

可以看到,这个 triton 算子的核心

1
x_exp_sum = x_exp_sum * tl.exp(x_max - x_max_new) + tl.sum(tl.exp(x - x_max_new), axis=-1)

就是

Flash-Attention

相信读者在看完上面的内容后能够一定程度上读懂下面的flash attn 2 论文中的 forward 算法。

Algorithm FlashAttention-2 forward pass

Require: Matrices in HBM, block sizes , .

  1. Divide into blocks of size each, and divide in to blocks and , of size each.

  2. Divide the output into blocks of size each, and divide the logsumexp into blocks of size each.

  3. for do

    • Load from HBM to on-chip SRAM.
    • On chip, initialize .
    • for do
      • Load from HBM to on-chip SRAM.
      • On chip, compute .
      • On chip, compute , (pointwise), .
      • On chip, compute .
    • end for
    • On chip, compute .
    • On chip, compute .
    • Write to HBM as the -th block of .
    • Write to HBM as the -th block of .
  4. end for

  5. Return the output and the logsumexp .

这里的forward和推理框架中的并不完全相同,比如这里的,在只做 forward 计算的时候并不需要。

triton 版:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@triton.jit
def flash_attention_kernel(q_ptr, k_ptr, v_ptr, output_ptr, n, d: tl.constexpr, BR: tl.constexpr, BC: tl.constexpr):
pid = tl.program_id(0)

q_row_range = tl.arange(0, BR) + pid * BR
q_range = q_row_range[:, None] * d + tl.arange(0, d)
q_mask = (q_row_range[:, None] < n) & (tl.arange(0, d) < d)
q = tl.load(q_ptr + q_range, mask=q_mask)

o = tl.full((BR, d), 0, dtype=tl.float32)
l = tl.full((BR,), 0, dtype=tl.float32)
m = tl.full((BR,), -float('inf'), dtype=tl.float32)

for j in range(0, n, BC):
kv_row_range = tl.arange(0, BC) + j
kv_range = kv_row_range[:, None] * d + tl.arange(0, d)
kv_mask = (kv_row_range[:, None] < n) & (tl.arange(0, d) < d)
k = tl.load(k_ptr + kv_range, mask=kv_mask, other=0)
v = tl.load(v_ptr + kv_range, mask=kv_mask, other=0)
s = tl.dot(q, tl.trans(k))
s_mask = (q_row_range[:, None] < n) & (kv_row_range < n)
s = tl.where(s_mask, s, -float('inf'))
m_new = tl.maximum(m, tl.max(s, axis=-1))
p = tl.exp(s - m_new[:, None])
l = tl.exp(m - m_new) * l + tl.sum(p, axis=-1)
o = tl.exp(m - m_new)[:, None] * o + tl.dot(p, v)
m = m_new

tl.store(output_ptr + q_range, o / l[:, None], mask=q_mask)



def flash_attention(q, k, v):
n, d = q.shape
BR = 32
BC = 64
output = torch.empty((n, d), device=q.device)
grid = lambda meta: (triton.cdiv(n, BR),)
flash_attention_kernel[grid](
q,
k,
v,
output,
n,
d,
BR,
BC,
)
return output

这里可以额外讲一点广播技巧,比如代码中的

1
q_range = q_row_range[:, None] * d + tl.arange(0, d)

q_row_range[:, None] 相当于在最后加上一维,e.g., (n,) -> (n, 1), [0,1,2] -> [[0],[1],[2]]

当与tl.arange(0, d)shape(m,))运算时,结果变为

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
range[:, None]:   广播后:
[0][0, 0, 0, 0]
[1][1, 1, 1, 1]
[2][2, 2, 2, 2]

tl.arange(0, d): 广播后:
[0,1,2,3][0,1,2,3]
[0,1,2,3]
[0,1,2,3]

最终的range
行起始地址: 列偏移: 结果:
[0*100=0] + [0,1,2,3] = [0, 1, 2, 3]
[1*100=100] + [0,1,2,3] = [100,101,102,103]
[2*100=200] + [0,1,2,3] = [200,201,202,203]

附录

本文参考