@@ -444,29 +444,16 @@ void for_slices( //
444
444
thread_index_t const threads_count = pool.threads ();
445
445
if (threads_count == 1 || prongs_count == 1 ) return function (prong_t {0 , 0 }, prongs_count);
446
446
447
- // Divide and round-up the workload size per thread - assuming some fuzzer may
448
- // pass an absurdly large `prongs_count` as an input, the addition may overflow,
449
- // so `(prongs_count + threads_count - 1) / threads_count` is not the safest option.
450
- // Instead, we can do: `prongs_count / threads_count + (prongs_count % threads_count != 0)`,
451
- // but avoiding the cost of the second integer division, replacing it with multiplication.
452
- index_t const tasks_per_thread_lower_bound = prongs_count / threads_count;
453
- index_t const tasks_per_thread =
454
- tasks_per_thread_lower_bound + ((tasks_per_thread_lower_bound * threads_count) < prongs_count);
455
-
456
- pool.broadcast ([prongs_count, tasks_per_thread, tasks_per_thread_lower_bound,
457
- function](thread_index_t const thread_index) noexcept {
458
- // Multiplying `thread_index` by `tasks_per_thread` may overflow. For an 8-bit `index_t` type:
459
- // - 254 threads,
460
- // - 255 tasks,
461
- // - each thread gets 1 or 2 tasks
462
- // In that case, both `begin` and `begin_lower_bound` will overflow, but we can use
463
- // their relative values to determine the real slice length for the thread.
464
- index_t const begin = thread_index * tasks_per_thread; // ? Handled overflow
465
- index_t const begin_lower_bound = tasks_per_thread_lower_bound * thread_index; // ? Handled overflow
466
- bool const begin_overflows = begin_lower_bound > begin;
467
- bool const begin_exceeds_n = begin >= prongs_count;
468
- if (begin_overflows || begin_exceeds_n) return ;
469
- index_t const count = (std::min<index_t >)(add_sat (begin, tasks_per_thread), prongs_count) - begin;
447
+ // The first (N % M) chunks have size ceil(N/M)
448
+ // The remaining N - (N % M) chunks have size floor(N/M)
449
+ // where N = prongs_count, M = threads_count
450
+ // See https://lemire.me/blog/2025/05/22/dividing-an-array-into-fair-sized-chunks/
451
+ index_t const quotient = prongs_count / threads_count;
452
+ index_t const remainder = prongs_count % threads_count;
453
+
454
+ pool.broadcast ([quotient, remainder, function](thread_index_t const thread_index) noexcept {
455
+ index_t const begin = quotient * thread_index + (thread_index < remainder ? thread_index : remainder);
456
+ index_t const count = quotient + (thread_index < remainder ? 1 : 0 );
470
457
function (prong_t {thread_index, begin}, count);
471
458
});
472
459
}
0 commit comments