Skip to content

Commit

Permalink
feat: add write function to State
Browse files Browse the repository at this point in the history
  • Loading branch information
ccbrown committed Dec 20, 2024
1 parent e4fe60b commit 4594098
Showing 1 changed file with 65 additions and 27 deletions.
92 changes: 65 additions & 27 deletions packages/iocraft/src/hooks/use_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use core::{
pin::Pin,
task::{Context, Poll, Waker},
};
use generational_box::{AnyStorage, GenerationalBox, Owner, SyncStorage};
use generational_box::{AnyStorage, BorrowError, GenerationalBox, Owner, SyncStorage};

mod private {
pub trait Sealed {}
Expand Down Expand Up @@ -115,6 +115,38 @@ impl<'a, T: 'static> ops::Deref for StateRef<'a, T> {
}
}

/// A mutable reference to the value of a [`State`].
pub struct StateMutRef<'a, T: 'static> {
inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
did_deref_mut: bool,
}

impl<'a, T: 'static> ops::Deref for StateMutRef<'a, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.inner.value
}
}

impl<'a, T: 'static> ops::DerefMut for StateMutRef<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.did_deref_mut = true;
&mut self.inner.value
}
}

impl<'a, T: 'static> Drop for StateMutRef<'a, T> {
fn drop(&mut self) {
if self.did_deref_mut {
self.inner.did_change = true;
if let Some(waker) = self.inner.waker.take() {
waker.wake();
}
}
}
}

/// `State` is a copyable wrapper for a value that can be observed for changes. States used by a
/// component will cause the component to be re-rendered when its value changes.
pub struct State<T: Send + Sync + 'static> {
Expand All @@ -132,14 +164,14 @@ impl<T: Sync + Send + 'static> Copy for State<T> {}
impl<T: Copy + Sync + Send + 'static> State<T> {
/// Gets a copy of the current value of the state.
pub fn get(&self) -> T {
self.inner.read().value
*self.read()
}
}

impl<T: Sync + Send + 'static> State<T> {
/// Sets the value of the state.
pub fn set(&mut self, value: T) {
self.modify(|v| *v = value);
*self.write() = value;
}

/// Returns a reference to the state's value.
Expand All @@ -148,33 +180,37 @@ impl<T: Sync + Send + 'static> State<T> {
/// multiple copies of the same state, writes to one will be blocked for as long as any
/// reference returned by this method exists.</div>
pub fn read(&self) -> StateRef<T> {
StateRef {
inner: self.inner.read(),
loop {
match self.inner.try_read() {
Ok(inner) => break StateRef { inner },
Err(BorrowError::AlreadyBorrowedMut(_)) => self.inner.write(),
Err(BorrowError::Dropped(_)) => panic!("state was read after owner was dropped"),
};
}
}

fn modify<F>(&mut self, f: F)
where
F: FnOnce(&mut T),
{
let mut inner = self.inner.write();
f(&mut inner.value);
inner.did_change = true;
if let Some(waker) = inner.waker.take() {
waker.wake();
/// Returns a mutable reference to the state's value.
///
/// <div class="warning">It is possible to create a deadlock using this method. If you have
/// multiple copies of the same state, operations on one will be blocked for as long as any
/// reference returned by this method exists.</div>
pub fn write(&mut self) -> StateMutRef<T> {
StateMutRef {
inner: self.inner.write(),
did_deref_mut: false,
}
}
}

impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.inner.read().value.fmt(f)
self.read().fmt(f)
}
}

impl<T: Display + Sync + Send + 'static> Display for State<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.inner.read().value.fmt(f)
self.read().fmt(f)
}
}

Expand All @@ -188,7 +224,7 @@ impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for Sta

impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
fn add_assign(&mut self, rhs: T) {
self.modify(|v| *v += rhs);
*self.write() += rhs;
}
}

Expand All @@ -202,7 +238,7 @@ impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for Sta

impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
fn sub_assign(&mut self, rhs: T) {
self.modify(|v| *v -= rhs);
*self.write() -= rhs;
}
}

Expand All @@ -216,7 +252,7 @@ impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for Sta

impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
fn mul_assign(&mut self, rhs: T) {
self.modify(|v| *v *= rhs);
*self.write() *= rhs;
}
}

Expand All @@ -230,34 +266,31 @@ impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for Sta

impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
fn div_assign(&mut self, rhs: T) {
self.modify(|v| *v /= rhs);
*self.write() /= rhs;
}
}

impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
fn eq(&self, other: &T) -> bool {
self.inner.read().value == *other
*self.read() == *other
}
}

impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
self.inner.read().value.partial_cmp(other)
self.read().partial_cmp(other)
}
}

impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
fn eq(&self, other: &State<T>) -> bool {
self.inner.read().value == other.inner.read().value
*self.read() == *other.read()
}
}

impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
self.inner
.read()
.value
.partial_cmp(&other.inner.read().value)
self.read().partial_cmp(&other.read())
}
}

Expand Down Expand Up @@ -307,5 +340,10 @@ mod tests {
assert!(state > 42);
assert!(state >= 43);
assert!(state < 44);

assert_eq!(*state.write(), 43);

let state_copy = state.clone();
assert_eq!(*state.read(), *state_copy.read());
}
}

0 comments on commit 4594098

Please sign in to comment.