1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
| import torch import triton import triton.language as tl from einops import rearrange, repeat
def rotate_half(x, interleaved=False):
    """Map every channel pair (a, b) of ``x`` to (-b, a) along the last dim.

    Non-interleaved: pair i is (x[..., i], x[..., i + d/2]), i.e. the two
    halves of the last dimension.
    Interleaved: pair i is (x[..., 2i], x[..., 2i + 1]).
    """
    if interleaved:
        evens = x[..., 0::2]
        odds = x[..., 1::2]
        # (..., d/2, 2) -> (..., d): re-interleave the rotated pairs.
        return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
    front, back = x.chunk(2, dim=-1)
    return torch.cat((-back, front), dim=-1)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Reference PyTorch implementation of rotary position embedding.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim = 2 * cos.shape[-1]`` channels of ``x`` are
    rotated; the remaining channels are concatenated back unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    def _widen(table):
        # (..., rotary_dim / 2) -> (..., 1, rotary_dim). The singleton axis
        # broadcasts over heads; the channel order matches rotate_half's
        # pairing (t0 t0 t1 t1 ... when interleaved, t0 t1 ... t0 t1 ... else).
        if interleaved:
            widened = table.repeat_interleave(2, dim=-1)
        else:
            widened = torch.cat((table, table), dim=-1)
        return widened.unsqueeze(-2)

    cos_full = _widen(cos)
    sin_full = _widen(sin)
    x_rot = x[..., :rotary_dim]
    x_pass = x[..., rotary_dim:]
    rotated = x_rot * cos_full + rotate_half(x_rot, interleaved) * sin_full
    return torch.cat([rotated, x_pass], dim=-1)


