import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-     silu_mul_fp8_quant_deep_gemm,
+     silu_mul_fp8_quant_deep_gemm_cuda,
)
from vllm.platforms import current_platform
+ from vllm.triton_utils import tl, triton
+ from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used


- def benchmark(E, T, H, G=128, runs=50):
+ @triton.jit
+ def _silu_mul_fp8_quant_deep_gemm(
+     # Pointers ------------------------------------------------------------
+     input_ptr,  # 16-bit activations (E, T, 2*H)
+     y_q_ptr,  # fp8 quantized activations (E, T, H)
+     y_s_ptr,  # fp32 scales (E, T, G)
+     counts_ptr,  # int32 num tokens per expert (E)
+     # Sizes ---------------------------------------------------------------
+     H: tl.constexpr,  # hidden dimension (per output)
+     GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
+     # Strides for input (elements) ----------------------------------------
+     stride_i_e,
+     stride_i_t,
+     stride_i_h,
+     # Strides for y_q (elements) -------------------------------------------
+     stride_yq_e,
+     stride_yq_t,
+     stride_yq_h,
+     # Strides for y_s (elements) -------------------------------------------
+     stride_ys_e,
+     stride_ys_t,
+     stride_ys_g,
+     # Stride for counts (elements)
+     stride_counts_e,
+     # Numeric params --------------------------------------------------------
+     eps: tl.constexpr,
+     fp8_min: tl.constexpr,
+     fp8_max: tl.constexpr,
+     use_ue8m0: tl.constexpr,
+     # Meta -------------------------------------------------------------------
+     BLOCK: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+ ):
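+     # One Triton program handles a single (expert, H-group) pair and
+     # loops over that expert's valid tokens.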
+     G = H // GROUP_SIZE
+
+     # map program id -> (e, g)
+     pid = tl.program_id(0)
+     e = pid // G
+     g = pid % G
+
+     e = e.to(tl.int64)
+     g = g.to(tl.int64)
+
+     # number of valid tokens for this expert
+     n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
+
+     cols = tl.arange(0, BLOCK).to(tl.int64)
+     mask = cols < BLOCK
+
+     base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
+     base_gate_offset = base_input_offset + cols * stride_i_h
+     base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
+     base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
+     base_ys_offset = e * stride_ys_e + g * stride_ys_g
+
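+     # Iterate over this expert's valid tokens; num_stages lets Triton
+     # software-pipeline the loads against the compute and stores.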
+     for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
+         gate = tl.load(
+             input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
+         ).to(tl.float32)
+         up = tl.load(
+             input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0
+         ).to(tl.float32)
+
+         gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
+         y = gate * up
+
+         y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
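+         # UE8M0: round the scale up to the next power of two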
+         if use_ue8m0:
+             y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
+
+         y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+         tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
+         tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
+
+
+ def silu_mul_fp8_quant_deep_gemm_triton(
+     y: torch.Tensor,  # (E, T, 2*H)
+     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+     num_parallel_tokens,
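+     # (unused in the Triton path; kept so both implementations share a signature)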
+     group_size: int = 128,
+     eps: float = 1e-10,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales.
+
+     y has shape (E, T, 2*H). The first half of the last dimension is
+     silu-activated, multiplied by the second half, then quantized into FP8.
+
+     Returns `(y_q, y_s)` where
+     * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
+     * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+     """
+     assert y.ndim == 3, "y must be (E, T, 2*H)"
+     E, T, H2 = y.shape
+     assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+     H = H2 // 2
+     G = H // group_size
+     assert H % group_size == 0, "H must be divisible by group_size"
+     assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
+         "tokens_per_expert must be shape (E,)"
+     )
+     tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
+
+     # allocate outputs
+     fp8_dtype = torch.float8_e4m3fn
+     y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+     # strides (elements)
+     stride_i_e, stride_i_t, stride_i_h = y.stride()
+     stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
+
+     # desired scale strides (elements): (T*G, 1, T)
+     stride_ys_e = T * G
+     stride_ys_t = 1
+     stride_ys_g = T
+     y_s = torch.empty_strided(
+         (E, T, G),
+         (stride_ys_e, stride_ys_t, stride_ys_g),
+         dtype=torch.float32,
+         device=y.device,
+     )
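+     # Token dim has stride 1: scales for consecutive tokens of a group are
+     # contiguous, matching the scale layout the DeepGEMM path consumes.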
+
+     stride_cnt_e = tokens_per_expert.stride()[0]
+
+     # Static grid over experts and H-groups.
+     # A loop inside the kernel handles the token dim
+     grid = (E * G,)
+
+     f_info = torch.finfo(fp8_dtype)
+     fp8_max = f_info.max
+     fp8_min = f_info.min
+
+     _silu_mul_fp8_quant_deep_gemm[grid](
+         y,
+         y_q,
+         y_s,
+         tokens_per_expert,
+         H,
+         group_size,
+         stride_i_e,
+         stride_i_t,
+         stride_i_h,
+         stride_yq_e,
+         stride_yq_t,
+         stride_yq_h,
+         stride_ys_e,
+         stride_ys_t,
+         stride_ys_g,
+         stride_cnt_e,
+         eps,
+         fp8_min,
+         fp8_max,
+         is_deep_gemm_e8m0_used(),
+         BLOCK=group_size,
+         NUM_STAGES=4,
+         num_warps=1,
+     )
+
+     return y_q, y_s
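+
+ # Illustrative call (shapes only; assumes a CUDA device and bf16 input):
+ #   y = torch.randn((8, 32, 2 * 1024), dtype=torch.bfloat16, device="cuda")
+ #   counts = torch.randint(16, 32, (8,), dtype=torch.int32, device="cuda")
+ #   y_q, y_s = silu_mul_fp8_quant_deep_gemm_triton(y, counts, num_parallel_tokens=16)
+ #   # y_q: (8, 32, 1024) float8_e4m3fn; y_s: (8, 32, 8) fp32, strides (256, 1, 32)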
+
+
+ def benchmark(k, E, T, H, num_parallel_tokens, G=128, runs=100):
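+     # `k` is the implementation under test; the CUDA kernel and the Triton
+     # baseline are called with identical arguments.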
    current_platform.seed_everything(42)
    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
    tokens_per_expert = torch.randint(
        T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
    )

    # Warmup
-     for _ in range(10):
-         silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+     for _ in range(20):
+         k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
    torch.cuda.synchronize()

    # Benchmark
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
-         silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+         k(y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G)
    torch.cuda.synchronize()

    avg_time = (time.perf_counter() - start) / runs * 1000
@@ -51,27 +214,37 @@ def benchmark(E, T, H, G=128, runs=50):
    return avg_time, gflops, memory_bw


- configs = [
-     (8, 32, 1024),
-     (16, 64, 2048),
-     (32, 128, 4096),
-     # DeepSeekV3 Configs
-     (256, 16, 7168),
-     (256, 32, 7168),
-     (256, 64, 7168),
-     (256, 128, 7168),
-     (256, 256, 7168),
-     (256, 512, 7168),
-     (256, 1024, 7168),
- ]
-
- print(f"GPU: {torch.cuda.get_device_name()}")
- print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
- print("-" * 50)
-
- for E, T, H in configs:
-     try:
-         time_ms, gflops, gbps = benchmark(E, T, H)
+ def benchmark_full():
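+     # Run both implementations over the same configs so the two result
+     # tables printed below are directly comparable.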
+     configs = [
+         (32, 8, 7168),
+         (32, 16, 7168),
+         (32, 32, 7168),
+         (32, 64, 7168),
+         (32, 128, 7168),
+         (32, 256, 7168),
+         (32, 512, 7168),
+         (32, 1024, 7168),
+     ]
+
+     print(f"GPU: {torch.cuda.get_device_name()} CUDA Kernel")
+     print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+     print("-" * 50)
+
+     for E, T, H in configs:
+         time_ms, gflops, gbps = benchmark(
+             silu_mul_fp8_quant_deep_gemm_cuda, E, T, H, 16
+         )
+         print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+
+     print(f"GPU: {torch.cuda.get_device_name()} Baseline")
+     print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+     print("-" * 50)
+
+     for E, T, H in configs:
+         time_ms, gflops, gbps = benchmark(
+             silu_mul_fp8_quant_deep_gemm_triton, E, T, H, 16
+         )
        print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
-     except Exception:
-         print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
+
+
+ benchmark_full()