Skip to content

Commit c866220

Browse files
authored
Merge pull request #17 from pmed/fair-sized-chunks-in-for_slice
Fair-sized chunks in `for_slice()` Closes #16
2 parents 797013c + 71749f6 commit c866220

File tree

1 file changed

+10
-23
lines changed

1 file changed

+10
-23
lines changed

include/fork_union.hpp

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -444,29 +444,16 @@ void for_slices( //
444444
thread_index_t const threads_count = pool.threads();
445445
if (threads_count == 1 || prongs_count == 1) return function(prong_t {0, 0}, prongs_count);
446446

447-
// Divide and round-up the workload size per thread - assuming some fuzzer may
448-
// pass an absurdly large `prongs_count` as an input, the addition may overflow,
449-
// so `(prongs_count + threads_count - 1) / threads_count` is not the safest option.
450-
// Instead, we can do: `prongs_count / threads_count + (prongs_count % threads_count != 0)`,
451-
// but avoiding the cost of the second integer division, replacing it with multiplication.
452-
index_t const tasks_per_thread_lower_bound = prongs_count / threads_count;
453-
index_t const tasks_per_thread =
454-
tasks_per_thread_lower_bound + ((tasks_per_thread_lower_bound * threads_count) < prongs_count);
455-
456-
pool.broadcast([prongs_count, tasks_per_thread, tasks_per_thread_lower_bound,
457-
function](thread_index_t const thread_index) noexcept {
458-
// Multiplying `thread_index` by `tasks_per_thread` may overflow. For an 8-bit `index_t` type:
459-
// - 254 threads,
460-
// - 255 tasks,
461-
// - each thread gets 1 or 2 tasks
462-
// In that case, both `begin` and `begin_lower_bound` will overflow, but we can use
463-
// their relative values to determine the real slice length for the thread.
464-
index_t const begin = thread_index * tasks_per_thread; // ? Handled overflow
465-
index_t const begin_lower_bound = tasks_per_thread_lower_bound * thread_index; // ? Handled overflow
466-
bool const begin_overflows = begin_lower_bound > begin;
467-
bool const begin_exceeds_n = begin >= prongs_count;
468-
if (begin_overflows || begin_exceeds_n) return;
469-
index_t const count = (std::min<index_t>)(add_sat(begin, tasks_per_thread), prongs_count) - begin;
447+
// The first (N % M) chunks have size ceil(N/M)
448+
// The remaining N - (N % M) chunks have size floor(N/M)
449+
// where N = prongs_count, M = threads_count
450+
// See https://lemire.me/blog/2025/05/22/dividing-an-array-into-fair-sized-chunks/
451+
index_t const quotient = prongs_count / threads_count;
452+
index_t const remainder = prongs_count % threads_count;
453+
454+
pool.broadcast([quotient, remainder, function](thread_index_t const thread_index) noexcept {
455+
index_t const begin = quotient * thread_index + (thread_index < remainder ? thread_index : remainder);
456+
index_t const count = quotient + (thread_index < remainder ? 1 : 0);
470457
function(prong_t {thread_index, begin}, count);
471458
});
472459
}

0 commit comments

Comments
 (0)