From ef0f9322dec467a6a0eb9b5d78a4624bacab61a9 Mon Sep 17 00:00:00 2001 From: Andrew Ferraiuolo Date: Sat, 24 Mar 2018 16:59:56 -0400 Subject: [PATCH] Satanic performance bug finally vanquished --- src/main/scala/firrtl/LoweringCompilers.scala | 4 +- .../passes/labelchecking/ConstraintGen.scala | 17 +++++-- .../passes/labelchecking/DepExprPasses.scala | 46 ++++++++++++++++--- .../passes/labelchecking/ForwardProp.scala | 5 +- .../passes/labelchecking/SimplifyLabels.scala | 10 ++-- 5 files changed, 62 insertions(+), 20 deletions(-) diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 7750bf3..f9475c7 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -113,11 +113,11 @@ class LabelChecking extends CoreTransform { def outputForm = HighForm def passSeq = Seq( passes.PropNodes, - passes.LabelExprs, - // passes.LabelMPorts, + passes.LabelMPorts, passes.DepsToWorkingIR, passes.DepsResolveKinds, passes.DepsInferTypes, + passes.LabelExprs, passes.DeterminePC, passes.NextCycleTransform, passes.SeqPortGenNext, diff --git a/src/main/scala/firrtl/passes/labelchecking/ConstraintGen.scala b/src/main/scala/firrtl/passes/labelchecking/ConstraintGen.scala index b1aa47a..cad4d4a 100644 --- a/src/main/scala/firrtl/passes/labelchecking/ConstraintGen.scala +++ b/src/main/scala/firrtl/passes/labelchecking/ConstraintGen.scala @@ -40,6 +40,7 @@ abstract class ConstraintGenerator { // Type declaration string def emitTypeDecl(typeDecs: TypeDeclSet)(t: AggregateType): String + val bot = ProdLabel(PolicyHolder.bottom, PolicyHolder.top) val top = ProdLabel(PolicyHolder.top, PolicyHolder.bottom) @@ -263,7 +264,11 @@ object BVConstraintGen extends ConstraintGenerator { case tx : VectorType => throw new Exception } - def refToIdent(e: Expression) = e match { + // h4x to get around bad compile times for transforms on dependent labels + // appearing in bundle types + def toWIR(e: Expression) = ToWorkingIR.toExp(e) + + def refToIdent(e: Expression) = toWIR(e) match { case ex: WSubIndex => val idx = CBVLit(ex.value, toBInt(vec_size(ex.exp.tpe))) CASelect(refToIdent(ex.exp), idx).serialize @@ -279,7 +284,7 @@ object BVConstraintGen extends ConstraintGenerator { case Endorse(exx, _) => refToIdent(exx) } - def exprToCons(e: Expression): Constraint = e match { + def exprToCons(e: Expression): Constraint = toWIR(e) match { case ex : Literal => CBVLit(ex.value, toBInt(ex.width)) case ex : DoPrim => primOpToBVOp(ex) case ex : WSubIndex => @@ -297,7 +302,7 @@ object BVConstraintGen extends ConstraintGenerator { case Endorse(exx,_) => exprToCons(exx) } - def exprToCons(e: Expression, w: BigInt) = e match { + def exprToCons(e: Expression, w: BigInt) = toWIR(e) match { case ex : Literal => CBVLit(ex.value, w) case _ => var diff = BigInt(0) @@ -314,8 +319,10 @@ object BVConstraintGen extends ConstraintGenerator { else c } - def exprToConsBool(e: Expression) = - CBVWrappedBV(exprToCons(e), bitWidth(e.tpe)) + def exprToConsBool(e: Expression) = { + val ePrime = toWIR(e) + CBVWrappedBV(exprToCons(ePrime), bitWidth(ePrime.tpe)) + } override def serialize(l: LabelComp) = l match { case FunLabel(fname, args) => //s"($fname ${refToIdent(arg)})" diff --git a/src/main/scala/firrtl/passes/labelchecking/DepExprPasses.scala b/src/main/scala/firrtl/passes/labelchecking/DepExprPasses.scala index f31adad..c520d7f 100644 --- a/src/main/scala/firrtl/passes/labelchecking/DepExprPasses.scala +++ b/src/main/scala/firrtl/passes/labelchecking/DepExprPasses.scala @@ -7,11 +7,39 @@ import firrtl.Utils._ // This should convert all expressions within labels to the WIR, but leave // actual expressions (that reside directly in statements) as they are. -object DepsToWorkingIR extends ToWorkingIRT { +// +// AF: These passes no longer do any checking within bundle labels because +// doing so is computationally expensive, and as it turns out, unnecessary; +// these passes are only relevant for: 1) getting kinds for dependands for +// the sequence of passes that perform next-cycle transforms, 2) getting types +// for VecTypes which is necessary to properly apply indexing expressions +// during the apply vec labels pass +object DepsToWorkingIR extends Pass with PassDebug { override def name = "Dep Expressions to Working IR" - override def toLbl(l: Label): Label = l map toLblComp map toLbl - def toLblComp(lc: LabelComp): LabelComp = - lc map toExp map toLblComp + def toLbl(l: Label): Label = l match { + case lx: BundleLabel => lx + case lx => lx map toLbl map toLblComp + } + def toLblComp(lc: LabelComp): LabelComp = lc map toExp map toLblComp + + def toExp(e: Expression): Expression = e map toExp match { + case ex: Reference => WRef(ex.name, ex.tpe, ex.lbl, NodeKind, UNKNOWNGENDER) + case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, ex.lbl, UNKNOWNGENDER) + case ex: SubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, ex.lbl, UNKNOWNGENDER) + case ex: SubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, ex.lbl, UNKNOWNGENDER) + case ex => ex // This might look like a case to use case _ => e, DO NOT! + } + + def toExpL(e: Expression): Expression = + toExp(e) map toExpL map toLbl + + def toStmt(s: Statement): Statement = s map toExpL map toLbl match { + case sx: DefInstance => WDefInstance(sx.info, sx.name, sx.module, UnknownType, UnknownLabel) + case sx => sx map toStmt + } + + def run (c:Circuit): Circuit = + c copy (modules = c.modules map (_ map toStmt)) } object DepsResolveKinds extends ResolveKindsT { @@ -29,7 +57,10 @@ object DepsResolveKinds extends ResolveKindsT { } override def resolve_lbl(kinds: KindMap)(l: Label): Label = - l map resolve_lbl(kinds) map resolve_lbl_cmp(kinds) + l match { + case lx: BundleLabel => lx + case lx => lx map resolve_lbl(kinds) map resolve_lbl_cmp(kinds) + } def resolve_lbl_cmp(kinds: KindMap)(l: LabelComp): LabelComp = { l map resolve_expr(kinds) map resolve_lbl_cmp(kinds) } @@ -38,7 +69,10 @@ object DepsResolveKinds extends ResolveKindsT { object DepsInferTypes extends InferTypesT { override def name = "Infer Types for Dependands" override def infer_types_l(types: TypeMap)(l: Label): Label = - l map infer_types_lc(types) map infer_types_l(types) + l match { + case lx: BundleLabel => lx + case lx => lx map infer_types_l(types) map infer_types_lc(types) + } def infer_types_lc(types: TypeMap)(l: LabelComp): LabelComp = l map infer_types_e(types) map infer_types_lc(types) } diff --git a/src/main/scala/firrtl/passes/labelchecking/ForwardProp.scala b/src/main/scala/firrtl/passes/labelchecking/ForwardProp.scala index da07fc1..5d78896 100644 --- a/src/main/scala/firrtl/passes/labelchecking/ForwardProp.scala +++ b/src/main/scala/firrtl/passes/labelchecking/ForwardProp.scala @@ -76,9 +76,10 @@ object ForwardProp extends Pass with PassDebug { e map prop_env_e(env) map prop_env_l(env) def prop_env_l(env: LabelVarMap)(l: Label): Label = - l map prop_env_l(env) match { + l match { case lx: VarLabel if env.contains(lx) => env(lx) - case lx => lx + case lx: BundleLabel => lx + case lx => lx map prop_env_l(env) } def forward_prop(m: DefModule): DefModule = { diff --git a/src/main/scala/firrtl/passes/labelchecking/SimplifyLabels.scala b/src/main/scala/firrtl/passes/labelchecking/SimplifyLabels.scala index 9529f86..c52b6bb 100644 --- a/src/main/scala/firrtl/passes/labelchecking/SimplifyLabels.scala +++ b/src/main/scala/firrtl/passes/labelchecking/SimplifyLabels.scala @@ -48,7 +48,7 @@ object SimplifyLabels extends Pass with PassDebug { // cnf_e gets called here even though integrity components are in // dnf because components have expressions which have labels and // labels are in cnf. - l map dnf_lb_cmp map cnf_e match { + l map dnf_lb_cmp match { case lbx: JoinLabelComp => sortClauses( (clauses(lbx.l) cross clauses(lbx.r)) map { case (lhs:LabelComp, rhs:LabelComp) => cnf_lb_cmp(lhs join rhs) @@ -100,7 +100,7 @@ object SimplifyLabels extends Pass with PassDebug { terms_(l); termSet } - l map cnf_lb_cmp map cnf_e match { + l map cnf_lb_cmp match { case lbx: MeetLabelComp => sortClauses( (clauses(lbx.l) cross clauses(lbx.r)) map { case (lhs:LabelComp, rhs:LabelComp) => dnf_lb_cmp(lhs meet rhs) @@ -181,11 +181,11 @@ object SimplifyLabels extends Pass with PassDebug { terms_(l); termSet } - l map cnf_lb match { + l match { case lbx: MeetLabel=> sortClauses( (clauses(lbx.l) cross clauses(lbx.r)) map { case (lhs:Label, rhs:Label) => lhs meet rhs - }).foldLeft(bot) { _ join _ } + }).foldLeft(bot) { _ join _ } map cnf_lb case lbx: JoinLabel => val simplified_clauses = new LinkedHashSet[Label] val lx_clauses = clauses(lbx) @@ -196,7 +196,7 @@ object SimplifyLabels extends Pass with PassDebug { } if(addme) simplified_clauses += cls_i } - sortClauses(simplified_clauses).foldLeft(bot) { _ join _ } + sortClauses(simplified_clauses).foldLeft(bot) { _ join _ } map cnf_lb case ProdLabel(conf, integ) => // Confidentiality components are in the same normal form as labels, // but integrity components are in the dual of that normal form