From ddf90c8d5eaff48a4e640cd939ad0c2cdecbc398 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Sun, 8 Jun 2025 15:07:41 -0400 Subject: [PATCH] fixes #946 -- return both the old and new values in `Atomic::compare_exchange` --- crossbeam-epoch/src/atomic.rs | 175 +++++++++++++++++++++++++++++++--- crossbeam-epoch/src/lib.rs | 4 +- 2 files changed, 164 insertions(+), 15 deletions(-) diff --git a/crossbeam-epoch/src/atomic.rs b/crossbeam-epoch/src/atomic.rs index 5320af6fb..50c26092d 100644 --- a/crossbeam-epoch/src/atomic.rs +++ b/crossbeam-epoch/src/atomic.rs @@ -28,6 +28,24 @@ fn strongest_failure_ordering(order: Ordering) -> Ordering { } } +/// The value returned from a compare-and-swap operation. +pub struct CompareExchangeValue<'g, T: ?Sized + Pointable> { + /// The previous value that was in the atomic pointer. + pub old: Shared<'g, T>, + + /// The new value that was stored. + pub new: Shared<'g, T>, +} + +impl fmt::Debug for CompareExchangeValue<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CompareExchangeValue") + .field("old", &self.old) + .field("new", &self.new) + .finish() + } +} + /// The error returned on failed compare-and-swap operation. pub struct CompareExchangeError<'g, T: ?Sized + Pointable, P: Pointer> { /// The value in the atomic pointer at the time of the failed operation. @@ -420,9 +438,10 @@ impl Atomic { /// value is the same as `current`. The tag is also taken into account, so two pointers to the /// same object, but with different tags, will not be considered equal. /// - /// The return value is a result indicating whether the new pointer was written. On success the - /// pointer that was written is returned. On failure the actual current value and `new` are - /// returned. + /// The return value contains both the previous value and the new value that was written. + /// On success, `old` contains the previous value (which equals `current`) and `new` contains + /// the value that was stored. On failure, `old` contains the actual current value and `new` + /// contains the value that was attempted to be stored. /// /// This method takes two `Ordering` arguments to describe the memory /// ordering of this operation. `success` describes the required ordering for the @@ -454,14 +473,19 @@ impl Atomic { success: Ordering, failure: Ordering, _: &'g Guard, - ) -> Result, CompareExchangeError<'g, T, P>> + ) -> Result, CompareExchangeError<'g, T, P>> where P: Pointer, { let new = new.into_ptr(); self.data .compare_exchange(current.into_ptr(), new, success, failure) - .map(|_| unsafe { Shared::from_ptr(new) }) + .map(|old| unsafe { + CompareExchangeValue { + old: Shared::from_ptr(old), + new: Shared::from_ptr(new), + } + }) .map_err(|current| unsafe { CompareExchangeError { current: Shared::from_ptr(current), @@ -475,9 +499,11 @@ impl Atomic { /// same object, but with different tags, will not be considered equal. /// /// Unlike [`compare_exchange`], this method is allowed to spuriously fail even when comparison - /// succeeds, which can result in more efficient code on some platforms. The return value is a - /// result indicating whether the new pointer was written. On success the pointer that was - /// written is returned. On failure the actual current value and `new` are returned. + /// succeeds, which can result in more efficient code on some platforms. The return value + /// contains both the previous value and the new value that was written. On success, `old` + /// contains the previous value (which equals `current`) and `new` contains the value that was + /// stored. On failure, `old` contains the actual current value and `new` contains the value + /// that was attempted to be stored. /// /// This method takes two `Ordering` arguments to describe the memory /// ordering of this operation. `success` describes the required ordering for the @@ -504,8 +530,8 @@ impl Atomic { /// # unsafe { drop(a.load(SeqCst, guard).into_owned()); } // avoid leak /// loop { /// match a.compare_exchange_weak(ptr, new, SeqCst, SeqCst, guard) { - /// Ok(p) => { - /// ptr = p; + /// Ok(result) => { + /// ptr = result.old; /// break; /// } /// Err(err) => { @@ -531,14 +557,19 @@ impl Atomic { success: Ordering, failure: Ordering, _: &'g Guard, - ) -> Result, CompareExchangeError<'g, T, P>> + ) -> Result, CompareExchangeError<'g, T, P>> where P: Pointer, { let new = new.into_ptr(); self.data .compare_exchange_weak(current.into_ptr(), new, success, failure) - .map(|_| unsafe { Shared::from_ptr(new) }) + .map(|old| unsafe { + CompareExchangeValue { + old: Shared::from_ptr(old), + new: Shared::from_ptr(new), + } + }) .map_err(|current| unsafe { CompareExchangeError { current: Shared::from_ptr(current), @@ -600,7 +631,7 @@ impl Atomic { let mut prev = self.load(fail_order, guard); while let Some(next) = func(prev) { match self.compare_exchange_weak(prev, next, set_order, fail_order, guard) { - Ok(shared) => return Ok(shared), + Ok(result) => return Ok(result.old), Err(next_prev) => prev = next_prev.current, } } @@ -1574,8 +1605,10 @@ impl Default for Shared<'_, T> { #[cfg(all(test, not(crossbeam_loom)))] mod tests { - use super::{Owned, Shared}; + use super::{Atomic, Owned, Shared}; + use crate::pin; use std::mem::MaybeUninit; + use std::sync::atomic::Ordering; #[test] fn valid_tag_i8() { @@ -1599,4 +1632,118 @@ mod tests { let arr: &[MaybeUninit] = &owned; assert_eq!(arr.len(), 10); } + + #[test] + fn compare_exchange_success() { + let atomic = Atomic::new(42); + let guard = &pin(); + + let current = atomic.load(Ordering::SeqCst, guard); + let new_value = Owned::new(100); + + let result = atomic + .compare_exchange( + current, + new_value, + Ordering::SeqCst, + Ordering::SeqCst, + guard, + ) + .unwrap(); + // On success, `old` should equal the current value we loaded + assert_eq!(unsafe { result.old.deref() }, &42); + // `new` should equal the value we stored + assert_eq!(unsafe { result.new.deref() }, &100); + + // Verify the atomic actually contains the new value + let current = atomic.load(Ordering::SeqCst, guard); + assert_eq!(unsafe { current.deref() }, &100); + + unsafe { + drop(result.old.into_owned()); + drop(atomic.into_owned()); + } + } + + #[test] + fn compare_exchange_failure() { + let atomic = Atomic::new(42); + let guard = &pin(); + + // Load the current value + let current = atomic.load(Ordering::SeqCst, guard); + + let old_value = atomic.swap(Owned::new(200), Ordering::SeqCst, guard); + unsafe { + drop(old_value.into_owned()); + } + + // Now try to compare_exchange with the old current value - this should fail + let new_value = Owned::new(300); + let error = atomic + .compare_exchange( + current, + new_value, + Ordering::SeqCst, + Ordering::SeqCst, + guard, + ) + .unwrap_err(); + // On failure, `current` should contain the actual current value (200) + assert_eq!(unsafe { error.current.deref() }, &200); + // `new` should contain the value we tried to store (300) + assert_eq!(&*error.new, &300); + + // Verify the atomic still contains the value we set (200), not the failed attempt (300) + let current = atomic.load(Ordering::SeqCst, guard); + assert_eq!(unsafe { current.deref() }, &200); + + unsafe { + drop(atomic.into_owned()); + } + } + + #[test] + fn compare_exchange_weak_success() { + let atomic = Atomic::new(42); + let guard = &pin(); + + let mut current = atomic.load(Ordering::SeqCst, guard); + let mut new_value = Owned::new(100); + + loop { + match atomic.compare_exchange_weak( + current, + new_value, + Ordering::SeqCst, + Ordering::SeqCst, + guard, + ) { + Ok(result) => { + // On success, `old` should equal the current value we loaded + assert_eq!(unsafe { result.old.deref() }, &42); + // `new` should equal the value we stored + assert_eq!(unsafe { result.new.deref() }, &100); + + // Verify the atomic actually contains the new value + let current = atomic.load(Ordering::SeqCst, guard); + assert_eq!(unsafe { current.deref() }, &100); + + unsafe { + drop(result.old.into_owned()); + } + break; + } + Err(e) => { + current = e.current; + new_value = e.new; + assert_eq!(unsafe { current.deref() }, &42); + } + } + } + + unsafe { + drop(atomic.into_owned()); + } + } } diff --git a/crossbeam-epoch/src/lib.rs b/crossbeam-epoch/src/lib.rs index 6ecba1b8c..eaed52512 100644 --- a/crossbeam-epoch/src/lib.rs +++ b/crossbeam-epoch/src/lib.rs @@ -145,7 +145,9 @@ mod sync; #[cfg(all(feature = "alloc", target_has_atomic = "ptr"))] pub use crate::{ - atomic::{Atomic, CompareExchangeError, Owned, Pointable, Pointer, Shared}, + atomic::{ + Atomic, CompareExchangeError, CompareExchangeValue, Owned, Pointable, Pointer, Shared, + }, collector::{Collector, LocalHandle}, guard::{unprotected, Guard}, };