diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index c60f1823b8a1..d1874515cc8f 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode( torch::Tensor const& seq_lens, torch::Tensor const& page_table, torch::Tensor const& workspace, + double sm_scale, int64_t num_kv_splits) { TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); } diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 9fd72ee152b5..a3b8f07638d9 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -771,11 +771,11 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, w13_ref = dequant_mxfp4_batches( w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( - num_experts, 2 * intermediate_size, hidden_size) + num_experts, 2 * intermediate_size, hidden_size).to(device) w2_ref = dequant_mxfp4_batches( w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( - num_experts, hidden_size, intermediate_size) + num_experts, hidden_size, intermediate_size).to(device) # Quantize activations for SM100 path and dequantize for reference hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)