Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class BlockTransformer(subst: SymbolSubst):
applyDefn(defn): defn2 =>
val rst2 = applySubBlock(rst)
if (defn2 is defn) && (rst2 is rst) then b else Define(defn2, rst2)
case HandleBlock(l, res, par, args, cls, hdr, bod, rst) =>
case h @ HandleBlock(l, res, par, args, cls, hdr, bod, rst) =>
val l2 = applyLocal(l)
val res2 = applyLocal(res)
applyPath(par): par2 =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class BlockTraverser:
cls.traverse
applyPath(path)
case Case.Tup(len, inf) => ()
case Case.Field(_, _) => ()

def applyHandler(hdr: Handler): Unit =
hdr.sym.traverse
Expand Down
1,191 changes: 406 additions & 785 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala

Large diffs are not rendered by default.

21 changes: 11 additions & 10 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ object Lifter:
* Lifts classes and functions to the top-level. Also automatically rewrites lambdas.
* Assumes the input block does not have any `HandleBlock`s.
*/
class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
class Lifter()(using State, Raise):
import Lifter.*

/**
Expand Down Expand Up @@ -247,9 +247,10 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
case PubField(isym, sym) => Select(isym.asPath, Tree.Ident(sym.nme))(d)


def isHandlerClsPath(p: Path) = handlerPaths match
case None => false
case Some(paths) => paths.isHandlerClsPath(p)

val ignoredSet = Set(State.globalThisSymbol.asPath.selSN("Object"), State.runtimeSymbol.asPath.selSN("NonLocalReturn"))

def isIgnoredPath(p: Path) = ignoredSet.contains(p)

/**
* Creates a capture class for a function consisting of its mutable (and possibly immutable) local variables.
Expand Down Expand Up @@ -476,7 +477,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
// If B extends A, then A -> B is an edge
parentPath match
case None => ()
case Some(path) if isHandlerClsPath(path) => ()
case Some(path) if isIgnoredPath(path) => ()
case Some(Select(RefOfBms(s, _), Tree.Ident("class"))) =>
if clsSyms.contains(s) then extendsGraph += (s -> defn.sym)
case Some(RefOfBms(s, _)) =>
Expand Down Expand Up @@ -639,14 +640,14 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
case Some(c: ClsLikeDefn) => Value.Lit(Tree.BoolLit(false)).asArg :: getCallArgs(l, ctx)
case _ => getCallArgs(l, ctx)
applyListOf(args, applyArg(_)(_)): newArgs =>
k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(c.isMlsFun, false))
k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(c.isMlsFun, c.mayRaiseEffects))
case _ => super.applyResult(r)(k)
case c @ Instantiate(mut, InstSel(l), args) =>
ctx.bmsReqdInfo.get(l) match
case Some(info) if !ctx.isModOrObj(l) =>
val extraArgs = Value.Lit(Tree.BoolLit(mut)).asArg :: getCallArgs(l, ctx)
applyListOf(args, applyArg(_)(_)): newArgs =>
k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(true, false))
k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(true, true))
case _ => super.applyResult(r)(k)
// LEGACY CODE: We previously directly created the closure and assigned it to the
// variable here. But, since this closure may be re-used later, this doesn't work
Expand Down Expand Up @@ -944,7 +945,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
val args2 = headPlistCopy.params.map(p => p.sym.asPath.asArg)

val bdy = blockBuilder
.ret(Call(singleCallBms.asPath, args1 ++ args2)(true, false)) // TODO: restParams not considered
.ret(Call(singleCallBms.asPath, args1 ++ args2)(true, true)) // TODO: restParams not considered

val mainDefn = FunDefn(f.owner, f.sym, PlainParamList(extraParamsCpy) :: headPlistCopy :: Nil, bdy)
val auxDefn = FunDefn(N, singleCallBms, flatPlist, lifted.body)
Expand Down Expand Up @@ -1035,7 +1036,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
)

for ps <- newAuxSyms do
val call = Call(curSym.asPath, ps.map(_.asPath.asArg))(true, false)
val call = Call(curSym.asPath, ps.map(_.asPath.asArg))(true, true)
curSym = TempSymbol(None, "tmp")
val thisSym = curSym
acc = acc.assign(thisSym, call)
Expand Down Expand Up @@ -1266,7 +1267,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
// so we need to desugar them again
val blk = LambdaRewriter.desugar(_blk)

val analyzer = UsedVarAnalyzer(blk, handlerPaths)
val analyzer = UsedVarAnalyzer(blk)
val ctx = LifterCtx
.withLocals(analyzer.findUsedLocals)
.withDefns(analyzer.defnsMap)
Expand Down
31 changes: 13 additions & 18 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -512,14 +512,6 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
t.toLoc :: Nil,
source = Diagnostic.Source.Compilation)
conclude(Value.Ref(State.runtimeSymbol).selSN("raisePrintStackEffect").withLocOf(f))
case t if instantiatedResolvedBms.exists(_ is ctx.builtins.debug.getLocals) =>
if !config.effectHandlers.exists(_.debug) then
return fail:
ErrorReport(
msg"Debugging functions are not enabled" ->
t.toLoc :: Nil,
source = Diagnostic.Source.Compilation)
conclude(Value.Ref(ctx.builtins.debug.getLocals, N).withLocOf(f))
// * Due to whacky JS semantics, we need to make sure that selections leading to a call
// * are preserved in the call and not moved to a temporary variable.
case sel @ Sel(prefix, nme) =>
Expand Down Expand Up @@ -950,21 +942,24 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
val desug = LambdaRewriter.desugar(blk)

val handlerPaths = new HandlerPaths

val withHandlers1 = config.effectHandlers.fold(desug): opt =>
HandlerLowering(handlerPaths, opt).translateHandleBlocks(desug)

val (withHandlers, doUnwindPaths) = config.effectHandlers.fold((desug, Map.empty)): opt =>
HandlerLowering(handlerPaths, opt).translateTopLevel(desug)
val lifted =
if lift then Lifter().transform(withHandlers1)
else withHandlers1

val (withHandlers2, doUnwindPaths) = config.effectHandlers.fold((lifted, Map.empty)): opt =>
HandlerLowering(handlerPaths, opt).translateTopLevel(lifted)

val stackSafe = config.stackSafety match
case N => withHandlers
case S(sts) => StackSafeTransform(sts.stackLimit, handlerPaths, doUnwindPaths).transformTopLevel(withHandlers)
case N => withHandlers2
case S(sts) => StackSafeTransform(sts.stackLimit, handlerPaths, doUnwindPaths).transformTopLevel(withHandlers2)

val flattened = stackSafe.flattened

val lifted =
if lift then Lifter(S(handlerPaths)).transform(flattened)
else flattened

val bufferable = BufferableTransform().transform(lifted)
val bufferable = BufferableTransform().transform(flattened)

val merged = MergeMatchArmTransformer.applyBlock(bufferable)

Expand Down Expand Up @@ -1004,7 +999,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
trait LoweringSelSanityChecks(using Config, TL, Raise, State)
extends Lowering:

private val instrument: Bool = config.sanityChecks.isDefined
private val instrument: Bool = config.sanityChecks.isDefined && config.effectHandlers.isEmpty

override def setupSelection(prefix: st, nme: Tree.Ident, disamb: Opt[DefinitionSymbol[?]])(k: Result => Block)(using LoweringCtx): Block =
if !instrument then return super.setupSelection(prefix, nme, disamb)(k)
Expand Down
33 changes: 8 additions & 25 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,8 @@ import hkmc2.semantics.*
import hkmc2.syntax.Tree
import hkmc2.codegen.HandlerLowering.FnOrCls

class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: Map[FnOrCls, Path])(using State):
class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: collection.Map[FnOrCls, Path => Return])(using State):
private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("stackDepth")

val doUnwindFns = doUnwindMap.values.collect:
case s: Select if s.symbol.isDefined => s.symbol.get
case Value.Ref(sym, _) => sym
.toSet

private val runtimePath: Path = State.runtimeSymbol.asPath
private val checkDepthPath: Path = runtimePath.selN(Tree.Ident("checkDepth"))
Expand Down Expand Up @@ -135,44 +130,32 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: Map[
usedDepth = true
TempSymbol(None, "curDepth")

val doUnwindPath = doUnwindMap.get(fnOrCls)
val doUnwind = doUnwindMap.get(fnOrCls)
val newBody = transform(blk, curDepth)

if isTrivial(blk) then
newBody
else if doUnwindPath.isEmpty then
val resSym = TempSymbol(None, "stackDelayRes")
else if doUnwind.isEmpty then
// The current function is not instrumented and we cannot provide stack safety.
// TODO: shouldn't we just return the old blk?
blockBuilder
.staticif(usedDepth, _.assign(curDepth, stackDepthPath))
.rest(newBody)
else
val resSym = TempSymbol(None, "stackDelayRes")
val rewritten = blockBuilder
blockBuilder
.staticif(usedDepth, _.assign(curDepth, stackDepthPath))
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(increment)))
.assign(resSym, Call(checkDepthPath, Nil)(true, true))
.ifthen(
resSym.asPath,
Case.Cls(paths.effectSigSym, paths.effectSigPath),
Return(
Call(doUnwindPath.get, resSym.asPath.asArg :: intLit(0).asArg :: Nil)(true, false),
false
)
doUnwind.get(resSym.asPath)
)
.rest(newBody)
// Float out defns, including the doUnwind function, so that they appear at the top of the block
// This is because the doUnwind function must appear before the checks inserted by the stack
// safety pass.
// However, due to how tightly coupled the stack safety and handler lowering are, it might be
// better to simply merge the two passes in the future.
val (blk, defns) = doUnwindPath.get match
case Value.Ref(sym, _) => rewritten.floatOutDefns()
case _ => (rewritten, Nil)
defns.foldLeft(blk)((acc, defn) => Define(defn, acc))


def rewriteFn(defn: FunDefn) =
if doUnwindFns.contains(defn.sym) then defn
else FunDefn(defn.owner, defn.sym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))
FunDefn(defn.owner, defn.sym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))

def transformTopLevel(b: Block) = transform(b, TempSymbol(N), true)
13 changes: 2 additions & 11 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import scala.collection.mutable.Map as MutMap
*
* Assumes the input trees have no lambdas.
*/
class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State):
class UsedVarAnalyzer(b: Block)(using State):
import Lifter.*

