From c41af09bc113c106b060cc6c5bfbdddf2bb16499 Mon Sep 17 00:00:00 2001 From: Grant Wuerker Date: Sat, 25 Nov 2023 22:25:37 +0300 Subject: [PATCH] ADT recursion --- Cargo.lock | 2 + crates/common2/Cargo.toml | 2 + crates/common2/src/lib.rs | 1 + crates/common2/src/recursive_def.rs | 132 ++++++++++++++++++ crates/hir-analysis/src/lib.rs | 1 + crates/hir-analysis/src/ty/def_analysis.rs | 29 ++-- crates/hir-analysis/src/ty/diagnostics.rs | 57 ++++---- crates/hir-analysis/src/ty/mod.rs | 40 +++++- .../fixtures/ty/def/recursive_type.snap | 22 ++- 9 files changed, 228 insertions(+), 58 deletions(-) create mode 100644 crates/common2/src/recursive_def.rs diff --git a/Cargo.lock b/Cargo.lock index 4c1f38447b..0167b39678 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -964,7 +964,9 @@ name = "fe-common2" version = "0.23.0" dependencies = [ "camino", + "ena", "fe-parser2", + "rustc-hash", "salsa-2022", "semver 1.0.17", "smol_str", diff --git a/crates/common2/Cargo.toml b/crates/common2/Cargo.toml index 7dd6492f04..fd12a55c74 100644 --- a/crates/common2/Cargo.toml +++ b/crates/common2/Cargo.toml @@ -15,3 +15,5 @@ camino = "1.1.4" smol_str = "0.1.24" salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" } parser = { path = "../parser2", package = "fe-parser2" } +rustc-hash = "1.1.0" +ena = "0.14" diff --git a/crates/common2/src/lib.rs b/crates/common2/src/lib.rs index c2694c6967..8a18870d07 100644 --- a/crates/common2/src/lib.rs +++ b/crates/common2/src/lib.rs @@ -1,5 +1,6 @@ pub mod diagnostics; pub mod input; +pub mod recursive_def; pub use input::{InputFile, InputIngot}; diff --git a/crates/common2/src/recursive_def.rs b/crates/common2/src/recursive_def.rs new file mode 100644 index 0000000000..e559142706 --- /dev/null +++ b/crates/common2/src/recursive_def.rs @@ -0,0 +1,132 @@ +use std::{fmt::Debug, hash::Hash}; + +use ena::unify::{InPlaceUnificationTable, UnifyKey}; +use rustc_hash::FxHashMap; + +/// Represents a definition that contains a direct reference to itself. +/// +/// Recursive definitions are not valid and must be reported to the user. +/// It is preferable to group definitions together such that recursions +/// are reported in-whole rather than separately. `RecursiveDef` can be +/// used with `RecursiveDefHelper` to perform this grouping operation. +/// +/// The fields `from` and `to` are the relevant identifiers and `site` can +/// be used to carry diagnostic information. +#[derive(Eq, PartialEq, Clone, Debug, Hash)] +pub struct RecursiveDef +where + T: PartialEq + Copy, +{ + pub from: T, + pub to: T, + pub site: U, +} + +impl RecursiveDef +where + T: PartialEq + Copy, +{ + pub fn new(from: T, to: T, site: U) -> Self { + Self { from, to, site } + } +} + +#[derive(PartialEq, Debug, Clone, Copy)] +struct RecursiveDefKey(u32); + +impl UnifyKey for RecursiveDefKey { + type Value = (); + + fn index(&self) -> u32 { + self.0 + } + + fn from_index(idx: u32) -> Self { + Self(idx) + } + + fn tag() -> &'static str { + "RecursiveDefKey" + } +} + +pub struct RecursiveDefHelper +where + T: Eq + Clone + Debug + Copy, +{ + defs: Vec>, + table: InPlaceUnificationTable, + keys: FxHashMap, +} + +impl RecursiveDefHelper +where + T: Eq + Clone + Debug + Copy + Hash, +{ + pub fn new(defs: Vec>) -> Self { + let mut table = InPlaceUnificationTable::new(); + let keys: FxHashMap<_, _> = defs + .iter() + .map(|def| (def.from, table.new_key(()))) + .collect(); + + for def in defs.iter() { + table.union(keys[&def.from], keys[&def.to]) + } + + Self { defs, table, keys } + } + + /// Removes a disjoint set of recursive definitions from the helper + /// and returns it, if one exists. + pub fn remove_disjoint_set(&mut self) -> Option>> { + let mut disjoint_set = vec![]; + let mut remaining_set = vec![]; + let mut union_key: Option<&RecursiveDefKey> = None; + + while let Some(def) = self.defs.pop() { + let cur_key = &self.keys[&def.from]; + + if union_key.is_none() || self.table.unioned(*union_key.unwrap(), *cur_key) { + disjoint_set.push(def) + } else { + remaining_set.push(def) + } + + if union_key.is_none() { + union_key = Some(cur_key) + } + } + + self.defs = remaining_set; + + if union_key.is_some() { + Some(disjoint_set) + } else { + None + } + } +} + +#[test] +fn one_recursion() { + let defs = vec![RecursiveDef::new(0, 1, ()), RecursiveDef::new(1, 0, ())]; + let mut helper = RecursiveDefHelper::new(defs); + assert!(helper.remove_disjoint_set().is_some()); + assert!(helper.remove_disjoint_set().is_none()); +} + +#[test] +fn two_recursions() { + let defs = vec![ + RecursiveDef::new(0, 1, ()), + RecursiveDef::new(1, 0, ()), + RecursiveDef::new(2, 3, ()), + RecursiveDef::new(3, 4, ()), + RecursiveDef::new(4, 2, ()), + ]; + let mut helper = RecursiveDefHelper::new(defs); + assert!(helper.remove_disjoint_set().is_some()); + assert!(helper.remove_disjoint_set().is_some()); + assert!(helper.remove_disjoint_set().is_none()); +} diff --git a/crates/hir-analysis/src/lib.rs b/crates/hir-analysis/src/lib.rs index 051071d3f5..0595aba417 100644 --- a/crates/hir-analysis/src/lib.rs +++ b/crates/hir-analysis/src/lib.rs @@ -72,6 +72,7 @@ pub struct Jar( ty::diagnostics::ImplTraitDefDiagAccumulator, ty::diagnostics::ImplDefDiagAccumulator, ty::diagnostics::FuncDefDiagAccumulator, + ty::diagnostics::RecursiveAdtDefAccumulator, ); pub trait HirAnalysisDb: salsa::DbWithJar + HirDb { diff --git a/crates/hir-analysis/src/ty/def_analysis.rs b/crates/hir-analysis/src/ty/def_analysis.rs index 4a074b42a9..dd539b1182 100644 --- a/crates/hir-analysis/src/ty/def_analysis.rs +++ b/crates/hir-analysis/src/ty/def_analysis.rs @@ -21,7 +21,10 @@ use super::{ collect_impl_block_constraints, collect_super_traits, AssumptionListId, SuperTraitCycle, }, constraint_solver::{is_goal_satisfiable, GoalSatisfiability}, - diagnostics::{ImplDiag, TraitConstraintDiag, TraitLowerDiag, TyDiagCollection, TyLowerDiag}, + diagnostics::{ + ImplDiag, RecursiveAdtDef, TraitConstraintDiag, TraitLowerDiag, TyDiagCollection, + TyLowerDiag, + }, trait_def::{ingot_trait_env, Implementor, TraitDef, TraitMethod}, trait_lower::{lower_trait, lower_trait_ref, TraitRefLowerError}, ty_def::{AdtDef, AdtRef, AdtRefId, FuncDef, InvalidCause, TyData, TyId}, @@ -33,7 +36,8 @@ use crate::{ ty::{ diagnostics::{ AdtDefDiagAccumulator, FuncDefDiagAccumulator, ImplDefDiagAccumulator, - ImplTraitDefDiagAccumulator, TraitDefDiagAccumulator, TypeAliasDefDiagAccumulator, + ImplTraitDefDiagAccumulator, RecursiveAdtDefAccumulator, TraitDefDiagAccumulator, + TypeAliasDefDiagAccumulator, }, method_table::collect_methods, trait_lower::lower_impl_trait, @@ -62,8 +66,8 @@ pub fn analyze_adt(db: &dyn HirAnalysisDb, adt_ref: AdtRefId) { AdtDefDiagAccumulator::push(db, diag); } - if let Some(diag) = check_recursive_adt(db, adt_ref) { - AdtDefDiagAccumulator::push(db, diag); + if let Some(def) = check_recursive_adt(db, adt_ref) { + RecursiveAdtDefAccumulator::push(db, def); } } @@ -764,7 +768,7 @@ impl<'db> Visitor for DefAnalyzer<'db> { pub(crate) fn check_recursive_adt( db: &dyn HirAnalysisDb, adt: AdtRefId, -) -> Option { +) -> Option { let adt_def = lower_adt(db, adt); for field in adt_def.fields(db) { for ty in field.iter_types(db) { @@ -781,7 +785,7 @@ fn check_recursive_adt_impl( db: &dyn HirAnalysisDb, cycle: &salsa::Cycle, adt: AdtRefId, -) -> Option { +) -> Option { let participants: FxHashSet<_> = cycle .participant_keys() .map(|key| check_recursive_adt::key_from_id(key.key_index())) @@ -792,11 +796,14 @@ fn check_recursive_adt_impl( for (ty_idx, ty) in field.iter_types(db).enumerate() { for field_adt_ref in ty.collect_direct_adts(db) { if participants.contains(&field_adt_ref) && participants.contains(&adt) { - let diag = TyLowerDiag::recursive_type( - adt.name_span(db), - adt_def.variant_ty_span(db, field_idx, ty_idx), - ); - return Some(diag.into()); + return Some(RecursiveAdtDef::new( + adt, + field_adt_ref, + ( + adt.name_span(db), + adt_def.variant_ty_span(db, field_idx, ty_idx), + ), + )); } } } diff --git a/crates/hir-analysis/src/ty/diagnostics.rs b/crates/hir-analysis/src/ty/diagnostics.rs index 100a6328ef..1aefa66dfa 100644 --- a/crates/hir-analysis/src/ty/diagnostics.rs +++ b/crates/hir-analysis/src/ty/diagnostics.rs @@ -1,7 +1,10 @@ use std::collections::BTreeSet; -use common::diagnostics::{ - CompleteDiagnostic, DiagnosticPass, GlobalErrorCode, LabelStyle, Severity, SubDiagnostic, +use common::{ + diagnostics::{ + CompleteDiagnostic, DiagnosticPass, GlobalErrorCode, LabelStyle, Severity, SubDiagnostic, + }, + recursive_def::RecursiveDef, }; use hir::{ diagnostics::DiagnosticVoucher, @@ -11,11 +14,12 @@ use hir::{ }; use itertools::Itertools; +use crate::HirAnalysisDb; + use super::{ constraint::PredicateId, - ty_def::{Kind, TyId}, + ty_def::{AdtRefId, Kind, TyId}, }; -use crate::HirAnalysisDb; #[salsa::accumulator] pub struct AdtDefDiagAccumulator(pub(super) TyDiagCollection); @@ -29,6 +33,10 @@ pub struct ImplDefDiagAccumulator(pub(super) TyDiagCollection); pub struct FuncDefDiagAccumulator(pub(super) TyDiagCollection); #[salsa::accumulator] pub struct TypeAliasDefDiagAccumulator(pub(super) TyDiagCollection); +#[salsa::accumulator] +pub struct RecursiveAdtDefAccumulator(pub(super) RecursiveAdtDef); + +pub type RecursiveAdtDef = RecursiveDef; #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::From)] pub enum TyDiagCollection { @@ -53,10 +61,7 @@ impl TyDiagCollection { pub enum TyLowerDiag { ExpectedStarKind(DynLazySpan), InvalidTypeArgKind(DynLazySpan, String), - RecursiveType { - primary_span: DynLazySpan, - field_span: DynLazySpan, - }, + AdtRecursion(Vec), UnboundTypeAliasParam { span: DynLazySpan, @@ -140,11 +145,8 @@ impl TyLowerDiag { Self::InvalidTypeArgKind(span, msg) } - pub(super) fn recursive_type(primary_span: DynLazySpan, field_span: DynLazySpan) -> Self { - Self::RecursiveType { - primary_span, - field_span, - } + pub(super) fn adt_recursion(defs: Vec) -> Self { + Self::AdtRecursion(defs) } pub(super) fn unbound_type_alias_param( @@ -249,7 +251,7 @@ impl TyLowerDiag { match self { Self::ExpectedStarKind(_) => 0, Self::InvalidTypeArgKind(_, _) => 1, - Self::RecursiveType { .. } => 2, + Self::AdtRecursion { .. } => 2, Self::UnboundTypeAliasParam { .. } => 3, Self::TypeAliasCycle { .. } => 4, Self::InconsistentKindBound(_, _) => 5, @@ -270,7 +272,7 @@ impl TyLowerDiag { match self { Self::ExpectedStarKind(_) => "expected `*` kind in this context".to_string(), Self::InvalidTypeArgKind(_, _) => "invalid type argument kind".to_string(), - Self::RecursiveType { .. } => "recursive type is not allowed".to_string(), + Self::AdtRecursion { .. } => "recursive type is not allowed".to_string(), Self::UnboundTypeAliasParam { .. } => { "all type parameters of type alias must be given".to_string() @@ -326,22 +328,23 @@ impl TyLowerDiag { span.resolve(db), )], - Self::RecursiveType { - primary_span, - field_span, - } => { - vec![ - SubDiagnostic::new( + Self::AdtRecursion(defs) => { + let mut diags = vec![]; + + for RecursiveAdtDef { site, .. } in defs { + diags.push(SubDiagnostic::new( LabelStyle::Primary, "recursive type definition".to_string(), - primary_span.resolve(db), - ), - SubDiagnostic::new( + site.0.resolve(db), + )); + diags.push(SubDiagnostic::new( LabelStyle::Secondary, "recursion occurs here".to_string(), - field_span.resolve(db), - ), - ] + site.1.resolve(db), + )); + } + + diags } Self::UnboundTypeAliasParam { diff --git a/crates/hir-analysis/src/ty/mod.rs b/crates/hir-analysis/src/ty/mod.rs index 539779c224..d63cfc9ad2 100644 --- a/crates/hir-analysis/src/ty/mod.rs +++ b/crates/hir-analysis/src/ty/mod.rs @@ -1,4 +1,6 @@ +use common::recursive_def::RecursiveDefHelper; use hir::{analysis_pass::ModuleAnalysisPass, hir_def::TopLevelMod}; +use itertools::Itertools; use self::{ def_analysis::{ @@ -7,7 +9,8 @@ use self::{ }, diagnostics::{ AdtDefDiagAccumulator, FuncDefDiagAccumulator, ImplDefDiagAccumulator, - ImplTraitDefDiagAccumulator, TraitDefDiagAccumulator, TypeAliasDefDiagAccumulator, + ImplTraitDefDiagAccumulator, RecursiveAdtDef, RecursiveAdtDefAccumulator, + TraitDefDiagAccumulator, TyDiagCollection, TyLowerDiag, TypeAliasDefDiagAccumulator, }, ty_def::AdtRefId, }; @@ -61,15 +64,40 @@ impl<'db> ModuleAnalysisPass for TypeDefAnalysisPass<'db> { .iter() .map(|c| AdtRefId::from_contract(self.db, *c)), ); + let (diags, recursive_adt_defs): (Vec<_>, Vec<_>) = adts + .map(|adt| { + ( + analyze_adt::accumulated::(self.db, adt), + analyze_adt::accumulated::(self.db, adt), + ) + }) + .unzip(); + let recursive_adt_defs = recursive_adt_defs.into_iter().flatten().collect_vec(); - adts.flat_map(|adt| { - analyze_adt::accumulated::(self.db, adt).into_iter() - }) - .map(|diag| diag.to_voucher()) - .collect() + diags + .into_iter() + .flatten() + .map(|diag| diag.to_voucher()) + .chain( + adt_recursion_diags(recursive_adt_defs) + .iter() + .map(|diag| diag.to_voucher()), + ) + .collect() } } +fn adt_recursion_diags(defs: Vec) -> Vec { + let mut helper = RecursiveDefHelper::new(defs); + let mut diags = vec![]; + + while let Some(defs) = helper.remove_disjoint_set() { + diags.push(TyLowerDiag::adt_recursion(defs).into()); + } + + diags +} + /// An analysis pass for trait definitions. pub struct TraitAnalysisPass<'db> { db: &'db dyn HirAnalysisDb, diff --git a/crates/uitest/fixtures/ty/def/recursive_type.snap b/crates/uitest/fixtures/ty/def/recursive_type.snap index daf6f6bf12..33c93eb64f 100644 --- a/crates/uitest/fixtures/ty/def/recursive_type.snap +++ b/crates/uitest/fixtures/ty/def/recursive_type.snap @@ -1,7 +1,7 @@ --- source: crates/uitest/tests/ty.rs expression: diags -input_file: crates/uitest/fixtures/ty/recursive_type.fe +input_file: crates/uitest/fixtures/ty/def/recursive_type.fe --- error[3-0002]: recursive type is not allowed ┌─ recursive_type.fe:1:12 @@ -12,24 +12,18 @@ error[3-0002]: recursive type is not allowed │ -- recursion occurs here error[3-0002]: recursive type is not allowed - ┌─ recursive_type.fe:5:12 - │ -5 │ pub struct S2 { - │ ^^ recursive type definition -6 │ s: S3 - │ -- recursion occurs here - -error[3-0002]: recursive type is not allowed - ┌─ recursive_type.fe:9:12 + ┌─ recursive_type.fe:5:12 │ + 5 │ pub struct S2 { + │ ^^ recursive type definition + 6 │ s: S3 + │ -- recursion occurs here + · 9 │ pub struct S3 { │ ^^ recursive type definition 10 │ s: S4 │ -- recursion occurs here - -error[3-0002]: recursive type is not allowed - ┌─ recursive_type.fe:13:12 - │ + · 13 │ pub struct S4 { │ ^^ recursive type definition 14 │ s: S2