From 459409815ebb3ae72bbfd36d2e71aa2bd4bc6a37 Mon Sep 17 00:00:00 2001 From: Chris Brown <1731074+ccbrown@users.noreply.github.com> Date: Fri, 20 Dec 2024 00:44:54 -0500 Subject: [PATCH] feat: add write function to State --- packages/iocraft/src/hooks/use_state.rs | 92 +++++++++++++++++-------- 1 file changed, 65 insertions(+), 27 deletions(-) diff --git a/packages/iocraft/src/hooks/use_state.rs b/packages/iocraft/src/hooks/use_state.rs index 3beab4f..2209c68 100644 --- a/packages/iocraft/src/hooks/use_state.rs +++ b/packages/iocraft/src/hooks/use_state.rs @@ -6,7 +6,7 @@ use core::{ pin::Pin, task::{Context, Poll, Waker}, }; -use generational_box::{AnyStorage, GenerationalBox, Owner, SyncStorage}; +use generational_box::{AnyStorage, BorrowError, GenerationalBox, Owner, SyncStorage}; mod private { pub trait Sealed {} @@ -115,6 +115,38 @@ impl<'a, T: 'static> ops::Deref for StateRef<'a, T> { } } +/// A mutable reference to the value of a [`State`]. +pub struct StateMutRef<'a, T: 'static> { + inner: ::Mut<'a, StateValue>, + did_deref_mut: bool, +} + +impl<'a, T: 'static> ops::Deref for StateMutRef<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.inner.value + } +} + +impl<'a, T: 'static> ops::DerefMut for StateMutRef<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.did_deref_mut = true; + &mut self.inner.value + } +} + +impl<'a, T: 'static> Drop for StateMutRef<'a, T> { + fn drop(&mut self) { + if self.did_deref_mut { + self.inner.did_change = true; + if let Some(waker) = self.inner.waker.take() { + waker.wake(); + } + } + } +} + /// `State` is a copyable wrapper for a value that can be observed for changes. States used by a /// component will cause the component to be re-rendered when its value changes. pub struct State { @@ -132,14 +164,14 @@ impl Copy for State {} impl State { /// Gets a copy of the current value of the state. pub fn get(&self) -> T { - self.inner.read().value + *self.read() } } impl State { /// Sets the value of the state. pub fn set(&mut self, value: T) { - self.modify(|v| *v = value); + *self.write() = value; } /// Returns a reference to the state's value. @@ -148,33 +180,37 @@ impl State { /// multiple copies of the same state, writes to one will be blocked for as long as any /// reference returned by this method exists. pub fn read(&self) -> StateRef { - StateRef { - inner: self.inner.read(), + loop { + match self.inner.try_read() { + Ok(inner) => break StateRef { inner }, + Err(BorrowError::AlreadyBorrowedMut(_)) => self.inner.write(), + Err(BorrowError::Dropped(_)) => panic!("state was read after owner was dropped"), + }; } } - fn modify(&mut self, f: F) - where - F: FnOnce(&mut T), - { - let mut inner = self.inner.write(); - f(&mut inner.value); - inner.did_change = true; - if let Some(waker) = inner.waker.take() { - waker.wake(); + /// Returns a mutable reference to the state's value. + /// + ///
It is possible to create a deadlock using this method. If you have + /// multiple copies of the same state, operations on one will be blocked for as long as any + /// reference returned by this method exists.
+ pub fn write(&mut self) -> StateMutRef { + StateMutRef { + inner: self.inner.write(), + did_deref_mut: false, } } } impl Debug for State { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.inner.read().value.fmt(f) + self.read().fmt(f) } } impl Display for State { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.inner.read().value.fmt(f) + self.read().fmt(f) } } @@ -188,7 +224,7 @@ impl + Copy + Sync + Send + 'static> ops::Add for Sta impl + Copy + Sync + Send + 'static> ops::AddAssign for State { fn add_assign(&mut self, rhs: T) { - self.modify(|v| *v += rhs); + *self.write() += rhs; } } @@ -202,7 +238,7 @@ impl + Copy + Sync + Send + 'static> ops::Sub for Sta impl + Copy + Sync + Send + 'static> ops::SubAssign for State { fn sub_assign(&mut self, rhs: T) { - self.modify(|v| *v -= rhs); + *self.write() -= rhs; } } @@ -216,7 +252,7 @@ impl + Copy + Sync + Send + 'static> ops::Mul for Sta impl + Copy + Sync + Send + 'static> ops::MulAssign for State { fn mul_assign(&mut self, rhs: T) { - self.modify(|v| *v *= rhs); + *self.write() *= rhs; } } @@ -230,34 +266,31 @@ impl + Copy + Sync + Send + 'static> ops::Div for Sta impl + Copy + Sync + Send + 'static> ops::DivAssign for State { fn div_assign(&mut self, rhs: T) { - self.modify(|v| *v /= rhs); + *self.write() /= rhs; } } impl + Sync + Send + 'static> cmp::PartialEq for State { fn eq(&self, other: &T) -> bool { - self.inner.read().value == *other + *self.read() == *other } } impl + Sync + Send + 'static> cmp::PartialOrd for State { fn partial_cmp(&self, other: &T) -> Option { - self.inner.read().value.partial_cmp(other) + self.read().partial_cmp(other) } } impl + Sync + Send + 'static> cmp::PartialEq> for State { fn eq(&self, other: &State) -> bool { - self.inner.read().value == other.inner.read().value + *self.read() == *other.read() } } impl + Sync + Send + 'static> cmp::PartialOrd> for State { fn partial_cmp(&self, other: &State) -> Option { - self.inner - .read() - .value - .partial_cmp(&other.inner.read().value) + self.read().partial_cmp(&other.read()) } } @@ -307,5 +340,10 @@ mod tests { assert!(state > 42); assert!(state >= 43); assert!(state < 44); + + assert_eq!(*state.write(), 43); + + let state_copy = state.clone(); + assert_eq!(*state.read(), *state_copy.read()); } }