Skip to content

Commit 4089a45

Browse files
authored
Fix: Safe broadcast in Rust (#15)
* Change the api of `ThreadPool::broadcast` to accept `&mut self` instead of `self`. This for getting compile time exclusivity, so we can prevent concurrent broadcasts at compile time not for mutability * Fix `nbody` test
1 parent 7658145 commit 4089a45

File tree

2 files changed

+39
-39
lines changed

2 files changed

+39
-39
lines changed

fork_union.rs

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ impl Inner {
157157
///
158158
/// ```no_run
159159
/// use fork_union as fu;
160-
/// let pool = fu::spawn(4); // ! Unsafe shortcut, see below
160+
/// let mut pool = fu::spawn(4); // ! Unsafe shortcut, see below
161161
/// pool.broadcast(|thread_index| {
162162
/// println!("Hello from thread {thread_index}!");
163163
/// });
@@ -175,8 +175,8 @@ impl Inner {
175175
/// fn heavy_math(_: usize) {}
176176
///
177177
/// fn main() -> Result<(), Box<dyn Error>> {
178-
/// let pool = fu::ThreadPool::try_spawn_in(4, Global)?;
179-
/// fu::for_n_dynamic(&pool, 400, |prong| {
178+
/// let mut pool = fu::ThreadPool::try_spawn_in(4, Global)?;
179+
/// fu::for_n_dynamic(&mut pool, 400, |prong| {
180180
/// heavy_math(prong.task_index);
181181
/// });
182182
/// Ok(())
@@ -260,7 +260,7 @@ impl<A: Allocator + Clone> ThreadPool<A> {
260260
}
261261

262262
/// Executes a function on each thread of the pool.
263-
pub fn broadcast<F>(&self, function: F)
263+
pub fn broadcast<F>(&mut self, function: F)
264264
where
265265
F: Fn(usize) + Sync,
266266
{
@@ -358,7 +358,7 @@ pub struct Prong {
358358
}
359359

360360
/// Distributes `n` similar duration calls between threads in slices.
361-
pub fn for_slices<A, F>(pool: &ThreadPool<A>, n: usize, function: F)
361+
pub fn for_slices<A, F>(pool: &mut ThreadPool<A>, n: usize, function: F)
362362
where
363363
A: Allocator + Clone,
364364
F: Fn(Prong, usize) + Sync,
@@ -403,7 +403,7 @@ where
403403
}
404404

405405
/// Distributes `n` similar duration calls between threads by individual indices.
406-
pub fn for_n<A, F>(pool: &ThreadPool<A>, n: usize, function: F)
406+
pub fn for_n<A, F>(pool: &mut ThreadPool<A>, n: usize, function: F)
407407
where
408408
A: Allocator + Clone,
409409
F: Fn(Prong) + Sync,
@@ -419,7 +419,7 @@ where
419419
}
420420

421421
/// Executes `n` uneven tasks on all threads, greedily stealing work.
422-
pub fn for_n_dynamic<A, F>(pool: &ThreadPool<A>, n: usize, function: F)
422+
pub fn for_n_dynamic<A, F>(pool: &mut ThreadPool<A>, n: usize, function: F)
423423
where
424424
A: Allocator + Clone,
425425
F: Fn(Prong) + Sync,
@@ -540,9 +540,9 @@ impl<F> SyncConstPtr<F> {
540540
///
541541
/// ```no_run
542542
/// use fork_union as fu;
543-
/// let pool = fu::spawn(1);
543+
/// let mut pool = fu::spawn(1);
544544
/// let mut data = vec![0u64; 1_000_000];
545-
/// fu::for_each_prong_mut(&pool, &mut data, |x, prong| {
545+
/// fu::for_each_prong_mut(&mut pool, &mut data, |x, prong| {
546546
/// *x = prong.task_index as u64 * 2;
547547
/// });
548548
/// ```
@@ -551,7 +551,7 @@ impl<F> SyncConstPtr<F> {
551551
/// set of elements in parallel, so this API serves as a shortcut.
552552
///
553553
/// Similar to Rayon's `par_chunks_mut`.
554-
pub fn for_each_prong_mut<A, T, F>(pool: &ThreadPool<A>, data: &mut [T], function: F)
554+
pub fn for_each_prong_mut<A, T, F>(pool: &mut ThreadPool<A>, data: &mut [T], function: F)
555555
where
556556
A: Allocator + Clone,
557557
T: Send,
@@ -576,9 +576,9 @@ where
576576
///
577577
/// ```no_run
578578
/// use fork_union as fu;
579-
/// let pool = fu::spawn(1);
579+
/// let mut pool = fu::spawn(1);
580580
/// let mut strings = vec![String::new(); 1_000];
581-
/// fu::for_each_prong_mut_dynamic(&pool, &mut strings, |s, prong| {
581+
/// fu::for_each_prong_mut_dynamic(&mut pool, &mut strings, |s, prong| {
582582
/// s.push_str(&format!("hello from thread {}", prong.thread_index));
583583
/// });
584584
/// ```
@@ -587,7 +587,7 @@ where
587587
/// set of elements in parallel, so this API serves as a shortcut.
588588
///
589589
/// Similar to Rayon's `par_iter_mut`.
590-
pub fn for_each_prong_mut_dynamic<A, T, F>(pool: &ThreadPool<A>, data: &mut [T], function: F)
590+
pub fn for_each_prong_mut_dynamic<A, T, F>(pool: &mut ThreadPool<A>, data: &mut [T], function: F)
591591
where
592592
A: Allocator + Clone,
593593
T: Send,
@@ -633,7 +633,7 @@ mod tests {
633633
#[test]
634634
fn for_each_thread_dispatch() {
635635
let count_threads = hw_threads();
636-
let pool = spawn(count_threads);
636+
let mut pool = spawn(count_threads);
637637

638638
let visited = Arc::new(
639639
(0..count_threads)
@@ -657,11 +657,11 @@ mod tests {
657657
#[test]
658658
fn for_each_static_uncomfortable_input_size() {
659659
let count_threads = hw_threads();
660-
let pool = spawn(count_threads);
660+
let mut pool = spawn(count_threads);
661661

662662
for input_size in 0..count_threads {
663663
let out_of_bounds = AtomicBool::new(false);
664-
for_n(&pool, input_size, |prong| {
664+
for_n(&mut pool, input_size, |prong| {
665665
let task_index = prong.task_index;
666666
if task_index >= count_threads {
667667
out_of_bounds.store(true, Ordering::Relaxed);
@@ -677,7 +677,7 @@ mod tests {
677677
#[test]
678678
fn for_each_static_static_scheduling() {
679679
const EXPECTED_PARTS: usize = 10_000_000;
680-
let pool = spawn(hw_threads());
680+
let mut pool = spawn(hw_threads());
681681

682682
let visited = Arc::new(
683683
(0..EXPECTED_PARTS)
@@ -688,7 +688,7 @@ mod tests {
688688
let visited_ref = Arc::clone(&visited);
689689
let duplicate_ref = Arc::clone(&duplicate);
690690

691-
for_n(&pool, EXPECTED_PARTS, move |prong| {
691+
for_n(&mut pool, EXPECTED_PARTS, move |prong| {
692692
let task_index = prong.task_index;
693693
if visited_ref[task_index].swap(true, Ordering::Relaxed) {
694694
duplicate_ref.store(true, Ordering::Relaxed);
@@ -707,7 +707,7 @@ mod tests {
707707
#[test]
708708
fn for_each_dynamic_dynamic_scheduling() {
709709
const EXPECTED_PARTS: usize = 10_000_000;
710-
let pool = spawn(hw_threads());
710+
let mut pool = spawn(hw_threads());
711711

712712
let visited = Arc::new(
713713
(0..EXPECTED_PARTS)
@@ -718,7 +718,7 @@ mod tests {
718718
let visited_ref = Arc::clone(&visited);
719719
let duplicate_ref = Arc::clone(&duplicate);
720720

721-
for_n_dynamic(&pool, EXPECTED_PARTS, move |prong| {
721+
for_n_dynamic(&mut pool, EXPECTED_PARTS, move |prong| {
722722
let task_index = prong.task_index;
723723
if visited_ref[task_index].swap(true, Ordering::Relaxed) {
724724
duplicate_ref.store(true, Ordering::Relaxed);
@@ -739,7 +739,7 @@ mod tests {
739739
const EXPECTED_PARTS: usize = 10_000_000;
740740
const OVERSUBSCRIPTION: usize = 7;
741741
let threads = hw_threads() * OVERSUBSCRIPTION;
742-
let pool = spawn(threads);
742+
let mut pool = spawn(threads);
743743

744744
let visited = Arc::new(
745745
(0..EXPECTED_PARTS)
@@ -752,7 +752,7 @@ mod tests {
752752

753753
thread_local! { static LOCAL_WORK: std::cell::Cell<usize> = std::cell::Cell::new(0); }
754754

755-
for_n_dynamic(&pool, EXPECTED_PARTS, move |prong| {
755+
for_n_dynamic(&mut pool, EXPECTED_PARTS, move |prong| {
756756
let task_index = prong.task_index;
757757
// Mildly unbalanced CPU burn
758758
LOCAL_WORK.with(|cell| {
@@ -787,8 +787,8 @@ mod tests {
787787
}
788788

789789
TASK_COUNTER.store(0, Ordering::Relaxed);
790-
let pool = spawn(hw_threads());
791-
for_n_dynamic(&pool, EXPECTED_PARTS, |prong| tally(prong.task_index));
790+
let mut pool = spawn(hw_threads());
791+
for_n_dynamic(&mut pool, EXPECTED_PARTS, |prong| tally(prong.task_index));
792792

793793
assert_eq!(
794794
TASK_COUNTER.load(Ordering::Relaxed),
@@ -801,7 +801,7 @@ mod tests {
801801
fn concurrent_histogram_array() {
802802
const HIST_SIZE: usize = 16;
803803
const ELEMENTS: usize = 1_000_000;
804-
let pool = spawn(hw_threads());
804+
let mut pool = spawn(hw_threads());
805805

806806
let values: Vec<usize> = (0..ELEMENTS).map(|i| i % HIST_SIZE).collect();
807807
let histogram = Arc::new(
@@ -811,7 +811,7 @@ mod tests {
811811
);
812812
let hist_ref = Arc::clone(&histogram);
813813

814-
for_n_dynamic(&pool, ELEMENTS, |prong| {
814+
for_n_dynamic(&mut pool, ELEMENTS, |prong| {
815815
let task_index = prong.task_index;
816816
let value = values[task_index];
817817
hist_ref[value].fetch_add(1, Ordering::Relaxed);
@@ -826,7 +826,7 @@ mod tests {
826826
}
827827
}
828828

829-
fn increment_all(pool: &ThreadPool, data: &[AtomicUsize]) {
829+
fn increment_all(pool: &mut ThreadPool, data: &[AtomicUsize]) {
830830
for_n(pool, data.len(), |prong| {
831831
data[prong.task_index].fetch_add(1, Ordering::Relaxed);
832832
});
@@ -835,14 +835,14 @@ mod tests {
835835
#[test]
836836
fn pass_pool_and_reuse() {
837837
const ELEMENTS: usize = 128;
838-
let pool = spawn(hw_threads());
838+
let mut pool = spawn(hw_threads());
839839

840840
let data = (0..ELEMENTS)
841841
.map(|_| AtomicUsize::new(0))
842842
.collect::<Vec<_>>();
843843

844-
increment_all(&pool, &data);
845-
increment_all(&pool, &data);
844+
increment_all(&mut pool, &data);
845+
increment_all(&mut pool, &data);
846846

847847
for counter in data.iter() {
848848
assert_eq!(counter.load(Ordering::Relaxed), 2);
@@ -854,7 +854,7 @@ mod tests {
854854
let mut pool = spawn(hw_threads());
855855
static COUNTER: AtomicUsize = AtomicUsize::new(0);
856856

857-
for_n(&pool, 1000, |_| {
857+
for_n(&mut pool, 1000, |_| {
858858
COUNTER.fetch_add(1, Ordering::Relaxed);
859859
});
860860

@@ -968,7 +968,7 @@ mod tests {
968968
);
969969

970970
let large_allocator = CountingAllocator::new(Some(1024 * 1024));
971-
let pool = ThreadPool::try_spawn_in(hw_threads(), large_allocator.clone())
971+
let mut pool = ThreadPool::try_spawn_in(hw_threads(), large_allocator.clone())
972972
.expect("We should have enough memory for this!");
973973

974974
let visited = Arc::new(
@@ -980,7 +980,7 @@ mod tests {
980980
let visited_ref = Arc::clone(&visited);
981981
let duplicate_ref = Arc::clone(&duplicate);
982982

983-
for_n_dynamic(&pool, EXPECTED_PARTS, move |prong| {
983+
for_n_dynamic(&mut pool, EXPECTED_PARTS, move |prong| {
984984
let task_index = prong.task_index;
985985
if visited_ref[task_index].swap(true, Ordering::Relaxed) {
986986
duplicate_ref.store(true, Ordering::Relaxed);

scripts/nbody.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ fn hw_threads() -> usize {
110110
// ────────────────────────────────────────────────────────────────────────────
111111
// Fork-Union kernels
112112
// ────────────────────────────────────────────────────────────────────────────
113-
fn iteration_fu_static(pool: &fu::ThreadPool, bodies: &mut [Body], forces: &mut [Vector3]) {
113+
fn iteration_fu_static(pool: &mut fu::ThreadPool, bodies: &mut [Body], forces: &mut [Vector3]) {
114114
let n = bodies.len();
115115
let bodies_ptr = fu::SyncConstPtr::new(bodies.as_ptr());
116116

@@ -129,7 +129,7 @@ fn iteration_fu_static(pool: &fu::ThreadPool, bodies: &mut [Body], forces: &mut
129129
});
130130
}
131131

132-
fn iteration_fu_dynamic(pool: &fu::ThreadPool, bodies: &mut [Body], forces: &mut [Vector3]) {
132+
fn iteration_fu_dynamic(pool: &mut fu::ThreadPool, bodies: &mut [Body], forces: &mut [Vector3]) {
133133
let n = bodies.len();
134134
let bodies_ptr = fu::SyncConstPtr::new(bodies.as_ptr());
135135

@@ -254,17 +254,17 @@ fn main() -> Result<(), Box<dyn Error>> {
254254
// Run the chosen backend
255255
match backend.as_str() {
256256
"fork_union_static" => {
257-
let pool = fu::ThreadPool::try_spawn(threads)
257+
let mut pool = fu::ThreadPool::try_spawn(threads)
258258
.unwrap_or_else(|e| panic!("Failed to start Fork-Union pool: {e}"));
259259
for _ in 0..iters {
260-
iteration_fu_static(&pool, &mut bodies, &mut forces);
260+
iteration_fu_static(&mut pool, &mut bodies, &mut forces);
261261
}
262262
}
263263
"fork_union_dynamic" => {
264-
let pool = fu::ThreadPool::try_spawn(threads)
264+
let mut pool = fu::ThreadPool::try_spawn(threads)
265265
.unwrap_or_else(|e| panic!("Failed to start Fork-Union pool: {e}"));
266266
for _ in 0..iters {
267-
iteration_fu_dynamic(&pool, &mut bodies, &mut forces);
267+
iteration_fu_dynamic(&mut pool, &mut bodies, &mut forces);
268268
}
269269
}
270270
"rayon_static" => {

0 commit comments

Comments
 (0)