Commit 49aab19

CUDA replacement implementation for silu
Signed-off-by: elvircrn <[email protected]>
1 parent 28f350e commit 49aab19

File tree

6 files changed: +780 -186 lines changed

benchmarks/kernels/benchmark_silu_mul_fp8_quant.py

Lines changed: 201 additions & 28 deletions
@@ -6,28 +6,191 @@
 import torch
 
 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-    silu_mul_fp8_quant_deep_gemm,
+    silu_mul_fp8_quant_deep_gemm_cuda,
 )
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 
 
-def benchmark(E, T, H, G=128, runs=50):
+@triton.jit
+def _silu_mul_fp8_quant_deep_gemm(
+    # Pointers ------------------------------------------------------------
+    input_ptr,  # 16-bit activations (E, T, 2*H)
+    y_q_ptr,  # fp8 quantized activations (E, T, H)
+    y_s_ptr,  # 16-bit scales (E, T, G)
+    counts_ptr,  # int32 num tokens per expert (E)
+    # Sizes ---------------------------------------------------------------
+    H: tl.constexpr,  # hidden dimension (per output)
+    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
+    # Strides for input (elements) ---------------------------------------
+    stride_i_e,
+    stride_i_t,
+    stride_i_h,
+    # Strides for y_q (elements) -----------------------------------------
+    stride_yq_e,
+    stride_yq_t,
+    stride_yq_h,
+    # Strides for y_s (elements) -----------------------------------------
+    stride_ys_e,
+    stride_ys_t,
+    stride_ys_g,
+    # Stride for counts (elements)
+    stride_counts_e,
+    # Numeric params ------------------------------------------------------
+    eps: tl.constexpr,
+    fp8_min: tl.constexpr,
+    fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
+    # Meta ---------------------------------------------------------------
+    BLOCK: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+):
+    G = H // GROUP_SIZE
+
+    # map program id -> (e, g)
+    pid = tl.program_id(0)
+    e = pid // G
+    g = pid % G
+
+    e = e.to(tl.int64)
+    g = g.to(tl.int64)
+
+    # number of valid tokens for this expert
+    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
+
+    cols = tl.arange(0, BLOCK).to(tl.int64)
+    mask = cols < BLOCK
+
+    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
+    base_gate_offset = base_input_offset + cols * stride_i_h
+    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
+    base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
+    base_ys_offset = e * stride_ys_e + g * stride_ys_g
+
+    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
+        gate = tl.load(
+            input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
+        ).to(tl.float32)
+        up = tl.load(
+            input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0
+        ).to(tl.float32)
+
+        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
+        y = gate * up
+
+        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
+        if use_ue8m0:
+            y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
+
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
+        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
+
+
+def silu_mul_fp8_quant_deep_gemm_triton(
+    y: torch.Tensor,  # (E, T, 2*H)
+    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+    num_parallel_tokens,
+    group_size: int = 128,
+    eps: float = 1e-10,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
+
+    y has shape (E, T, 2*H). The first half of the last dimension is
+    silu-activated, multiplied by the second half, then quantized into FP8.
+
+    Returns `(y_q, y_s)` where
+    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
+    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+    """
+    assert y.ndim == 3, "y must be (E, T, 2*H)"
+    E, T, H2 = y.shape
+    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+    H = H2 // 2
+    G = H // group_size
+    assert H % group_size == 0, "H must be divisible by group_size"
+    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
+        "tokens_per_expert must be shape (E,)"
+    )
+    tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
+
+    # allocate outputs
+    fp8_dtype = torch.float8_e4m3fn
+    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+    # strides (elements)
+    stride_i_e, stride_i_t, stride_i_h = y.stride()
+    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
+
+    # desired scale strides (elements): (T*G, 1, T)
+    stride_ys_e = T * G
+    stride_ys_t = 1
+    stride_ys_g = T
+    y_s = torch.empty_strided(
+        (E, T, G),
+        (stride_ys_e, stride_ys_t, stride_ys_g),
+        dtype=torch.float32,
+        device=y.device,
+    )
+
+    stride_cnt_e = tokens_per_expert.stride()[0]
+
+    # Static grid over experts and H-groups.
+    # A loop inside the kernel handles the token dim
+    grid = (E * G,)
+
+    f_info = torch.finfo(fp8_dtype)
+    fp8_max = f_info.max
+    fp8_min = f_info.min
+
+    _silu_mul_fp8_quant_deep_gemm[grid](
+        y,
+        y_q,
+        y_s,
+        tokens_per_expert,
+        H,
+        group_size,
+        stride_i_e,
+        stride_i_t,
+        stride_i_h,
+        stride_yq_e,
+        stride_yq_t,
+        stride_yq_h,
+        stride_ys_e,
+        stride_ys_t,
+        stride_ys_g,
+        stride_cnt_e,
+        eps,
+        fp8_min,
+        fp8_max,
+        is_deep_gemm_e8m0_used(),
+        BLOCK=group_size,
+        NUM_STAGES=4,
+        num_warps=1,
+    )
+
+    return y_q, y_s
+
+
+def benchmark(k, E, T, H, num_parallel_tokens, G=128, runs=100):
     current_platform.seed_everything(42)
     y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
     tokens_per_expert = torch.randint(
         T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
     )
 
     # Warmup
-    for _ in range(10):
-        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+    for _ in range(20):
+        k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
     torch.cuda.synchronize()
 
     # Benchmark
     torch.cuda.synchronize()
     start = time.perf_counter()
     for _ in range(runs):
-        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+        k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
     torch.cuda.synchronize()
 
     avg_time = (time.perf_counter() - start) / runs * 1000
@@ -51,27 +214,37 @@ def benchmark(E, T, H, G=128, runs=50):
     return avg_time, gflops, memory_bw
 
 
-configs = [
-    (8, 32, 1024),
-    (16, 64, 2048),
-    (32, 128, 4096),
-    # DeepSeekV3 Configs
-    (256, 16, 7168),
-    (256, 32, 7168),
-    (256, 64, 7168),
-    (256, 128, 7168),
-    (256, 256, 7168),
-    (256, 512, 7168),
-    (256, 1024, 7168),
-]
-
-print(f"GPU: {torch.cuda.get_device_name()}")
-print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
-print("-" * 50)
-
-for E, T, H in configs:
-    try:
-        time_ms, gflops, gbps = benchmark(E, T, H)
+def benchmark_full():
+    configs = [
+        (32, 8, 7168),
+        (32, 16, 7168),
+        (32, 32, 7168),
+        (32, 64, 7168),
+        (32, 128, 7168),
+        (32, 256, 7168),
+        (32, 512, 7168),
+        (32, 1024, 7168),
+    ]
+
+    print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
+    print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+    print("-" * 50)
+
+    for E, T, H in configs:
+        time_ms, gflops, gbps = benchmark(
+            silu_mul_fp8_quant_deep_gemm_cuda, E, T, H, 16
+        )
+        print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+
+    print(f"GPU: {torch.cuda.get_device_name()} Baseline")
+    print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+    print("-" * 50)
+
+    for E, T, H in configs:
+        time_ms, gflops, gbps = benchmark(
+            silu_mul_fp8_quant_deep_gemm_triton, E, T, H, 16
+        )
         print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
-    except Exception:
-        print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
+
+
+benchmark_full()
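
Both kernels under test compute the same math: per expert and per 128-element group, silu(gate) * up, one FP8 scale derived from the group's absolute maximum (rounded up to a power of two via exp2(ceil(log2(s))) when is_deep_gemm_e8m0_used() is true), then clamp and cast to float8_e4m3fn. A minimal PyTorch sketch of that math for reference only: it quantizes all T tokens rather than masking by tokens_per_expert, omits the UE8M0 rounding branch, and silu_mul_fp8_quant_ref is a name introduced here, not part of the diff.

import torch
import torch.nn.functional as F

def silu_mul_fp8_quant_ref(y, group_size=128, eps=1e-10):
    # y: (E, T, 2*H) -> y_q: (E, T, H) fp8, y_s: (E, T, H // group_size) fp32
    E, T, H2 = y.shape
    H = H2 // 2
    # silu(gate) * up in fp32, as in the Triton kernel's inner loop
    out = F.silu(y[..., :H].float()) * y[..., H:].float()
    groups = out.reshape(E, T, H // group_size, group_size)
    f = torch.finfo(torch.float8_e4m3fn)
    # one scale per group: max(|y|, eps) / fp8_max
    y_s = groups.abs().amax(dim=-1).clamp_min(eps) / f.max
    y_q = (groups / y_s.unsqueeze(-1)).clamp(f.min, f.max).to(torch.float8_e4m3fn)
    return y_q.reshape(E, T, H), y_s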

csrc/ops.h

Lines changed: 8 additions & 1 deletion
@@ -137,6 +137,13 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
                               torch::Tensor& input,
                               torch::Tensor& input_global_scale);
 #endif
+void silu_mul_fp8_quant_deep_gemm_cuda(
+    const at::Tensor& input,   // (E, T, 2*H)
+    const at::Tensor& counts,  // (E)
+    at::Tensor& y_q,           // (E, T, H) [OUT]
+    at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
+    int64_t group_size, double eps, double fp8_min, double fp8_max,
+    bool use_ue8m0, int64_t num_parallel_tokens);
 
 void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
 
@@ -354,4 +361,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
 void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                    int64_t quant_level, bool cast_bf2half = false);
 int64_t qr_max_size();
-#endif
+#endif
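
This declaration is the native entry point behind silu_mul_fp8_quant_deep_gemm_cuda, which the benchmark above imports from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe. A usage sketch mirroring the benchmark's call pattern; the (y_q, y_s) return convention is an assumption based on the Triton wrapper, since the two are swapped in interchangeably above.

import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm_cuda,
)

E, T, H = 32, 64, 7168  # one of the benchmarked configs
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.randint(
    T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
)

# y_q: (E, T, H) float8_e4m3fn; y_s: (E, T, H // 128) float32 scales
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(
    y, tokens_per_expert, num_parallel_tokens=16, group_size=128
)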
