diff --git a/src/compile_util.rs b/src/compile_util.rs index 4621795..dcc5a2c 100644 --- a/src/compile_util.rs +++ b/src/compile_util.rs @@ -101,12 +101,11 @@ pub fn span_to_path(span: Span, source_map: &SourceMap) -> Option { pub type Suggestions = HashMap>; -pub fn apply_suggestions(suggestions: &mut Suggestions) { +pub fn apply_suggestions(suggestions: &Suggestions) { for (path, suggestions) in suggestions { if suggestions.is_empty() { continue; } - suggestions.sort_by_key(|s| s.snippets[0].range.start); let code = String::from_utf8(fs::read(path).unwrap()).unwrap(); let fixed = rustfix::apply_suggestions(&code, suggestions).unwrap(); fs::write(path, fixed.as_bytes()).unwrap(); diff --git a/src/must_analysis/analysis.rs b/src/must_analysis/analysis.rs index c851821..3fce529 100644 --- a/src/must_analysis/analysis.rs +++ b/src/must_analysis/analysis.rs @@ -143,15 +143,18 @@ impl Analyzer<'_, '_, '_> { while let Some(location) = work_list.pop() { let state = states.get(&location).unwrap_or(&bot); - let nexts = self.body.stmt_at(location).either( + let (nexts, is_call) = self.body.stmt_at(location).either( |stmt| { let mut next_state = state.clone(); self.transfer_stmt(stmt, location, &mut next_state); - vec![(location.successor_within_block(), next_state)] + (vec![(location.successor_within_block(), next_state)], false) }, |terminator| { let v = self.discriminant_values.get(&location.block); - self.transfer_term(terminator, v, location, state) + ( + self.transfer_term(terminator, v, location, state), + matches!(terminator.kind, TerminatorKind::Call { .. }), + ) }, ); // println!("{:?}", state); @@ -160,7 +163,7 @@ impl Analyzer<'_, '_, '_> { // println!("{:?}", nexts); // println!("-----------------"); for (next_location, new_next_state) in nexts { - if self.join_terminators.contains(&location) { + if self.join_terminators.contains(&location) || is_call { let out_state = out_states.get(&location).unwrap_or(&bot); let joined = out_state.join(&new_next_state); out_states.insert(location, joined); diff --git a/src/tag_analysis.rs b/src/tag_analysis.rs index 3ae9d09..3abebf8 100644 --- a/src/tag_analysis.rs +++ b/src/tag_analysis.rs @@ -228,18 +228,18 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { visitor.visit_body(body); let mut hvisitor = HBodyVisitor::new(tcx); hvisitor.visit_body(hbody); - if visitor.accesses.is_empty() - && visitor.struct_accesses.is_empty() - && visitor.aggregates.is_empty() - { - continue; - } basic_blocks.insert(local_def_id, visitor.basic_blocks); let locals = locals.entry(local_def_id).or_default(); for (local, local_def) in body.local_decls.iter_enumerated() { let hir_id = some_or!(hvisitor.bindings.get(&local_def.source_info.span), continue); locals.entry(*hir_id).or_insert(local); } + if visitor.accesses.is_empty() + && visitor.struct_accesses.is_empty() + && visitor.aggregates.is_empty() + { + continue; + } if !hvisitor.inits.is_empty() { for (local, location) in &visitor.inits { let span = body @@ -810,6 +810,7 @@ impl {} {{ let mut match_targets = 0; let mut if_targets = 0; + let mut aggregates_num = 0; for item_id in hir.items() { let item = hir.item(item_id); let (ItemKind::Fn(_, _, body_id) | ItemKind::Static(_, _, body_id)) = item.kind else { @@ -817,7 +818,6 @@ impl {} {{ }; let hir_body = hir.body(body_id); let local_def_id = item_id.owner_id.def_id; - let basic_blocks = some_or!(basic_blocks.get(&local_def_id), continue); let typeck = tcx.typeck(local_def_id); let mut visitor = SuggestingVisitor { tcx, @@ -829,34 +829,38 @@ impl {} {{ unions: &tagged_unions, access_in_matches: &access_in_matches, access_in_ifs: &access_in_ifs, - basic_blocks, + basic_blocks: &basic_blocks[&local_def_id], hir_id_to_locals: &locals[&local_def_id], suggestions: &mut suggestions, locals: HashMap::new(), match_targets: HashMap::new(), if_targets: HashMap::new(), + aggregates_num: 0, aggregate_spans: vec![], }; visitor.visit_body(hir_body); match_targets += visitor.match_targets.len(); if_targets += visitor.if_targets.len(); + aggregates_num += visitor.aggregates_num; } println!("match_targets: {}", match_targets); println!("if_targets: {}", if_targets); + println!("aggregates_num: {}", aggregates_num); let mut suggestions = suggestions.suggestions; - for (path, suggestions) in &suggestions { + for (path, suggestions) in &mut suggestions { tracing::info!("{:?}", path); + suggestions.sort_by_key(|s| s.snippets[0].range.start); for suggestion in suggestions { tracing::info!("{:?}", suggestion); } } if conf.transform { - compile_util::apply_suggestions(&mut suggestions); + compile_util::apply_suggestions(&suggestions); } } @@ -1237,16 +1241,20 @@ impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> { } fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { - if let Some(stmt) = data.statements.get(0) { - let span = stmt.source_info.span; - let mut lo = span.lo(); - let mut hi = span.hi(); - for stmt in &data.statements[1..] { - let span = stmt.source_info.span; - lo = lo.min(span.lo()); - hi = hi.max(span.hi()); - } - let span = span.with_lo(lo).with_hi(hi); + let spans = data.statements.iter().map(|stmt| stmt.source_info.span); + let term = data.terminator(); + let spans: Box> = + if matches!(term.kind, TerminatorKind::Call { .. }) { + Box::new(spans.chain(std::iter::once(term.source_info.span))) + } else { + Box::new(spans) + }; + if let Some(span) = spans.reduce(|span1, span2| { + let lo = span1.lo().min(span2.lo()); + let hi = span1.hi().max(span2.hi()); + span1.with_lo(lo).with_hi(hi) + }) { + let span = span.with_hi(span.hi() + BytePos(1)); let location = Location { block, statement_index: data.statements.len(), @@ -1339,6 +1347,7 @@ struct SuggestingVisitor<'a, 'tcx> { locals: HashMap>, match_targets: HashMap, if_targets: HashMap, + aggregates_num: usize, aggregate_spans: Vec, } @@ -1581,17 +1590,17 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { self.suggestions.add(expr.span, "(*__v)".to_string()); } else { let (ctx, _) = get_expr_context(expr, self.tcx); - match ctx { - ExprContext::Value => { - let call = format!("get_{}()", field.name); - self.suggestions.add(field.span, call); - } - ExprContext::Store(_) | ExprContext::Address => { - if !self - .aggregate_spans - .iter() - .any(|span| span.contains(expr.span)) - { + if !self + .aggregate_spans + .iter() + .any(|span| span.contains(expr.span)) + { + match ctx { + ExprContext::Value => { + let call = format!("get_{}()", field.name); + self.suggestions.add(field.span, call); + } + ExprContext::Store(_) | ExprContext::Address => { let span = expr.span.shrink_to_lo(); self.suggestions.add(span, "(*".to_string()); @@ -1796,14 +1805,15 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { let (variant, value) = vs.into_iter().next().unwrap(); if let ExprKind::Path(QPath::Resolved(_, path)) = root.kind { let Res::Local(hir_id) = path.res else { continue }; - let local = self.hir_id_to_locals[&hir_id]; + let local = *some_or!(self.hir_id_to_locals.get(&hir_id), continue); let field_at = FieldAt { func: self.func, location, local, field: ts.tag_index, }; - let tags = some_or!(self.field_values.get(&field_at), continue); + let tags = self.field_values.get(&field_at); + let tags = some_or!(tags, continue); if tags.len() != 1 { continue; } @@ -1814,6 +1824,7 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { .filter(|bs| &bs.fields[0] == union_field_name) .collect(); let bs = bss.pop().unwrap(); + let value = self.assigned_value_to_string(&value); let code = format!( "{}.{} = {}::{}{}({});", self.tcx @@ -1848,6 +1859,7 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { if removed { self.suggestions.add(tag_assign.span, "".to_string()); self.aggregate_spans.push(tag_assign.span); + self.aggregates_num += 1; } } } @@ -1877,6 +1889,42 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { None } } + + fn assigned_value_to_string(&self, assigned_value: &AssignedValue<'tcx>) -> String { + let source_map = self.tcx.sess.source_map(); + match assigned_value { + AssignedValue::Compound(name, fields) => { + let mut s = name.clone(); + s.push_str(" { "); + for (i, (field, value)) in fields.iter().enumerate() { + if i > 0 { + s.push_str(", "); + } + s.push_str(field); + let value = self.assigned_value_to_string(value); + if field != &value { + write!(&mut s, ": {}", value).unwrap(); + } + } + s.push_str(" }"); + s + } + AssignedValue::Primitive(value) => { + if let ExprKind::Field(e, f) = value.kind { + let ty = self.typeck.expr_ty(e); + if let TyKind::Adt(adt_def, _) = ty.kind() { + if let Some(did) = adt_def.did().as_local() { + if self.unions.contains_key(&did) { + let e = source_map.span_to_snippet(e.span).unwrap(); + return format!("{}.get_{}()", e, f.name); + } + } + } + } + source_map.span_to_snippet(value.span).unwrap() + } + } + } } fn make_aggregate<'tcx>( @@ -1884,19 +1932,17 @@ fn make_aggregate<'tcx>( fields: &mut Vec, assigns: &HashMap, &AssignBlockStmt<'tcx>>, tcx: TyCtxt<'tcx>, -) -> Option { +) -> Option> { if let Some(bs) = assigns.get(fields) { - tcx.sess - .source_map() - .span_to_snippet(bs.rhs.span) - .ok() - .map(AssignedValue::Primitive) + Some(AssignedValue::Primitive(bs.rhs)) } else { let ty = ty?; let def_path = tcx.def_path(ty.to_def_id()); let mut name = "crate".to_string(); for data in def_path.data { - write!(name, "::{}", data).unwrap(); + let data = format!("{}", data); + let escape = if data == "async" { "r#" } else { "" }; + write!(name, "::{}{}", escape, data).unwrap(); } let adt_def = tcx.adt_def(ty); let (ItemKind::Struct(vd, _) | ItemKind::Union(vd, _)) = tcx.hir().expect_item(ty).kind @@ -1935,37 +1981,9 @@ fn make_aggregate<'tcx>( } } -enum AssignedValue { - Compound(String, HashMap), - Primitive(String), -} - -impl std::fmt::Debug for AssignedValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AssignedValue::Compound(name, fields) => { - write!(f, "{} {{ ", name)?; - for (i, (field, value)) in fields.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", field)?; - let value = value.to_string(); - if field != &value { - write!(f, ": {}", value)?; - } - } - write!(f, " }}") - } - AssignedValue::Primitive(value) => write!(f, "{}", value), - } - } -} - -impl std::fmt::Display for AssignedValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } +enum AssignedValue<'tcx> { + Compound(String, HashMap>), + Primitive(&'tcx Expr<'tcx>), } #[derive(Debug, Default)]