private case class DefnMetadata(
Expand Down Expand Up @@ -120,10 +120,6 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State):
val DefnMetadata(definedLocals, defnsMap, existingVars,
inScopeDefns, nestedDefns, nestedDeep, nestedIn, companionMap) = createMetadata

def isHandlerClsPath(p: Path) = handlerPaths match
case None => false
case Some(paths) => paths.isHandlerClsPath(p)

private val blkMutCache: MutMap[Local, AccessInfo] = MutMap.empty
private def blkAccessesShallow(b: Block, cacheId: Opt[Local] = N): AccessInfo =
cacheId.flatMap(blkMutCache.get) match
Expand Down Expand Up @@ -173,7 +169,7 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State):
blkAccessesShallow(f.body).withoutLocals(fVars)
case c: ClsLikeDefn =>
val methodSyms = c.methods.map(_.sym).toSet
val ret = c.methods.foldLeft(blkAccessesShallow(c.preCtor) ++ blkAccessesShallow(c.ctor)):
c.methods.foldLeft(blkAccessesShallow(c.preCtor) ++ blkAccessesShallow(c.ctor)):
case (acc, fDefn) =>
// class methods do not need to be lifted, so we don't count calls to their methods.
// a previous reference to this class's block member symbol is enough to assume any
Expand All @@ -182,11 +178,6 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State):
// however, we must keep references to the class itself!
val defnAccess = findAccessesShallow(fDefn)
acc ++ defnAccess.withoutBms(methodSyms)
if c.parentPath.isDefined && isHandlerClsPath(c.parentPath.get) then
// for continuation classes, treat them like they only read variables
AccessInfo(ret.accessed ++ ret.mutated, Set.empty, ret.refdDefns)
else
ret
case _: ValDefn => AccessInfo.empty

