Skip to content

Commit

Permalink
match target
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed May 28, 2024
1 parent 849d6d1 commit 1e1e9e2
Showing 1 changed file with 122 additions and 17 deletions.
139 changes: 122 additions & 17 deletions src/tag_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rustc_hir::{
def::Res,
definitions::DefPathDataName,
intravisit::{self, Visitor as HVisitor},
ByRef, Expr, ExprKind, ItemKind, MatchSource, Node, Pat, PatKind, QPath, StmtKind, UnOp,
ByRef, Expr, ExprKind, HirId, ItemKind, MatchSource, Node, Pat, PatKind, QPath, StmtKind, UnOp,
VariantData,
};
use rustc_index::{bit_set::BitSet, IndexVec};
Expand Down Expand Up @@ -205,6 +205,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
let mut union_uses: HashMap<_, UnionUse> = HashMap::new();
let mut aggregates: HashMap<_, _> = HashMap::new();
let mut field_values: HashMap<FieldAt, BTreeSet<Tag>> = HashMap::new();
let mut access_in_matches: HashMap<_, Vec<_>> = HashMap::new();
println!("Start analysis");
for item_id in hir.items() {
let item = hir.item(item_id);
Expand Down Expand Up @@ -288,7 +289,12 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
let tags = if let Some(access_match) =
access_in_match(access, &hvisitor.matches, &visitor.switches, ctx)
{
access_match.field_tags
let tags = access_match.field_tags.clone();
access_in_matches
.entry(access_match.match_span)
.or_default()
.push(access_match);
tags
} else {
let accesses_if = access_in_if(access, &hvisitor.ifs, &visitor.ifs, ctx);
accesses_if.into_iter().flat_map(|a| a.field_tags).collect()
Expand Down Expand Up @@ -402,6 +408,8 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
let item = hir.expect_item(*u);
let name = item.ident.name.to_ident_string();
let ItemKind::Union(VariantData::Struct(fs, _), _) = item.kind else { unreachable!() };
let field_names: IndexVec<FieldIdx, _> =
fs.iter().map(|f| f.ident.name.to_ident_string()).collect();
let field_name_to_index = fs
.iter()
.enumerate()
Expand All @@ -412,6 +420,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
name,
variant_tags,
empty_tags,
field_names,
field_name_to_index,
struct_local_def_id,
index_in_struct,
Expand Down Expand Up @@ -771,7 +780,10 @@ impl {} {{
field_values: &field_values,
structs: &tagged_structs,
unions: &tagged_unions,
access_in_matches: &access_in_matches,
suggestions: &mut suggestions,

locals: HashMap::new(),
};
visitor.visit_body(hir_body);
}
Expand Down Expand Up @@ -1220,6 +1232,7 @@ struct TaggedUnion {
name: String,
variant_tags: BTreeMap<FieldIdx, Vec<Tag>>,
empty_tags: Vec<Tag>,
field_names: IndexVec<FieldIdx, String>,
field_name_to_index: HashMap<String, FieldIdx>,
struct_local_def_id: LocalDefId,
index_in_struct: FieldIdx,
Expand All @@ -1234,12 +1247,77 @@ struct SuggestingVisitor<'a, 'tcx> {
field_values: &'a HashMap<FieldAt, BTreeSet<Tag>>,
structs: &'a HashMap<LocalDefId, TaggedStruct>,
unions: &'a HashMap<LocalDefId, TaggedUnion>,
access_in_matches: &'a HashMap<Span, Vec<AccessInMatch<'tcx>>>,
suggestions: &'a mut Suggestions<'tcx>,

locals: HashMap<HirId, &'tcx Expr<'tcx>>,
}

impl<'tcx> SuggestingVisitor<'_, 'tcx> {
fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) {
let source_map = self.tcx.sess.source_map();

if let Some(access) = self.access_in_matches.get(&expr.span) {
let tu = &self.unions[&access[0].ty];
let ts = &self.structs[&tu.struct_local_def_id];
let union_field_name = &ts.field_names[tu.index_in_struct];

let expr_wo_cast = unwrap_cast(expr);
match expr_wo_cast.kind {
ExprKind::Field(expr_struct, _) => {
let span = expr.span.with_lo(expr_struct.span.hi());
self.suggestions.add(span, format!(".{}", union_field_name));
}
ExprKind::Path(QPath::Resolved(_, path)) => {
let Res::Local(hir_id) = path.res else { unreachable!() };
let init = self.locals[&hir_id];
let ExprKind::Field(expr_struct, _) = init.kind else { unreachable!() };
let struct_str = source_map.span_to_snippet(expr_struct.span).unwrap();
self.suggestions
.add(expr.span, format!("{}.{}", struct_str, union_field_name));
}
_ => unreachable!("{:?}", expr),
}

let Node::Expr(match_expr) = self.tcx.hir().get_parent(expr.hir_id) else {
unreachable!()
};
let ExprKind::Match(_, arms, _) = match_expr.kind else { unreachable!() };
for arm in arms {
let span = arm.pat.span;
if let Some(tags) = pat_to_tags(arm.pat) {
assert_eq!(tags.len(), 1);
let tag = Tag(*tags.iter().next().unwrap() as _);
let pat = if tu.empty_tags.contains(&tag) {
format!("{}::Empty{}", tu.name, tag)
} else {
assert!(matches!(arm.body.kind, ExprKind::Block(_, _)));
let pos = arm.body.span.lo() + BytePos(1);
let span = arm.body.span.with_lo(pos).with_hi(pos);
self.suggestions
.add(span, "let mut __v = __v as *mut _;".to_string());

let f = tu
.variant_tags
.iter()
.find_map(
|(f, tags)| {
if tags.contains(&tag) {
Some(*f)
} else {
None
}
},
)
.unwrap();
let f_name = &tu.field_names[f];
format!("{}::{}{}(ref mut __v)", tu.name, f_name, tag)
};
self.suggestions.add(span, pat);
}
}
}

match expr.kind {
ExprKind::Struct(_, fs, _) => {
let TyKind::Adt(adt_def, _) = self.typeck.expr_ty(expr).kind() else {
Expand Down Expand Up @@ -1318,8 +1396,14 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> {
let (ctx, e2) = get_expr_context(e, self.tcx);
match ctx {
ExprContext::Value => {
let span = field.span.shrink_to_hi();
self.suggestions.add(span, "()".to_string());
if self
.access_in_matches
.iter()
.all(|(span, _)| !span.contains(expr.span))
{
let span = field.span.shrink_to_hi();
self.suggestions.add(span, "()".to_string());
}
}
ExprContext::Store(op) => {
assert!(!op);
Expand Down Expand Up @@ -1399,6 +1483,12 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> {
_ => {}
}
}

fn handle_local(&mut self, local: &'tcx rustc_hir::Local<'tcx>) {
let PatKind::Binding(_, hir_id, _, _) = local.pat.kind else { return };
let init = some_or!(local.init, return);
self.locals.insert(hir_id, init);
}
}

impl<'tcx> HVisitor<'tcx> for SuggestingVisitor<'_, 'tcx> {
Expand All @@ -1412,6 +1502,11 @@ impl<'tcx> HVisitor<'tcx> for SuggestingVisitor<'_, 'tcx> {
self.handle_expr(expr);
intravisit::walk_expr(self, expr);
}

fn visit_local(&mut self, local: &'tcx rustc_hir::Local<'tcx>) {
self.handle_local(local);
intravisit::walk_local(self, local);
}
}

struct HIf {
Expand Down Expand Up @@ -1677,6 +1772,7 @@ impl<'a, 'tcx> AccessCtx<'a, 'tcx> {
#[derive(Debug)]
#[allow(dead_code)]
struct AccessInMatch<'tcx> {
ty: LocalDefId,
access: FieldAccess<'tcx>,
field_tags: Vec<(FieldIdx, HashSet<u128>)>,
match_loc: Location,
Expand Down Expand Up @@ -1734,6 +1830,7 @@ fn access_in_match<'tcx>(
}

let access = AccessInMatch {
ty: access.ty,
access,
field_tags,
match_loc,
Expand Down Expand Up @@ -1852,19 +1949,22 @@ fn filter_paths_by_tag(
same: bool,
ctx: AccessCtx<'_, '_>,
) {
let state = &ctx.states[&loc];
let g = state.g();
paths.retain(|(l, path)| {
let len = path.len();
let AccElem::Field(_) = path[len - 1] else { unreachable!() };
let objs = g.objs_at(*l, &path[..len - 1]);
objs.iter().any(|obj| {
let loc_tags = extract_tags_from_obj(obj, g);
let (_, loc_tags) =
some_or!(loc_tags.into_iter().find(|(f1, _)| f == *f1), return false);
&loc_tags == tags
}) == same
});
if let Some(state) = ctx.states.get(&loc) {
let g = state.g();
paths.retain(|(l, path)| {
let len = path.len();
let AccElem::Field(_) = path[len - 1] else { unreachable!() };
let objs = g.objs_at(*l, &path[..len - 1]);
objs.iter().any(|obj| {
let loc_tags = extract_tags_from_obj(obj, g);
let (_, loc_tags) =
some_or!(loc_tags.into_iter().find(|(f1, _)| f == *f1), return false);
&loc_tags == tags
}) == same
});
} else {
paths.clear();
}
}

fn set_eq<T: Eq + Ord + std::hash::Hash>(s1: &HashSet<T>, s2: &BTreeSet<T>) -> bool {
Expand All @@ -1890,3 +1990,8 @@ fn extract_tags_from_obj(obj: &Obj, g: &Graph) -> Vec<(FieldIdx, HashSet<u128>)>
})
.collect()
}

fn unwrap_cast<'a, 'tcx>(e: &'a Expr<'tcx>) -> &'a Expr<'tcx> {
let ExprKind::Cast(e, _) = e.kind else { return e };
unwrap_cast(e)
}

0 comments on commit 1e1e9e2

Please sign in to comment.