Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add write function to State #45

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 65 additions & 27 deletions packages/iocraft/src/hooks/use_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use core::{
pin::Pin,
task::{Context, Poll, Waker},
};
use generational_box::{AnyStorage, GenerationalBox, Owner, SyncStorage};
use generational_box::{AnyStorage, BorrowError, GenerationalBox, Owner, SyncStorage};

mod private {
pub trait Sealed {}
Expand Down Expand Up @@ -115,6 +115,38 @@ impl<'a, T: 'static> ops::Deref for StateRef<'a, T> {
}
}

/// A mutable reference to the value of a [`State`].
pub struct StateMutRef<'a, T: 'static> {
inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
did_deref_mut: bool,
}

impl<'a, T: 'static> ops::Deref for StateMutRef<'a, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.inner.value
}
}

impl<'a, T: 'static> ops::DerefMut for StateMutRef<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.did_deref_mut = true;
&mut self.inner.value
}
}

impl<'a, T: 'static> Drop for StateMutRef<'a, T> {
fn drop(&mut self) {
if self.did_deref_mut {
self.inner.did_change = true;
if let Some(waker) = self.inner.waker.take() {
waker.wake();
}
}
}
}

/// `State` is a copyable wrapper for a value that can be observed for changes. States used by a
/// component will cause the component to be re-rendered when its value changes.
pub struct State<T: Send + Sync + 'static> {
Expand All @@ -132,14 +164,14 @@ impl<T: Sync + Send + 'static> Copy for State<T> {}
impl<T: Copy + Sync + Send + 'static> State<T> {
/// Gets a copy of the current value of the state.
pub fn get(&self) -> T {
self.inner.read().value
*self.read()
}
}

impl<T: Sync + Send + 'static> State<T> {
/// Sets the value of the state.
pub fn set(&mut self, value: T) {
self.modify(|v| *v = value);
*self.write() = value;
}

/// Returns a reference to the state's value.
Expand All @@ -148,33 +180,37 @@ impl<T: Sync + Send + 'static> State<T> {
/// multiple copies of the same state, writes to one will be blocked for as long as any
/// reference returned by this method exists.</div>
pub fn read(&self) -> StateRef<T> {
StateRef {
inner: self.inner.read(),
loop {
match self.inner.try_read() {
Ok(inner) => break StateRef { inner },
Err(BorrowError::AlreadyBorrowedMut(_)) => self.inner.write(),
Err(BorrowError::Dropped(_)) => panic!("state was read after owner was dropped"),
};
}
}

fn modify<F>(&mut self, f: F)
where
F: FnOnce(&mut T),
{
let mut inner = self.inner.write();
f(&mut inner.value);
inner.did_change = true;
if let Some(waker) = inner.waker.take() {
waker.wake();
/// Returns a mutable reference to the state's value.
///
/// <div class="warning">It is possible to create a deadlock using this method. If you have
/// multiple copies of the same state, operations on one will be blocked for as long as any
/// reference returned by this method exists.</div>
pub fn write(&mut self) -> StateMutRef<T> {
StateMutRef {
inner: self.inner.write(),
did_deref_mut: false,
}
}
}

impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.inner.read().value.fmt(f)
self.read().fmt(f)
}
}

impl<T: Display + Sync + Send + 'static> Display for State<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.inner.read().value.fmt(f)
self.read().fmt(f)
}
}

Expand All @@ -188,7 +224,7 @@ impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for Sta

impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
fn add_assign(&mut self, rhs: T) {
self.modify(|v| *v += rhs);
*self.write() += rhs;
}
}

Expand All @@ -202,7 +238,7 @@ impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for Sta

impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
fn sub_assign(&mut self, rhs: T) {
self.modify(|v| *v -= rhs);
*self.write() -= rhs;
}
}

Expand All @@ -216,7 +252,7 @@ impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for Sta

impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
fn mul_assign(&mut self, rhs: T) {
self.modify(|v| *v *= rhs);
*self.write() *= rhs;
}
}

Expand All @@ -230,34 +266,31 @@ impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for Sta

impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
fn div_assign(&mut self, rhs: T) {
self.modify(|v| *v /= rhs);
*self.write() /= rhs;
}
}

impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
fn eq(&self, other: &T) -> bool {
self.inner.read().value == *other
*self.read() == *other
}
}

impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
self.inner.read().value.partial_cmp(other)
self.read().partial_cmp(other)
}
}

impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
fn eq(&self, other: &State<T>) -> bool {
self.inner.read().value == other.inner.read().value
*self.read() == *other.read()
}
}

impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
self.inner
.read()
.value
.partial_cmp(&other.inner.read().value)
self.read().partial_cmp(&other.read())
}
}

Expand Down Expand Up @@ -307,5 +340,10 @@ mod tests {
assert!(state > 42);
assert!(state >= 43);
assert!(state < 44);

assert_eq!(*state.write(), 43);

let state_copy = state.clone();
assert_eq!(*state.read(), *state_copy.read());
}
}
Loading