diff --git a/src/tag_analysis.rs b/src/tag_analysis.rs index 866cc54..1b60d5d 100644 --- a/src/tag_analysis.rs +++ b/src/tag_analysis.rs @@ -8,20 +8,21 @@ use compile_util::{make_suggestion, span_to_snippet}; use etrace::{ok_or, some_or}; use must_analysis::Obj; use rustc_abi::{FieldIdx, VariantIdx}; -use rustc_ast::{BindingAnnotation, Mutability}; +use rustc_ast::{BindingAnnotation, LitKind, Mutability}; use rustc_hir::{ def::Res, definitions::DefPathDataName, intravisit::{self, Visitor as HVisitor}, - ByRef, Expr, ExprKind, ItemKind, Node, PatKind, QPath, StmtKind, UnOp, VariantData, + ByRef, Expr, ExprKind, ItemKind, MatchSource, Node, Pat, PatKind, QPath, StmtKind, UnOp, + VariantData, }; use rustc_index::{bit_set::BitSet, IndexVec}; use rustc_middle::{ hir::nested_filter, mir::{ visit::{MutatingUseContext, PlaceContext, Visitor as MVisitor}, - AggregateKind, ConstantKind, HasLocalDecls, Local, LocalDecl, Location, Operand, Place, - PlaceElem, ProjectionElem, Rvalue, Terminator, TerminatorKind, + AggregateKind, BasicBlock, Body, ConstantKind, HasLocalDecls, Local, LocalDecl, Location, + Operand, Place, PlaceElem, ProjectionElem, Rvalue, Terminator, TerminatorKind, }, ty::{List, Ty, TyCtxt, TyKind, TypeAndMut, TypeckResults}, }; @@ -216,7 +217,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { let hbody = hir.body(body_id); let mut visitor = MBodyVisitor::new(tcx, &body.local_decls, &structs, &unions); visitor.visit_body(body); - let mut hvisitor = BitFieldInitVisitor::new(tcx); + let mut hvisitor = HBodyVisitor::new(tcx); hvisitor.visit_body(hbody); if !hvisitor.inits.is_empty() { for (local, location) in &visitor.inits { @@ -282,7 +283,21 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { gc: true, }; let states = must_analysis::analyze_body(body, ctx, tcx); + let ctx = AccessCtx::new(&states.states, body, tcx); for access in visitor.accesses { + if let Some(access_match) = + access_in_match(access, &hvisitor.matches, &visitor.switches, ctx) + { + println!( + "{:?} {:?}", + access_match.match_span, access_match.field_tags + ); + } else { + let accesses_if = access_in_if(access, &hvisitor.ifs, &visitor.ifs, ctx); + for access_if in accesses_if { + println!("{:?} {:?}", access_if.if_span, access_if.field_tags); + } + } if matches!( access.ctx, PlaceContext::MutatingUse(MutatingUseContext::Store) @@ -290,7 +305,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { continue; } let span = body.source_info(access.location).span; - let tags = compute_tags(access, &states.states, &body.local_decls, tcx); + let tags = ctx.compute_tags(access); for (f, ns) in tags { let vts = union_uses .entry(access.ty) @@ -760,7 +775,7 @@ impl {} {{ let hir_body = hir.body(body_id); let local_def_id = item_id.owner_id.def_id; let typeck = tcx.typeck(local_def_id); - let mut visitor = HBodyVisitor { + let mut visitor = SuggestingVisitor { tcx, typeck, func: local_def_id, @@ -1006,28 +1021,13 @@ fn find_paths( visited.remove(&curr); } -fn compute_tags<'tcx, D: HasLocalDecls<'tcx> + ?Sized>( - access: FieldAccess<'tcx>, - states: &HashMap, - local_decls: &D, - tcx: TyCtxt<'tcx>, -) -> Vec<(FieldIdx, AbsInt)> { - let state = some_or!(states.get(&access.location), return vec![]); - let (path, is_deref) = access.get_path(state, local_decls, tcx); - let g = state.g(); - let obj = some_or!(g.get_obj(&path, is_deref), return vec![]); - let Obj::Struct(fields, _) = obj else { return vec![] }; - let mut v: Vec<_> = fields - .iter() - .filter_map(|(f, obj)| { - let Obj::Ptr(loc) = obj else { return None }; - let obj = g.obj_at_location(loc)?; - let Obj::AtAddr(n) = obj else { return None }; - Some((*f, n.clone())) - }) - .collect(); - v.sort_by_key(|(f, _)| *f); - v +type ArmTags = Option>; + +struct MIf { + c: Span, + loc: Location, + t: BasicBlock, + f: BasicBlock, } struct MBodyVisitor<'tcx, 'a> { @@ -1039,6 +1039,8 @@ struct MBodyVisitor<'tcx, 'a> { struct_accesses: HashSet, aggregates: HashMap>, inits: Vec<(Local, Location)>, + switches: HashMap)>, + ifs: Vec, } impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> { @@ -1057,6 +1059,8 @@ impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> { struct_accesses: HashSet::new(), aggregates: HashMap::new(), inits: vec![], + switches: HashMap::new(), + ifs: vec![], } } } @@ -1124,20 +1128,50 @@ impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> { } fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - if let TerminatorKind::Call { func, args, .. } = &terminator.kind { - if let Some(constant) = func.constant() { - let ConstantKind::Val(_, ty) = constant.literal else { unreachable!() }; - let TyKind::FnDef(def_id, _) = ty.kind() else { unreachable!() }; - if def_id.is_local() && self.tcx.impl_of_method(*def_id).is_some() { - let ty = args[0].ty(self.local_decls, self.tcx); - let TyKind::Ref(_, ty, _) = ty.kind() else { unreachable!() }; - let TyKind::Adt(adt_def, _) = ty.kind() else { unreachable!() }; - let local_def_id = adt_def.did().expect_local(); - if self.structs.contains_key(&local_def_id) { - self.struct_accesses.insert(args[0].place().unwrap().local); + match &terminator.kind { + TerminatorKind::Call { func, args, .. } => { + if let Some(constant) = func.constant() { + let ConstantKind::Val(_, ty) = constant.literal else { unreachable!() }; + let TyKind::FnDef(def_id, _) = ty.kind() else { unreachable!() }; + if def_id.is_local() && self.tcx.impl_of_method(*def_id).is_some() { + let ty = args[0].ty(self.local_decls, self.tcx); + let TyKind::Ref(_, ty, _) = ty.kind() else { unreachable!() }; + let TyKind::Adt(adt_def, _) = ty.kind() else { unreachable!() }; + let local_def_id = adt_def.did().expect_local(); + if self.structs.contains_key(&local_def_id) { + self.struct_accesses.insert(args[0].place().unwrap().local); + } + } + } + } + TerminatorKind::SwitchInt { discr, targets } => { + let span = terminator.source_info.span; + let ty = discr.ty(self.local_decls, self.tcx); + if ty.is_bool() { + let mut targets_iter = targets.iter(); + let (tag, bb) = targets_iter.next().unwrap(); + assert!(targets_iter.next().is_none()); + assert_eq!(tag, 0); + let mif = MIf { + c: span, + loc: location, + t: targets.otherwise(), + f: bb, + }; + self.ifs.push(mif); + } else { + let mut tags: HashMap<_, BTreeSet<_>> = HashMap::new(); + for (tag, bb) in targets.iter() { + tags.entry(bb).or_default().insert(tag); } + tags.remove(&targets.otherwise()); + let mut ts: HashMap<_, _> = + tags.into_iter().map(|(k, v)| (Some(v), k)).collect(); + ts.insert(None, targets.otherwise()); + self.switches.insert(span.hi(), (location, ts)); } } + _ => {} } self.super_terminator(terminator, location); } @@ -1206,7 +1240,7 @@ struct TaggedUnion { tag_index: FieldIdx, } -struct HBodyVisitor<'a, 'tcx> { +struct SuggestingVisitor<'a, 'tcx> { tcx: TyCtxt<'tcx>, typeck: &'a TypeckResults<'tcx>, func: LocalDefId, @@ -1217,7 +1251,7 @@ struct HBodyVisitor<'a, 'tcx> { suggestions: &'a mut Suggestions<'tcx>, } -impl<'tcx> HBodyVisitor<'_, 'tcx> { +impl<'tcx> SuggestingVisitor<'_, 'tcx> { fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) { let source_map = self.tcx.sess.source_map(); match expr.kind { @@ -1381,7 +1415,7 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> { } } -impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'_, 'tcx> { +impl<'tcx> HVisitor<'tcx> for SuggestingVisitor<'_, 'tcx> { type NestedFilter = nested_filter::OnlyBodies; fn nested_visit_map(&mut self) -> Self::Map { @@ -1394,48 +1428,77 @@ impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'_, 'tcx> { } } -struct BitFieldInitVisitor<'tcx> { +struct HIf { + c: Span, + t: Span, + f: Option, +} + +struct HBodyVisitor<'tcx> { tcx: TyCtxt<'tcx>, inits: HashMap, set_exprs: HashSet, + matches: Vec<(Span, Vec<(Span, ArmTags)>)>, + ifs: Vec, } -impl<'tcx> BitFieldInitVisitor<'tcx> { +impl<'tcx> HBodyVisitor<'tcx> { fn new(tcx: TyCtxt<'tcx>) -> Self { Self { tcx, inits: HashMap::new(), set_exprs: HashSet::new(), + matches: vec![], + ifs: vec![], } } fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) { - let ExprKind::Block(block, _) = expr.kind else { return }; - if block.stmts.len() <= 1 { - return; - } - let StmtKind::Local(local) = block.stmts[0].kind else { return }; - let PatKind::Binding(_, hir_id, ident, _) = local.pat.kind else { return }; - if ident.name.to_ident_string() != "init" { - return; - } - let init = some_or!(local.init, return); - let ExprKind::Struct(_, _, _) = init.kind else { return }; - let e = some_or!(block.expr, return); - let ExprKind::Path(QPath::Resolved(_, path)) = e.kind else { return }; - let Res::Local(id) = path.res else { return }; - if hir_id != id { - return; - } - self.inits.insert(e.span, init.span); - for stmt in block.stmts.iter().skip(1) { - let StmtKind::Semi(e) = stmt.kind else { continue }; - self.set_exprs.insert(e.span); + match expr.kind { + ExprKind::Block(block, _) => { + if block.stmts.len() <= 1 { + return; + } + let StmtKind::Local(local) = block.stmts[0].kind else { return }; + let PatKind::Binding(_, hir_id, ident, _) = local.pat.kind else { return }; + if ident.name.to_ident_string() != "init" { + return; + } + let init = some_or!(local.init, return); + let ExprKind::Struct(_, _, _) = init.kind else { return }; + let e = some_or!(block.expr, return); + let ExprKind::Path(QPath::Resolved(_, path)) = e.kind else { return }; + let Res::Local(id) = path.res else { return }; + if hir_id != id { + return; + } + self.inits.insert(e.span, init.span); + for stmt in block.stmts.iter().skip(1) { + let StmtKind::Semi(e) = stmt.kind else { continue }; + self.set_exprs.insert(e.span); + } + } + ExprKind::Match(e, arms, MatchSource::Normal) => { + let arms = arms + .iter() + .map(|arm| (arm.span, pat_to_tags(arm.pat))) + .collect(); + self.matches.push((e.span, arms)); + } + ExprKind::If(c, t, f) => { + let hif = HIf { + c: c.span, + t: t.span, + f: f.map(|f| f.span), + }; + self.ifs.push(hif); + } + _ => {} } } } -impl<'tcx> HVisitor<'tcx> for BitFieldInitVisitor<'tcx> { +impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'tcx> { type NestedFilter = nested_filter::OnlyBodies; fn nested_visit_map(&mut self) -> Self::Map { @@ -1555,3 +1618,246 @@ fn tag_to_string(tag: Tag, ty: &str) -> String { tag.to_string() } } + +fn expr_to_tag(expr: &Expr<'_>) -> u128 { + match expr.kind { + ExprKind::Lit(lit) => { + let LitKind::Int(n, _) = lit.node else { unreachable!() }; + n + } + ExprKind::Unary(UnOp::Neg, e) => 0u128.wrapping_sub(expr_to_tag(e)), + _ => unreachable!(), + } +} + +fn pat_to_tags(pat: &Pat<'_>) -> ArmTags { + match pat.kind { + PatKind::Lit(expr) => Some([expr_to_tag(expr)].into_iter().collect()), + PatKind::Or(pats) => { + let mut tags = BTreeSet::new(); + for pat in pats { + tags.extend(pat_to_tags(pat)?); + } + Some(tags) + } + PatKind::Wild => None, + _ => unreachable!("{:?}", pat), + } +} + +#[derive(Clone, Copy)] +struct AccessCtx<'a, 'tcx> { + states: &'a HashMap, + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, +} + +impl<'a, 'tcx> AccessCtx<'a, 'tcx> { + fn new(states: &'a HashMap, body: &'a Body<'tcx>, tcx: TyCtxt<'tcx>) -> Self { + Self { states, body, tcx } + } + + fn compute_tags(&self, access: FieldAccess<'tcx>) -> Vec<(FieldIdx, AbsInt)> { + let state = some_or!(self.states.get(&access.location), return vec![]); + let (path, is_deref) = access.get_path(state, &self.body.local_decls, self.tcx); + let g = state.g(); + let obj = some_or!(g.get_obj(&path, is_deref), return vec![]); + let Obj::Struct(fields, _) = obj else { return vec![] }; + let mut v: Vec<_> = fields + .iter() + .filter_map(|(f, obj)| { + let Obj::Ptr(loc) = obj else { return None }; + let obj = g.obj_at_location(loc)?; + let Obj::AtAddr(n) = obj else { return None }; + Some((*f, n.clone())) + }) + .collect(); + v.sort_by_key(|(f, _)| *f); + v + } +} + +#[derive(Debug)] +#[allow(dead_code)] +struct AccessInMatch<'tcx> { + access: FieldAccess<'tcx>, + field_tags: Vec<(FieldIdx, HashSet)>, + match_loc: Location, + arm_loc: Location, + match_span: Span, + arm_span: Span, +} + +fn access_in_match<'tcx>( + access: FieldAccess<'tcx>, + matches: &[(Span, Vec<(Span, ArmTags)>)], + switches: &HashMap)>, + ctx: AccessCtx<'_, 'tcx>, +) -> Option> { + let span = ctx.body.source_info(access.location).span; + let (match_span, arm_span, tags) = matches.iter().find_map(|(match_span, arms)| { + let (arm_span, tags) = arms.iter().find(|(s, _)| s.overlaps(span))?; + Some((*match_span, *arm_span, tags)) + })?; + let state_tags = ctx.compute_tags(access); + let mut field_tags = if let Some(tags) = tags { + vec![state_tags.iter().find_map(|(f, n)| { + let ts = n.into_set(); + if set_eq(&ts, tags) { + Some((*f, ts)) + } else { + None + } + })?] + } else if state_tags.is_empty() { + return None; + } else { + state_tags + .into_iter() + .filter_map(|(f, n)| { + let ts = n.into_set(); + if ts.is_empty() { + None + } else { + Some((f, ts)) + } + }) + .collect() + }; + + let (match_loc, ref targets) = switches[&match_span.hi()]; + let block = targets[tags]; + let arm_loc = Location { + block, + statement_index: 0, + }; + field_tags.retain(|(f, tags)| is_tag_from_branch(*f, tags, match_loc, arm_loc, access, ctx)); + if field_tags.is_empty() { + return None; + } + + let access = AccessInMatch { + access, + field_tags, + match_loc, + arm_loc, + match_span, + arm_span, + }; + Some(access) +} + +#[derive(Debug)] +#[allow(dead_code)] +struct AccessInIf<'tcx> { + access: FieldAccess<'tcx>, + field_tags: Vec<(FieldIdx, HashSet)>, + if_loc: Location, + branch_loc: Location, + if_span: Span, + branch_span: Span, +} + +fn access_in_if<'tcx>( + access: FieldAccess<'tcx>, + hifs: &[HIf], + mifs: &[MIf], + ctx: AccessCtx<'_, 'tcx>, +) -> Vec> { + let span = ctx.body.source_info(access.location).span; + let state_tags = ctx.compute_tags(access); + let field_tags: Vec<_> = state_tags + .into_iter() + .filter_map(|(f, n)| { + let ts = n.into_set(); + if ts.is_empty() { + None + } else { + Some((f, ts)) + } + }) + .collect(); + + let mut accesses = vec![]; + for hif in hifs { + let (c_span, branch_span, is_true) = if hif.t.overlaps(span) { + (hif.c, hif.t, true) + } else if let Some(f) = hif.f { + if f.overlaps(span) { + (hif.c, f, false) + } else { + continue; + } + } else { + continue; + }; + + for mif in mifs { + if !mif.c.overlaps(c_span) { + continue; + } + + let mut field_tags = field_tags.clone(); + let if_loc = mif.loc; + let block = if is_true { mif.t } else { mif.f }; + let branch_loc = Location { + block, + statement_index: 0, + }; + field_tags + .retain(|(f, tags)| is_tag_from_branch(*f, tags, if_loc, branch_loc, access, ctx)); + if !field_tags.is_empty() { + let access = AccessInIf { + access, + field_tags, + if_loc: mif.loc, + branch_loc, + if_span: mif.c, + branch_span, + }; + accesses.push(access); + } + } + } + + accesses +} + +fn is_tag_from_branch<'tcx>( + f: FieldIdx, + tags: &HashSet, + loc_before: Location, + loc_after: Location, + access: FieldAccess<'tcx>, + ctx: AccessCtx<'_, 'tcx>, +) -> bool { + let access = FieldAccess { + location: loc_after, + ..access + }; + let arm_tags = ctx.compute_tags(access); + let (_, arm_tags) = some_or!(arm_tags.into_iter().find(|(f1, _)| f == *f1), return false); + if &arm_tags.into_set() != tags { + return false; + } + + let access = FieldAccess { + location: loc_before, + ..access + }; + let match_tags = ctx.compute_tags(access); + if let Some((_, match_tags)) = match_tags.into_iter().find(|(f1, _)| f == *f1) { + if &match_tags.into_set() == tags { + return false; + } + } + + true +} + +fn set_eq(s1: &HashSet, s2: &BTreeSet) -> bool { + if s1.len() != s2.len() { + return false; + } + s2.iter().all(|x| s1.contains(x)) +}