Skip to content

Commit 41e0211

Browse files
committed
Annotate temporary nodes in when statements
Fix rameloni/tywaves-chisel#28
1 parent 2d57816 commit 41e0211

File tree

4 files changed

+232
-28
lines changed

4 files changed

+232
-28
lines changed

core/src/main/scala/chisel3/tywavesinternal/TywavesAnnotation.scala

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package chisel3.tywavesinternal
22

3-
import chisel3.{Data, Record, Vec, VecLike}
3+
import chisel3.{Data, MemBase, Record, Vec, VecLike}
44
import chisel3.experimental.{BaseModule, ChiselAnnotation}
5-
import chisel3.internal.HasId
5+
import chisel3.internal.{HasId, NamedComponent}
66
import chisel3.internal.firrtl.ir._
7+
import chisel3.properties.{DynamicObject, StaticObject}
78
import firrtl.annotations.{Annotation, IsMember, SingleTargetAnnotation}
89

10+
import scala.collection.mutable
11+
912
// TODO: if the code touches a lot of Chisel internals, it might be better to put it into
1013
// - core
1114
// otherwise:
@@ -45,11 +48,31 @@ private[chisel3] case class TywavesAnnotation[T <: IsMember](
4548
}
4649

4750
object TywavesChiselAnnotation {
51+
52+
private val annoCreated = new mutable.HashSet[IsMember]()
53+
private def createTywavesChiselAnno[T <: IsMember](
54+
target: T,
55+
name: String,
56+
paramsOpt: Option[Seq[ClassParam]]
57+
): Option[ChiselAnnotation] = {
58+
59+
if (annoCreated.contains(target)) {
60+
None
61+
} else {
62+
annoCreated.add(target)
63+
Some(new ChiselAnnotation {
64+
override def toFirrtl: Annotation = TywavesAnnotation(target, name, paramsOpt)
65+
})
66+
}
67+
}
68+
4869
def generate(circuit: Circuit): Seq[ChiselAnnotation] = {
4970
// TODO: iterate over a circuit and generate TywavesAnnotation
5071
val typeAliases: Seq[String] = circuit.typeAliases.map(_.name)
5172

52-
circuit.components.flatMap(c => generate(c, typeAliases))
73+
val result = circuit.components.flatMap(c => generate(c, typeAliases))
74+
annoCreated.clear()
75+
result
5376
// circuit.layers
5477
// circuit.options
5578

@@ -59,18 +82,18 @@ object TywavesChiselAnnotation {
5982
def generate(component: Component, typeAliases: Seq[String]): Seq[ChiselAnnotation] = component match {
6083
case ctx @ DefModule(id, name, public, layers, ports, cmds) =>
6184
// TODO: Add tywaves annotation: components, ports, commands, layers
62-
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p =>
85+
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p =>
6386
generate(p, typeAliases)
6487
) ++ (cmds ++ ctx.secretCommands).flatMap(c => generate(c, typeAliases))
6588
case ctx @ DefBlackBox(id, name, ports, topDir, params) =>
6689
// TODO: Add tywaves annotation, ports, ?params?
67-
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
90+
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
6891
case ctx @ DefIntrinsicModule(id, name, ports, topDir, params) =>
6992
// TODO: Add tywaves annotation: ports, ?params?
70-
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
93+
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
7194
case ctx @ DefClass(id, name, ports, cmds) =>
7295
// TODO: Add tywaves annotation: ports, commands
73-
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases)) ++ cmds.flatMap(c =>
96+
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases)) ++ cmds.flatMap(c =>
7497
generate(c, typeAliases)
7598
)
7699
case ctx => throw new Exception(s"Failed to generate TywavesAnnotation. Unknown component type: $ctx")
@@ -84,9 +107,10 @@ object TywavesChiselAnnotation {
84107
val name = s"$binding[${dataToTypeName(innerType)}[$size]]"
85108
// TODO: what if innerType is a Vec or a Bundle?
86109

87-
Seq(new ChiselAnnotation {
88-
override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, None)
89-
}) //++ createAnno(chisel3.Wire(innerType))
110+
createTywavesChiselAnno(target.toTarget, name, None).toSeq
111+
// Seq(new ChiselAnnotation {
112+
// override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, None)
113+
// }) //++ createAnno(chisel3.Wire(innerType))
90114
}
91115
command match {
92116
case e: DefPrim[_] => Seq.empty // TODO: check prim
@@ -98,7 +122,7 @@ object TywavesChiselAnnotation {
98122
case e @ FirrtlMemory(info, id, t, size, readPortNames, writePortNames, readwritePortNames) =>
99123
createAnnoMem(id, id.getClass.getSimpleName, size, t)
100124
case e @ DefMemPort(info, id, source, dir, idx, clock) => createAnno(id)
101-
case Connect(info, loc, exp) => Seq.empty // TODO: check connect
125+
case Connect(info, loc, exp) => createAnno(exp)
102126
case PropAssign(info, loc, exp) => ???
103127
case Attach(info, locs) => ???
104128
case DefInvalid(info, arg) => Seq.empty // TODO: check invalid
@@ -113,6 +137,11 @@ object TywavesChiselAnnotation {
113137
case e @ ProbeForce(sourceInfo, clock, cond, probe, value) => ???
114138
case e @ ProbeRelease(sourceInfo, clock, cond, probe) => ???
115139
case e @ Verification(_, op, info, clk, pred, pable) => ???
140+
case e @ When(info, arg, ifRegion, elseRegion) =>
141+
println(s"$ifRegion")
142+
println(s"$elseRegion")
143+
ifRegion.flatMap(generate(_, typeAliases)) ++ elseRegion
144+
.flatMap(generate(_, typeAliases))
116145
case e =>
117146
println(s"Unknown command: $e") // TODO: replace with logger
118147
Seq.empty
@@ -316,18 +345,59 @@ object TywavesChiselAnnotation {
316345
case _ => getConstructorParamsOpt(target)
317346
}
318347

319-
annotations :+ new ChiselAnnotation {
320-
override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
321-
}
348+
annotations ++
349+
createTywavesChiselAnno(target.toTarget, name, paramsOpt).toSeq
350+
// new ChiselAnnotation {
351+
// override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
352+
// }
322353
}
323354

324-
private def createAnno(target: BaseModule): ChiselAnnotation = {
355+
private def createAnno(target: BaseModule): Seq[ChiselAnnotation] = {
325356
val name = target.desiredName
326357
val paramsOpt = getConstructorParamsOpt(target)
327358
// val name = target.getClass.getTypeName
328-
new ChiselAnnotation {
329-
override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
359+
createTywavesChiselAnno(target.toTarget, name, paramsOpt).toSeq
360+
361+
// new ChiselAnnotation {
362+
// override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
363+
// }
364+
}
365+
366+
// TODO: replace ??? with a nice logger to avoid unexpected crashes
367+
private def createAnno(target: HasId): Seq[ChiselAnnotation] = {
368+
target match {
369+
case t: Data => createAnno(t)
370+
case t: BaseModule => createAnno(t)
371+
case t: MemBase[_] => ???
372+
case t: NamedComponent => ???
373+
case t: VecLike[_] => ???
374+
case t if t.isInstanceOf[DynamicObject] => ???
375+
case t if t.isInstanceOf[StaticObject] => ???
376+
}
377+
}
378+
// TODO: check all the cases for Arg
379+
private def createAnno(target: Arg): Seq[ChiselAnnotation] = {
380+
target match {
381+
case t @ Node(id) => createAnno(id)
382+
case t @ ModuleIO(mod, name) => ???
383+
case t @ ILit(n) => ???
384+
case t @ Ref(name) => ???
385+
case t @ PropertyLit(propertyType, lit) => ???
386+
case t @ PropExpr(sourceInfo, tpe, op, args) => ???
387+
case t @ Slot(imm, name) => ???
388+
case t @ ProbeExpr(probe) => ???
389+
case t @ ProbeRead(probe) => ???
390+
case t @ RWProbeExpr(probe) => ???
391+
case t @ Index(imm, value) => ???
392+
case t @ ModuleCloneIO(mod, name) => ???
393+
case t @ OpaqueSlot(imm) => ???
394+
case t: Component => ???
395+
case t: LitArg => Seq.empty // Ignore
396+
case t =>
397+
println(s"Unknown Arg type: $t")
398+
Seq.empty
330399
}
400+
331401
}
332402

333403
}

src/test/scala/circtTests/tywavesTests/TywavesAnnotationCircuits.scala

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,60 @@ import org.scalatest.matchers.should.Matchers
1313
/** Utility functions for testing [[chisel3.tywavesinternal.TywavesAnnotation]] */
1414
object TestUtils extends Matchers {
1515

16+
def getSubOccurency(mainString: String, subString: String): List[String] = { // Split the text T into lines
17+
val lines = mainString.split("\n").toList
18+
19+
// Compile a regex pattern to find the string X
20+
val pattern = (".*" + subString + ".*").r
21+
22+
// Find the index of the line that matches the pattern
23+
val matchIndex = lines.indexWhere(line => pattern.matches(line))
24+
25+
// If a match is found, return the line and the immediate next line
26+
if (matchIndex >= 0) {
27+
lines.slice(matchIndex, matchIndex + 2)
28+
} else {
29+
List() // Return an empty list if no match is found
30+
}
31+
}
32+
33+
def getMissingSubOccurency(mainString: String, subString: String, expectedMatches: Seq[String]): String = {
34+
// Split the text into lines
35+
val lines = mainString.split("\n").toSeq
36+
37+
// Compile a regex pattern to find the string X
38+
val pattern = (".*" + subString + ".*").r
39+
40+
// Filter lines that match the pattern and get also the next line
41+
val matchingIndexes = lines.indices.filter(i => pattern.findFirstIn(lines(i)).isDefined)
42+
43+
// Get the lines at these indexes and the next line
44+
val linesWithNext = matchingIndexes.flatMap { idx =>
45+
if (idx < lines.length - 1) List(lines(idx), lines(idx + 1))
46+
else List(lines(idx)) // Handle edge case where last line matches
47+
}.distinct // Remove duplicates in case same line is matched more than once
48+
49+
val expectedRegex = expectedMatches.map(_.r)
50+
// Filter out lines that are in expectedMatches
51+
val missingLines = linesWithNext.filterNot { line =>
52+
expectedRegex.exists(_.findFirstIn(line).isDefined)
53+
}
54+
missingLines.mkString("\n")
55+
56+
}
57+
1658
def countSubstringOccurrences(mainString: String, subString: String): Int = {
1759
val pattern = subString.r
1860
pattern.findAllMatchIn(mainString).length
1961
}
2062

63+
// Return target and expected regex string
2164
def createExpected(
2265
target: String,
2366
typeName: String,
2467
binding: String = "",
2568
params: Option[Seq[ClassParam]] = None
26-
): String = {
69+
): (String, String) = {
2770
val realTypeName = binding match {
2871
case "" => typeName
2972
case _ => s"$binding\\[$typeName\\]"
@@ -43,16 +86,25 @@ object TestUtils extends Matchers {
4386
}.mkString(",\\s+")}\\s+\\]"""
4487
case None => ""
4588
}
46-
s"""\"target\":\"$target\",\\s+\"typeName\":\"$realTypeName\"$realParams\\s*}""".stripMargin
89+
(target, s"""\"target\":\"$target\",\\s+\"typeName\":\"$realTypeName\"$realParams\\s*""".stripMargin)
4790
}
4891

49-
def checkAnno(expectedMatches: Seq[(String, Int)], refString: String, includeConstructor: Boolean = false): Unit = {
92+
def checkAnno(
93+
expectedMatches: Seq[((String, String), Int)],
94+
refString: String,
95+
includeConstructor: Boolean = false
96+
): Unit = {
5097
def totalAnnoCheck(n: Int): (String, Int) =
5198
(""""class":"chisel3.tywavesinternal.TywavesAnnotation"""", if (includeConstructor) n else n + 1)
52-
53-
(expectedMatches :+ totalAnnoCheck(expectedMatches.map(_._2).sum)).foreach {
99+
val targetStrings = expectedMatches.map(p => {
100+
val s = p._1._1
101+
"\"target\":\"" + s + "\""
102+
})
103+
(expectedMatches.map(p => (p._1._2, p._2)) :+ totalAnnoCheck(expectedMatches.map(_._2).sum)).foreach {
54104
case (pattern, count) =>
55-
(countSubstringOccurrences(refString, pattern) should be(count)).withClue(s"Pattern: $pattern")
105+
(countSubstringOccurrences(refString, pattern) should be(count)).withClue(
106+
s"Pattern: $pattern: ${getMissingSubOccurency(refString, pattern, targetStrings)}"
107+
)
56108
}
57109
}
58110
}
@@ -237,6 +289,29 @@ object TywavesAnnotationCircuits {
237289
class TopCircuitTypeInSubmodule(bindingChoice: BindingChoice) extends RawModule {
238290
val mod = Module(new TopCircuitGroundTypes(bindingChoice))
239291
}
292+
293+
// Test temporary values declared inside when and otherwise blocks
294+
class TopCircuitWhenElse extends RawModule {
295+
// Internally implement a MUX
296+
val inSeq = IO(Input(Vec(8, UInt(8.W))))
297+
val out = IO(Output(UInt(8.W)))
298+
val sel = IO(Input(UInt(math.sqrt(8).ceil.toInt.W)))
299+
300+
when(sel % 2.U === 0.U) {
301+
val outTmp = inSeq(sel)
302+
val evenSel = outTmp + 1.U
303+
out := evenSel
304+
}.elsewhen(sel === 1.U) {
305+
val outTmp = inSeq(sel)
306+
val selIsOne = outTmp + 1.U
307+
out := selIsOne
308+
}.otherwise {
309+
val outTmp = inSeq(sel)
310+
val oddSel = outTmp + 1.U
311+
out := oddSel
312+
}
313+
314+
}
240315
}
241316

242317
object MemCircuits {

src/test/scala/circtTests/tywavesTests/dataTypesTests/TypeAnnotationDataTypesSpec.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,13 @@ class TypeAnnotationDataTypesSpec extends AnyFunSpec with Matchers with chiselTe
126126
(createExpected("~TopCircuitTypeInSubmodule\\|TopCircuitGroundTypes>sint", "SInt<8>", b.toString), 1),
127127
(createExpected("~TopCircuitTypeInSubmodule\\|TopCircuitGroundTypes>bool", "Bool", b.toString), 1),
128128
(createExpected("~TopCircuitTypeInSubmodule\\|TopCircuitGroundTypes>bits", "UInt<8>", b.toString), 1),
129-
(""""target":"~TopCircuitTypeInSubmodule\|TopCircuitGroundTypes",\s+"typeName":"TopCircuitGroundTypes"""", 1)
129+
(
130+
(
131+
"""~TopCircuitTypeInSubmodule\|TopCircuitGroundTypes""",
132+
""""target":"~TopCircuitTypeInSubmodule\|TopCircuitGroundTypes",\s+"typeName":"TopCircuitGroundTypes""""
133+
),
134+
1
135+
)
130136
) ++ addClockReset("TopCircuitTypeInSubmodule", Some("TopCircuitGroundTypes")) ++ analog
131137
checkAnno(expectedMatches, string)
132138
}
@@ -185,4 +191,26 @@ class TypeAnnotationDataTypesSpec extends AnyFunSpec with Matchers with chiselTe
185191
typeTests(args, targetDir, RegBinding)
186192
}
187193

194+
describe("Tmp Values Annotations") {
195+
val targetDir = os.pwd / "test_run_dir" / "TywavesAnnotationSpec" / "Tmp Values Annotations"
196+
val args: Array[String] = Array("--target", "chirrtl", "--target-dir", targetDir.toString)
197+
// format: off
198+
it("should annotate tmp value in when") {
199+
(new ChiselStage(true)).execute(args, Seq(ChiselGeneratorAnnotation(() => new TopCircuitWhenElse)))
200+
val string = os.read(targetDir / "TopCircuitWhenElse.fir")
201+
val expectedMatches = Seq(
202+
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>inSeq", "UInt<8>\\[8\\]", "IO",
203+
params = Some(Seq(ClassParam("gen", "=> T", None), ClassParam("length", "Int", Some("8"))))), 1),
204+
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>inSeq\\[0\\]", "UInt<8>", "IO"), 1),
205+
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>out", "UInt<8>", "IO"), 1),
206+
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>sel", "UInt<3>", "IO"), 1),
207+
// Tmp
208+
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>evenSel", "UInt<8>", "OpResult"), 1),
209+
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>oddSel", "UInt<8>", "OpResult"), 1),
210+
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>selIsOne", "UInt<8>", "OpResult"), 1)
211+
)
212+
checkAnno(expectedMatches, string)
213+
// format: on
214+
}
215+
}
188216
}

0 commit comments

Comments
 (0)