diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 3ae23eac9..4af9bcc7e 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -105,8 +105,7 @@ macro_rules! setup_input_struct { }) } - pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { - let zalsa_mut = db.zalsa_mut(); + pub fn ingredient_mut(zalsa_mut: &mut $zalsa::Zalsa) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { zalsa_mut.new_revision(); let index = zalsa_mut.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); @@ -185,8 +184,10 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient_(db.zalsa()).field( - db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + let fields = $Configuration::ingredient_(zalsa).field( + zalsa, + zalsa_local, self, $field_index, ); @@ -205,7 +206,8 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let (ingredient, revision) = $Configuration::ingredient_mut(db.as_dyn_database_mut()); + let zalsa = db.zalsa_mut(); + let (ingredient, revision) = $Configuration::ingredient_mut(zalsa); $zalsa::input::SetterImpl::new( revision, self, @@ -244,7 +246,8 @@ macro_rules! setup_input_struct { $(for<'__trivial_bounds> $field_ty: std::fmt::Debug),* { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).leak_fields(db, this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).leak_fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this)); $( @@ -273,11 +276,11 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let current_revision = zalsa.current_revision(); let ingredient = $Configuration::ingredient_(zalsa); let (fields, revision, durabilities) = builder::builder_into_inner(self, current_revision); - ingredient.new_input(db.as_dyn_database(), fields, revision, durabilities) + ingredient.new_input(zalsa, zalsa_local, fields, revision, durabilities) } } diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 7c6381fbf..764823f20 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -140,14 +140,11 @@ macro_rules! setup_interned_struct { } impl $Configuration { - pub fn ingredient(db: &Db) -> &$zalsa_struct::IngredientImpl - where - Db: ?Sized + $zalsa::Database, + pub fn ingredient(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); - let zalsa = db.zalsa(); CACHE.get_or_create(zalsa, || { zalsa.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) @@ -215,7 +212,8 @@ macro_rules! setup_interned_struct { $field_ty: $zalsa::interned::HashEqLike<$indexed_ty>, )* { - $Configuration::ingredient(db).intern(db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + $Configuration::ingredient(zalsa).intern(zalsa, zalsa_local, StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default()), |_, data| ($($zalsa::interned::Lookup::into_owned(data.$field_index),)*)) } @@ -226,7 +224,8 @@ macro_rules! setup_interned_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), self); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, self); $zalsa::return_mode_expression!( $field_option, $field_ty, @@ -238,7 +237,8 @@ macro_rules! setup_interned_struct { /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); $( let f = f.field(stringify!($field_id), &fields.$field_index); diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 37cf48e46..e15fc3895 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -150,14 +150,18 @@ macro_rules! setup_tracked_fn { impl $Configuration { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); + Self::fn_ingredient_(db, zalsa) + } + + fn fn_ingredient_<'z>(db: &dyn $Db, zalsa: &'z $zalsa::Zalsa) -> &'z $zalsa::function::IngredientImpl<$Configuration> { $FN_CACHE.get_or_create(zalsa, || { - ::zalsa_register_downcaster(db); + ::zalsa_register_upcaster(db); zalsa.add_or_lookup_jar_by_type::<$Configuration>() }) } pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl { - ::zalsa_register_downcaster(db); + ::zalsa_register_upcaster(db); let zalsa_mut = db.zalsa_mut(); let index = zalsa_mut.add_or_lookup_jar_by_type::<$Configuration>(); let (ingredient, _) = zalsa_mut.lookup_ingredient_mut(index); @@ -169,8 +173,14 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); + Self::intern_ingredient_(db, zalsa) + } + fn intern_ingredient_<'z>( + db: &dyn $Db, + zalsa: &'z $zalsa::Zalsa + ) -> &'z $zalsa::interned::IngredientImpl<$Configuration> { $INTERN_CACHE.get_or_create(zalsa, || { - ::zalsa_register_downcaster(db); + ::zalsa_register_upcaster(db); zalsa.add_or_lookup_jar_by_type::<$Configuration>().successor(0) }) } @@ -218,11 +228,12 @@ macro_rules! setup_tracked_fn { } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { + let zalsa = db.zalsa(); $zalsa::macro_if! { if $needs_interner { - $Configuration::intern_ingredient(db).data(db.as_dyn_database(), key).clone() + $Configuration::intern_ingredient_(db, zalsa).data(zalsa, key).clone() } else { - $zalsa::FromIdWithDb::from_id(key, db.zalsa()) + $zalsa::FromIdWithDb::from_id(key, zalsa) } } } @@ -280,10 +291,10 @@ macro_rules! setup_tracked_fn { }; let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( + zalsa, first_index, memo_ingredient_indices, $lru, - zalsa.views().downcaster_for::(), ); $zalsa::macro_if! { if $needs_interner { @@ -312,9 +323,10 @@ macro_rules! setup_tracked_fn { ) -> Vec<&$db_lt A> { use salsa::plumbing as $zalsa; let key = $zalsa::macro_if! { - if $needs_interner { - $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data) - } else { + if $needs_interner {{ + let (zalsa, zalsa_local) = $db.zalsas(); + $Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data) + }} else { $zalsa::AsId::as_id(&($($input_id),*)) } }; @@ -355,11 +367,15 @@ macro_rules! setup_tracked_fn { let result = $zalsa::macro_if! { if $needs_interner { { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data); - $Configuration::fn_ingredient($db).fetch($db, key) + let (zalsa, zalsa_local) = $db.zalsas(); + let key = $Configuration::intern_ingredient_($db, zalsa).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data); + $Configuration::fn_ingredient_($db, zalsa).fetch($db, zalsa, zalsa_local, key) } } else { - $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) + { + let (zalsa, zalsa_local) = $db.zalsas(); + $Configuration::fn_ingredient_($db, zalsa).fetch($db, zalsa, zalsa_local, $zalsa::AsId::as_id(&($($input_id),*))) + } } }; diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 58945eac6..6e6d7fb59 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -259,8 +259,9 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - $Configuration::ingredient(db.as_dyn_database()).new_struct( - db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + $Configuration::ingredient_(zalsa).new_struct( + zalsa,zalsa_local, ($($field_id,)*) ) } @@ -272,8 +273,8 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let db = db.as_dyn_database(); - let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index); + let (zalsa, zalsa_local) = db.zalsas(); + let fields = $Configuration::ingredient_(zalsa).tracked_field(zalsa, zalsa_local, self, $relative_tracked_index); $crate::return_mode_expression!( $tracked_option, $tracked_ty, @@ -289,8 +290,8 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let db = db.as_dyn_database(); - let fields = $Configuration::ingredient(db).untracked_field(db, self); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).untracked_field(zalsa, self); $crate::return_mode_expression!( $untracked_option, $untracked_ty, @@ -312,7 +313,8 @@ macro_rules! setup_tracked_struct { $(for<$db_lt> $field_ty: std::fmt::Debug),* { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).leak_fields(db, this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).leak_fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this)); $( diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index 478ebea5d..44fb7f109 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -110,18 +110,14 @@ impl DbMacro { let trait_name = &input.ident; input.items.push(parse_quote! { #[doc(hidden)] - fn zalsa_register_downcaster(&self); + fn zalsa_register_upcaster(&self); }); - let comment = format!(" Downcast a [`dyn Database`] to a [`dyn {trait_name}`]"); + let comment = format!(" upcast `Self` to a [`dyn {trait_name}`]"); input.items.push(parse_quote! { #[doc = #comment] - /// - /// # Safety - /// - /// The input database must be of type `Self`. #[doc(hidden)] - unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #trait_name where Self: Sized; + fn upcast(&self) -> &dyn #trait_name where Self: Sized; }); Ok(()) } @@ -137,17 +133,15 @@ impl DbMacro { input.items.push(parse_quote! { #[doc(hidden)] #[inline(always)] - fn zalsa_register_downcaster(&self) { - salsa::plumbing::views(self).add(::downcast); + fn zalsa_register_upcaster(&self) { + salsa::plumbing::views(self).add(::upcast); } }); input.items.push(parse_quote! { #[doc(hidden)] #[inline(always)] - unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #TraitPath where Self: Sized { - debug_assert_eq!(db.type_id(), ::core::any::TypeId::of::()); - // SAFETY: The input database must be of type `Self`. - unsafe { &*salsa::plumbing::transmute_data_ptr::(db) } + fn upcast(&self) -> &dyn #TraitPath where Self: Sized { + self } }); Ok(()) diff --git a/src/accumulator.rs b/src/accumulator.rs index 7542bc7e0..c78a6018d 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -103,7 +103,8 @@ impl Ingredient for IngredientImpl { unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabasePointer<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, diff --git a/src/attach.rs b/src/attach.rs index 671933b50..ac4918694 100644 --- a/src/attach.rs +++ b/src/attach.rs @@ -45,7 +45,7 @@ impl Attached { impl<'s> DbGuard<'s> { #[inline] - fn new(attached: &'s Attached, db: &dyn Database) -> Self { + fn new(attached: &'s Attached, db: &Db) -> Self { match attached.database.get() { Some(current_db) => { let new_db = NonNull::from(db); @@ -56,7 +56,9 @@ impl Attached { } None => { // Otherwise, set the database. - attached.database.set(Some(NonNull::from(db))); + attached + .database + .set(Some(NonNull::from(db.dyn_database()))); Self { state: Some(attached), } @@ -75,7 +77,7 @@ impl Attached { } } - let _guard = DbGuard::new(self, db.as_dyn_database()); + let _guard = DbGuard::new(self, db); op() } diff --git a/src/database.rs b/src/database.rs index 72204a582..a63569847 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,12 +1,38 @@ -use std::any::Any; use std::borrow::Cow; +use std::ptr::NonNull; use crate::zalsa::{IngredientIndex, ZalsaDatabase}; use crate::{Durability, Revision}; +#[derive(Copy, Clone)] +pub struct RawDatabasePointer<'db> { + pub(crate) ptr: NonNull<()>, + _marker: std::marker::PhantomData<&'db dyn Database>, +} + +impl<'db, Db: Database + ?Sized> From<&'db Db> for RawDatabasePointer<'db> { + #[inline] + fn from(db: &'db Db) -> Self { + RawDatabasePointer { + ptr: NonNull::from(db).cast(), + _marker: std::marker::PhantomData, + } + } +} + +impl<'db, Db: Database + ?Sized> From<&'db mut Db> for RawDatabasePointer<'db> { + #[inline] + fn from(db: &'db mut Db) -> Self { + RawDatabasePointer { + ptr: NonNull::from(db).cast(), + _marker: std::marker::PhantomData, + } + } +} + /// The trait implemented by all Salsa databases. /// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro. -pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { +pub trait Database: Send + ZalsaDatabase { /// Enforces current LRU limits, evicting entries if necessary. /// /// **WARNING:** Just like an ordinary write, this method triggers @@ -82,37 +108,28 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { #[doc(hidden)] #[inline(always)] - fn zalsa_register_downcaster(&self) { - // The no-op downcaster is special cased in view caster construction. + fn dyn_database(&self) -> &dyn Database { + // SAFETY: The upcaster is derived from its argument + unsafe { + self.zalsa() + .views() + .base_database_upcaster() + .upcast_unchecked(self.into()) + } } #[doc(hidden)] #[inline(always)] - unsafe fn downcast(db: &dyn Database) -> &dyn Database - where - Self: Sized, - { - // No-op - db - } -} - -/// Upcast to a `dyn Database`. -/// -/// Only required because upcasts not yet stabilized (*grr*). -pub trait AsDynDatabase { - fn as_dyn_database(&self) -> &dyn Database; - fn as_dyn_database_mut(&mut self) -> &mut dyn Database; -} - -impl AsDynDatabase for T { - #[inline(always)] - fn as_dyn_database(&self) -> &dyn Database { - self + fn zalsa_register_upcaster(&self) { + // The no-op upcaster is special cased in view caster construction. } + #[doc(hidden)] #[inline(always)] - fn as_dyn_database_mut(&mut self) -> &mut dyn Database { + fn upcast(&self) -> &dyn Database + where + Self: Sized, + { self } } @@ -120,16 +137,3 @@ impl AsDynDatabase for T { pub fn current_revision(db: &Db) -> Revision { db.zalsa().current_revision() } - -impl dyn Database { - /// Upcasts `self` to the given view. - /// - /// # Panics - /// - /// If the view has not been added to the database (see [`crate::views::Views`]). - #[track_caller] - pub fn as_view(&self) -> &DbView { - let views = self.zalsa().views(); - views.downcaster_for().downcast(self) - } -} diff --git a/src/function.rs b/src/function.rs index be88f7e39..6403378b4 100644 --- a/src/function.rs +++ b/src/function.rs @@ -9,6 +9,7 @@ use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, }; +use crate::database::RawDatabasePointer; use crate::function::delete::DeletedEntries; use crate::function::sync::{ClaimResult, SyncTable}; use crate::ingredient::{Ingredient, WaitForResult}; @@ -18,10 +19,10 @@ use crate::salsa_struct::SalsaStructInDb; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::table::Table; -use crate::views::DatabaseDownCaster; +use crate::views::DatabaseUpCaster; use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; -use crate::{Database, Id, Revision}; +use crate::{Id, Revision}; mod accumulated; mod backdate; @@ -118,13 +119,13 @@ pub struct IngredientImpl { /// Used to find memos to throw out when we have too many memoized values. lru: lru::Lru, - /// A downcaster from `dyn Database` to `C::DbView`. + /// An upcaster to `C::DbView`. /// /// # Safety /// /// The supplied database must be be the same as the database used to construct the [`Views`] - /// instances that this downcaster was derived from. - view_caster: DatabaseDownCaster, + /// instances that this upcaster was derived from. + view_caster: DatabaseUpCaster, sync_table: SyncTable, @@ -148,17 +149,17 @@ where C: Configuration, { pub fn new( + zalsa: &Zalsa, index: IngredientIndex, memo_ingredient_indices: as SalsaStructInDb>::MemoIngredientMap, lru: usize, - view_caster: DatabaseDownCaster, ) -> Self { Self { index, memo_ingredient_indices, lru: lru::Lru::new(lru), deleted_entries: Default::default(), - view_caster, + view_caster: zalsa.views().upcaster_for::().clone(), sync_table: SyncTable::new(index), } } @@ -237,13 +238,14 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + db: RawDatabasePointer<'_>, input: Id, revision: Revision, cycle_heads: &mut CycleHeads, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient as per caller invariant - let db = unsafe { self.view_caster.downcast_unchecked(db) }; + let db = unsafe { self.view_caster.upcast_unchecked(db) }; self.maybe_changed_after(db, input, revision, cycle_heads) } @@ -342,12 +344,13 @@ where C::CYCLE_STRATEGY } - fn accumulated<'db>( + unsafe fn accumulated<'db>( &'db self, - db: &'db dyn Database, + db: RawDatabasePointer<'db>, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { - let db = self.view_caster.downcast(db); + // SAFETY: The `db` belongs to the ingredient as per caller invariant + let db = unsafe { self.view_caster.upcast_unchecked(db) }; self.accumulated_map(db, key_index) } } diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 47fe09a84..a65804e64 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -4,7 +4,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::hash::FxHashSet; use crate::zalsa::ZalsaDatabase; use crate::zalsa_local::QueryOriginRef; -use crate::{AsDynDatabase, DatabaseKeyIndex, Id}; +use crate::{DatabaseKeyIndex, Id}; impl IngredientImpl where @@ -37,9 +37,8 @@ where let mut output = vec![]; // First ensure the result is up to date - self.fetch(db, key); + self.fetch(db, zalsa, zalsa_local, key); - let db = db.as_dyn_database(); let db_key = self.database_key_index(key); let mut visited: FxHashSet = FxHashSet::default(); let mut stack: Vec = vec![db_key]; @@ -54,7 +53,9 @@ where let ingredient = zalsa.lookup_ingredient(k.ingredient_index()); // Extend `output` with any values accumulated by `k`. - let (accumulated_map, input) = ingredient.accumulated(db, k.key_index()); + // SAFETY: `db` owns the `ingredient` + let (accumulated_map, input) = + unsafe { ingredient.accumulated(db.into(), k.key_index()) }; if let Some(accumulated_map) = accumulated_map { accumulated_map.extend_with_accumulated(accumulator.index(), &mut output); } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index bfd5ffedc..60ff8de92 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -2,7 +2,7 @@ use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl, VerifyResult}; -use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; +use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryRevisions, ZalsaLocal}; use crate::Id; @@ -10,8 +10,13 @@ impl IngredientImpl where C: Configuration, { - pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> { - let (zalsa, zalsa_local) = db.zalsas(); + pub fn fetch<'db>( + &'db self, + db: &'db C::DbView, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, + id: Id, + ) -> &'db C::Output<'db> { zalsa.unwind_if_revision_cancelled(zalsa_local); let database_key_index = self.database_key_index(id); diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 96f5eae5f..0b745ea38 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -7,7 +7,7 @@ use crate::key::DatabaseKeyIndex; use crate::sync::atomic::Ordering; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{QueryEdgeKind, QueryOriginRef, ZalsaLocal}; -use crate::{AsDynDatabase as _, Id, Revision}; +use crate::{Id, Revision}; /// Result of memo validation. pub enum VerifyResult { @@ -432,8 +432,6 @@ where return VerifyResult::Changed; } - let dyn_db = db.as_dyn_database(); - let mut inputs = InputAccumulatedValues::Empty; // Fully tracked inputs? Iterate over the inputs and check them, one by one. // @@ -445,7 +443,7 @@ where match edge.kind() { QueryEdgeKind::Input(dependency_index) => { match dependency_index.maybe_changed_after( - dyn_db, + db.into(), zalsa, old_memo.verified_at.load(), cycle_heads, diff --git a/src/ingredient.rs b/src/ingredient.rs index bfbcc2d30..bf9903894 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -5,6 +5,7 @@ use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, }; +use crate::database::RawDatabasePointer; use crate::function::VerifyResult; use crate::plumbing::IngredientIndices; use crate::runtime::Running; @@ -13,7 +14,7 @@ use crate::table::memo::MemoTableTypes; use crate::table::Table; use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; -use crate::{Database, DatabaseKeyIndex, Id, Revision}; +use crate::{DatabaseKeyIndex, Id, Revision}; /// A "jar" is a group of ingredients that are added atomically. /// Each type implementing jar can be added to the database at most once. @@ -61,9 +62,10 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// # Safety /// /// The passed in database needs to be the same one that the ingredient was created with. - unsafe fn maybe_changed_after<'db>( - &'db self, - db: &'db dyn Database, + unsafe fn maybe_changed_after( + &self, + zalsa: &crate::zalsa::Zalsa, + db: crate::database::RawDatabasePointer<'_>, input: Id, revision: Revision, cycle_heads: &mut CycleHeads, @@ -173,9 +175,13 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// What values were accumulated during the creation of the value at `key_index` /// (if any). - fn accumulated<'db>( + /// + /// # Safety + /// + /// The passed in database needs to be the same one that the ingredient was created with. + unsafe fn accumulated<'db>( &'db self, - db: &'db dyn Database, + db: RawDatabasePointer<'db>, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { let _ = (db, key_index); diff --git a/src/input.rs b/src/input.rs index f209fe0b0..1a9511379 100644 --- a/src/input.rs +++ b/src/input.rs @@ -19,7 +19,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Durability, Id, Revision, Runtime}; +use crate::{zalsa_local, Durability, Id, Revision, Runtime}; pub trait Configuration: Any { const DEBUG_NAME: &'static str; @@ -105,13 +105,12 @@ impl IngredientImpl { pub fn new_input( &self, - db: &dyn Database, + zalsa: &Zalsa, + zalsa_local: &zalsa_local::ZalsaLocal, fields: C::Fields, revisions: C::Revisions, durabilities: C::Durabilities, ) -> C::Struct { - let (zalsa, zalsa_local) = db.zalsas(); - let id = self.singleton.with_scope(|| { zalsa_local.allocate(zalsa, self.ingredient_index, |_| Value:: { fields, @@ -176,11 +175,11 @@ impl IngredientImpl { /// The caller is responsible for selecting the appropriate element. pub fn field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db zalsa_local::ZalsaLocal, id: C::Struct, field_index: usize, ) -> &'db C::Fields { - let (zalsa, zalsa_local) = db.zalsas(); let field_ingredient_index = self.ingredient_index.successor(field_index); let id = id.as_id(); let value = Self::data(zalsa, id); @@ -196,17 +195,13 @@ impl IngredientImpl { #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the input struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } /// Peek at the field values without recording any read dependency. /// Used for debug printouts. - pub fn leak_fields<'db>(&'db self, db: &'db dyn Database, id: C::Struct) -> &'db C::Fields { - let zalsa = db.zalsa(); + pub fn leak_fields<'db>(&'db self, zalsa: &'db Zalsa, id: C::Struct) -> &'db C::Fields { let id = id.as_id(); let value = Self::data(zalsa, id); &value.fields @@ -224,7 +219,8 @@ impl Ingredient for IngredientImpl { unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabasePointer<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 5e1df4874..363167d93 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -8,7 +8,7 @@ use crate::input::{Configuration, IngredientImpl, Value}; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::zalsa::IngredientIndex; -use crate::{Database, Id, Revision}; +use crate::{Id, Revision}; /// Ingredient used to represent the fields of a `#[salsa::input]`. /// @@ -52,12 +52,12 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabasePointer<'_>, input: Id, revision: Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); let value = >::data(zalsa, input); VerifyResult::changed_if(value.revisions[self.field_index] > revision) } diff --git a/src/interned.rs b/src/interned.rs index 269da3b71..0eb2a4a18 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -21,7 +21,7 @@ use crate::sync::{Arc, Mutex, OnceLock}; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::Slot; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, DatabaseKeyIndex, Event, EventKind, Id, Revision}; +use crate::{DatabaseKeyIndex, Event, EventKind, Id, Revision}; /// Trait that defines the key properties of an interned struct. /// @@ -275,7 +275,8 @@ where /// the database ends up trying to intern or allocate a new value. pub fn intern<'db, Key>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, key: Key, assemble: impl FnOnce(Id, Key) -> C::Fields<'db>, ) -> C::Struct<'db> @@ -283,7 +284,7 @@ where Key: Hash, C::Fields<'db>: HashEqLike, { - FromId::from_id(self.intern_id(db, key, assemble)) + FromId::from_id(self.intern_id(zalsa, zalsa_local, key, assemble)) } /// Intern data to a unique reference. @@ -298,7 +299,8 @@ where /// the database ends up trying to intern or allocate a new value. pub fn intern_id<'db, Key>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, key: Key, assemble: impl FnOnce(Id, Key) -> C::Fields<'db>, ) -> crate::Id @@ -310,8 +312,6 @@ where // so instead we go with this and transmute the lifetime in the `eq` closure C::Fields<'db>: HashEqLike, { - let (zalsa, zalsa_local) = db.zalsas(); - // Record the current revision as active. let current_revision = zalsa.current_revision(); self.revision_queue.record(current_revision); @@ -394,7 +394,6 @@ where // Fill up the table for the first few revisions without attempting garbage collection. if !self.revision_queue.is_primed() { return self.intern_id_cold( - db, key, zalsa, zalsa_local, @@ -531,16 +530,7 @@ where } // If we could not find any stale slots, we are forced to allocate a new one. - self.intern_id_cold( - db, - key, - zalsa, - zalsa_local, - assemble, - shard, - shard_index, - hash, - ) + self.intern_id_cold(key, zalsa, zalsa_local, assemble, shard, shard_index, hash) } /// The cold path for interning a value, allocating a new slot. @@ -549,7 +539,6 @@ where #[allow(clippy::too_many_arguments)] fn intern_id_cold<'db, Key>( &'db self, - _db: &'db dyn crate::Database, key: Key, zalsa: &Zalsa, zalsa_local: &ZalsaLocal, @@ -720,8 +709,7 @@ where } /// Lookup the data for an interned value based on its ID. - pub fn data<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { - let zalsa = db.zalsa(); + pub fn data<'db>(&'db self, zalsa: &'db Zalsa, id: Id) -> &'db C::Fields<'db> { let value = zalsa.table().get::>(id); debug_assert!( @@ -746,12 +734,12 @@ where /// Lookup the fields from an interned struct. /// /// Note that this is not "leaking" since no dependency edge is required. - pub fn fields<'db>(&'db self, db: &'db dyn Database, s: C::Struct<'db>) -> &'db C::Fields<'db> { - self.data(db, AsId::as_id(&s)) + pub fn fields<'db>(&'db self, zalsa: &'db Zalsa, s: C::Struct<'db>) -> &'db C::Fields<'db> { + self.data(zalsa, AsId::as_id(&s)) } - pub fn reset(&mut self, db: &mut dyn Database) { - _ = db.zalsa_mut(); + pub fn reset(&mut self, zalsa_mut: &mut Zalsa) { + _ = zalsa_mut; for shard in self.shards.iter() { // We can clear the key maps now that we have cancelled all other handles. @@ -761,11 +749,8 @@ where #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the interned struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } } @@ -783,13 +768,12 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabasePointer<'_>, input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); - // Record the current revision as active. let current_revision = zalsa.current_revision(); self.revision_queue.record(current_revision); diff --git a/src/key.rs b/src/key.rs index 5883ef9cb..64cfaf95a 100644 --- a/src/key.rs +++ b/src/key.rs @@ -3,7 +3,7 @@ use core::fmt; use crate::cycle::CycleHeads; use crate::function::VerifyResult; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Id}; +use crate::Id; // ANCHOR: DatabaseKeyIndex /// An integer that uniquely identifies a particular query instance within the @@ -36,16 +36,27 @@ impl DatabaseKeyIndex { pub(crate) fn maybe_changed_after( &self, - db: &dyn Database, + db: crate::database::RawDatabasePointer<'_>, zalsa: &Zalsa, last_verified_at: crate::Revision, cycle_heads: &mut CycleHeads, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient unsafe { + // heer, `db` has to be either the correct type already, or a subtype (as far as trait + // hierarchy is concerned) zalsa .lookup_ingredient(self.ingredient_index()) - .maybe_changed_after(db, self.key_index(), last_verified_at, cycle_heads) + .maybe_changed_after( + zalsa, + // lets say we do turn this into an opaque pair of data and vtable pointer + // then we also need an upcast function, from our dbview to the ingredients + // dbview + db, + self.key_index(), + last_verified_at, + cycle_heads, + ) } } diff --git a/src/lib.rs b/src/lib.rs index 2d1465ee1..9cf3d8013 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,7 +43,7 @@ pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; pub use self::cycle::CycleRecoveryAction; -pub use self::database::{AsDynDatabase, Database}; +pub use self::database::Database; pub use self::database_impl::DatabaseImpl; pub use self::durability::Durability; pub use self::event::{Event, EventKind}; diff --git a/src/parallel.rs b/src/parallel.rs index 1d2504b77..d69b055e9 100644 --- a/src/parallel.rs +++ b/src/parallel.rs @@ -1,44 +1,90 @@ use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; -use crate::Database; +use crate::{database::RawDatabasePointer, views::DatabaseUpCaster, Database}; pub fn par_map(db: &Db, inputs: impl IntoParallelIterator, op: F) -> C where - Db: Database + ?Sized, + Db: Database + ?Sized + Send, F: Fn(&Db, T) -> R + Sync + Send, T: Send, R: Send + Sync, C: FromParallelIterator, { + let caster = db.zalsa().views().upcaster_for::(); + let db_caster = db.zalsa().views().base_database_upcaster(); inputs .into_par_iter() - .map_with(DbForkOnClone(db.fork_db()), |db, element| { - op(db.0.as_view(), element) - }) + .map_with( + DbForkOnClone(db.fork_db(), caster, db_caster), + |db, element| op(db.as_view(), element), + ) .collect() } -struct DbForkOnClone(Box); +struct DbForkOnClone<'views, Db: Database + ?Sized>( + RawDatabasePointer<'static>, + &'views DatabaseUpCaster, + &'views DatabaseUpCaster, +); -impl Clone for DbForkOnClone { +// SAFETY: `T: Send` -> `&own T: Send`, `DbForkOnClone` is an owning pointer +unsafe impl Send for DbForkOnClone<'_, Db> {} + +impl DbForkOnClone<'_, Db> { + fn as_view(&self) -> &Db { + // SAFETY: The upcaster ensures that the pointer is valid for the lifetime of the view. + unsafe { self.1.upcast_unchecked(self.0) } + } +} + +impl Drop for DbForkOnClone<'_, Db> { + fn drop(&mut self) { + // SAFETY: `caster` is derived from a `db` fitting for our database clone + let db = unsafe { self.1.upcast_mut_unchecked(self.0) }; + // SAFETY: `db` has been box allocated and leaked by `fork_db` + _ = unsafe { Box::from_raw(db) }; + } +} + +impl Clone for DbForkOnClone<'_, Db> { fn clone(&self) -> Self { - DbForkOnClone(self.0.fork_db()) + DbForkOnClone( + // SAFETY: `caster` is derived from a `db` fitting for our database clone + unsafe { self.2.upcast_unchecked(self.0) }.fork_db(), + self.1, + self.2, + ) } } -pub fn join(db: &Db, a: A, b: B) -> (RA, RB) +pub fn join(db: &Db, a: A, b: B) -> (RA, RB) where A: FnOnce(&Db) -> RA + Send, B: FnOnce(&Db) -> RB + Send, RA: Send, RB: Send, { + #[derive(Copy, Clone)] + struct AssertSend(T); + // SAFETY: We send owning pointers over, which are Send, given the `Db` type parameter above is Send + unsafe impl Send for AssertSend {} + + let caster = db.zalsa().views().upcaster_for::(); // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get // moved to another thread before the closure is executed - let db_a = db.fork_db(); - let db_b = db.fork_db(); - rayon::join( - move || a(db_a.as_view::()), - move || b(db_b.as_view::()), - ) + let db_a = AssertSend(db.fork_db()); + let db_b = AssertSend(db.fork_db()); + let res = rayon::join( + // SAFETY: `caster` is derived from a `db` fitting for our database clone + move || a(unsafe { caster.upcast_unchecked({ db_a }.0) }), + // SAFETY: `caster` is derived from a `db` fitting for our database clone + move || b(unsafe { caster.upcast_unchecked({ db_b }.0) }), + ); + + // SAFETY: `db` has been box allocated and leaked by `fork_db` + // FIXME: Clean this mess up, RAII + _ = unsafe { Box::from_raw(caster.upcast_mut_unchecked(db_a.0)) }; + // SAFETY: `db` has been box allocated and leaked by `fork_db` + _ = unsafe { Box::from_raw(caster.upcast_mut_unchecked(db_b.0)) }; + res } diff --git a/src/storage.rs b/src/storage.rs index 19dd55a40..78aed3fb2 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use std::panic::RefUnwindSafe; +use crate::database::RawDatabasePointer; use crate::sync::{Arc, Condvar, Mutex}; use crate::zalsa::{Zalsa, ZalsaDatabase}; use crate::zalsa_local::{self, ZalsaLocal}; @@ -185,8 +186,8 @@ unsafe impl ZalsaDatabase for T { } #[inline(always)] - fn fork_db(&self) -> Box { - Box::new(self.clone()) + fn fork_db(&self) -> RawDatabasePointer<'static> { + Box::leak(Box::new(self.clone())).into() } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 3190c9e7f..1fe85d82a 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -23,7 +23,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Durability, Event, EventKind, Id, Revision}; +use crate::{Durability, Event, EventKind, Id, Revision}; pub mod tracked_field; @@ -376,11 +376,10 @@ where pub fn new_struct<'db>( &'db self, - db: &'db dyn Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, mut fields: C::Fields<'db>, ) -> C::Struct<'db> { - let (zalsa, zalsa_local) = db.zalsas(); - let identity_hash = IdentityHash { ingredient_index: self.ingredient_index, hash: crate::hash::hash(&C::untracked_fields(&fields)), @@ -730,11 +729,11 @@ where /// Used for debugging. pub fn leak_fields<'db>( &'db self, - db: &'db dyn Database, + zalsa: &'db Zalsa, s: C::Struct<'db>, ) -> &'db C::Fields<'db> { let id = AsId::as_id(&s); - let data = Self::data(db.zalsa().table(), id); + let data = Self::data(zalsa.table(), id); data.fields() } @@ -744,11 +743,11 @@ where /// The caller is responsible for selecting the appropriate element. pub fn tracked_field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, s: C::Struct<'db>, relative_tracked_index: usize, ) -> &'db C::Fields<'db> { - let (zalsa, zalsa_local) = db.zalsas(); let id = AsId::as_id(&s); let field_ingredient_index = self.ingredient_index.successor(relative_tracked_index); let data = Self::data(zalsa.table(), id); @@ -772,10 +771,9 @@ where /// The caller is responsible for selecting the appropriate element. pub fn untracked_field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, s: C::Struct<'db>, ) -> &'db C::Fields<'db> { - let zalsa = db.zalsa(); let id = AsId::as_id(&s); let data = Self::data(zalsa.table(), id); @@ -790,11 +788,8 @@ where #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the tracked struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } } @@ -812,7 +807,8 @@ where unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabasePointer<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 5ec38c680..03a948620 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -7,7 +7,7 @@ use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::tracked_struct::{Configuration, Value}; use crate::zalsa::IngredientIndex; -use crate::{Database, Id}; +use crate::Id; /// Created for each tracked struct. /// @@ -55,14 +55,14 @@ where self.ingredient_index } - unsafe fn maybe_changed_after<'db>( - &'db self, - db: &'db dyn Database, + unsafe fn maybe_changed_after( + &self, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabasePointer<'_>, input: Id, revision: crate::Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; VerifyResult::changed_if(field_changed_at > revision) diff --git a/src/views.rs b/src/views.rs index a14852898..82a913653 100644 --- a/src/views.rs +++ b/src/views.rs @@ -1,10 +1,15 @@ -use std::any::{Any, TypeId}; +use std::{ + any::{Any, TypeId}, + marker::PhantomData, + mem, + ptr::NonNull, +}; -use crate::Database; +use crate::{database::RawDatabasePointer, Database}; /// A `Views` struct is associated with some specific database type /// (a `DatabaseImpl` for some existential `U`). It contains functions -/// to downcast from `dyn Database` to `dyn DbView` for various traits `DbView` via this specific +/// to upcast to `dyn DbView` for various traits `DbView` via this specific /// database type. /// None of these types are known at compilation time, they are all checked /// dynamically through `TypeId` magic. @@ -13,6 +18,7 @@ pub struct Views { view_casters: boxcar::Vec, } +#[derive(Clone)] struct ViewCaster { /// The id of the target type `dyn DbView` that we can cast to. target_type_id: TypeId, @@ -20,50 +26,68 @@ struct ViewCaster { /// The name of the target type `dyn DbView` that we can cast to. type_name: &'static str, - /// Type-erased function pointer that downcasts from `dyn Database` to `dyn DbView`. - cast: ErasedDatabaseDownCasterSig, + /// Type-erased function pointer that upcasts to `dyn DbView`. + cast: ErasedDatabaseUpCasterSig, } impl ViewCaster { - fn new(func: unsafe fn(&dyn Database) -> &DbView) -> ViewCaster { + fn new(func: DatabaseUpCasterSigRaw) -> ViewCaster { ViewCaster { target_type_id: TypeId::of::(), type_name: std::any::type_name::(), // SAFETY: We are type erasing for storage, taking care of unerasing before we call // the function pointer. cast: unsafe { - std::mem::transmute::, ErasedDatabaseDownCasterSig>( - func, - ) + mem::transmute::, ErasedDatabaseUpCasterSig>(func) }, } } } -type ErasedDatabaseDownCasterSig = unsafe fn(&dyn Database) -> *const (); -type DatabaseDownCasterSig = unsafe fn(&dyn Database) -> &DbView; +type ErasedDatabaseUpCasterSig = unsafe fn(RawDatabasePointer<'_>) -> NonNull<()>; +type DatabaseUpCasterSigRaw = + for<'db> unsafe fn(RawDatabasePointer<'db>) -> NonNull; +type DatabaseUpCasterSig = for<'db> unsafe fn(RawDatabasePointer<'db>) -> &'db DbView; +type DatabaseUpCasterSigMut = + for<'db> unsafe fn(RawDatabasePointer<'db>) -> &'db mut DbView; -pub struct DatabaseDownCaster(TypeId, DatabaseDownCasterSig); +#[repr(transparent)] +pub struct DatabaseUpCaster(ViewCaster, PhantomData DbView>); -impl DatabaseDownCaster { - pub fn downcast<'db>(&self, db: &'db dyn Database) -> &'db DbView { - assert_eq!( - self.0, - db.type_id(), - "Database type does not match the expected type for this `Views` instance" - ); - // SAFETY: We've asserted that the database is correct. - unsafe { (self.1)(db) } +impl Clone for DatabaseUpCaster { + fn clone(&self) -> Self { + Self(self.0.clone(), self.1) } +} - /// Downcast `db` to `DbView`. +impl DatabaseUpCaster { + /// Upcast `db` to `DbView`. + /// + /// # Safety + /// + /// The caller must ensure that `db` is of the correct type. + #[inline] + pub unsafe fn upcast_unchecked<'db>(&self, db: RawDatabasePointer<'db>) -> &'db DbView { + // SAFETY: The caller must ensure that `db` is of the correct type. + unsafe { + (mem::transmute::>(self.0.cast))( + db, + ) + } + } + /// Upcast `db` to `DbView`. /// /// # Safety /// /// The caller must ensure that `db` is of the correct type. - pub unsafe fn downcast_unchecked<'db>(&self, db: &'db dyn Database) -> &'db DbView { + #[inline] + pub unsafe fn upcast_mut_unchecked<'db>(&self, db: RawDatabasePointer<'db>) -> &'db mut DbView { // SAFETY: The caller must ensure that `db` is of the correct type. - unsafe { (self.1)(db) } + unsafe { + (mem::transmute::>( + self.0.cast, + ))(db) + } } } @@ -71,16 +95,16 @@ impl Views { pub(crate) fn new() -> Self { let source_type_id = TypeId::of::(); let view_casters = boxcar::Vec::new(); - // special case the no-op transformation, that way we skip out on reconstructing the wide pointer - view_casters.push(ViewCaster::new::(|db| db)); + view_casters.push(ViewCaster::new::(|db| db.ptr.cast::())); Self { source_type_id, view_casters, } } - /// Add a new downcaster from `dyn Database` to `dyn DbView`. - pub fn add(&self, func: DatabaseDownCasterSig) { + /// Add a new upcaster to `dyn DbView`. + pub fn add(&self, func: fn(&Concrete) -> &DbView) { + assert_eq!(self.source_type_id, TypeId::of::()); let target_type_id = TypeId::of::(); if self .view_casters @@ -89,30 +113,39 @@ impl Views { { return; } - self.view_casters.push(ViewCaster::new::(func)); + // SAFETY: We are type erasing the function pointer for storage, and we will unerase it + // before we call it. + self.view_casters.push(ViewCaster::new::(unsafe { + mem::transmute:: &DbView, DatabaseUpCasterSigRaw>(func) + })); + } + + #[inline] + pub fn base_database_upcaster(&self) -> &DatabaseUpCaster { + // SAFETY: The type-erased function pointer is guaranteed to be valid for `dyn Database` + // since we created it with the same type. + unsafe { &*((&raw const self.view_casters[0]).cast::>()) } } - /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`. + /// Retrieve an upcaster function to `dyn DbView`. /// /// # Panics /// /// If the underlying type of `db` is not the same as the database type this upcasts was created for. - pub fn downcaster_for(&self) -> DatabaseDownCaster { + pub fn upcaster_for(&self) -> &DatabaseUpCaster { let view_type_id = TypeId::of::(); for (_idx, view) in self.view_casters.iter() { if view.target_type_id == view_type_id { // SAFETY: We are unerasing the type erased function pointer having made sure the // TypeId matches. - return DatabaseDownCaster(self.source_type_id, unsafe { - std::mem::transmute::>( - view.cast, - ) - }); + return unsafe { + &*((view as *const ViewCaster).cast::>()) + }; } } panic!( - "No downcaster registered for type `{}` in `Views`", + "No upcaster registered for type `{}` in `Views`", std::any::type_name::(), ); } diff --git a/src/zalsa.rs b/src/zalsa.rs index b5b90a04d..6576e26af 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -7,6 +7,7 @@ use std::panic::RefUnwindSafe; use rustc_hash::FxHashMap; +use crate::database::RawDatabasePointer; use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; @@ -55,7 +56,7 @@ pub unsafe trait ZalsaDatabase: Any { /// Clone the database. #[doc(hidden)] - fn fork_db(&self) -> Box; + fn fork_db(&self) -> RawDatabasePointer<'static>; } pub fn views(db: &Db) -> &Views { diff --git a/tests/debug_db_contents.rs b/tests/debug_db_contents.rs index 30efb1736..5acab277b 100644 --- a/tests/debug_db_contents.rs +++ b/tests/debug_db_contents.rs @@ -20,14 +20,15 @@ fn tracked_fn(db: &dyn salsa::Database, input: InputStruct) -> TrackedStruct<'_> #[test] fn execute() { + use salsa::plumbing::ZalsaDatabase; let db = salsa::DatabaseImpl::new(); let _ = InternedStruct::new(&db, "Salsa".to_string()); let _ = InternedStruct::new(&db, "Salsa2".to_string()); // test interned structs - let interned = InternedStruct::ingredient(&db) - .entries(&db) + let interned = InternedStruct::ingredient(db.zalsa()) + .entries(db.zalsa()) .collect::>(); assert_eq!(interned.len(), 2); @@ -38,7 +39,7 @@ fn execute() { let input = InputStruct::new(&db, 22); let inputs = InputStruct::ingredient(&db) - .entries(&db) + .entries(db.zalsa()) .collect::>(); assert_eq!(inputs.len(), 1); @@ -48,7 +49,7 @@ fn execute() { let computed = tracked_fn(&db, input).field(&db); assert_eq!(computed, 44); let tracked = TrackedStruct::ingredient(&db) - .entries(&db) + .entries(db.zalsa()) .collect::>(); assert_eq!(tracked.len(), 1); diff --git a/tests/interned-structs.rs b/tests/interned-structs.rs index da9ec6ae5..9734992a8 100644 --- a/tests/interned-structs.rs +++ b/tests/interned-structs.rs @@ -130,13 +130,13 @@ fn interning_boxed() { #[test] fn interned_structs_have_public_ingredients() { - use salsa::plumbing::AsId; + use salsa::plumbing::{AsId, ZalsaDatabase}; let db = salsa::DatabaseImpl::new(); let s = InternedString::new(&db, String::from("Hello, world!")); let underlying_id = s.0; - let data = InternedString::ingredient(&db).data(&db, underlying_id.as_id()); + let data = InternedString::ingredient(db.zalsa()).data(db.zalsa(), underlying_id.as_id()); assert_eq!(data, &(String::from("Hello, world!"),)); } diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 384dea7ec..2e82f6925 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -149,7 +149,8 @@ const _: () = { String: zalsa_::interned::HashEqLike, { Configuration_::ingredient(db).intern( - db.as_dyn_database(), + db.zalsa(), + db.zalsa_local(), StructKey::<'db>(data, std::marker::PhantomData::default()), |id, data| { StructData( @@ -163,20 +164,20 @@ const _: () = { where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.0)) } fn other(self, db: &'db Db_) -> InternedString<'db> where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.1)) } #[doc = r" Default debug formatting for this struct (may be useful if you define your own `Debug` impl)"] pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { zalsa_::with_attached_database(|db| { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), this); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), this); let mut f = f.debug_struct("InternedString"); let f = f.field("data", &fields.0); let f = f.field("other", &fields.1);