RoPE

This post was last updated on February 4, 2026.

What is positional encoding

We want to attach positional information to the query (q) and key (k) vectors.

As a concrete example, the words "喜欢" (like) and "吃饭" (eat) may reasonably have a high attention score, but we do not want two words that are very far apart in the sequence to still attend to each other strongly.

There are two families of positional encoding: absolute and relative.
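For reference, here is a minimal sketch of the classic sinusoidal absolute positional encoding from the original Transformer; the function name and shapes are illustrative choices, not taken from any particular library.

```python
import torch

def sinusoidal_pe(seq_len: int, dim: int) -> torch.Tensor:
    """Absolute sinusoidal positional encoding, shape (seq_len, dim); dim must be even."""
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)          # (seq_len, 1)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))    # (dim / 2,)
    angles = pos * inv_freq                                                # (seq_len, dim / 2)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe

# Absolute encodings are simply added to the token embeddings:
# x = token_emb + sinusoidal_pe(seq_len, dim)
```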

What is RoPE

RoPE (Rotary Position Embedding) is a positional encoding algorithm that is widely used in Transformer models.

RoPE rotates the query of each token by an angle that depends on its position:

$$f(\mathbf{q}, m) = \mathbf{R}^{d}_{\Theta, m}\,\mathbf{q}, \qquad \theta_i = 10000^{-2(i-1)/d},\quad i \in \{1, \dots, d/2\}$$

where $d$ is the head dimension, $\mathbf{q}$ is the query of a token, $m$ is its position, and $\mathbf{R}^{d}_{\Theta, m}$ is the block-diagonal rotation matrix whose $i$-th $2\times 2$ block rotates the pair $(q_{2i-1}, q_{2i})$ by the angle $m\theta_i$. The same rotation is applied to the key.
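To see why a rotation encodes relative position, note that $\langle \mathbf{R}_m \mathbf{q}, \mathbf{R}_n \mathbf{k}\rangle = \mathbf{q}^\top \mathbf{R}_{n-m}\,\mathbf{k}$ depends only on the offset $n - m$. The small check below (a sketch with made-up numbers, using the $d = 2$ case) illustrates this numerically.

```python
import math
import torch

def rotate_2d(v: torch.Tensor, angle: float) -> torch.Tensor:
    """Rotate a 2-d vector by `angle` (the d = 2 case of RoPE)."""
    c, s = math.cos(angle), math.sin(angle)
    rot = torch.tensor([[c, -s], [s, c]])
    return rot @ v

theta = 0.1
q = torch.tensor([0.3, -1.2])
k = torch.tensor([0.8, 0.5])

# <R_m q, R_n k> depends only on the relative offset m - n.
for m, n in [(5, 2), (105, 102), (1005, 1002)]:
    score = torch.dot(rotate_2d(q, m * theta), rotate_2d(k, n * theta))
    print(f"m - n = {m - n}, score = {score.item():.6f}")  # same score for all three pairs
```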

Because this rotation matrix is sparse, implementing it as a full matrix multiplication wastes compute. The recommended way to implement RoPE is the elementwise form:

$$\mathbf{R}^{d}_{\Theta,m}\,\mathbf{q} =
\begin{pmatrix} q_1 \\ q_2 \\ q_3 \\ q_4 \\ \vdots \\ q_{d-1} \\ q_d \end{pmatrix} \otimes
\begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} +
\begin{pmatrix} -q_2 \\ q_1 \\ -q_4 \\ q_3 \\ \vdots \\ -q_d \\ q_{d-1} \end{pmatrix} \otimes
\begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}$$

In practice, the $\cos m\theta_i$ and $\sin m\theta_i$ tables are usually precomputed once and cached.
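A minimal sketch of how such a cache could be precomputed is shown below; the function name is my own, but the resulting shapes match the `cos`/`sin` arguments expected by the code in the next section.

```python
import torch

def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0, device="cpu"):
    """Precompute cos/sin tables of shape (seq_len, head_dim / 2)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)   # (seq_len, head_dim / 2)
    return torch.cos(angles), torch.sin(angles)

# cos, sin = build_rope_cache(seq_len=4096, head_dim=128)
# The tables can then be reused for every layer and every forward pass.
```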

The RoPE operator

The PyTorch code below comes from the flash-attn repository and is very concise. It also shows that RoPE acts only on the last dimension, head_dim, so different heads do not affect one another.

```python
import torch
import triton
import triton.language as tl
from einops import rearrange, repeat


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )


apply_rotary_emb_compile = torch.compile(apply_rotary_emb_torch)


@triton.jit
def rotary_kernel(
    x_ptr, cos_ptr, sin_ptr, out_ptr,
    stride_x_bs, stride_x_s, stride_x_h, stride_x_d,
    stride_c_bs, stride_c_s, stride_c_d,
    stride_o_bs, stride_o_s, stride_o_h, stride_o_d,
    B, S, H, D, RO_DIM,
    COS_HAS_BS: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid = tl.program_id(0)

    # -------- pid -> (b, s, h) --------
    h = pid % H
    tmp = pid // H
    s = tmp % S
    b = tmp // S

    # -------- base pointers --------
    x_base = x_ptr + b * stride_x_bs + s * stride_x_s + h * stride_x_h
    o_base = out_ptr + b * stride_o_bs + s * stride_o_s + h * stride_o_h

    if COS_HAS_BS:
        cos_base = cos_ptr + b * stride_c_bs + s * stride_c_s
        sin_base = sin_ptr + b * stride_c_bs + s * stride_c_s
    else:
        cos_base = cos_ptr + s * stride_c_s
        sin_base = sin_ptr + s * stride_c_s

    offs = tl.arange(0, BLOCK_D)

    # -------- loop over D --------
    for d0 in range(0, D, BLOCK_D):
        d = d0 + offs
        mask = d < D

        # load x
        x = tl.load(
            x_base + d * stride_x_d,
            mask=mask,
            other=0.0,
        )

        rope_mask = d < RO_DIM

        if INTERLEAVED:
            # ======================================================
            # interleaved=True
            # PyTorch equivalent:
            #   x1 = x[..., ::2]
            #   x2 = x[..., 1::2]
            #   [-x2, x1] reshape back
            # ======================================================
            even = (d & 1) == 0
            pair_idx = d >> 1  # d // 2

            cos = tl.load(
                cos_base + pair_idx * stride_c_d,
                mask=rope_mask,
                other=0.0,
            )
            sin = tl.load(
                sin_base + pair_idx * stride_c_d,
                mask=rope_mask,
                other=0.0,
            )

            other_d = tl.where(even, d + 1, d - 1)

            x_other = tl.load(
                x_base + other_d * stride_x_d,
                mask=rope_mask,
                other=0.0,
            )

            out_rope = tl.where(
                even,
                x * cos - x_other * sin,
                x_other * sin + x * cos,
            )

        else:
            # ======================================================
            # interleaved=False (chunk-based)
            # PyTorch equivalent:
            #   x1, x2 = x.chunk(2, dim=-1)
            #   [-x2, x1]
            # ======================================================
            half = RO_DIM // 2
            first_half = d < half

            pair_idx = tl.where(first_half, d, d - half)

            cos = tl.load(
                cos_base + pair_idx * stride_c_d,
                mask=rope_mask,
                other=0.0,
            )
            sin = tl.load(
                sin_base + pair_idx * stride_c_d,
                mask=rope_mask,
                other=0.0,
            )

            other_d = tl.where(first_half, d + half, d - half)

            x_other = tl.load(
                x_base + other_d * stride_x_d,
                mask=rope_mask,
                other=0.0,
            )

            out_rope = tl.where(
                first_half,
                x * cos - x_other * sin,
                x_other * sin + x * cos,
            )

        # -------- merge rope / non-rope --------
        out = tl.where(rope_mask, out_rope, x)

        # -------- store --------
        tl.store(
            o_base + d * stride_o_d,
            out,
            mask=mask,
        )


def apply_rotary_emb_triton(x, cos, sin, interleaved=False):
    """
    x: (B, S, H, D)
    cos/sin: (S, D/2) or (B, S, D/2)
    """
    B, S, H, D = x.shape
    ro_dim = cos.shape[-1] * 2

    out = torch.empty_like(x)

    grid = (B * S * H,)

    COS_HAS_BS = (cos.dim() == 3)

    rotary_kernel[grid](
        x, cos, sin, out,
        x.stride(0), x.stride(1), x.stride(2), x.stride(3),
        cos.stride(0) if COS_HAS_BS else 0,
        cos.stride(-2),
        cos.stride(-1),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        B, S, H, D, ro_dim,
        COS_HAS_BS=COS_HAS_BS,
        INTERLEAVED=interleaved,
        BLOCK_D=128,
    )
    return out


if __name__ == "__main__":
    # simple correctness check against the reference implementation
    B, S, H, D = 40, 4000, 64, 128
    x = torch.randn(B, S, H, D, device="cuda")
    cos = torch.randn(S, D // 2, device="cuda")
    sin = torch.randn(S, D // 2, device="cuda")

    out_torch = apply_rotary_emb_torch(x, cos, sin, interleaved=False)
    out_triton = apply_rotary_emb_triton(x, cos, sin, interleaved=False)

    print("Max diff:", (out_torch - out_triton).abs().max().item())

    import torch.utils.benchmark as benchmark

    t_torch = benchmark.Timer(
        stmt="apply_rotary_emb_torch(x, cos, sin, interleaved=False)",
        globals=globals(),
    ).blocked_autorange(min_run_time=1.0)

    t_triton = benchmark.Timer(
        stmt="apply_rotary_emb_triton(x, cos, sin, interleaved=False)",
        globals=globals(),
    ).blocked_autorange(min_run_time=1.0)

    t_compiled = benchmark.Timer(
        stmt="apply_rotary_emb_compile(x, cos, sin, interleaved=False)",
        globals=globals(),
    ).blocked_autorange(min_run_time=1.0)

    print(t_torch)
    print(t_triton)
    print(t_compiled)
```
```
apply_rotary_emb_torch(x, cos, sin, interleaved=False)
  Median: 23.93 ms
  IQR:    0.04 ms (23.90 to 23.94)
apply_rotary_emb_triton(x, cos, sin, interleaved=False)
  Median: 7.95 ms
  IQR:    0.01 ms (7.95 to 7.96)
apply_rotary_emb_compile(x, cos, sin, interleaved=False)
  Median: 2.80 ms
  IQR:    0.02 ms (2.80 to 2.82)
```

It seems that, for a simple elementwise operator like this, a hand-written kernel still cannot beat torch.compile.