Commit 874c8dd

mgoin authored and bbartels committed

[CI Failure] Fix test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe (vllm-project#24750)

Signed-off-by: mgoin <[email protected]>
Signed-off-by: bbartels <[email protected]>

1 parent d9b07a7 commit 874c8dd

File tree

2 files changed: +3 −2 lines changed


csrc/attention/mla/sm100_cutlass_mla_kernel.cu

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode(
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
+    double sm_scale,
     int64_t num_kv_splits) {
   TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
 }
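For context, `sm_scale` in attention kernels conventionally denotes the softmax scaling factor 1/sqrt(head_dim) applied to the Q·Kᵀ scores; the exact convention inside the vLLM kernel is an assumption here, not stated by this diff. A minimal generic sketch:

```python
import math

import torch

# Assumption: sm_scale is the conventional softmax scaling factor applied to
# attention scores before softmax. Generic illustration, not the vLLM kernel.
def softmax_scale(head_dim: int) -> float:
    return 1.0 / math.sqrt(head_dim)

# Tiny single-head attention reference using that scale.
q = torch.randn(1, 4, 64)  # (batch, seq_len, head_dim)
k = torch.randn(1, 4, 64)
v = torch.randn(1, 4, 64)

scores = (q @ k.transpose(-1, -2)) * softmax_scale(64)
out = scores.softmax(dim=-1) @ v
```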

tests/kernels/moe/test_mxfp4_moe.py

Lines changed: 2 additions & 2 deletions

@@ -771,11 +771,11 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
     w13_ref = dequant_mxfp4_batches(
         w13_q.view(torch.uint8),
         w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
-            num_experts, 2 * intermediate_size, hidden_size)
+            num_experts, 2 * intermediate_size, hidden_size).to(device)
     w2_ref = dequant_mxfp4_batches(
         w2_q.view(torch.uint8),
         w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
-            num_experts, hidden_size, intermediate_size)
+            num_experts, hidden_size, intermediate_size).to(device)

     # Quantize activations for SM100 path and dequantize for reference
     hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
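The `.to(device)` additions fix a device mismatch: the dequantized reference tensors were built on the CPU while the tensors they are compared against live on the GPU. A minimal sketch of that failure mode, assuming the comparison uses `torch.testing.assert_close` (the variable names here are generic stand-ins, not the test's actual identifiers):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-ins for the kernel output and the dequantized reference. In the real
# test the reference comes from dequant_mxfp4_batches(...) built on the CPU.
out = torch.ones(4, device=device)
ref = torch.ones(4)  # built on the CPU

# Tensors on different devices fail assert_close's device check, so the
# reference is moved first, mirroring the .to(device) fix in the diff.
torch.testing.assert_close(out, ref.to(device))
```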
