From a24c58c07f0c66270224a3d057fa10ab115ccc89 Mon Sep 17 00:00:00 2001 From: Jaemin Hong Date: Tue, 28 May 2024 16:37:10 +0000 Subject: [PATCH] match arm --- src/tag_analysis.rs | 151 ++++++++++++++++++++++++++++---------------- 1 file changed, 95 insertions(+), 56 deletions(-) diff --git a/src/tag_analysis.rs b/src/tag_analysis.rs index cbca2f4..aacd43d 100644 --- a/src/tag_analysis.rs +++ b/src/tag_analysis.rs @@ -410,6 +410,10 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { let ItemKind::Union(VariantData::Struct(fs, _), _) = item.kind else { unreachable!() }; let field_names: IndexVec = fs.iter().map(|f| f.ident.name.to_ident_string()).collect(); + let field_tys: IndexVec = fs + .iter() + .map(|f| source_map.span_to_snippet(f.ty.span).unwrap()) + .collect(); let field_name_to_index = fs .iter() .enumerate() @@ -421,6 +425,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { variant_tags, empty_tags, field_names, + field_tys, field_name_to_index, struct_local_def_id, index_in_struct, @@ -784,6 +789,7 @@ impl {} {{ suggestions: &mut suggestions, locals: HashMap::new(), + match_targets: HashMap::new(), }; visitor.visit_body(hir_body); } @@ -1233,6 +1239,7 @@ struct TaggedUnion { variant_tags: BTreeMap>, empty_tags: Vec, field_names: IndexVec, + field_tys: IndexVec, field_name_to_index: HashMap, struct_local_def_id: LocalDefId, index_in_struct: FieldIdx, @@ -1251,6 +1258,7 @@ struct SuggestingVisitor<'a, 'tcx> { suggestions: &'a mut Suggestions<'tcx>, locals: HashMap>, + match_targets: HashMap, } impl<'tcx> SuggestingVisitor<'_, 'tcx> { @@ -1267,6 +1275,9 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { ExprKind::Field(expr_struct, _) => { let span = expr.span.with_lo(expr_struct.span.hi()); self.suggestions.add(span, format!(".{}", union_field_name)); + + let s = source_map.span_to_snippet(expr_struct.span).unwrap(); + self.match_targets.insert(expr.span, normalize_expr_str(&s)); } ExprKind::Path(QPath::Resolved(_, path)) => { let Res::Local(hir_id) = path.res else { unreachable!() }; @@ -1275,6 +1286,8 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { let struct_str = source_map.span_to_snippet(expr_struct.span).unwrap(); self.suggestions .add(expr.span, format!("{}.{}", struct_str, union_field_name)); + self.match_targets + .insert(expr.span, normalize_expr_str(&struct_str)); } _ => unreachable!("{:?}", expr), } @@ -1291,12 +1304,6 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { let pat = if tu.empty_tags.contains(&tag) { format!("{}::Empty{}", tu.name, tag) } else { - assert!(matches!(arm.body.kind, ExprKind::Block(_, _))); - let pos = arm.body.span.lo() + BytePos(1); - let span = arm.body.span.with_lo(pos).with_hi(pos); - self.suggestions - .add(span, "let mut __v = __v as *mut _;".to_string()); - let f = tu .variant_tags .iter() @@ -1310,6 +1317,14 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { }, ) .unwrap(); + + assert!(matches!(arm.body.kind, ExprKind::Block(_, _))); + let pos = arm.body.span.lo() + BytePos(1); + let span = arm.body.span.with_lo(pos).with_hi(pos); + let f_ty = &tu.field_tys[f]; + let code = format!("let mut __v = __v as *mut {};", f_ty); + self.suggestions.add(span, code); + let f_name = &tu.field_names[f]; format!("{}::{}{}(ref mut __v)", tu.name, f_name, tag) }; @@ -1423,57 +1438,75 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { } } } else if let Some(tu) = self.unions.get(&did) { - let (ctx, _) = get_expr_context(expr, self.tcx); - match ctx { - ExprContext::Value => { - let call = format!("get_{}()", field.name); - self.suggestions.add(field.span, call); - } - ExprContext::Store(_) | ExprContext::Address => { - let span = expr.span.shrink_to_lo(); - self.suggestions.add(span, "(*".to_string()); - - let ItemKind::Union(VariantData::Struct(fs, _), _) = - self.tcx.hir().expect_item(did).kind - else { - unreachable!() - }; - let (i, _) = fs - .iter() - .enumerate() - .find(|(_, f)| f.ident.name == field.name) - .unwrap(); - let tags = &tu.variant_tags[&FieldIdx::from(i)]; - - let call = if tags.len() == 1 { - format!("deref_{}_mut())", field.name) - } else { - let ts = &self.structs[&tu.struct_local_def_id]; - let field_name = &ts.field_names[ts.tag_index]; - let ExprKind::Field(e2, _) = e.kind else { unreachable!() }; - let tag = format!( - "{}.{}", - source_map.span_to_snippet(e2.span).unwrap(), - field_name, - ); - format!("deref_{}_mut({}()))", field.name, tag) - }; - self.suggestions.add(field.span, call); - - let root = get_root(expr); - if let ExprKind::Unary(UnOp::Deref, e) = root.kind { - let ty = self.typeck.expr_ty(e); - if let TyKind::RawPtr(TypeAndMut { - mutbl: Mutability::Not, - ty, - }) = ty.kind() - { - let span = e.span.shrink_to_lo(); - self.suggestions.add(span, "(".to_string()); + let matched = if let Some((match_span, _)) = self + .access_in_matches + .iter() + .find(|(_, ams)| ams.iter().any(|am| am.arm_span.contains(expr.span))) + { + let s1 = &self.match_targets[match_span]; + let ExprKind::Field(expr_struct, _) = e.kind else { unreachable!() }; + let struct_str = source_map.span_to_snippet(expr_struct.span).unwrap(); + let s2 = normalize_expr_str(&struct_str); + s1 == &s2 + } else { + false + }; + + if matched { + self.suggestions.add(expr.span, "(*__v)".to_string()); + } else { + let (ctx, _) = get_expr_context(expr, self.tcx); + match ctx { + ExprContext::Value => { + let call = format!("get_{}()", field.name); + self.suggestions.add(field.span, call); + } + ExprContext::Store(_) | ExprContext::Address => { + let span = expr.span.shrink_to_lo(); + self.suggestions.add(span, "(*".to_string()); + + let ItemKind::Union(VariantData::Struct(fs, _), _) = + self.tcx.hir().expect_item(did).kind + else { + unreachable!() + }; + let (i, _) = fs + .iter() + .enumerate() + .find(|(_, f)| f.ident.name == field.name) + .unwrap(); + let tags = &tu.variant_tags[&FieldIdx::from(i)]; + + let call = if tags.len() == 1 { + format!("deref_{}_mut())", field.name) + } else { + let ts = &self.structs[&tu.struct_local_def_id]; + let field_name = &ts.field_names[ts.tag_index]; + let ExprKind::Field(e2, _) = e.kind else { unreachable!() }; + let tag = format!( + "{}.{}", + source_map.span_to_snippet(e2.span).unwrap(), + field_name, + ); + format!("deref_{}_mut({}()))", field.name, tag) + }; + self.suggestions.add(field.span, call); + + let root = get_root(expr); + if let ExprKind::Unary(UnOp::Deref, e) = root.kind { + let ty = self.typeck.expr_ty(e); + if let TyKind::RawPtr(TypeAndMut { + mutbl: Mutability::Not, + ty, + }) = ty.kind() + { + let span = e.span.shrink_to_lo(); + self.suggestions.add(span, "(".to_string()); - let span = e.span.shrink_to_hi(); - let cast = format!(" as *mut crate::{:?})", ty); - self.suggestions.add(span, cast); + let span = e.span.shrink_to_hi(); + let cast = format!(" as *mut crate::{:?})", ty); + self.suggestions.add(span, cast); + } } } } @@ -1995,3 +2028,9 @@ fn unwrap_cast<'a, 'tcx>(e: &'a Expr<'tcx>) -> &'a Expr<'tcx> { let ExprKind::Cast(e, _) = e.kind else { return e }; unwrap_cast(e) } + +fn normalize_expr_str(s: &str) -> String { + s.chars() + .filter(|&c| !c.is_whitespace() && c != '(' && c != ')') + .collect() +}