diff --git a/src/tag_analysis.rs b/src/tag_analysis.rs index b8dc601..b6a98dc 100644 --- a/src/tag_analysis.rs +++ b/src/tag_analysis.rs @@ -26,7 +26,7 @@ use rustc_middle::{ ty::{List, Ty, TyCtxt, TyKind, TypeAndMut, TypeckResults}, }; use rustc_session::config::Input; -use rustc_span::{def_id::LocalDefId, source_map::SourceMap, BytePos, Span}; +use rustc_span::{def_id::LocalDefId, source_map::SourceMap, BytePos, Span, Symbol}; use rustfix::Suggestion; use typed_arena::Arena; @@ -68,6 +68,41 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { .unwrap_or_else(|| may_analysis::analyze(&pre, &tss, tcx)); let may_points_to = may_analysis::post_analyze(pre, solutions, &tss, tcx); + let mut non_tag_fields = HashMap::new(); + for item_id in hir.items() { + let item = hir.item(item_id); + let body_id = match item.kind { + ItemKind::Fn(_, _, body_id) | ItemKind::Static(_, _, body_id) => body_id, + _ => continue, + }; + let body = hir.body(body_id); + let typeck = tcx.typeck(item_id.owner_id.def_id); + let mut visitor = FieldVisitor { + tcx, + typeck, + fields: &mut non_tag_fields, + }; + visitor.visit_body(body); + } + let non_tag_fields: HashMap<_, HashSet<_>> = non_tag_fields + .into_iter() + .map(|(s, symbols)| { + let item = hir.expect_item(s); + let ItemKind::Struct(VariantData::Struct(fs, _), _) = item.kind else { unreachable!() }; + let fields = symbols + .into_iter() + .map(|sym| { + fs.iter() + .enumerate() + .find(|(_, f)| f.ident.name == sym) + .map(|(i, _)| FieldIdx::from_usize(i)) + .unwrap() + }) + .collect(); + (s, fields) + }) + .collect(); + let mut structs = HashMap::new(); let mut unions = vec![]; let mut union_to_struct = HashMap::new(); @@ -99,11 +134,20 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { } let adt_def = tcx.adt_def(local_def_id); let variant = adt_def.variant(VariantIdx::from_u32(0)); - let has_int_field = variant + let non_tag_fields = non_tag_fields.get(&local_def_id); + let mut int_fields: HashSet<_> = variant .fields - .iter() - .any(|f| f.ty(tcx, List::empty()).is_integral()); - if !has_int_field && !tss.bitfields.contains_key(&local_def_id) { + .iter_enumerated() + .filter(|(i, f)| { + f.ty(tcx, List::empty()).is_integral() + && non_tag_fields.map_or(true, |fields| !fields.contains(i)) + }) + .map(|(i, _)| i) + .collect(); + if let Some(bitfield) = tss.bitfields.get(&local_def_id) { + int_fields.extend(bitfield.fields.keys()); + } + if int_fields.is_empty() { continue; } let mut struct_unions = vec![]; @@ -130,7 +174,11 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { } } if !struct_unions.is_empty() { - structs.insert(local_def_id, struct_unions); + let info = StructInfo { + int_fields, + unions: struct_unions, + }; + structs.insert(local_def_id, info); } } @@ -320,7 +368,9 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { let mut tagged_unions = HashMap::new(); for (u, uu) in &union_uses { - if let Some((tag_index, mut variant_tags)) = uu.compute_tags() { + let (index_in_struct, struct_local_def_id) = union_to_struct[u]; + let int_fields = &structs[&struct_local_def_id].int_fields; + if let Some((tag_index, mut variant_tags)) = uu.compute_tags(int_fields) { println!("Union {:?}", u); println!("Used fields: {:?}", uu.fields); println!("Tag field: {:?}", tag_index); @@ -353,7 +403,6 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { .enumerate() .map(|(i, f)| (f.ident.name.to_ident_string(), FieldIdx::from_usize(i))) .collect(); - let (index_in_struct, struct_local_def_id) = union_to_struct[u]; let tu = TaggedUnion { local_def_id: *u, name, @@ -763,10 +812,16 @@ impl UnionUse { self.obj_tags.entry(field).or_default() } - fn compute_tags(&self) -> Option<(FieldIdx, BTreeMap>)> { + fn compute_tags( + &self, + int_fields: &HashSet, + ) -> Option<(FieldIdx, BTreeMap>)> { self.access_tags .iter() .filter_map(|(f, ts)| { + if !int_fields.contains(f) { + return None; + } let mut tags = ts.compute_tags()?; if let Some(tag_spans) = self.obj_tags.get(f) { tag_spans.compute_tags_with(&mut tags); @@ -919,7 +974,7 @@ fn compute_tags<'tcx, D: HasLocalDecls<'tcx> + ?Sized>( struct MBodyVisitor<'tcx, 'a> { tcx: TyCtxt<'tcx>, local_decls: &'a IndexVec>, - structs: &'a HashMap>, + structs: &'a HashMap, unions: &'a Vec, accesses: Vec>, struct_accesses: HashSet, @@ -931,7 +986,7 @@ impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> { fn new( tcx: TyCtxt<'tcx>, local_decls: &'a IndexVec>, - structs: &'a HashMap>, + structs: &'a HashMap, unions: &'a Vec, ) -> Self { Self { @@ -1065,6 +1120,12 @@ struct FieldAt { field: FieldIdx, } +struct StructInfo { + int_fields: HashSet, + #[allow(unused)] + unions: Vec<(FieldIdx, LocalDefId)>, +} + struct TaggedStruct { #[allow(dead_code)] local_def_id: LocalDefId, @@ -1168,7 +1229,9 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> { let span = field.span.shrink_to_hi(); self.suggestions.add(span, "()".to_string()); } - ExprContext::Store => { + ExprContext::Store(op) => { + assert!(!op); + let span = field.span.shrink_to_lo(); self.suggestions.add(span, "set_".to_string()); @@ -1190,7 +1253,7 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> { let call = format!("get_{}()", field.name); self.suggestions.add(field.span, call); } - ExprContext::Store | ExprContext::Address => { + ExprContext::Store(_) | ExprContext::Address => { let span = expr.span.shrink_to_lo(); self.suggestions.add(span, "(*".to_string()); @@ -1313,11 +1376,52 @@ impl<'tcx> HVisitor<'tcx> for BitFieldInitVisitor<'tcx> { } } +struct FieldVisitor<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + typeck: &'tcx TypeckResults<'tcx>, + fields: &'a mut HashMap>, +} + +impl<'tcx> FieldVisitor<'_, 'tcx> { + fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) { + let ExprKind::Field(e, f) = expr.kind else { return }; + let ty = self.typeck.expr_ty(expr); + if !ty.is_integral() { + return; + } + let ty = self.typeck.expr_ty(e); + let TyKind::Adt(adt_def, _) = ty.kind() else { return }; + if !adt_def.is_struct() { + return; + } + let def_id = some_or!(adt_def.did().as_local(), return); + if matches!( + get_expr_context(expr, self.tcx).0, + ExprContext::Store(true) | ExprContext::Address + ) { + self.fields.entry(def_id).or_default().insert(f.name); + } + } +} + +impl<'tcx> HVisitor<'tcx> for FieldVisitor<'_, 'tcx> { + type NestedFilter = nested_filter::OnlyBodies; + + fn nested_visit_map(&mut self) -> Self::Map { + self.tcx.hir() + } + + fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { + self.handle_expr(expr); + intravisit::walk_expr(self, expr); + } +} + #[derive(Debug, Clone, Copy)] enum ExprContext { Value, Address, - Store, + Store(bool), } fn get_expr_context<'tcx>( @@ -1329,7 +1433,10 @@ fn get_expr_context<'tcx>( Node::Expr(e) => match e.kind { ExprKind::Assign(l, _, _) | ExprKind::AssignOp(_, l, _) => { if expr.hir_id == l.hir_id { - (ExprContext::Store, e) + ( + ExprContext::Store(matches!(e.kind, ExprKind::AssignOp(_, _, _))), + e, + ) } else { (ExprContext::Value, expr) }