@@ -3656,10 +3656,33 @@ GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, gg
 }
 
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MOE_FUSED_UP_GATE));
+    constexpr int min_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
+
+    // Why do we want to do this? The heuristic that a batch must have at least min_batch_size tokens
+    // for offloading the required model weights to be worthwhile comes from dense models. For MoE models,
+    // the average number of tokens each expert deals with in a batch is (active_experts / total_experts) * batch_size.
+    // Hence, according to the learned heuristic, we need (active_experts / total_experts) * batch_size >= min_batch_size.
+    // Rearranging we get
+    //
+    //     batch_size * active_experts >= min_batch_size * total_experts
+    //
+    // as the condition for offloading model weights residing in RAM to the GPU.
+    // In this case the number of tokens is not in op->ne[1] as usual, but rather in op->ne[2].
+    if (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MOE_FUSED_UP_GATE) {
+        auto ids = op->op == GGML_OP_MUL_MAT_ID ? op->src[2] : op->src[3];
+        int64_t batch_size = op->ne[2];
+        if (batch_size < min_batch_size) return false;
+        int64_t n_experts_tot = op->src[0]->ne[2];
+        int64_t n_experts_active = ids->ne[0];
+        //printf("%s(%s): op->ne[2] = %ld, n_experts_tot = %ld, n_experts_active = %ld, ids: %s, %ld x %ld x %ld x %ld\n", __func__, op->name, op->ne[2], n_experts_tot, n_experts_active, ids->name, ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3]);
+        return batch_size*n_experts_active >= min_batch_size*n_experts_tot;
+    }
+
+    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+
+    // Original:
+    // return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+    //        (op->ne[2] >= min_batch_size && (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MOE_FUSED_UP_GATE));
 
     GGML_UNUSED(backend);
 }
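
For reference, a minimal standalone sketch of the new offload condition follows. It assumes GGML_CUDA_MIN_BATCH_OFFLOAD keeps the previous hardcoded value of 32 (the macro's definition is outside this hunk); the helper name moe_offload_worth_it and the 2-of-8 expert counts are illustrative only, not taken from the commit.

// Standalone sketch, not part of the commit: reproduces the MoE offload heuristic.
#include <cstdint>
#include <cstdio>

// Assumption: GGML_CUDA_MIN_BATCH_OFFLOAD keeps the old hardcoded default of 32.
constexpr int64_t min_batch_size = 32;

// batch_size*active >= min_batch_size*total  <=>  (active/total)*batch_size >= min_batch_size
static bool moe_offload_worth_it(int64_t batch_size, int64_t n_experts_active, int64_t n_experts_tot) {
    if (batch_size < min_batch_size) return false;
    return batch_size*n_experts_active >= min_batch_size*n_experts_tot;
}

int main() {
    // Mixtral-style 2-of-8 experts: each expert sees ~1/4 of the batch on average,
    // so the effective offload threshold rises from 32 to 32*8/2 = 128 tokens.
    printf("batch  64: %s\n", moe_offload_worth_it( 64, 2, 8) ? "offload" : "keep weights in RAM");
    printf("batch 128: %s\n", moe_offload_worth_it(128, 2, 8) ? "offload" : "keep weights in RAM");
    return 0;
}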