accessedCache.getOrElseUpdate(defn.sym, create)
Expand Down
13 changes: 11 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:

val privs = mkPrivs(pubFlds, privFlds, mtdPrefix, isym)

val preCtorCode = body(preCtor, true)
val ctorCode = doc"$preCtorCode${body(ctor, endSemi = true)}${
val ctorCode = doc"${body(Begin(preCtor, ctor), endSemi = true)}${
kind match
case syntax.Obj =>
doc" # ${defineProperty(doc"this", "class", doc"${scope.lookup_!(isym, isym.toLoc)}")}"
Expand Down Expand Up @@ -447,6 +446,16 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
case S(el) => returningTerm(el, endSemi = true)
case N => doc""
e :: returningTerm(rest, endSemi)
case Match(scrut, arms, els, rest) if arms.forall(_._1.isInstanceOf[Case.Lit]) =>
val l = arms.foldLeft(doc""): (acc, arm) =>
acc :: doc" # case ${arm._1.asInstanceOf[Case.Lit].lit.idStr}: #{ ${
returningTerm(arm._2, endSemi = true)
} # break; #} "
val e = els match
case S(el) =>
doc" # default: #{ ${ returningTerm(el, endSemi = true) } # break; #} "
case N => doc""
doc" # switch (${result(scrut)}) { #{ ${l :: e} #} # }" :: returningTerm(rest, endSemi)
case Match(scrut, hd :: tl, els, rest) =>
val sd = result(scrut)
def cond(cse: Case) = cse match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ object Elaborator:
val plus_impl = assumeObject("plus_impl")
object debug extends VirtualModule(assumeBuiltinMod("debug")):
val printStack = assumeObject("printStack")
val getLocals = assumeObject("getLocals")
object annotations extends VirtualModule(assumeBuiltinMod("annotations")):
val compile = assumeObject("compile")
val buffered = assumeObject("buffered")
Expand Down
44 changes: 25 additions & 19 deletions hkmc2/shared/src/test/mlscript-compile/Predef.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -127,27 +127,33 @@ globalThis.Object.freeze(class Predef {
let len, scrut, i, init, scrut1, tmp, tmp1, tmp2, tmp3;
len = rest.length;
scrut = len == 0;
if (scrut === true) {
return first
} else {
i = len - 1;
init = runtime.safeCall(rest.at(i));
tmp4: while (true) {
scrut1 = i > 0;
if (scrut1 === true) {
tmp = i - 1;
i = tmp;
tmp1 = runtime.safeCall(rest.at(i));
tmp2 = runtime.safeCall(f(tmp1, init));
init = tmp2;
tmp3 = runtime.Unit;
continue tmp4
} else {
tmp3 = runtime.Unit;
switch (scrut) {
case true:
return first;
break;
default:
i = len - 1;
init = runtime.safeCall(rest.at(i));
tmp4: while (true) {
scrut1 = i > 0;
switch (scrut1) {
case true:
tmp = i - 1;
i = tmp;
tmp1 = runtime.safeCall(rest.at(i));
tmp2 = runtime.safeCall(f(tmp1, init));
init = tmp2;
tmp3 = runtime.Unit;
continue tmp4;
break;
default:
tmp3 = runtime.Unit;
break;
}
break;
}
return runtime.safeCall(f(first, init));
break;
}
return runtime.safeCall(f(first, init))
}
}
}
Expand Down
Loading