From 7a4c15bcba9e511dfb8b63a3de902927a77d5b7c Mon Sep 17 00:00:00 2001 From: lcnr Date: Thu, 26 Jun 2025 14:51:16 +0200 Subject: [PATCH 1/5] significantly improve provisional cache rebasing --- .../src/solve/search_graph.rs | 4 +- .../rustc_type_ir/src/search_graph/mod.rs | 166 +++++++++--------- .../rustc_type_ir/src/search_graph/stack.rs | 5 +- 3 files changed, 90 insertions(+), 85 deletions(-) diff --git a/compiler/rustc_next_trait_solver/src/solve/search_graph.rs b/compiler/rustc_next_trait_solver/src/solve/search_graph.rs index 12cbc7e8f91e8..84f8eda4f8da7 100644 --- a/compiler/rustc_next_trait_solver/src/solve/search_graph.rs +++ b/compiler/rustc_next_trait_solver/src/solve/search_graph.rs @@ -48,7 +48,9 @@ where ) -> QueryResult { match kind { PathKind::Coinductive => response_no_constraints(cx, input, Certainty::Yes), - PathKind::Unknown => response_no_constraints(cx, input, Certainty::overflow(false)), + PathKind::Unknown | PathKind::ForcedAmbiguity => { + response_no_constraints(cx, input, Certainty::overflow(false)) + } // Even though we know these cycles to be unproductive, we still return // overflow during coherence. This is both as we are not 100% confident in // the implementation yet and any incorrect errors would be unsound there. diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs index 8941360d2d0c5..e2bc7b372b331 100644 --- a/compiler/rustc_type_ir/src/search_graph/mod.rs +++ b/compiler/rustc_type_ir/src/search_graph/mod.rs @@ -21,10 +21,9 @@ use std::marker::PhantomData; use derive_where::derive_where; #[cfg(feature = "nightly")] use rustc_macros::{Decodable_NoContext, Encodable_NoContext, HashStable_NoContext}; +use rustc_type_ir::data_structures::HashMap; use tracing::{debug, instrument}; -use crate::data_structures::HashMap; - mod stack; use stack::{Stack, StackDepth, StackEntry}; mod global_cache; @@ -137,6 +136,12 @@ pub enum PathKind { Unknown, /// A path with at least one coinductive step. Such cycles hold. Coinductive, + /// A path which is treated as ambiguous. Once a path has this path kind + /// any other segment does not change its kind. + /// + /// This is currently only used when fuzzing to support negative reasoning. + /// For more details, see #143054. + ForcedAmbiguity, } impl PathKind { @@ -149,6 +154,9 @@ impl PathKind { /// to `max(self, rest)`. fn extend(self, rest: PathKind) -> PathKind { match (self, rest) { + (PathKind::ForcedAmbiguity, _) | (_, PathKind::ForcedAmbiguity) => { + PathKind::ForcedAmbiguity + } (PathKind::Coinductive, _) | (_, PathKind::Coinductive) => PathKind::Coinductive, (PathKind::Unknown, _) | (_, PathKind::Unknown) => PathKind::Unknown, (PathKind::Inductive, PathKind::Inductive) => PathKind::Inductive, @@ -187,41 +195,6 @@ impl UsageKind { } } -/// For each goal we track whether the paths from this goal -/// to its cycle heads are coinductive. -/// -/// This is a necessary condition to rebase provisional cache -/// entries. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum AllPathsToHeadCoinductive { - Yes, - No, -} -impl From for AllPathsToHeadCoinductive { - fn from(path: PathKind) -> AllPathsToHeadCoinductive { - match path { - PathKind::Coinductive => AllPathsToHeadCoinductive::Yes, - _ => AllPathsToHeadCoinductive::No, - } - } -} -impl AllPathsToHeadCoinductive { - #[must_use] - fn merge(self, other: impl Into) -> Self { - match (self, other.into()) { - (AllPathsToHeadCoinductive::Yes, AllPathsToHeadCoinductive::Yes) => { - AllPathsToHeadCoinductive::Yes - } - (AllPathsToHeadCoinductive::No, _) | (_, AllPathsToHeadCoinductive::No) => { - AllPathsToHeadCoinductive::No - } - } - } - fn and_merge(&mut self, other: impl Into) { - *self = self.merge(other); - } -} - #[derive(Debug, Clone, Copy)] struct AvailableDepth(usize); impl AvailableDepth { @@ -261,9 +234,9 @@ impl AvailableDepth { /// /// We also track all paths from this goal to that head. This is necessary /// when rebasing provisional cache results. -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Clone, Debug, Default)] struct CycleHeads { - heads: BTreeMap, + heads: BTreeMap, } impl CycleHeads { @@ -283,27 +256,16 @@ impl CycleHeads { self.heads.first_key_value().map(|(k, _)| *k) } - fn remove_highest_cycle_head(&mut self) { + fn remove_highest_cycle_head(&mut self) -> PathsToNested { let last = self.heads.pop_last(); - debug_assert_ne!(last, None); + last.unwrap().1 } - fn insert( - &mut self, - head: StackDepth, - path_from_entry: impl Into + Copy, - ) { - self.heads.entry(head).or_insert(path_from_entry.into()).and_merge(path_from_entry); + fn insert(&mut self, head: StackDepth, path_from_entry: impl Into + Copy) { + *self.heads.entry(head).or_insert(path_from_entry.into()) |= path_from_entry.into(); } - fn merge(&mut self, heads: &CycleHeads) { - for (&head, &path_from_entry) in heads.heads.iter() { - self.insert(head, path_from_entry); - debug_assert!(matches!(self.heads[&head], AllPathsToHeadCoinductive::Yes)); - } - } - - fn iter(&self) -> impl Iterator + '_ { + fn iter(&self) -> impl Iterator + '_ { self.heads.iter().map(|(k, v)| (*k, *v)) } @@ -317,13 +279,7 @@ impl CycleHeads { Ordering::Equal => continue, Ordering::Greater => unreachable!(), } - - let path_from_entry = match step_kind { - PathKind::Coinductive => AllPathsToHeadCoinductive::Yes, - PathKind::Unknown | PathKind::Inductive => path_from_entry, - }; - - self.insert(head, path_from_entry); + self.insert(head, path_from_entry.extend_with(step_kind)); } } } @@ -332,13 +288,14 @@ bitflags::bitflags! { /// Tracks how nested goals have been accessed. This is necessary to disable /// global cache entries if computing them would otherwise result in a cycle or /// access a provisional cache entry. - #[derive(Debug, Clone, Copy)] + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct PathsToNested: u8 { /// The initial value when adding a goal to its own nested goals. 
const EMPTY = 1 << 0; const INDUCTIVE = 1 << 1; const UNKNOWN = 1 << 2; const COINDUCTIVE = 1 << 3; + const FORCED_AMBIGUITY = 1 << 4; } } impl From for PathsToNested { @@ -347,6 +304,7 @@ impl From for PathsToNested { PathKind::Inductive => PathsToNested::INDUCTIVE, PathKind::Unknown => PathsToNested::UNKNOWN, PathKind::Coinductive => PathsToNested::COINDUCTIVE, + PathKind::ForcedAmbiguity => PathsToNested::FORCED_AMBIGUITY, } } } @@ -379,10 +337,45 @@ impl PathsToNested { self.insert(PathsToNested::COINDUCTIVE); } } + PathKind::ForcedAmbiguity => { + if self.intersects( + PathsToNested::EMPTY + | PathsToNested::INDUCTIVE + | PathsToNested::UNKNOWN + | PathsToNested::COINDUCTIVE, + ) { + self.remove( + PathsToNested::EMPTY + | PathsToNested::INDUCTIVE + | PathsToNested::UNKNOWN + | PathsToNested::COINDUCTIVE, + ); + self.insert(PathsToNested::FORCED_AMBIGUITY); + } + } } self } + + #[must_use] + fn extend_with_paths(self, path: PathsToNested) -> Self { + let mut new = PathsToNested::empty(); + for p in path.iter_paths() { + new |= self.extend_with(p); + } + new + } + + fn iter_paths(self) -> impl Iterator { + let (PathKind::Inductive + | PathKind::Unknown + | PathKind::Coinductive + | PathKind::ForcedAmbiguity); + [PathKind::Inductive, PathKind::Unknown, PathKind::Coinductive, PathKind::ForcedAmbiguity] + .into_iter() + .filter(move |&p| self.contains(p.into())) + } } /// The nested goals of each stack entry and the path from the @@ -693,7 +686,7 @@ impl, X: Cx> SearchGraph { if let Some((_scope, expected)) = validate_cache { // Do not try to move a goal into the cache again if we're testing // the global cache. - assert_eq!(evaluation_result.result, expected, "input={input:?}"); + assert_eq!(expected, evaluation_result.result, "input={input:?}"); } else if D::inspect_is_noop(inspect) { self.insert_global_cache(cx, input, evaluation_result, dep_node) } @@ -782,7 +775,7 @@ impl, X: Cx> SearchGraph { stack_entry: &StackEntry, mut mutate_result: impl FnMut(X::Input, X::Result) -> X::Result, ) { - let head = self.stack.next_index(); + let popped_head = self.stack.next_index(); #[allow(rustc::potential_query_instability)] self.provisional_cache.retain(|&input, entries| { entries.retain_mut(|entry| { @@ -792,30 +785,37 @@ impl, X: Cx> SearchGraph { path_from_head, result, } = entry; - if heads.highest_cycle_head() == head { + let ep = if heads.highest_cycle_head() == popped_head { heads.remove_highest_cycle_head() } else { return true; - } - - // We only try to rebase if all paths from the cache entry - // to its heads are coinductive. In this case these cycle - // kinds won't change, no matter the goals between these - // heads and the provisional cache entry. - if heads.iter().any(|(_, p)| matches!(p, AllPathsToHeadCoinductive::No)) { - return false; - } + }; - // The same for nested goals of the cycle head. - if stack_entry.heads.iter().any(|(_, p)| matches!(p, AllPathsToHeadCoinductive::No)) - { - return false; + // We're rebasing an entry `e` over a head `p`. This head + // has a number of own heads `h` it depends on. We need to + // make sure that the path kind of all paths `hph` remain the + // same after rebasing. + // + // After rebasing the cycles `hph` will go through `e`. 
We need + // to make sure that forall possible paths `hep` and `heph` + // is equal to `hph.` + for (h, ph) in stack_entry.heads.iter() { + let hp = + Self::cycle_path_kind(&self.stack, stack_entry.step_kind_from_parent, h); + let hph = ph.extend_with(hp); + let he = hp.extend(*path_from_head); + let hep = ep.extend_with(he); + for hep in hep.iter_paths() { + let heph = ph.extend_with(hep); + if hph != heph { + return false; + } + } + + let eph = ep.extend_with_paths(ph); + heads.insert(h, eph); } - // Merge the cycle heads of the provisional cache entry and the - // popped head. If the popped cycle head was a root, discard all - // provisional cache entries which depend on it. - heads.merge(&stack_entry.heads); let Some(head) = heads.opt_highest_cycle_head() else { return false; }; diff --git a/compiler/rustc_type_ir/src/search_graph/stack.rs b/compiler/rustc_type_ir/src/search_graph/stack.rs index e0fd934df698f..a58cd82b02303 100644 --- a/compiler/rustc_type_ir/src/search_graph/stack.rs +++ b/compiler/rustc_type_ir/src/search_graph/stack.rs @@ -3,7 +3,7 @@ use std::ops::{Index, IndexMut}; use derive_where::derive_where; use rustc_index::IndexVec; -use super::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind}; +use crate::search_graph::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind}; rustc_index::newtype_index! { #[orderable] @@ -79,6 +79,9 @@ impl Stack { } pub(super) fn push(&mut self, entry: StackEntry) -> StackDepth { + if cfg!(debug_assertions) && self.entries.iter().any(|e| e.input == entry.input) { + panic!("pushing duplicate entry on stack: {entry:?} {:?}", self.entries); + } self.entries.push(entry) } From fe2752f327a37fab66e22d53d513eef1291a481c Mon Sep 17 00:00:00 2001 From: lcnr Date: Fri, 13 Jun 2025 17:53:41 +0200 Subject: [PATCH 2/5] gaming :3 --- .../src/solve/eval_ctxt/mod.rs | 15 +- .../src/search_graph/global_cache.rs | 10 +- .../rustc_type_ir/src/search_graph/mod.rs | 519 ++++++++++++++---- .../rustc_type_ir/src/search_graph/stack.rs | 14 +- .../rustc_type_ir/src/search_graph/tree.rs | 213 +++++++ .../cycles/inductive-cycle-but-err.stderr | 10 + 6 files changed, 675 insertions(+), 106 deletions(-) create mode 100644 compiler/rustc_type_ir/src/search_graph/tree.rs diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs index b42587618b57b..dc2d953d168a3 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs @@ -449,12 +449,15 @@ where let (orig_values, canonical_goal) = self.canonicalize_goal(goal); let mut goal_evaluation = self.inspect.new_goal_evaluation(goal, &orig_values, goal_evaluation_kind); - let canonical_result = self.search_graph.evaluate_goal( - self.cx(), - canonical_goal, - self.step_kind_for_source(source), - &mut goal_evaluation, - ); + let canonical_result = self + .search_graph + .evaluate_goal( + self.cx(), + canonical_goal, + self.step_kind_for_source(source), + &mut goal_evaluation, + ) + .1; goal_evaluation.query_result(canonical_result); self.inspect.goal_evaluation(goal_evaluation); let response = match canonical_result { diff --git a/compiler/rustc_type_ir/src/search_graph/global_cache.rs b/compiler/rustc_type_ir/src/search_graph/global_cache.rs index eb56c1af408bb..080bde7db5e2d 100644 --- a/compiler/rustc_type_ir/src/search_graph/global_cache.rs +++ b/compiler/rustc_type_ir/src/search_graph/global_cache.rs @@ -47,8 +47,14 @@ impl GlobalCache { 
evaluation_result: EvaluationResult, dep_node: X::DepNodeIndex, ) { - let EvaluationResult { encountered_overflow, required_depth, heads, nested_goals, result } = - evaluation_result; + let EvaluationResult { + node_id: _, + encountered_overflow, + required_depth, + heads, + nested_goals, + result, + } = evaluation_result; debug_assert!(heads.is_empty()); let result = cx.mk_tracked(result, dep_node); let entry = self.map.entry(input).or_default(); diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs index e2bc7b372b331..87bc52f076e25 100644 --- a/compiler/rustc_type_ir/src/search_graph/mod.rs +++ b/compiler/rustc_type_ir/src/search_graph/mod.rs @@ -21,14 +21,16 @@ use std::marker::PhantomData; use derive_where::derive_where; #[cfg(feature = "nightly")] use rustc_macros::{Decodable_NoContext, Encodable_NoContext, HashStable_NoContext}; -use rustc_type_ir::data_structures::HashMap; +use rustc_type_ir::data_structures::{HashMap, HashSet}; use tracing::{debug, instrument}; mod stack; use stack::{Stack, StackDepth, StackEntry}; mod global_cache; +mod tree; use global_cache::CacheData; pub use global_cache::GlobalCache; +use tree::SearchTree; /// The search graph does not simply use `Interner` directly /// to enable its fuzzing without having to stub the rest of @@ -282,6 +284,14 @@ impl CycleHeads { self.insert(head, path_from_entry.extend_with(step_kind)); } } + + fn contains_stack_entry(&self, depth: StackDepth) -> bool { + self.heads.contains_key(&depth) + } + + fn contains(&self, other: &CycleHeads) -> bool { + other.heads.iter().all(|(h, &path)| self.heads.get(h).is_some_and(|p| p.contains(path))) + } } bitflags::bitflags! { @@ -436,6 +446,7 @@ impl NestedGoals { /// goals still on the stack. #[derive_where(Debug; X: Cx)] struct ProvisionalCacheEntry { + entry_node_id: tree::NodeId, /// Whether evaluating the goal encountered overflow. This is used to /// disable the cache entry except if the last goal on the stack is /// already involved in this cycle. @@ -459,6 +470,7 @@ struct ProvisionalCacheEntry { /// evaluation. #[derive_where(Debug; X: Cx)] struct EvaluationResult { + node_id: tree::NodeId, encountered_overflow: bool, required_depth: usize, heads: CycleHeads, @@ -473,13 +485,14 @@ impl EvaluationResult { result: X::Result, ) -> EvaluationResult { EvaluationResult { - encountered_overflow, + encountered_overflow: final_entry.encountered_overflow | encountered_overflow, // Unlike `encountered_overflow`, we share `heads`, `required_depth`, // and `nested_goals` between evaluations. required_depth: final_entry.required_depth, heads: final_entry.heads, nested_goals: final_entry.nested_goals, - // We only care about the final result. + // We only care about the result and the `node_id` of the final iteration. + node_id: final_entry.node_id, result, } } @@ -497,6 +510,8 @@ pub struct SearchGraph, X: Cx = ::Cx> { /// is only valid until the result of one of its cycle heads changes. provisional_cache: HashMap>>, + tree: SearchTree, + _marker: PhantomData, } @@ -520,6 +535,7 @@ impl, X: Cx> SearchGraph { root_depth: AvailableDepth(root_depth), stack: Default::default(), provisional_cache: Default::default(), + tree: Default::default(), _marker: PhantomData, } } @@ -574,6 +590,11 @@ impl, X: Cx> SearchGraph { self.stack.len() } + /// This should only be used for debugging purposes. 
+ pub fn debug_stack(&self) -> &impl std::fmt::Debug { + &self.stack + } + /// Whether the path from `head` to the current stack entry is inductive or coinductive. /// /// The `step_kind_to_head` is used to add a single additional path segment to the path on @@ -598,13 +619,16 @@ impl, X: Cx> SearchGraph { input: X::Input, step_kind_from_parent: PathKind, inspect: &mut D::ProofTreeBuilder, - ) -> X::Result { + ) -> (Option, X::Result) { let Some(available_depth) = AvailableDepth::allowed_depth_for_nested::(self.root_depth, &self.stack) else { - return self.handle_overflow(cx, input, inspect); + return (None, self.handle_overflow(cx, input, inspect)); }; + let node_id = + self.tree.create_node(&self.stack, input, step_kind_from_parent, available_depth); + // We check the provisional cache before checking the global cache. This simplifies // the implementation as we can avoid worrying about cases where both the global and // provisional cache may apply, e.g. consider the following example @@ -613,8 +637,8 @@ impl, X: Cx> SearchGraph { // - A // - BA cycle // - CB :x: - if let Some(result) = self.lookup_provisional_cache(input, step_kind_from_parent) { - return result; + if let Some(result) = self.lookup_provisional_cache(node_id, input, step_kind_from_parent) { + return (Some(node_id), result); } // Lookup the global cache unless we're building proof trees or are currently @@ -630,9 +654,9 @@ impl, X: Cx> SearchGraph { .inspect(|expected| debug!(?expected, "validate cache entry")) .map(|r| (scope, r)) } else if let Some(result) = - self.lookup_global_cache(cx, input, step_kind_from_parent, available_depth) + self.lookup_global_cache(cx, node_id, input, step_kind_from_parent, available_depth) { - return result; + return (None, result); } else { None }; @@ -641,13 +665,14 @@ impl, X: Cx> SearchGraph { // avoid iterating over the stack in case a goal has already been computed. // This may not have an actual performance impact and we could reorder them // as it may reduce the number of `nested_goals` we need to track. - if let Some(result) = self.check_cycle_on_stack(cx, input, step_kind_from_parent) { + if let Some(result) = self.check_cycle_on_stack(cx, node_id, input, step_kind_from_parent) { debug_assert!(validate_cache.is_none(), "global cache and cycle on stack: {input:?}"); - return result; + return (Some(node_id), result); } // Unfortunate, it looks like we actually have to compute this goal. self.stack.push(StackEntry { + node_id, input, step_kind_from_parent, available_depth, @@ -678,6 +703,14 @@ impl, X: Cx> SearchGraph { evaluation_result.encountered_overflow, UpdateParentGoalCtxt::Ordinary(&evaluation_result.nested_goals), ); + // FIXME: Cloning the cycle heads here is quite ass. We should make cycle heads + // CoW and use reference counting. + self.tree.finish_evaluation( + evaluation_result.node_id, + evaluation_result.encountered_overflow, + evaluation_result.heads.clone(), + evaluation_result.result, + ); let result = evaluation_result.result; // We're now done with this goal. We only add the root of cycles to the global cache. 
@@ -690,10 +723,13 @@ impl, X: Cx> SearchGraph { } else if D::inspect_is_noop(inspect) { self.insert_global_cache(cx, input, evaluation_result, dep_node) } + + (None, result) } else if D::ENABLE_PROVISIONAL_CACHE { debug_assert!(validate_cache.is_none(), "unexpected non-root: {input:?}"); let entry = self.provisional_cache.entry(input).or_default(); let EvaluationResult { + node_id, encountered_overflow, required_depth: _, heads, @@ -705,15 +741,20 @@ impl, X: Cx> SearchGraph { step_kind_from_parent, heads.highest_cycle_head(), ); - let provisional_cache_entry = - ProvisionalCacheEntry { encountered_overflow, heads, path_from_head, result }; + let provisional_cache_entry = ProvisionalCacheEntry { + entry_node_id: node_id, + encountered_overflow, + heads, + path_from_head, + result, + }; debug!(?provisional_cache_entry); entry.push(provisional_cache_entry); + (Some(node_id), result) } else { debug_assert!(validate_cache.is_none(), "unexpected non-root: {input:?}"); + (Some(node_id), result) } - - result } fn handle_overflow( @@ -741,11 +782,17 @@ impl, X: Cx> SearchGraph { /// When reevaluating a goal with a changed provisional result, all provisional cache entry /// which depend on this goal get invalidated. - fn clear_dependent_provisional_results(&mut self) { - let head = self.stack.next_index(); + fn clear_dependent_provisional_results( + stack: &Stack, + provisional_cache: &mut HashMap>>, + mut handle_removed_entry: impl FnMut(X::Input, ProvisionalCacheEntry), + ) { + let head = stack.next_index(); #[allow(rustc::potential_query_instability)] - self.provisional_cache.retain(|_, entries| { - entries.retain(|entry| entry.heads.highest_cycle_head() != head); + provisional_cache.retain(|&input, entries| { + for e in entries.extract_if(.., |entry| entry.heads.highest_cycle_head() == head) { + handle_removed_entry(input, e) + } !entries.is_empty() }); } @@ -771,15 +818,17 @@ impl, X: Cx> SearchGraph { /// goals whose result doesn't actually depend on this cycle head, but that's acceptable /// to me. fn rebase_provisional_cache_entries( - &mut self, + stack: &Stack, + provisional_cache: &mut HashMap>>, stack_entry: &StackEntry, mut mutate_result: impl FnMut(X::Input, X::Result) -> X::Result, ) { - let popped_head = self.stack.next_index(); + let popped_head = stack.next_index(); #[allow(rustc::potential_query_instability)] - self.provisional_cache.retain(|&input, entries| { + provisional_cache.retain(|&input, entries| { entries.retain_mut(|entry| { let ProvisionalCacheEntry { + entry_node_id: _, encountered_overflow: _, heads, path_from_head, @@ -800,8 +849,7 @@ impl, X: Cx> SearchGraph { // to make sure that forall possible paths `hep` and `heph` // is equal to `hph.` for (h, ph) in stack_entry.heads.iter() { - let hp = - Self::cycle_path_kind(&self.stack, stack_entry.step_kind_from_parent, h); + let hp = Self::cycle_path_kind(&stack, stack_entry.step_kind_from_parent, h); let hph = ph.extend_with(hp); let he = hp.extend(*path_from_head); let hep = ep.extend_with(he); @@ -823,13 +871,14 @@ impl, X: Cx> SearchGraph { // We now care about the path from the next highest cycle head to the // provisional cache entry. *path_from_head = path_from_head.extend(Self::cycle_path_kind( - &self.stack, + &stack, stack_entry.step_kind_from_parent, head, )); // Mutate the result of the provisional cache entry in case we did // not reach a fixpoint. 
*result = mutate_result(input, *result); + debug!(?input, ?entry, "rebased entry"); true }); !entries.is_empty() @@ -838,6 +887,7 @@ impl, X: Cx> SearchGraph { fn lookup_provisional_cache( &mut self, + node_id: tree::NodeId, input: X::Input, step_kind_from_parent: PathKind, ) -> Option { @@ -846,8 +896,13 @@ impl, X: Cx> SearchGraph { } let entries = self.provisional_cache.get(&input)?; - for &ProvisionalCacheEntry { encountered_overflow, ref heads, path_from_head, result } in - entries + for &ProvisionalCacheEntry { + entry_node_id, + encountered_overflow, + ref heads, + path_from_head, + result, + } in entries { let head = heads.highest_cycle_head(); if encountered_overflow { @@ -879,6 +934,12 @@ impl, X: Cx> SearchGraph { ); debug_assert!(self.stack[head].has_been_used.is_some()); debug!(?head, ?path_from_head, "provisional cache hit"); + let provisional_results = self + .stack + .iter_enumerated() + .filter_map(|(depth, entry)| entry.provisional_result.map(|r| (depth, r))) + .collect(); + self.tree.provisional_cache_hit(node_id, entry_node_id, provisional_results); return Some(result); } } @@ -919,6 +980,7 @@ impl, X: Cx> SearchGraph { // A provisional cache entry is applicable if the path to // its highest cycle head is equal to the expected path. for &ProvisionalCacheEntry { + entry_node_id: _, encountered_overflow, ref heads, path_from_head: head_to_provisional, @@ -977,6 +1039,7 @@ impl, X: Cx> SearchGraph { fn lookup_global_cache( &mut self, cx: X, + node_id: tree::NodeId, input: X::Input, step_kind_from_parent: PathKind, available_depth: AvailableDepth, @@ -1000,6 +1063,7 @@ impl, X: Cx> SearchGraph { ); debug!(?required_depth, "global cache hit"); + self.tree.global_cache_hit(node_id); Some(result) }) } @@ -1007,6 +1071,7 @@ impl, X: Cx> SearchGraph { fn check_cycle_on_stack( &mut self, cx: X, + node_id: tree::NodeId, input: X::Input, step_kind_from_parent: PathKind, ) -> Option { @@ -1037,16 +1102,21 @@ impl, X: Cx> SearchGraph { // Return the provisional result or, if we're in the first iteration, // start with no constraints. - if let Some(result) = self.stack[head].provisional_result { - Some(result) - } else { - Some(D::initial_provisional_result(cx, path_kind, input)) - } + let result = self.stack[head] + .provisional_result + .unwrap_or_else(|| D::initial_provisional_result(cx, path_kind, input)); + + let provisional_results = self + .stack + .iter_enumerated() + .filter_map(|(depth, entry)| entry.provisional_result.map(|r| (depth, r))) + .collect(); + self.tree.cycle_on_stack(node_id, self.stack[head].node_id, result, provisional_results); + Some(result) } /// Whether we've reached a fixpoint when evaluating a cycle head. fn reached_fixpoint( - &mut self, cx: X, stack_entry: &StackEntry, usage_kind: UsageKind, @@ -1072,41 +1142,40 @@ impl, X: Cx> SearchGraph { input: X::Input, inspect: &mut D::ProofTreeBuilder, ) -> EvaluationResult { - // We reset `encountered_overflow` each time we rerun this goal - // but need to make sure we currently propagate it to the global - // cache even if only some of the evaluations actually reach the - // recursion limit. - let mut encountered_overflow = false; - let mut i = 0; - loop { - let result = D::compute_goal(self, cx, input, inspect); - let stack_entry = self.stack.pop(); - encountered_overflow |= stack_entry.encountered_overflow; - debug_assert_eq!(stack_entry.input, input); - - // If the current goal is not the root of a cycle, we are done. - // - // There are no provisional cache entries which depend on this goal. 
- let Some(usage_kind) = stack_entry.has_been_used else { - return EvaluationResult::finalize(stack_entry, encountered_overflow, result); - }; + let mut result = D::compute_goal(self, cx, input, inspect); + let mut stack_entry = self.stack.pop(); + let mut encountered_overflow = stack_entry.encountered_overflow; + debug_assert_eq!(stack_entry.input, input); + // If the current goal is not the root of a cycle, we are done. + // + // There are no provisional cache entries which depend on this goal. + let Some(usage_kind) = stack_entry.has_been_used else { + return EvaluationResult::finalize(stack_entry, encountered_overflow, result); + }; - // If it is a cycle head, we have to keep trying to prove it until - // we reach a fixpoint. We need to do so for all cycle heads, - // not only for the root. - // - // See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs - // for an example. - // - // Check whether we reached a fixpoint, either because the final result - // is equal to the provisional result of the previous iteration, or because - // this was only the root of either coinductive or inductive cycles, and the - // final result is equal to the initial response for that case. - if self.reached_fixpoint(cx, &stack_entry, usage_kind, result) { - self.rebase_provisional_cache_entries(&stack_entry, |_, result| result); - return EvaluationResult::finalize(stack_entry, encountered_overflow, result); - } + // If it is a cycle head, we have to keep trying to prove it until + // we reach a fixpoint. We need to do so for all cycle heads, + // not only for the root. + // + // See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs + // for an example. + // + // Check whether we reached a fixpoint, either because the final result + // is equal to the provisional result of the previous iteration, or because + // this was only the root of either coinductive or inductive cycles, and the + // final result is equal to the initial response for that case. + if Self::reached_fixpoint(cx, &stack_entry, usage_kind, result) { + Self::rebase_provisional_cache_entries( + &self.stack, + &mut self.provisional_cache, + &stack_entry, + |_, result| result, + ); + return EvaluationResult::finalize(stack_entry, encountered_overflow, result); + } + let mut i = 0; + loop { // If computing this goal results in ambiguity with no constraints, // we do not rerun it. It's incredibly difficult to get a different // response in the next iteration in this case. These changes would @@ -1120,9 +1189,12 @@ impl, X: Cx> SearchGraph { // we also taint all provisional cache entries which depend on the // current goal. if D::is_ambiguous_result(result) { - self.rebase_provisional_cache_entries(&stack_entry, |input, _| { - D::propagate_ambiguity(cx, input, result) - }); + Self::rebase_provisional_cache_entries( + &self.stack, + &mut self.provisional_cache, + &stack_entry, + |input, _| D::propagate_ambiguity(cx, input, result), + ); return EvaluationResult::finalize(stack_entry, encountered_overflow, result); }; @@ -1130,40 +1202,295 @@ impl, X: Cx> SearchGraph { // provisional cache entries which depend on the current goal. 
i += 1; if i >= D::FIXPOINT_STEP_LIMIT { - debug!("canonical cycle overflow"); + debug!(?result, "canonical cycle overflow"); let result = D::on_fixpoint_overflow(cx, input); - self.rebase_provisional_cache_entries(&stack_entry, |input, _| { - D::on_fixpoint_overflow(cx, input) - }); + Self::rebase_provisional_cache_entries( + &self.stack, + &mut self.provisional_cache, + &stack_entry, + |input, _| D::on_fixpoint_overflow(cx, input), + ); return EvaluationResult::finalize(stack_entry, encountered_overflow, result); } - // Clear all provisional cache entries which depend on a previous provisional - // result of this goal and rerun. - self.clear_dependent_provisional_results(); - - debug!(?result, "fixpoint changed provisional results"); - self.stack.push(StackEntry { - input, - step_kind_from_parent: stack_entry.step_kind_from_parent, - available_depth: stack_entry.available_depth, - provisional_result: Some(result), - // We can keep these goals from previous iterations as they are only - // ever read after finalizing this evaluation. - required_depth: stack_entry.required_depth, - heads: stack_entry.heads, - nested_goals: stack_entry.nested_goals, - // We reset these two fields when rerunning this goal. We could - // keep `encountered_overflow` as it's only used as a performance - // optimization. However, given that the proof tree will likely look - // similar to the previous iterations when reevaluating, it's better - // for caching if the reevaluation also starts out with `false`. - encountered_overflow: false, - has_been_used: None, - }); + debug!(?i, ?result, "changed provisional results"); + match self.reevaluate_goal_on_stack(cx, stack_entry, result, inspect) { + (new_stack_entry, new_result) => { + if new_result == result { + Self::rebase_provisional_cache_entries( + &self.stack, + &mut self.provisional_cache, + &new_stack_entry, + |_, result| result, + ); + return EvaluationResult::finalize( + new_stack_entry, + encountered_overflow, + result, + ); + } else { + result = new_result; + encountered_overflow |= new_stack_entry.encountered_overflow; + stack_entry = new_stack_entry; + } + } + } } } + fn reevaluate_goal_on_stack( + &mut self, + cx: X, + prev_stack_entry: StackEntry, + provisional_result: X::Result, + inspect: &mut D::ProofTreeBuilder, + ) -> (StackEntry, X::Result) { + let node_id = prev_stack_entry.node_id; + let current_depth = self.stack.next_index(); + + let mut removed_entries = BTreeMap::new(); + // Clear all provisional cache entries which depend on a previous provisional + // result of this goal and rerun. 
+ Self::clear_dependent_provisional_results( + &self.stack, + &mut self.provisional_cache, + |input, entry| { + let prev = removed_entries.insert(entry.entry_node_id, (input, entry)); + if let Some(prev) = prev { + unreachable!("duplicate entries for the same `NodeId`: {prev:?}"); + } + }, + ); + self.stack.push(StackEntry { + node_id, + input: prev_stack_entry.input, + step_kind_from_parent: prev_stack_entry.step_kind_from_parent, + available_depth: prev_stack_entry.available_depth, + required_depth: prev_stack_entry.required_depth, + heads: prev_stack_entry.heads, + nested_goals: prev_stack_entry.nested_goals, + provisional_result: Some(provisional_result), + encountered_overflow: false, + has_been_used: None, + }); + + if !D::ENABLE_PROVISIONAL_CACHE { + let result = D::compute_goal(self, cx, prev_stack_entry.input, inspect); + let reeval_entry = self.stack.pop(); + return (reeval_entry, result); + } + + let truncate_stack = |stack: &mut Stack, provisional_cache: &mut _, depth| { + while stack.next_index() > depth { + let reeval_entry = stack.pop(); + // TODO: How can we tell whether this entry was the final revision. + // + // We should be able to rebase provisional entries in most cases. + Self::clear_dependent_provisional_results(stack, provisional_cache, |_, _| ()); + Self::update_parent_goal( + stack, + reeval_entry.step_kind_from_parent, + reeval_entry.required_depth, + &reeval_entry.heads, + reeval_entry.encountered_overflow, + UpdateParentGoalCtxt::Ordinary(&reeval_entry.nested_goals), + ); + } + }; + + let cycles = self.tree.rerun_get_and_reset_cycles(prev_stack_entry.node_id); + let current_stack_len = self.stack.len(); + let mut has_changed = HashSet::default(); + 'outer: for cycle in cycles { + let &tree::Cycle { node_id: cycle_node_id, ref provisional_results } = + self.tree.get_cycle(cycle); + + match self.tree.node_kind_raw(cycle_node_id) { + &tree::NodeKind::InProgress { .. } | &tree::NodeKind::Finished { .. } => { + unreachable!() + } + &tree::NodeKind::CycleOnStack { entry_node_id, result: _ } => { + if entry_node_id != node_id { + continue; + } + } + &tree::NodeKind::ProvisionalCacheHit { entry_node_id } => { + // We evaluated the provisional cache entry before evaluating this goal. It + // cannot depend on the current goal. + if entry_node_id < node_id { + continue; + } + // This provisional cache entry was computed with the current goal on the + // stack. Check whether it depends on it. + if !self.tree.get_heads(entry_node_id).contains_stack_entry(current_depth) { + continue; + } + + // We've evaluated the `entry_node_id` before evaluating this goal. In case + // that node and its parents has not changed, we can reinsert the cache entry + // before starting to reevaluate it. + if !self.tree.goal_or_parent_has_changed(node_id, &has_changed, entry_node_id) { + continue; + } + } + }; + + // We then build the stack at the point of reaching this cycle. + let mut rev_stack = + self.tree.compute_rev_stack(cycle_node_id, prev_stack_entry.node_id); + let span = tracing::debug_span!("reevaluate cycle", ?rev_stack, ?provisional_result); + let _span = span.enter(); + let mut current_goal = rev_stack.remove(0); + let mut added_goals = rev_stack.into_iter().rev().peekable(); + // We only pop from the stack when checking whether a result has changed. + // If a later cycle does not have to truncate the stack, we've already reevaluated + // a parent so there's no need to consider that goal. + for idx in current_stack_len.. 
{ + let stack_depth = StackDepth::from_usize(idx); + match (added_goals.peek(), self.stack.get(stack_depth)) { + (Some(&(node_id, info)), Some(existing_entry)) => { + let provisional_result = provisional_results.get(&stack_depth).copied(); + if existing_entry.node_id == node_id + && provisional_result == existing_entry.provisional_result + { + debug_assert_eq!(existing_entry.input, info.input); + debug_assert_eq!( + existing_entry.step_kind_from_parent, + info.step_kind_from_parent + ); + let _ = added_goals.next().unwrap(); + } else { + truncate_stack( + &mut self.stack, + &mut self.provisional_cache, + stack_depth, + ); + break; + } + } + (Some(&(node_id, info)), None) => { + if current_goal.0 == node_id { + debug!(parent = ?info.input, cycle = ?added_goals.last().unwrap(), "reevaluated parent, skip cycle"); + continue 'outer; + } else { + break; + } + } + (None, Some(_)) => { + truncate_stack(&mut self.stack, &mut self.provisional_cache, stack_depth); + break; + } + (None, None) => break, + } + } + + for (node_id, info) in added_goals { + let tree::GoalInfo { input, step_kind_from_parent, available_depth } = info; + let stack_depth = self.stack.next_index(); + let provisional_result = provisional_results.get(&stack_depth).copied(); + self.stack.push(StackEntry { + node_id, + input, + step_kind_from_parent, + available_depth, + provisional_result, + required_depth: 0, + heads: Default::default(), + encountered_overflow: false, + has_been_used: None, + nested_goals: Default::default(), + }); + } + + /* + while let Some((&entry_node_id, _)) = removed_entries.first_key_value() { + if entry_node_id < current_goal.0 + && self.stack.iter().all(|e| e.node_id != entry_node_id) + { + let (entry_node_id, (input, entry)) = removed_entries.pop_first().unwrap(); + if !self.tree.goal_or_parent_has_changed(node_id, &has_changed, entry_node_id) { + self.provisional_cache.entry(input).or_default().push(entry); + } + } + }*/ + + loop { + let span = tracing::debug_span!( + "reevaluate_canonical_goal", + input = ?current_goal.1.input, + step_kind_from_parent = ?current_goal.1.step_kind_from_parent + ); + let _span = span.enter(); + let (node_id, result) = self.evaluate_goal( + cx, + current_goal.1.input, + current_goal.1.step_kind_from_parent, + inspect, + ); + if node_id.is_some_and(|node_id| self.tree.result_matches(current_goal.0, node_id)) + { + // TODO: This seems wrong. If a later loop reevaluates this goal again, we'd use + // its updated `NodeId`. 
+ removed_entries.remove(¤t_goal.0); + debug!(input = ?current_goal.1.input, ?result, "goal did not change"); + continue 'outer; + } else { + has_changed.insert(current_goal.0); + debug!(input = ?current_goal.1.input, ?result, "goal did change"); + if self.stack.len() > current_stack_len { + let parent = self.stack.pop(); + Self::clear_dependent_provisional_results( + &self.stack, + &mut self.provisional_cache, + |_, _| (), + ); + Self::update_parent_goal( + &mut self.stack, + parent.step_kind_from_parent, + parent.required_depth, + &parent.heads, + parent.encountered_overflow, + UpdateParentGoalCtxt::Ordinary(&parent.nested_goals), + ); + current_goal = ( + parent.node_id, + tree::GoalInfo { + input: parent.input, + step_kind_from_parent: parent.step_kind_from_parent, + available_depth: parent.available_depth, + }, + ); + } else { + break; + } + } + } + + debug!("reevaluating goal itself"); + debug_assert_eq!(self.stack.len(), current_stack_len); + debug_assert_eq!(self.stack.last().unwrap().input, prev_stack_entry.input); + let result = D::compute_goal(self, cx, prev_stack_entry.input, inspect); + let reeval_entry = self.stack.pop(); + return (reeval_entry, result); + } + + truncate_stack( + &mut self.stack, + &mut self.provisional_cache, + StackDepth::from_usize(current_stack_len), + ); + + for (entry_node_id, (input, entry)) in removed_entries { + if !self.tree.goal_or_parent_has_changed(node_id, &has_changed, entry_node_id) { + self.provisional_cache.entry(input).or_default().push(entry); + } + } + + debug_assert_eq!(self.stack.len(), current_stack_len); + let reeval_entry = self.stack.pop(); + (reeval_entry, provisional_result) + } + /// When encountering a cycle, both inductive and coinductive, we only /// move the root into the global cache. We also store all other cycle /// participants involved. diff --git a/compiler/rustc_type_ir/src/search_graph/stack.rs b/compiler/rustc_type_ir/src/search_graph/stack.rs index a58cd82b02303..489b8dba8442d 100644 --- a/compiler/rustc_type_ir/src/search_graph/stack.rs +++ b/compiler/rustc_type_ir/src/search_graph/stack.rs @@ -3,7 +3,7 @@ use std::ops::{Index, IndexMut}; use derive_where::derive_where; use rustc_index::IndexVec; -use crate::search_graph::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind}; +use crate::search_graph::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind, tree}; rustc_index::newtype_index! { #[orderable] @@ -15,6 +15,8 @@ rustc_index::newtype_index! { /// when popping a child goal or completely immutable. #[derive_where(Debug; X: Cx)] pub(super) struct StackEntry { + pub node_id: tree::NodeId, + pub input: X::Input, /// Whether proving this goal is a coinductive step. 
@@ -48,7 +50,7 @@ pub(super) struct StackEntry { pub nested_goals: NestedGoals, } -#[derive_where(Default; X: Cx)] +#[derive_where(Debug, Default; X: Cx)] pub(super) struct Stack { entries: IndexVec>, } @@ -85,6 +87,10 @@ impl Stack { self.entries.push(entry) } + pub(super) fn get(&self, depth: StackDepth) -> Option<&StackEntry> { + self.entries.get(depth) + } + pub(super) fn pop(&mut self) -> StackEntry { self.entries.pop().unwrap() } @@ -97,6 +103,10 @@ impl Stack { self.entries.iter() } + pub(super) fn iter_enumerated(&self) -> impl Iterator)> { + self.entries.iter_enumerated() + } + pub(super) fn find(&self, input: X::Input) -> Option { self.entries.iter_enumerated().find(|(_, e)| e.input == input).map(|(idx, _)| idx) } diff --git a/compiler/rustc_type_ir/src/search_graph/tree.rs b/compiler/rustc_type_ir/src/search_graph/tree.rs new file mode 100644 index 0000000000000..7bebd3c2a3647 --- /dev/null +++ b/compiler/rustc_type_ir/src/search_graph/tree.rs @@ -0,0 +1,213 @@ +use std::hash::Hash; +use std::ops::Range; + +use derive_where::derive_where; +use rustc_index::IndexVec; +use rustc_type_ir::data_structures::{HashMap, HashSet}; + +use crate::search_graph::{AvailableDepth, Cx, CycleHeads, PathKind, Stack, StackDepth}; + +#[derive_where(Debug, Clone, Copy; X: Cx)] +pub(super) struct GoalInfo { + pub input: X::Input, + pub step_kind_from_parent: PathKind, + pub available_depth: AvailableDepth, +} + +rustc_index::newtype_index! { + #[orderable] + #[gate_rustc_only] + pub struct NodeId {} // TODO: private +} + +rustc_index::newtype_index! { + #[orderable] + #[gate_rustc_only] + pub(super) struct CycleId {} +} + +#[derive_where(Debug; X: Cx)] +pub(super) enum NodeKind { + InProgress { cycles_start: CycleId }, + Finished { encountered_overflow: bool, heads: CycleHeads, result: X::Result }, + CycleOnStack { entry_node_id: NodeId, result: X::Result }, + ProvisionalCacheHit { entry_node_id: NodeId }, +} + +#[derive_where(Debug; X: Cx)] +struct Node { + info: GoalInfo, + parent: Option, + kind: NodeKind, +} + +#[derive_where(Debug; X: Cx)] +pub(super) struct Cycle { + pub node_id: NodeId, + pub provisional_results: HashMap, +} + +#[derive_where(Debug, Default; X: Cx)] +pub(super) struct SearchTree { + nodes: IndexVec>, + cycles: IndexVec>, +} + +impl SearchTree { + pub(super) fn create_node( + &mut self, + stack: &Stack, + input: X::Input, + step_kind_from_parent: PathKind, + available_depth: AvailableDepth, + ) -> NodeId { + let info = GoalInfo { input, step_kind_from_parent, available_depth }; + let parent = stack.last().map(|e| e.node_id); + self.nodes.push(Node { + info, + parent, + kind: NodeKind::InProgress { cycles_start: self.cycles.next_index() }, + }) + } + + pub(super) fn global_cache_hit(&mut self, node_id: NodeId) { + debug_assert_eq!(node_id, self.nodes.last_index().unwrap()); + debug_assert!(matches!(self.nodes[node_id].kind, NodeKind::InProgress { .. })); + self.nodes.pop(); + } + + pub(super) fn provisional_cache_hit( + &mut self, + node_id: NodeId, + entry_node_id: NodeId, + provisional_results: HashMap, + ) { + debug_assert_eq!(node_id, self.nodes.last_index().unwrap()); + debug_assert!(matches!(self.nodes[node_id].kind, NodeKind::InProgress { .. 
})); + self.cycles.push(Cycle { node_id, provisional_results }); + self.nodes[node_id].kind = NodeKind::ProvisionalCacheHit { entry_node_id }; + } + + pub(super) fn cycle_on_stack( + &mut self, + node_id: NodeId, + entry_node_id: NodeId, + result: X::Result, + provisional_results: HashMap, + ) { + debug_assert_eq!(node_id, self.nodes.last_index().unwrap()); + debug_assert!(matches!(self.nodes[node_id].kind, NodeKind::InProgress { .. })); + self.cycles.push(Cycle { node_id, provisional_results }); + self.nodes[node_id].kind = NodeKind::CycleOnStack { entry_node_id, result } + } + + pub(super) fn finish_evaluation( + &mut self, + node_id: NodeId, + encountered_overflow: bool, + heads: CycleHeads, + result: X::Result, + ) { + let NodeKind::InProgress { cycles_start: _ } = self.nodes[node_id].kind else { + panic!("unexpected node kind: {:?}", self.nodes[node_id]); + }; + self.nodes[node_id].kind = NodeKind::Finished { encountered_overflow, heads, result } + } + + pub(super) fn get_cycle(&self, cycle_id: CycleId) -> &Cycle { + &self.cycles[cycle_id] + } + + pub(super) fn node_kind_raw(&self, node_id: NodeId) -> &NodeKind { + &self.nodes[node_id].kind + } + + pub(super) fn result_matches(&self, prev: NodeId, new: NodeId) -> bool { + match (&self.nodes[prev].kind, &self.nodes[new].kind) { + ( + NodeKind::Finished { + encountered_overflow: prev_overflow, + heads: prev_heads, + result: prev_result, + }, + NodeKind::Finished { + encountered_overflow: new_overflow, + heads: new_heads, + result: new_result, + }, + ) => { + prev_result == new_result + && (*prev_overflow || !*new_overflow) + && prev_heads.contains(new_heads) + } + ( + NodeKind::CycleOnStack { entry_node_id: _, result: prev }, + NodeKind::CycleOnStack { entry_node_id: _, result: new }, + ) => prev == new, + (&NodeKind::ProvisionalCacheHit { entry_node_id }, _) => { + self.result_matches(entry_node_id, new) + } + (_, &NodeKind::ProvisionalCacheHit { entry_node_id }) => { + self.result_matches(prev, entry_node_id) + } + result_matches => { + tracing::debug!(?result_matches); + false + } + } + } + + pub(super) fn rerun_get_and_reset_cycles(&mut self, node_id: NodeId) -> Range { + if let NodeKind::InProgress { cycles_start, .. } = &mut self.nodes[node_id].kind { + let prev = *cycles_start; + *cycles_start = self.cycles.next_index(); + prev..self.cycles.next_index() + } else { + panic!("unexpected node kind: {:?}", self.nodes[node_id]); + } + } + + pub(super) fn get_heads(&self, node_id: NodeId) -> &CycleHeads { + if let NodeKind::Finished { heads, .. } = &self.nodes[node_id].kind { + heads + } else { + panic!("unexpected node kind: {:?}", self.nodes[node_id]); + } + } + + pub(super) fn goal_or_parent_has_changed( + &self, + cycle_head: NodeId, + has_changed: &HashSet, + mut node_id: NodeId, + ) -> bool { + loop { + if node_id == cycle_head { + return false; + } else if has_changed.contains(&node_id) { + return true; + } else { + node_id = self.nodes[node_id].parent.unwrap(); + } + } + } + + /// Compute the list of parents of `node_id` until encountering the node + /// `until`. We're excluding `until` and are including `node_id`. 
+ pub(super) fn compute_rev_stack( + &self, + mut node_id: NodeId, + until: NodeId, + ) -> Vec<(NodeId, GoalInfo)> { + let mut rev_stack = Vec::new(); + loop { + if node_id == until { + return rev_stack; + } + + let node = &self.nodes[node_id]; + rev_stack.push((node_id, node.info)); + node_id = node.parent.unwrap(); + } + } +} diff --git a/tests/ui/traits/next-solver/cycles/inductive-cycle-but-err.stderr b/tests/ui/traits/next-solver/cycles/inductive-cycle-but-err.stderr index 7895a2636345a..cb22ff7d04a0c 100644 --- a/tests/ui/traits/next-solver/cycles/inductive-cycle-but-err.stderr +++ b/tests/ui/traits/next-solver/cycles/inductive-cycle-but-err.stderr @@ -17,6 +17,16 @@ LL | impls_trait::(); | ^^^^^^^^^^^^^^^^^^ the trait `Trait` is not implemented for `MultipleCandidates` | = help: the trait `Trait` is implemented for `MultipleCandidates` +note: required for `MultipleCandidates` to implement `Trait` + --> $DIR/inductive-cycle-but-err.rs:43:6 + | +LL | impl Trait for MultipleNested + | ^^^^^ ^^^^^^^^^^^^^^ +... +LL | MultipleCandidates: Trait, + | ----- unsatisfied trait bound introduced here + = note: 8 redundant requirements hidden + = note: required for `MultipleCandidates` to implement `Trait` note: required by a bound in `impls_trait` --> $DIR/inductive-cycle-but-err.rs:51:19 | From bcc6c3bcea2687bb8ada04732e8025a6c1181162 Mon Sep 17 00:00:00 2001 From: lcnr Date: Tue, 1 Jul 2025 11:10:29 +0200 Subject: [PATCH 3/5] impoopment --- compiler/rustc_type_ir/src/search_graph/mod.rs | 13 +++++-------- compiler/rustc_type_ir/src/search_graph/tree.rs | 6 +++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs index 87bc52f076e25..ed435de886d38 100644 --- a/compiler/rustc_type_ir/src/search_graph/mod.rs +++ b/compiler/rustc_type_ir/src/search_graph/mod.rs @@ -1300,7 +1300,7 @@ impl, X: Cx> SearchGraph { let cycles = self.tree.rerun_get_and_reset_cycles(prev_stack_entry.node_id); let current_stack_len = self.stack.len(); - let mut has_changed = HashSet::default(); + let mut was_reevaluated = HashSet::default(); 'outer: for cycle in cycles { let &tree::Cycle { node_id: cycle_node_id, ref provisional_results } = self.tree.get_cycle(cycle); @@ -1329,7 +1329,7 @@ impl, X: Cx> SearchGraph { // We've evaluated the `entry_node_id` before evaluating this goal. In case // that node and its parents has not changed, we can reinsert the cache entry // before starting to reevaluate it. - if !self.tree.goal_or_parent_has_changed(node_id, &has_changed, entry_node_id) { + if !self.tree.goal_or_parent_was_reevaluated(node_id, &was_reevaluated, entry_node_id) { continue; } } @@ -1369,7 +1369,7 @@ impl, X: Cx> SearchGraph { } } (Some(&(node_id, info)), None) => { - if current_goal.0 == node_id { + if was_reevaluated.contains(&node_id) { debug!(parent = ?info.input, cycle = ?added_goals.last().unwrap(), "reevaluated parent, skip cycle"); continue 'outer; } else { @@ -1427,15 +1427,12 @@ impl, X: Cx> SearchGraph { current_goal.1.step_kind_from_parent, inspect, ); + was_reevaluated.insert(current_goal.0); if node_id.is_some_and(|node_id| self.tree.result_matches(current_goal.0, node_id)) { - // TODO: This seems wrong. If a later loop reevaluates this goal again, we'd use - // its updated `NodeId`. 
- removed_entries.remove(¤t_goal.0); debug!(input = ?current_goal.1.input, ?result, "goal did not change"); continue 'outer; } else { - has_changed.insert(current_goal.0); debug!(input = ?current_goal.1.input, ?result, "goal did change"); if self.stack.len() > current_stack_len { let parent = self.stack.pop(); @@ -1481,7 +1478,7 @@ impl, X: Cx> SearchGraph { ); for (entry_node_id, (input, entry)) in removed_entries { - if !self.tree.goal_or_parent_has_changed(node_id, &has_changed, entry_node_id) { + if !self.tree.goal_or_parent_was_reevaluated(node_id, &was_reevaluated, entry_node_id) { self.provisional_cache.entry(input).or_default().push(entry); } } diff --git a/compiler/rustc_type_ir/src/search_graph/tree.rs b/compiler/rustc_type_ir/src/search_graph/tree.rs index 7bebd3c2a3647..3ac4f15dff247 100644 --- a/compiler/rustc_type_ir/src/search_graph/tree.rs +++ b/compiler/rustc_type_ir/src/search_graph/tree.rs @@ -175,16 +175,16 @@ impl SearchTree { } } - pub(super) fn goal_or_parent_has_changed( + pub(super) fn goal_or_parent_was_reevaluated( &self, cycle_head: NodeId, - has_changed: &HashSet, + was_reevaluated: &HashSet, mut node_id: NodeId, ) -> bool { loop { if node_id == cycle_head { return false; - } else if has_changed.contains(&node_id) { + } else if was_reevaluated.contains(&node_id) { return true; } else { node_id = self.nodes[node_id].parent.unwrap(); From d372c6cd86cd0db9d58aaf646f3b0a6b311a3c5e Mon Sep 17 00:00:00 2001 From: lcnr Date: Tue, 1 Jul 2025 12:44:11 +0200 Subject: [PATCH 4/5] wf --- compiler/rustc_type_ir/src/search_graph/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs index ed435de886d38..890fb223d9617 100644 --- a/compiler/rustc_type_ir/src/search_graph/mod.rs +++ b/compiler/rustc_type_ir/src/search_graph/mod.rs @@ -1329,7 +1329,11 @@ impl, X: Cx> SearchGraph { // We've evaluated the `entry_node_id` before evaluating this goal. In case // that node and its parents has not changed, we can reinsert the cache entry // before starting to reevaluate it. 
- if !self.tree.goal_or_parent_was_reevaluated(node_id, &was_reevaluated, entry_node_id) { + if !self.tree.goal_or_parent_was_reevaluated( + node_id, + &was_reevaluated, + entry_node_id, + ) { continue; } } From cde64b31657b6a5da4287c867bc42780bebbbe9a Mon Sep 17 00:00:00 2001 From: lcnr Date: Wed, 2 Jul 2025 14:44:14 +0200 Subject: [PATCH 5/5] overwhelm --- .../rustc_type_ir/src/search_graph/mod.rs | 145 ++++++++++++------ .../rustc_type_ir/src/search_graph/tree.rs | 87 +++++++++-- 2 files changed, 177 insertions(+), 55 deletions(-) diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs index 890fb223d9617..fee20c7892e2b 100644 --- a/compiler/rustc_type_ir/src/search_graph/mod.rs +++ b/compiler/rustc_type_ir/src/search_graph/mod.rs @@ -785,14 +785,11 @@ impl, X: Cx> SearchGraph { fn clear_dependent_provisional_results( stack: &Stack, provisional_cache: &mut HashMap>>, - mut handle_removed_entry: impl FnMut(X::Input, ProvisionalCacheEntry), ) { let head = stack.next_index(); #[allow(rustc::potential_query_instability)] - provisional_cache.retain(|&input, entries| { - for e in entries.extract_if(.., |entry| entry.heads.highest_cycle_head() == head) { - handle_removed_entry(input, e) - } + provisional_cache.retain(|_, entries| { + entries.retain(|entry| entry.heads.highest_cycle_head() != head); !entries.is_empty() }); } @@ -1171,6 +1168,7 @@ impl, X: Cx> SearchGraph { &stack_entry, |_, result| result, ); + self.tree.set_rebase_kind(stack_entry.node_id, tree::RebaseEntriesKind::Normal); return EvaluationResult::finalize(stack_entry, encountered_overflow, result); } @@ -1195,6 +1193,8 @@ impl, X: Cx> SearchGraph { &stack_entry, |input, _| D::propagate_ambiguity(cx, input, result), ); + + self.tree.set_rebase_kind(stack_entry.node_id, tree::RebaseEntriesKind::Ambiguity); return EvaluationResult::finalize(stack_entry, encountered_overflow, result); }; @@ -1210,6 +1210,7 @@ impl, X: Cx> SearchGraph { &stack_entry, |input, _| D::on_fixpoint_overflow(cx, input), ); + self.tree.set_rebase_kind(stack_entry.node_id, tree::RebaseEntriesKind::Overflow); return EvaluationResult::finalize(stack_entry, encountered_overflow, result); } @@ -1223,6 +1224,7 @@ impl, X: Cx> SearchGraph { &new_stack_entry, |_, result| result, ); + self.tree.set_rebase_kind(new_stack_entry.node_id, tree::RebaseEntriesKind::Normal); return EvaluationResult::finalize( new_stack_entry, encountered_overflow, @@ -1247,20 +1249,9 @@ impl, X: Cx> SearchGraph { ) -> (StackEntry, X::Result) { let node_id = prev_stack_entry.node_id; let current_depth = self.stack.next_index(); - - let mut removed_entries = BTreeMap::new(); // Clear all provisional cache entries which depend on a previous provisional // result of this goal and rerun. 
- Self::clear_dependent_provisional_results( - &self.stack, - &mut self.provisional_cache, - |input, entry| { - let prev = removed_entries.insert(entry.entry_node_id, (input, entry)); - if let Some(prev) = prev { - unreachable!("duplicate entries for the same `NodeId`: {prev:?}"); - } - }, - ); + Self::clear_dependent_provisional_results(&self.stack, &mut self.provisional_cache); self.stack.push(StackEntry { node_id, input: prev_stack_entry.input, @@ -1280,13 +1271,58 @@ impl, X: Cx> SearchGraph { return (reeval_entry, result); } - let truncate_stack = |stack: &mut Stack, provisional_cache: &mut _, depth| { + let truncate_unchanged_stack = |stack: &mut Stack, + provisional_cache: &mut _, + tree: &SearchTree, + depth| { while stack.next_index() > depth { - let reeval_entry = stack.pop(); - // TODO: How can we tell whether this entry was the final revision. - // - // We should be able to rebase provisional entries in most cases. - Self::clear_dependent_provisional_results(stack, provisional_cache, |_, _| ()); + let mut reeval_entry = stack.pop(); + let &tree::NodeKind::Finished { + encountered_overflow, + ref heads, + last_iteration_provisional_result, + rebase_entries_kind, + result, + } = tree.node_kind_raw(reeval_entry.node_id) + else { + unreachable!(); + }; + if last_iteration_provisional_result == reeval_entry.provisional_result { + reeval_entry.heads = heads.clone(); + match rebase_entries_kind { + Some(tree::RebaseEntriesKind::Normal) => { + debug!(?reeval_entry.input, "rebase entries while truncating stack"); + Self::rebase_provisional_cache_entries( + stack, + provisional_cache, + &reeval_entry, + |_, result| result, + ) + } + Some(tree::RebaseEntriesKind::Ambiguity) => { + Self::rebase_provisional_cache_entries( + stack, + provisional_cache, + &reeval_entry, + |input, result| D::propagate_ambiguity(cx, input, result), + ) + } + Some(tree::RebaseEntriesKind::Overflow) => { + Self::rebase_provisional_cache_entries( + stack, + provisional_cache, + &reeval_entry, + |input, _| D::on_fixpoint_overflow(cx, input), + ) + } + None | _ => { + Self::clear_dependent_provisional_results(stack, provisional_cache) + } + } + } else { + Self::clear_dependent_provisional_results(stack, provisional_cache); + } + Self::update_parent_goal( stack, reeval_entry.step_kind_from_parent, @@ -1295,10 +1331,39 @@ impl, X: Cx> SearchGraph { reeval_entry.encountered_overflow, UpdateParentGoalCtxt::Ordinary(&reeval_entry.nested_goals), ); + let entry = provisional_cache.entry(reeval_entry.input).or_default(); + + for (head, path_to_nested) in heads.iter() { + let path_from_head = + Self::cycle_path_kind(&stack, reeval_entry.step_kind_from_parent, head); + for path_kind in path_to_nested.extend_with(path_from_head).iter_paths() { + let usage_kind = UsageKind::Single(path_kind); + stack[head].has_been_used = Some( + stack[head] + .has_been_used + .map_or(usage_kind, |prev| prev.merge(usage_kind)), + ); + } + } + let path_from_head = Self::cycle_path_kind( + &stack, + reeval_entry.step_kind_from_parent, + heads.highest_cycle_head(), + ); + let provisional_cache_entry = ProvisionalCacheEntry { + entry_node_id: reeval_entry.node_id, + encountered_overflow, + heads: heads.clone(), + path_from_head, + result, + }; + debug!(?provisional_cache_entry); + entry.push(provisional_cache_entry); } }; - let cycles = self.tree.rerun_get_and_reset_cycles(prev_stack_entry.node_id); + let cycles = + self.tree.rerun_get_and_reset_cycles(prev_stack_entry.node_id, provisional_result); let current_stack_len = self.stack.len(); let mut 
was_reevaluated = HashSet::default();
         'outer: for cycle in cycles {
@@ -1364,9 +1429,10 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
                     );
                     let _ = added_goals.next().unwrap();
                 } else {
-                    truncate_stack(
+                    truncate_unchanged_stack(
                         &mut self.stack,
                         &mut self.provisional_cache,
+                        &self.tree,
                         stack_depth,
                     );
                     break;
@@ -1381,7 +1447,12 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
                 }
             }
             (None, Some(_)) => {
-                truncate_stack(&mut self.stack, &mut self.provisional_cache, stack_depth);
+                truncate_unchanged_stack(
+                    &mut self.stack,
+                    &mut self.provisional_cache,
+                    &self.tree,
+                    stack_depth,
+                );
                 break;
             }
             (None, None) => break,
@@ -1406,18 +1477,6 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
             });
         }
 
-        /*
-        while let Some((&entry_node_id, _)) = removed_entries.first_key_value() {
-            if entry_node_id < current_goal.0
-                && self.stack.iter().all(|e| e.node_id != entry_node_id)
-            {
-                let (entry_node_id, (input, entry)) = removed_entries.pop_first().unwrap();
-                if !self.tree.goal_or_parent_has_changed(node_id, &has_changed, entry_node_id) {
-                    self.provisional_cache.entry(input).or_default().push(entry);
-                }
-            }
-        }*/
-
         loop {
             let span = tracing::debug_span!(
                 "reevaluate_canonical_goal",
@@ -1443,7 +1502,6 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
             Self::clear_dependent_provisional_results(
                 &self.stack,
                 &mut self.provisional_cache,
-                |_, _| (),
             );
             Self::update_parent_goal(
                 &mut self.stack,
@@ -1475,18 +1533,13 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
            return (reeval_entry, result);
        }

-        truncate_stack(
+        truncate_unchanged_stack(
            &mut self.stack,
            &mut self.provisional_cache,
+            &self.tree,
            StackDepth::from_usize(current_stack_len),
        );

-        for (entry_node_id, (input, entry)) in removed_entries {
-            if !self.tree.goal_or_parent_was_reevaluated(node_id, &was_reevaluated, entry_node_id) {
-                self.provisional_cache.entry(input).or_default().push(entry);
-            }
-        }
-
        debug_assert_eq!(self.stack.len(), current_stack_len);
        let reeval_entry = self.stack.pop();
        (reeval_entry, provisional_result)
diff --git a/compiler/rustc_type_ir/src/search_graph/tree.rs b/compiler/rustc_type_ir/src/search_graph/tree.rs
index 3ac4f15dff247..c46c8c3e3011a 100644
--- a/compiler/rustc_type_ir/src/search_graph/tree.rs
+++ b/compiler/rustc_type_ir/src/search_graph/tree.rs
@@ -26,12 +26,34 @@ rustc_index::newtype_index! 
{
     pub(super) struct CycleId {}
 }
 
+#[derive(Debug, PartialEq, Eq, Copy, Clone)]
+pub(super) enum RebaseEntriesKind {
+    Normal,
+    Ambiguity,
+    Overflow,
+}
+
 #[derive_where(Debug; X: Cx)]
 pub(super) enum NodeKind<X: Cx> {
-    InProgress { cycles_start: CycleId },
-    Finished { encountered_overflow: bool, heads: CycleHeads, result: X::Result },
-    CycleOnStack { entry_node_id: NodeId, result: X::Result },
-    ProvisionalCacheHit { entry_node_id: NodeId },
+    InProgress {
+        cycles_start: CycleId,
+        last_iteration_provisional_result: Option<X::Result>,
+        rebase_entries_kind: Option<RebaseEntriesKind>,
+    },
+    Finished {
+        encountered_overflow: bool,
+        heads: CycleHeads,
+        last_iteration_provisional_result: Option<X::Result>,
+        rebase_entries_kind: Option<RebaseEntriesKind>,
+        result: X::Result,
+    },
+    CycleOnStack {
+        entry_node_id: NodeId,
+        result: X::Result,
+    },
+    ProvisionalCacheHit {
+        entry_node_id: NodeId,
+    },
 }
 
 #[derive_where(Debug; X: Cx)]
@@ -66,7 +88,11 @@ impl<X: Cx> SearchTree<X> {
         self.nodes.push(Node {
             info,
             parent,
-            kind: NodeKind::InProgress { cycles_start: self.cycles.next_index() },
+            kind: NodeKind::InProgress {
+                cycles_start: self.cycles.next_index(),
+                last_iteration_provisional_result: None,
+                rebase_entries_kind: None,
+            },
         })
     }
 
@@ -108,10 +134,21 @@ impl<X: Cx> SearchTree<X> {
         heads: CycleHeads,
         result: X::Result,
     ) {
-        let NodeKind::InProgress { cycles_start: _ } = self.nodes[node_id].kind else {
+        let NodeKind::InProgress {
+            cycles_start: _,
+            last_iteration_provisional_result,
+            rebase_entries_kind,
+        } = self.nodes[node_id].kind
+        else {
             panic!("unexpected node kind: {:?}", self.nodes[node_id]);
         };
-        self.nodes[node_id].kind = NodeKind::Finished { encountered_overflow, heads, result }
+        self.nodes[node_id].kind = NodeKind::Finished {
+            encountered_overflow,
+            heads,
+            result,
+            last_iteration_provisional_result,
+            rebase_entries_kind,
+        }
     }
 
     pub(super) fn get_cycle(&self, cycle_id: CycleId) -> &Cycle {
@@ -129,15 +166,22 @@ impl<X: Cx> SearchTree<X> {
                 encountered_overflow: prev_overflow,
                 heads: prev_heads,
                 result: prev_result,
+                rebase_entries_kind: prev_rebase_entries_kind,
+                last_iteration_provisional_result: prev_last_iteration_provisional_result,
             },
             NodeKind::Finished {
                 encountered_overflow: new_overflow,
                 heads: new_heads,
                 result: new_result,
+                rebase_entries_kind: new_rebase_entries_kind,
+                last_iteration_provisional_result: new_last_iteration_provisional_result,
             },
         ) => {
             prev_result == new_result
                 && (*prev_overflow || !*new_overflow)
+                && prev_rebase_entries_kind == new_rebase_entries_kind
+                && prev_last_iteration_provisional_result
+                    == new_last_iteration_provisional_result
                 && prev_heads.contains(new_heads)
         }
         (
@@ -157,10 +201,35 @@ impl<X: Cx> SearchTree<X> {
         }
     }
 
-    pub(super) fn rerun_get_and_reset_cycles(&mut self, node_id: NodeId) -> Range<CycleId> {
-        if let NodeKind::InProgress { cycles_start, .. 
} = &mut self.nodes[node_id].kind {
+    pub(super) fn set_rebase_kind(&mut self, node_id: NodeId, rebase_kind: RebaseEntriesKind) {
+        if let NodeKind::InProgress {
+            cycles_start: _,
+            last_iteration_provisional_result: _,
+            rebase_entries_kind,
+        } = &mut self.nodes[node_id].kind
+        {
+            let prev = rebase_entries_kind.replace(rebase_kind);
+            debug_assert!(prev.is_none());
+        } else {
+            panic!("unexpected node kind: {:?}", self.nodes[node_id]);
+        }
+    }
+
+    pub(super) fn rerun_get_and_reset_cycles(
+        &mut self,
+        node_id: NodeId,
+        provisional_result: X::Result,
+    ) -> Range<CycleId> {
+        if let NodeKind::InProgress {
+            cycles_start,
+            last_iteration_provisional_result,
+            rebase_entries_kind,
+        } = &mut self.nodes[node_id].kind
+        {
+            debug_assert!(rebase_entries_kind.is_none());
             let prev = *cycles_start;
             *cycles_start = self.cycles.next_index();
+            *last_iteration_provisional_result = Some(provisional_result);
             prev..self.cycles.next_index()
         } else {
             panic!("unexpected node kind: {:?}", self.nodes[node_id]);
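// A minimal standalone sketch of the rebasing decision implemented by
// `truncate_unchanged_stack` in the hunks above, not part of the patch itself.
// It uses simplified stand-in types: plain `String` results instead of
// `X::Result`, a `Vec` instead of the provisional cache, and a hypothetical
// helper name `rebase_or_clear`. The idea it illustrates: if the final
// fixpoint iteration of a popped goal used the same provisional result we are
// about to reuse, its dependent provisional cache entries stay valid and can
// be rebased (their results possibly mapped to ambiguity or overflow);
// otherwise they have to be cleared.

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RebaseEntriesKind {
    Normal,
    Ambiguity,
    Overflow,
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct CacheEntry {
    result: String,
}

/// Decide what happens to provisional cache entries depending on a goal that is
/// popped off the stack without its result having changed.
fn rebase_or_clear(
    last_iteration_provisional_result: Option<&str>,
    current_provisional_result: Option<&str>,
    rebase_entries_kind: Option<RebaseEntriesKind>,
    dependent_entries: &mut Vec<CacheEntry>,
) {
    if last_iteration_provisional_result != current_provisional_result {
        // The head was evaluated with a different provisional result, so all
        // dependent provisional cache entries are stale and must be dropped.
        dependent_entries.clear();
        return;
    }

    match rebase_entries_kind {
        // Entries remain valid unchanged.
        Some(RebaseEntriesKind::Normal) => {}
        // Entries remain valid, but their results are weakened to ambiguity.
        Some(RebaseEntriesKind::Ambiguity) => {
            for entry in dependent_entries.iter_mut() {
                entry.result = "ambiguous".to_string();
            }
        }
        // Entries remain valid, but their results become the overflow response.
        Some(RebaseEntriesKind::Overflow) => {
            for entry in dependent_entries.iter_mut() {
                entry.result = "overflow".to_string();
            }
        }
        // We never recorded how entries may be rebased, so clear them conservatively.
        None => dependent_entries.clear(),
    }
}

fn main() {
    // Same provisional result as the last iteration: entries survive, weakened to ambiguity.
    let mut entries = vec![CacheEntry { result: "yes".to_string() }];
    rebase_or_clear(Some("yes"), Some("yes"), Some(RebaseEntriesKind::Ambiguity), &mut entries);
    assert_eq!(entries[0].result, "ambiguous");

    // Fixpoint overflow was recorded for the head: results are replaced by the overflow response.
    let mut entries = vec![CacheEntry { result: "yes".to_string() }];
    rebase_or_clear(Some("yes"), Some("yes"), Some(RebaseEntriesKind::Overflow), &mut entries);
    assert_eq!(entries[0].result, "overflow");

    // The head is re-evaluated with a different provisional result: entries are cleared.
    let mut entries = vec![CacheEntry { result: "yes".to_string() }];
    rebase_or_clear(Some("yes"), Some("maybe"), Some(RebaseEntriesKind::Normal), &mut entries);
    assert!(entries.is_empty());
}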