Skip to content

Commit

Permalink
find possible tag fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed May 23, 2024
1 parent 6a80b74 commit db6dc69
Showing 1 changed file with 122 additions and 15 deletions.
137 changes: 122 additions & 15 deletions src/tag_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use rustc_middle::{
ty::{List, Ty, TyCtxt, TyKind, TypeAndMut, TypeckResults},
};
use rustc_session::config::Input;
use rustc_span::{def_id::LocalDefId, source_map::SourceMap, BytePos, Span};
use rustc_span::{def_id::LocalDefId, source_map::SourceMap, BytePos, Span, Symbol};
use rustfix::Suggestion;
use typed_arena::Arena;

Expand Down Expand Up @@ -68,6 +68,41 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
.unwrap_or_else(|| may_analysis::analyze(&pre, &tss, tcx));
let may_points_to = may_analysis::post_analyze(pre, solutions, &tss, tcx);

let mut non_tag_fields = HashMap::new();
for item_id in hir.items() {
let item = hir.item(item_id);
let body_id = match item.kind {
ItemKind::Fn(_, _, body_id) | ItemKind::Static(_, _, body_id) => body_id,
_ => continue,
};
let body = hir.body(body_id);
let typeck = tcx.typeck(item_id.owner_id.def_id);
let mut visitor = FieldVisitor {
tcx,
typeck,
fields: &mut non_tag_fields,
};
visitor.visit_body(body);
}
let non_tag_fields: HashMap<_, HashSet<_>> = non_tag_fields
.into_iter()
.map(|(s, symbols)| {
let item = hir.expect_item(s);
let ItemKind::Struct(VariantData::Struct(fs, _), _) = item.kind else { unreachable!() };
let fields = symbols
.into_iter()
.map(|sym| {
fs.iter()
.enumerate()
.find(|(_, f)| f.ident.name == sym)
.map(|(i, _)| FieldIdx::from_usize(i))
.unwrap()
})
.collect();
(s, fields)
})
.collect();

let mut structs = HashMap::new();
let mut unions = vec![];
let mut union_to_struct = HashMap::new();
Expand Down Expand Up @@ -99,11 +134,20 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
}
let adt_def = tcx.adt_def(local_def_id);
let variant = adt_def.variant(VariantIdx::from_u32(0));
let has_int_field = variant
let non_tag_fields = non_tag_fields.get(&local_def_id);
let mut int_fields: HashSet<_> = variant
.fields
.iter()
.any(|f| f.ty(tcx, List::empty()).is_integral());
if !has_int_field && !tss.bitfields.contains_key(&local_def_id) {
.iter_enumerated()
.filter(|(i, f)| {
f.ty(tcx, List::empty()).is_integral()
&& non_tag_fields.map_or(true, |fields| !fields.contains(i))
})
.map(|(i, _)| i)
.collect();
if let Some(bitfield) = tss.bitfields.get(&local_def_id) {
int_fields.extend(bitfield.fields.keys());
}
if int_fields.is_empty() {
continue;
}
let mut struct_unions = vec![];
Expand All @@ -130,7 +174,11 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
}
}
if !struct_unions.is_empty() {
structs.insert(local_def_id, struct_unions);
let info = StructInfo {
int_fields,
unions: struct_unions,
};
structs.insert(local_def_id, info);
}
}