apply_rotary_emb_compile = torch.compile(apply_rotary_emb_torch)
@triton.jit
def rotary_kernel(
    x_ptr,
    cos_ptr,
    sin_ptr,
    out_ptr,
    stride_x_bs,
    stride_x_s,
    stride_x_h,
    stride_x_d,
    stride_c_bs,
    stride_c_s,
    stride_c_d,
    stride_o_bs,
    stride_o_s,
    stride_o_h,
    stride_o_d,
    B,
    S,
    H,
    D,
    RO_DIM,
    COS_HAS_BS: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Rotate the first RO_DIM channels of one (b, s, h) row of x into out.

    Program ``pid`` handles h = pid % H, s = (pid // H) % S, b = pid // (H * S),
    so the launch grid is expected to be (B * S * H,). The D channels are
    processed in chunks of BLOCK_D; channels in [RO_DIM, D) are copied as-is.

    cos and sin are both addressed with the stride_c_* strides; stride_c_bs is
    only used when COS_HAS_BS is set. Pair i reads cos[i] / sin[i] and is made
    of channels (2i, 2i + 1) when INTERLEAVED, else (i, i + RO_DIM / 2).
    """
    # Use an int64 program id so the row offsets below (b * stride_x_bs, ...)
    # are computed in 64-bit arithmetic; with an int32 id they can overflow
    # once a tensor holds more than 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)
    h = pid % H
    tmp = pid // H
    s = tmp % S
    b = tmp // S

    x_base = x_ptr + b * stride_x_bs + s * stride_x_s + h * stride_x_h
    o_base = out_ptr + b * stride_o_bs + s * stride_o_s + h * stride_o_h

    if COS_HAS_BS:
        cos_base = cos_ptr + b * stride_c_bs + s * stride_c_s
        sin_base = sin_ptr + b * stride_c_bs + s * stride_c_s
    else:
        cos_base = cos_ptr + s * stride_c_s
        sin_base = sin_ptr + s * stride_c_s

    # Loop-invariant distance between the two members of a pair in the
    # non-interleaved layout.
    half = RO_DIM // 2
    offs = tl.arange(0, BLOCK_D)

    for d0 in range(0, D, BLOCK_D):
        d = d0 + offs
        mask = d < D
        x = tl.load(x_base + d * stride_x_d, mask=mask, other=0.0)

        # Only channels below RO_DIM are rotated; cos/sin/partner loads for
        # the other lanes are masked off and their result is discarded below.
        rope_mask = d < RO_DIM

        if INTERLEAVED:
            # Channel d pairs with its neighbour; the even member comes first.
            first = (d & 1) == 0
            pair_idx = d >> 1
            other_d = tl.where(first, d + 1, d - 1)
        else:
            # Channel d pairs with the channel half a rotary span away.
            first = d < half
            pair_idx = tl.where(first, d, d - half)
            other_d = tl.where(first, d + half, d - half)

        cos = tl.load(cos_base + pair_idx * stride_c_d, mask=rope_mask, other=0.0)
        sin = tl.load(sin_base + pair_idx * stride_c_d, mask=rope_mask, other=0.0)
        x_other = tl.load(x_base + other_d * stride_x_d, mask=rope_mask, other=0.0)

        # first member:  x * cos - partner * sin
        # second member: partner * sin + x * cos
        out_rope = tl.where(
            first,
            x * cos - x_other * sin,
            x_other * sin + x * cos,
        )
        out = tl.where(rope_mask, out_rope, x)

        tl.store(o_base + d * stride_o_d, out, mask=mask)
def apply_rotary_emb_triton(x, cos, sin, interleaved=False):
    """Apply rotary position embedding to ``x`` with ``rotary_kernel``.

    x: (B, S, H, D)
    cos/sin: (S, rotary_dim / 2) or (B, S, rotary_dim / 2), where
        rotary_dim = 2 * cos.shape[-1] must not exceed D.

    Leading cos/sin dims of size 1 are broadcast (as in the PyTorch
    reference). Leading dims longer than x's B / S, e.g. a table built for a
    longer maximum length, are truncated. Channels at or beyond rotary_dim
    are copied unchanged.

    Returns a new tensor allocated with ``torch.empty_like(x)``.

    Raises:
        ValueError: x is not 4-D, cos is not 2-D or 3-D, sin's rank or last
            dim differs from cos's, or rotary_dim > D.
        RuntimeError: a leading cos/sin dim is shorter than x's matching dim
            and is not 1 (raised by ``Tensor.expand``).
    """
    if x.dim() != 4:
        raise ValueError(f"x must be 4-D (B, S, H, D), got shape {tuple(x.shape)}")
    B, S, H, D = x.shape
    if cos.dim() not in (2, 3):
        raise ValueError(f"cos must be 2-D or 3-D, got shape {tuple(cos.shape)}")
    if sin.dim() != cos.dim() or sin.shape[-1] != cos.shape[-1]:
        raise ValueError(
            f"sin shape {tuple(sin.shape)} is incompatible with cos shape {tuple(cos.shape)}"
        )
    half_dim = cos.shape[-1]
    ro_dim = half_dim * 2
    if ro_dim > D:
        raise ValueError(f"rotary_dim ({ro_dim}) exceeds head dim D ({D})")

    COS_HAS_BS = cos.dim() == 3

    # The kernel indexes cos/sin by (b, s) without bounds checks. Truncate a
    # longer table to x's extent and expand size-1 dims (stride 0) so every
    # index the kernel forms is in bounds; a table that is too short fails here.
    lead = (B, S) if COS_HAS_BS else (S,)
    window = tuple(slice(0, n) for n in lead)
    cos = cos[window].expand(*lead, half_dim)
    sin = sin[window].expand(*lead, half_dim)

    # The kernel addresses sin with cos's strides, so their layouts must agree.
    if cos.stride() != sin.stride():
        cos = cos.contiguous()
        sin = sin.contiguous()

    out = torch.empty_like(x)
    if out.numel() == 0:
        return out

    grid = (B * S * H,)
    rotary_kernel[grid](
        x,
        cos,
        sin,
        out,
        *x.stride(),
        cos.stride(0) if COS_HAS_BS else 0,
        cos.stride(-2),
        cos.stride(-1),
        *out.stride(),
        B,
        S,
        H,
        D,
        ro_dim,
        COS_HAS_BS=COS_HAS_BS,
        INTERLEAVED=interleaved,
        BLOCK_D=128,
    )
    return out
if __name__ == "__main__":
    import torch.utils.benchmark as benchmark

    B, S, H, D = 40, 4000, 64, 128
    x = torch.randn(B, S, H, D, device="cuda")
    cos = torch.randn(S, D // 2, device="cuda")
    sin = torch.randn(S, D // 2, device="cuda")

    # Correctness: compare the Triton kernel with the PyTorch reference for
    # both channel-pairing layouts.
    for interleaved in (False, True):
        out_torch = apply_rotary_emb_torch(x, cos, sin, interleaved=interleaved)
        out_triton = apply_rotary_emb_triton(x, cos, sin, interleaved=interleaved)
        max_diff = (out_torch - out_triton).abs().max().item()
        print(f"Max diff (interleaved={interleaved}):", max_diff)
        # Each output is as large as x (~5.2 GB of fp32 at these sizes); free
        # them before the next check and the benchmarks.
        del out_torch, out_triton

    # Warm up so torch.compile / Triton JIT compilation happens outside the
    # timers' estimation and measurement phases.
    apply_rotary_emb_compile(x, cos, sin, interleaved=False)
    apply_rotary_emb_triton(x, cos, sin, interleaved=False)
    torch.cuda.synchronize()

    stmts = (
        "apply_rotary_emb_torch(x, cos, sin, interleaved=False)",
        "apply_rotary_emb_triton(x, cos, sin, interleaved=False)",
        "apply_rotary_emb_compile(x, cos, sin, interleaved=False)",
    )
    for stmt in stmts:
        measurement = benchmark.Timer(
            stmt=stmt,
            globals=globals(),
        ).blocked_autorange(min_run_time=1.0)
        print(measurement)
|