Skip to content

Commit

Permalink
fix is_tag_from_branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed May 28, 2024
1 parent 59d9898 commit 849d6d1
Showing 1 changed file with 70 additions and 41 deletions.
111 changes: 70 additions & 41 deletions src/tag_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{

use compile_util::{make_suggestion, span_to_snippet};
use etrace::{ok_or, some_or};
use must_analysis::Obj;
use must_analysis::{Graph, Obj};
use rustc_abi::{FieldIdx, VariantIdx};
use rustc_ast::{BindingAnnotation, LitKind, Mutability};
use rustc_hir::{
Expand All @@ -20,7 +20,7 @@ use rustc_index::{bit_set::BitSet, IndexVec};
use rustc_middle::{
hir::nested_filter,
mir::{
visit::{MutatingUseContext, PlaceContext, Visitor as MVisitor},
visit::{PlaceContext, Visitor as MVisitor},
AggregateKind, BasicBlock, Body, ConstantKind, HasLocalDecls, Local, LocalDecl, Location,
Operand, Place, PlaceElem, ProjectionElem, Rvalue, Terminator, TerminatorKind,
},
Expand Down Expand Up @@ -283,35 +283,23 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
gc: true,
};
let states = must_analysis::analyze_body(body, ctx, tcx);
let ctx = AccessCtx::new(&states.states, body, tcx);
let ctx = AccessCtx::new(&states.states, &local_to_unions, body, tcx);
for access in visitor.accesses {
if let Some(access_match) =
let tags = if let Some(access_match) =
access_in_match(access, &hvisitor.matches, &visitor.switches, ctx)
{
println!(
"{:?} {:?}",
access_match.match_span, access_match.field_tags
);
access_match.field_tags
} else {
let accesses_if = access_in_if(access, &hvisitor.ifs, &visitor.ifs, ctx);
for access_if in accesses_if {
println!("{:?} {:?}", access_if.if_span, access_if.field_tags);
}
}
if matches!(
access.ctx,
PlaceContext::MutatingUse(MutatingUseContext::Store)
) {
continue;
}
accesses_if.into_iter().flat_map(|a| a.field_tags).collect()
};
let span = body.source_info(access.location).span;
let tags = ctx.compute_tags(access);
for (f, ns) in tags {
let vts = union_uses
.entry(access.ty)
.or_default()
.get_access_tags_mut(f);
for n in ns.into_set() {
for n in ns {
vts.insert(access.field, n, span);
}
}
Expand Down Expand Up @@ -1089,7 +1077,6 @@ impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> {
local: place.local,
projection: &place.projection[..=i],
field: f,
ctx: context,
location,
};
self.accesses.push(access);
Expand Down Expand Up @@ -1183,7 +1170,6 @@ struct FieldAccess<'tcx> {
local: Local,
projection: &'tcx [PlaceElem<'tcx>],
field: FieldIdx,
ctx: PlaceContext,
location: Location,
}

Expand Down Expand Up @@ -1648,13 +1634,24 @@ fn pat_to_tags(pat: &Pat<'_>) -> ArmTags {
#[derive(Clone, Copy)]
struct AccessCtx<'a, 'tcx> {
states: &'a HashMap<Location, AbsMem>,
local_to_unions: &'a HashMap<LocalDefId, Vec<(Local, Vec<AccElem>)>>,
body: &'a Body<'tcx>,
tcx: TyCtxt<'tcx>,
}

impl<'a, 'tcx> AccessCtx<'a, 'tcx> {
fn new(states: &'a HashMap<Location, AbsMem>, body: &'a Body<'tcx>, tcx: TyCtxt<'tcx>) -> Self {
Self { states, body, tcx }
fn new(
states: &'a HashMap<Location, AbsMem>,
local_to_unions: &'a HashMap<LocalDefId, Vec<(Local, Vec<AccElem>)>>,
body: &'a Body<'tcx>,
tcx: TyCtxt<'tcx>,
) -> Self {
Self {
states,
local_to_unions,
body,
tcx,
}
}

fn compute_tags(&self, access: FieldAccess<'tcx>) -> Vec<(FieldIdx, AbsInt)> {
Expand Down Expand Up @@ -1831,28 +1828,43 @@ fn is_tag_from_branch<'tcx>(
access: FieldAccess<'tcx>,
ctx: AccessCtx<'_, 'tcx>,
) -> bool {
let access = FieldAccess {
location: loc_after,
..access
};
let arm_tags = ctx.compute_tags(access);
let (_, arm_tags) = some_or!(arm_tags.into_iter().find(|(f1, _)| f == *f1), return false);
if &arm_tags.into_set() != tags {
let mut paths = ctx.local_to_unions[&access.ty].clone();

filter_paths_by_tag(f, tags, access.location, &mut paths, true, ctx);
if paths.is_empty() {
return false;
}

let access = FieldAccess {
location: loc_before,
..access
};
let match_tags = ctx.compute_tags(access);
if let Some((_, match_tags)) = match_tags.into_iter().find(|(f1, _)| f == *f1) {
if &match_tags.into_set() == tags {
return false;
}
filter_paths_by_tag(f, tags, loc_after, &mut paths, true, ctx);
if paths.is_empty() {
return false;
}

true
filter_paths_by_tag(f, tags, loc_before, &mut paths, false, ctx);
!paths.is_empty()
}

fn filter_paths_by_tag(
f: FieldIdx,
tags: &HashSet<u128>,
loc: Location,
paths: &mut Vec<(Local, Vec<AccElem>)>,
same: bool,
ctx: AccessCtx<'_, '_>,
) {
let state = &ctx.states[&loc];
let g = state.g();
paths.retain(|(l, path)| {
let len = path.len();
let AccElem::Field(_) = path[len - 1] else { unreachable!() };
let objs = g.objs_at(*l, &path[..len - 1]);
objs.iter().any(|obj| {
let loc_tags = extract_tags_from_obj(obj, g);
let (_, loc_tags) =
some_or!(loc_tags.into_iter().find(|(f1, _)| f == *f1), return false);
&loc_tags == tags
}) == same
});
}

fn set_eq<T: Eq + Ord + std::hash::Hash>(s1: &HashSet<T>, s2: &BTreeSet<T>) -> bool {
Expand All @@ -1861,3 +1873,20 @@ fn set_eq<T: Eq + Ord + std::hash::Hash>(s1: &HashSet<T>, s2: &BTreeSet<T>) -> b
}
s2.iter().all(|x| s1.contains(x))
}

fn extract_tags_from_obj(obj: &Obj, g: &Graph) -> Vec<(FieldIdx, HashSet<u128>)> {
let Obj::Struct(fs, _) = obj else { return vec![] };
fs.iter()
.filter_map(|(f, obj)| {
let Obj::Ptr(loc) = obj else { return None };
let obj = g.obj_at_location(loc)?;
let Obj::AtAddr(n) = obj else { return None };
let ns = n.into_set();
if ns.is_empty() {
None
} else {
Some((*f, ns))
}
})
.collect()
}

0 comments on commit 849d6d1

Please sign in to comment.