Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OperationDef as macros. #223

Merged
merged 9 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 70 additions & 109 deletions clairV2/src/main/scala-3/CV2Mirror.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import scala.deriving._

import scala.Tuple.Zip
import scala.collection.View.Empty
import scala.quoted._

// ░█████╗░ ██╗░░░░░ ░█████╗░ ██╗ ██████╗░ ██╗░░░██╗ ██████╗░
// ██╔══██╗ ██║░░░░░ ██╔══██╗ ██║ ██╔══██╗ ██║░░░██║ ╚════██╗
Expand All @@ -28,77 +29,62 @@ import scala.collection.View.Empty
|| MIRROR LOGIC ||
\*≡==----=≡≡≡≡=----==≡*/

// currently not supported, but will soon be :)
inline def inputVariadicity[Elem] = inline erasedValue[Elem] match
case _: Variadic[t] => Variadicity.Variadic
case _ => Variadicity.Single

// for some reason match types do not work here, as an inline erasedValue[unwrappedInput[Elem]]
// tries to match on that type exactly (ie. unwrappedType[Value[IntegerType]] for example) rather than the matched type...
// very weird things going on

/** Produces an OpInput to OperationDef given a definition of a Type.
*
* @return
* Input to OperationDef, either: OperandDef, ResultDef, RegionDef,
* SuccessorDef, OpPropertyDef, OpAttributeDef
*/
inline def getDefInput[Label, Elem]: OpInput = {

val name = inline erasedValue[Label] match
case _: String => constValue[Label].asInstanceOf[String]
case _ =>
throw new Exception("Internal error!")
def getDefInput[Label: Type, Elem: Type](using Quotes): OpInputDef = {
val name = Type.of[Label] match
case '[String] =>
Type.valueOfConstant[Label].get.asInstanceOf[String]

inline erasedValue[Elem] match
case _: Variadic[Operand[t]] =>
OperandDef(
id = name,
Type.of[Elem] match
case '[Variadic[Result[t]]] =>
ResultDef(
name = name,
typeString = typeToString[t],
Variadicity.Variadic
)
case _: Variadic[Result[t]] =>
case '[Variadic[Operand[t]]] =>
OperandDef(
id = name,
name = name,
typeString = typeToString[t],
Variadicity.Variadic
)
case _: Result[t] =>
case '[Result[t]] =>
ResultDef(
id = name,
name = name,
typeString = typeToString[t],
Variadicity.Single
)
case _: Operand[t] =>
case '[Operand[t]] =>
OperandDef(
id = name,
name = name,
typeString = typeToString[t],
Variadicity.Single
)
case _: Region =>
case '[Region] =>
RegionDef(
id = name,
name = name,
Variadicity.Single
)
case _: Successor =>
case '[Successor] =>
SuccessorDef(
id = name,
name = name,
Variadicity.Single
)
case _: Property[t] =>
case '[Property[t]] =>
OpPropertyDef(
id = name,
name = name,
typeString = typeToString[t]
)
case _: Attr[t] =>
case '[Attr[t]] =>
OpAttributeDef(
id = name,
name = name,
typeString = typeToString[t]
)
case _ =>
throw new Exception(
s"Unsupported shennaigans here with field $name of type ${typeToString[Elem]}"
)
}

/** Loops through a Tuple of Input definitions and produces a List of inputs to
Expand All @@ -107,25 +93,28 @@ inline def getDefInput[Label, Elem]: OpInput = {
* @return
* Lambda that produces an input to OperationDef, given a string
*/
inline def summonInput[Labels <: Tuple, Elems <: Tuple]: List[OpInput] = {
def summonInput[Labels: Type, Elems: Type](using Quotes): List[OpInputDef] = {

inline erasedValue[(Labels, Elems)] match
case _: ((label *: labels, elem *: elems)) =>
Type.of[(Labels, Elems)] match
case '[(label *: labels, elem *: elems)] =>
getDefInput[label, elem] :: summonInput[labels, elems]
case _: (EmptyTuple, EmptyTuple) => Nil
case '[(EmptyTuple, EmptyTuple)] => Nil
}

/** Translates a Tuple of string types into a list of strings.
*
* @return
* Tuple of String types
*/
inline def stringifyLabels[Elems <: Tuple]: List[String] = {

inline erasedValue[Elems] match
case _: (elem *: elems) =>
constValue[elem].asInstanceOf[String] :: stringifyLabels[elems]
case _: EmptyTuple => Nil
def stringifyLabels[Elems: Type](using Quotes): List[String] = {

Type.of[Elems] match
case '[elem *: elems] =>
Type
.valueOfConstant[elem]
.get
.asInstanceOf[String] :: stringifyLabels[elems]
case '[EmptyTuple] => Nil
}

inline def getMLIRName[T] = inline erasedValue[T] match
Expand All @@ -135,67 +124,39 @@ inline def getMLIRName[T] = inline erasedValue[T] match
"Expected this type to extend MLIRName with a constant type-parameter."
)

/** Generates a OperationDef given param m.
*
* @param m
* \- Mirror Product of an dialect enum case.
* @return
* Lambda that produces an Operadtion Def given a dialect name.
*/
inline def getDef[T](using
m: Mirror.ProductOf[T]
): OperationDef = {

val defname = constValue[m.MirroredLabel]
val paramLabels = stringifyLabels[m.MirroredElemLabels]

val inputs = summonInput[m.MirroredElemLabels, m.MirroredElemTypes]

val operands: ListType[OperandDef] = ListType()
val results: ListType[ResultDef] = ListType()
val regions: ListType[RegionDef] = ListType()
val successors: ListType[SuccessorDef] = ListType()
val opProperty: ListType[OpPropertyDef] = ListType()
val opAttribute: ListType[OpAttributeDef] = ListType()
var assembly_format: Option[String] = None

for (input <- inputs) yield input match {
case a: OperandDef => operands += a
case b: ResultDef => results += b
case c: RegionDef => regions += c
case d: SuccessorDef => successors += d
case e: OpPropertyDef => opProperty += e
case f: OpAttributeDef => opAttribute += f
case _ => throw new Exception("Internal error!")
}

val name = getMLIRName[T]

OperationDef(
name,
defname,
operands.toSeq,
results.toSeq,
regions.toSeq,
successors.toSeq,
opProperty.toSeq,
opAttribute.toSeq,
assembly_format
)

}

/** Generates a list of OperationDef given enum cases.
*
* @param dialect_name
*/
inline def summonOperationDefs[Prods <: Tuple]: Seq[OperationDef] = {
def getDefImpl[T: Type](using quotes: Quotes): OperationDef =

val m = Expr.summon[Mirror.ProductOf[T]].get
m match
case '{
$m: Mirror.ProductOf[T] {
type MirroredLabel = label; type MirroredElemLabels = elemLabels;
type MirroredElemTypes = elemTypes
}
} =>
val defname = Type.valueOfConstant[label].get.asInstanceOf[String]

val paramLabels = stringifyLabels[elemLabels]
val name = Type.of[T] match
case '[MLIRName[name]] =>
Type.valueOfConstant[name].get.asInstanceOf[String]

val inputs = Type.of[(elemLabels, elemTypes)] match
case _: Type[(Tuple, Tuple)] => summonInput[elemLabels, elemTypes]
val e = OperationDef(
name = name,
className = defname,
operands = inputs.collect { case a: OperandDef => a },
results = inputs.collect { case a: ResultDef => a },
regions = inputs.collect { case a: RegionDef => a },
successors = inputs.collect { case a: SuccessorDef => a },
properties = inputs.collect { case a: OpPropertyDef => a },
attributes = inputs.collect { case a: OpAttributeDef => a },
assembly_format = None
)
e

inline erasedValue[Prods] match
case _: (prod *: prods) =>
getDef[prod](using
summonInline[Mirror.ProductOf[prod]]
) +: summonOperationDefs[prods]
inline def getNameDefBlaBla[T] = ${ getNameDefBlaBlaImpl[T] }

case _: EmptyTuple => Seq.empty
}
def getNameDefBlaBlaImpl[T: Type](using quotes: Quotes): Expr[String] =
Expr(getDefImpl[T].name)
51 changes: 34 additions & 17 deletions clairV2/src/main/scala-3/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scair.clairV2.codegen
import java.io.File
import java.io.PrintStream
import scala.reflect.*
import scair.ir._

// ░█████╗░ ██╗░░░░░ ░█████╗░ ██╗ ██████╗░ ██╗░░░██╗ ██████╗░
// ██╔══██╗ ██║░░░░░ ██╔══██╗ ██║ ██╔══██╗ ██║░░░██║ ╚════██╗
Expand All @@ -22,7 +23,10 @@ import scala.reflect.*
|| CONTAINERS ||
\*≡=--==≡≡==--=≡*/

abstract class OpInput {}
sealed trait OpInputDef(val name: String) {}

sealed trait MayVariadicOpInputDef(val variadicity: Variadicity)
extends OpInputDef

// TODO: Add support for optionals AFTER variadic support is laid out
// It really just adds cognitive noise otherwise IMO. The broader structure and logic is exactly the same.
Expand All @@ -31,37 +35,50 @@ enum Variadicity {
case Single, Variadic
}

type DefinedInput[T <: OpInputDef] = T match {
case OperandDef => Operand[Attribute]
case ResultDef => Result[Attribute]
case RegionDef => Region
case SuccessorDef => Successor
case OpPropertyDef => Property[Attribute]
case OpAttributeDef => Attribute
}

case class OperandDef(
val id: String,
override val name: String,
val typeString: String,
val variadicity: Variadicity = Variadicity.Single
) extends OpInput {}
override val variadicity: Variadicity = Variadicity.Single
) extends OpInputDef(name)
with MayVariadicOpInputDef(variadicity) {}

case class ResultDef(
val id: String,
override val name: String,
val typeString: String,
val variadicity: Variadicity = Variadicity.Single
) extends OpInput {}
override val variadicity: Variadicity = Variadicity.Single
) extends OpInputDef(name)
with MayVariadicOpInputDef(variadicity) {}

case class RegionDef(
val id: String,
val variadicity: Variadicity = Variadicity.Single
) extends OpInput {}
override val name: String,
override val variadicity: Variadicity = Variadicity.Single
) extends OpInputDef(name)
with MayVariadicOpInputDef(variadicity) {}

case class SuccessorDef(
val id: String,
val variadicity: Variadicity = Variadicity.Single
) extends OpInput {}
override val name: String,
override val variadicity: Variadicity = Variadicity.Single
) extends OpInputDef(name)
with MayVariadicOpInputDef(variadicity) {}

case class OpPropertyDef(
val id: String,
override val name: String,
val typeString: String
) extends OpInput {}
) extends OpInputDef(name) {}

case class OpAttributeDef(
val id: String,
override val name: String,
val typeString: String
) extends OpInput {}
) extends OpInputDef(name) {}

/*≡≡=---=≡≡≡≡≡=---=≡≡*\
|| OPERATION DEF ||
Expand Down
Loading