Expand Down Expand Up @@ -320,7 +368,9 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {

let mut tagged_unions = HashMap::new();
for (u, uu) in &union_uses {
if let Some((tag_index, mut variant_tags)) = uu.compute_tags() {
let (index_in_struct, struct_local_def_id) = union_to_struct[u];
let int_fields = &structs[&struct_local_def_id].int_fields;
if let Some((tag_index, mut variant_tags)) = uu.compute_tags(int_fields) {
println!("Union {:?}", u);
println!("Used fields: {:?}", uu.fields);
println!("Tag field: {:?}", tag_index);
Expand Down Expand Up @@ -353,7 +403,6 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
.enumerate()
.map(|(i, f)| (f.ident.name.to_ident_string(), FieldIdx::from_usize(i)))
.collect();
let (index_in_struct, struct_local_def_id) = union_to_struct[u];
let tu = TaggedUnion {
local_def_id: *u,
name,
Expand Down Expand Up @@ -763,10 +812,16 @@ impl UnionUse {
self.obj_tags.entry(field).or_default()
}

fn compute_tags(&self) -> Option<(FieldIdx, BTreeMap<FieldIdx, Vec<Tag>>)> {
fn compute_tags(
&self,
int_fields: &HashSet<FieldIdx>,
) -> Option<(FieldIdx, BTreeMap<FieldIdx, Vec<Tag>>)> {
self.access_tags
.iter()
.filter_map(|(f, ts)| {
if !int_fields.contains(f) {
return None;
}
let mut tags = ts.compute_tags()?;
if let Some(tag_spans) = self.obj_tags.get(f) {
tag_spans.compute_tags_with(&mut tags);
Expand Down Expand Up @@ -919,7 +974,7 @@ fn compute_tags<'tcx, D: HasLocalDecls<'tcx> + ?Sized>(
struct MBodyVisitor<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
local_decls: &'a IndexVec<Local, LocalDecl<'tcx>>,
structs: &'a HashMap<LocalDefId, Vec<(FieldIdx, LocalDefId)>>,
structs: &'a HashMap<LocalDefId, StructInfo>,
unions: &'a Vec<LocalDefId>,
accesses: Vec<FieldAccess<'tcx>>,
struct_accesses: HashSet<Local>,
Expand All @@ -931,7 +986,7 @@ impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> {
fn new(
tcx: TyCtxt<'tcx>,
local_decls: &'a IndexVec<Local, LocalDecl<'tcx>>,
structs: &'a HashMap<LocalDefId, Vec<(FieldIdx, LocalDefId)>>,
structs: &'a HashMap<LocalDefId, StructInfo>,
unions: &'a Vec<LocalDefId>,
) -> Self {
Self {
Expand Down Expand Up @@ -1065,6 +1120,12 @@ struct FieldAt {
field: FieldIdx,
}

struct StructInfo {
int_fields: HashSet<FieldIdx>,
#[allow(unused)]
unions: Vec<(FieldIdx, LocalDefId)>,
}

struct TaggedStruct {
#[allow(dead_code)]
local_def_id: LocalDefId,
Expand Down Expand Up @@ -1168,7 +1229,9 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> {
let span = field.span.shrink_to_hi();
self.suggestions.add(span, "()".to_string());
}
ExprContext::Store => {
ExprContext::Store(op) => {
assert!(!op);

let span = field.span.shrink_to_lo();
self.suggestions.add(span, "set_".to_string());

Expand All @@ -1190,7 +1253,7 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> {
let call = format!("get_{}()", field.name);
self.suggestions.add(field.span, call);
}
ExprContext::Store | ExprContext::Address => {
ExprContext::Store(_) | ExprContext::Address => {
let span = expr.span.shrink_to_lo();
self.suggestions.add(span, "(*".to_string());

Expand Down Expand Up @@ -1313,11 +1376,52 @@ impl<'tcx> HVisitor<'tcx> for BitFieldInitVisitor<'tcx> {
}
}

struct FieldVisitor<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
typeck: &'tcx TypeckResults<'tcx>,
fields: &'a mut HashMap<LocalDefId, HashSet<Symbol>>,
}

impl<'tcx> FieldVisitor<'_, 'tcx> {
fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) {
let ExprKind::Field(e, f) = expr.kind else { return };
let ty = self.typeck.expr_ty(expr);
if !ty.is_integral() {
return;
}
let ty = self.typeck.expr_ty(e);
let TyKind::Adt(adt_def, _) = ty.kind() else { return };
if !adt_def.is_struct() {
return;
}
let def_id = some_or!(adt_def.did().as_local(), return);
if matches!(
get_expr_context(expr, self.tcx).0,
ExprContext::Store(true) | ExprContext::Address
) {
self.fields.entry(def_id).or_default().insert(f.name);
}
}
}

impl<'tcx> HVisitor<'tcx> for FieldVisitor<'_, 'tcx> {
type NestedFilter = nested_filter::OnlyBodies;

fn nested_visit_map(&mut self) -> Self::Map {
self.tcx.hir()
}

fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
self.handle_expr(expr);
intravisit::walk_expr(self, expr);
}
}

#[derive(Debug, Clone, Copy)]
enum ExprContext {
Value,
Address,
Store,
Store(bool),
}

fn get_expr_context<'tcx>(
Expand All @@ -1329,7 +1433,10 @@ fn get_expr_context<'tcx>(
Node::Expr(e) => match e.kind {
ExprKind::Assign(l, _, _) | ExprKind::AssignOp(_, l, _) => {
if expr.hir_id == l.hir_id {
(ExprContext::Store, e)
(
ExprContext::Store(matches!(e.kind, ExprKind::AssignOp(_, _, _))),
e,
)
} else {
(ExprContext::Value, expr)
}
Expand Down

0 comments on commit db6dc69

Please sign in to comment.