diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala index f81f5af8..4545ca6d 100644 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -16,16 +16,18 @@ private[async] trait AnfTransform { import c.internal._ import decorators._ - def anfTransform(tree: Tree): Block = { + def anfTransform(tree: Tree, owner: Symbol): Block = { // Must prepend the () for issue #31. - val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe) + val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe) sealed abstract class AnfMode case object Anf extends AnfMode case object Linearizing extends AnfMode + val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner) + var mode: AnfMode = Anf - typingTransform(block)((tree, api) => { + typingTransform(tree1, owner)((tree, api) => { def blockToList(tree: Tree): List[Tree] = tree match { case Block(stats, expr) => stats :+ expr case t => t :: Nil @@ -34,7 +36,7 @@ private[async] trait AnfTransform { def listToBlock(trees: List[Tree]): Block = trees match { case trees @ (init :+ last) => val pos = trees.map(_.pos).reduceLeft(_ union _) - Block(init, last).setType(last.tpe).setPos(pos) + newBlock(init, last).setType(last.tpe).setPos(pos) } object linearize { @@ -66,6 +68,17 @@ private[async] trait AnfTransform { stats :+ valDef :+ atPos(tree.pos)(ref1) case If(cond, thenp, elsep) => + // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}` + // as though it was typed with `Unit`. + def isPatMatGeneratedJump(t: Tree): Boolean = t match { + case Block(_, expr) => isPatMatGeneratedJump(expr) + case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep) + case _: Apply if isLabel(t.symbol) => true + case _ => false + } + if (isPatMatGeneratedJump(expr)) { + internal.setType(expr, definitions.UnitTpe) + } // if type of if-else is Unit don't introduce assignment, // but add Unit value to bring it into form expected by async transform if (expr.tpe =:= definitions.UnitTpe) { @@ -77,7 +90,7 @@ private[async] trait AnfTransform { def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) { def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol)) orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) + case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) case _ => Assign(Ident(varDef.symbol), cast(orig)) } }) @@ -115,7 +128,7 @@ private[async] trait AnfTransform { } } - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp)) valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos) } @@ -152,8 +165,7 @@ private[async] trait AnfTransform { } def _transformToList(tree: Tree): List[Tree] = trace(tree) { - val containsAwait = tree exists isAwait - if (!containsAwait) { + if (!containsAwait(tree)) { tree match { case Block(stats, expr) => // avoids nested block in `while(await(false)) ...`. @@ -207,10 +219,11 @@ private[async] trait AnfTransform { funStats ++ argStatss.flatten.flatten :+ typedNewApply case Block(stats, expr) => - (stats :+ expr).flatMap(linearize.transformToList) + val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr) + eliminateMatchEndLabelParameter(trees) case ValDef(mods, name, tpt, rhs) => - if (rhs exists isAwait) { + if (containsAwait(rhs)) { val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs)) stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner)) stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) @@ -247,7 +260,7 @@ private[async] trait AnfTransform { scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs) case LabelDef(name, params, rhs) => - List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) case TypeApply(fun, targs) => val funStats :+ simpleFun = linearize.transformToList(fun) @@ -259,6 +272,52 @@ private[async] trait AnfTransform { } } + // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable + // + // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts + // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this). + // + // For our purposes, it is easier to: + // - extract a `matchRes` variable + // - rewrite the terminal label def to take no parameters, and instead read this temp variable + // - change jumps to the terminal label to an assignment and a no-arg label application + def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = { + import internal.{methodType, setInfo} + val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]() + + val matchResults = collection.mutable.Buffer[Tree]() + val statsExpr0 = statsExpr.reverseMap { + case ld @ LabelDef(_, param :: Nil, body) => + val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos) + matchResults += matchResult + caseDefToMatchResult(ld.symbol) = matchResult.symbol + val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)) + setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType)) + ld2 + case t => + if (caseDefToMatchResult.isEmpty) t + else typingTransform(t)((tree, api) => + tree match { + case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) => + api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil)))) + case Block(stats, expr) => + api.default(tree) match { + case Block(stats, Block(stats1, expr)) => + treeCopy.Block(tree, stats ::: stats1, expr) + case t => t + } + case _ => + api.default(tree) + } + ) + } + matchResults.toList match { + case Nil => statsExpr + case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol)) + case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr + } + } + def anfLinearize(tree: Tree): Block = { val trees: List[Tree] = mode match { case Anf => anf._transformToList(tree) diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala index 7464c42d..7a1e274d 100644 --- a/src/main/scala/scala/async/internal/AsyncBase.scala +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -43,9 +43,9 @@ abstract class AsyncBase { (body: c.Expr[T]) (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { import c.universe._, c.internal._, decorators._ - val asyncMacro = AsyncMacro(c, self) + val asyncMacro = AsyncMacro(c, self)(body.tree) - val code = asyncMacro.asyncTransform[T](body.tree, execContext.tree)(c.weakTypeTag[T]) + val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T]) AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") // Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala index 3afa55b3..86544746 100644 --- a/src/main/scala/scala/async/internal/AsyncId.scala +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -41,11 +41,11 @@ object AsyncTestLV extends AsyncBase { * A trivial implementation of [[FutureSystem]] that performs computations * on the current thread. Useful for testing. */ +class Box[A] { + var a: A = _ +} object IdentityFutureSystem extends FutureSystem { - - class Prom[A] { - var a: A = _ - } + type Prom[A] = Box[A] type Fut[A] = A type ExecContext = Unit @@ -57,7 +57,7 @@ object IdentityFutureSystem extends FutureSystem { def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(()))) - def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] + def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]] def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]] def execContextType: Type = weakTypeOf[Unit] diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala index e969f9bc..e22407da 100644 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -1,15 +1,17 @@ package scala.async.internal object AsyncMacro { - def apply(c0: reflect.macros.Context, base: AsyncBase): AsyncMacro { val c: c0.type } = { + def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = { import language.reflectiveCalls new AsyncMacro { self => val c: c0.type = c0 + val body: c.Tree = body0 // This member is required by `AsyncTransform`: val asyncBase: AsyncBase = base // These members are required by `ExprBuilder`: val futureSystem: FutureSystem = base.futureSystem val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c) + val containsAwait: c.Tree => Boolean = containsAwaitCached(body0) } } } @@ -19,7 +21,10 @@ private[async] trait AsyncMacro with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables { val c: scala.reflect.macros.Context + val body: c.Tree + val containsAwait: c.Tree => Boolean lazy val macroPos = c.macroApplication.pos.makeTransparent def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t) + } diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index baa3fc20..af290e4f 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -9,7 +9,7 @@ trait AsyncTransform { val asyncBase: AsyncBase - def asyncTransform[T](body: Tree, execContext: Tree) + def asyncTransform[T](execContext: Tree) (resultType: WeakTypeTag[T]): Tree = { // We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce @@ -22,7 +22,7 @@ trait AsyncTransform { // Transform to A-normal form: // - no await calls in qualifiers or arguments, // - if/match only used in statement position. - val anfTree0: Block = anfTransform(body) + val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner) val anfTree = futureSystemOps.postAnfTransform(anfTree0) @@ -35,7 +35,7 @@ trait AsyncTransform { val stateMachine: ClassDef = { val body: List[Tree] = { val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial))) - val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) + val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) val apply0DefDef: DefDef = { @@ -43,7 +43,7 @@ trait AsyncTransform { // See SI-1247 for the the optimization that avoids creation. DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil)) } - List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) + List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) } val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit]) @@ -98,10 +98,11 @@ trait AsyncTransform { } val isSimple = asyncBlock.asyncStates.size == 1 - if (isSimple) + val result = if (isSimple) futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` else startStateMachine + cleanupContainsAwaitAttachments(result) } def logDiagnostics(anfTree: Tree, states: Seq[String]) { diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 164e85b3..16b9207b 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -146,6 +146,8 @@ trait ExprBuilder { private val stats = ListBuffer[Tree]() /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None + private var nextJumpSymbol: Symbol = NoSymbol + def effectiveNextState(nextState: Int) = nextJumpState.orElse(if (nextJumpSymbol == NoSymbol) None else Some(stateIdForLabel(nextJumpSymbol))).getOrElse(nextState) def +=(stat: Tree): this.type = { stat match { @@ -155,11 +157,16 @@ trait ExprBuilder { } def addStat() = stats += stat stat match { - case Apply(fun, Nil) => + case Apply(fun, args) if isLabel(fun.symbol) => // labelDefStates belongs to the current ExprBuilder labelDefStates get fun.symbol match { - case opt @ Some(nextState) => nextJumpState = opt // re-use object - case None => addStat() + case opt@Some(nextState) => + // A backward jump + nextJumpState = opt // re-use object + nextJumpSymbol = fun.symbol + case None => + // We haven't the corresponding LabelDef, this is a forward jump + nextJumpSymbol = fun.symbol } case _ => addStat() } @@ -169,13 +176,11 @@ trait ExprBuilder { def resultWithAwait(awaitable: Awaitable, onCompleteState: Int, nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup) + new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState(nextState), awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup) } def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { @@ -243,9 +248,17 @@ trait ExprBuilder { } import stateAssigner.nextState + def directlyAdjacentLabelDefs(t: Tree): List[Tree] = { + def isPatternCaseLabelDef(t: Tree) = t match { + case LabelDef(name, _, _) => name.toString.startsWith("case") + case _ => false + } + val (before, _ :: after) = (stats :+ expr).span(_ ne t) + before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + } // populate asyncStates - for (stat <- stats) stat match { + for (stat <- (stats :+ expr)) stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val onCompleteState = nextState() @@ -255,7 +268,7 @@ trait ExprBuilder { currState = afterAwaitState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) => + case If(cond, thenp, elsep) if containsAwait(stat) || containsForiegnLabelJump(stat) => checkForUnsupportedAwait(cond) val thenStartState = nextState() @@ -275,7 +288,7 @@ trait ExprBuilder { currState = afterIfState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case Match(scrutinee, cases) if stat exists isAwait => + case Match(scrutinee, cases) if containsAwait(stat) => checkForUnsupportedAwait(scrutinee) val caseStates = cases.map(_ => nextState()) @@ -293,24 +306,21 @@ trait ExprBuilder { currState = afterMatchState stateBuilder = new AsyncStateBuilder(currState, symLookup) + case ld @ LabelDef(name, params, rhs) + if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) => - case ld @ LabelDef(name, params, rhs) if rhs exists isAwait => - val startLabelState = nextState() + val startLabelState = stateIdForLabel(ld.symbol) val afterLabelState = nextState() asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) asyncStates ++= builder.asyncStates - currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } - // complete last state builder (representing the expressions after the last await) - stateBuilder += expr val lastState = stateBuilder.resultSimple(endState) asyncStates += lastState } @@ -383,18 +393,26 @@ trait ExprBuilder { * } * } */ - private def resumeFunTree[T: WeakTypeTag]: Tree = - Try( - Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ), - List( - CaseDef( - Bind(name.t, Ident(nme.WILDCARD)), - Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), { - val t = c.Expr[Throwable](Ident(name.t)) - val complete = futureSystemOps.completeProm[T]( + private def resumeFunTree[T: WeakTypeTag]: Tree = { + val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T])) + Try( + body, + List( + CaseDef( + Bind(name.t, Typed(Ident(nme.WILDCARD), Ident(defn.ThrowableClass))), + EmptyTree, { + val then = { + val t = c.Expr[Throwable](Ident(name.t)) + val complete = futureSystemOps.completeProm[T]( c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree - Block(toList(complete), Return(literalUnit)) - })), EmptyTree) + Block(toList(complete), Return(literalUnit)) + } + If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), then, Throw(Ident(name.t))) + then + })), EmptyTree) + + //body + } def forever(t: Tree): Tree = { val labelName = name.fresh("while$") @@ -435,6 +453,14 @@ trait ExprBuilder { private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = mkHandlerCase(num, adaptToUnit(rhs)) + // We use the convention that the state machine's ID for a state corresponding to + // a labeldef will a negative number be based on the symbol ID. This allows us + // to translate a forward jump to the label as a state transition to a known state + // ID, even though the state machine transform hasn't yet processed the target label + // def. Negative numbers are used so as as not to clash with regular state IDs, which + // are allocated in ascending order from 0. + private def stateIdForLabel(sym: Symbol): Int = -symId(sym) + private def tpeOf(t: Tree): Type = t match { case _ if t.tpe != null => t.tpe case Try(body, Nil, _) => tpeOf(body) diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala index 4242a8e3..2998bafe 100644 --- a/src/main/scala/scala/async/internal/Lifter.scala +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -40,6 +40,7 @@ trait Lifter { val defs: Map[Tree, Int] = { /** Collect the DefTrees directly enclosed within `t` that have the same owner */ def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { + case ld: LabelDef => Nil case dt: DefTree => dt :: Nil case _: Function => Nil case t => diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala index 55e7a51c..2b74e8dc 100644 --- a/src/main/scala/scala/async/internal/StateAssigner.scala +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -7,8 +7,7 @@ package scala.async.internal private[async] final class StateAssigner { private var current = StateAssigner.Initial - def nextState(): Int = - try current finally current += 1 + def nextState(): Int = try current finally current += 1 } object StateAssigner { diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 547f9807..90419d3d 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -41,6 +41,19 @@ private[async] trait TransformUtils { def isAwait(fun: Tree) = fun.symbol == defn.Async_await + def newBlock(stats: List[Tree], expr: Tree): Block = { + Block(stats, expr) + } + + def isLiteralUnit(t: Tree) = t match { + case Literal(Constant(())) => + true + case _ => false + } + + def isPastTyper = + c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper + // Copy pasted from TreeInfo in the compiler. // Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not // sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match @@ -150,6 +163,7 @@ private[async] trait TransformUtils { } val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val ThrowableClass = rootMirror.staticClass("java.lang.Throwable") val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol) val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException") } @@ -161,16 +175,26 @@ private[async] trait TransformUtils { val labelDefs = t.collect { case ld: LabelDef => ld.symbol }.toSet - t.exists { + val result = t.exists { case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol) case _ => false } + result } - private def isLabel(sym: Symbol): Boolean = { + def isLabel(sym: Symbol): Boolean = { val LABEL = 1L << 17 // not in the public reflection API. (internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L } + def symId(sym: Symbol): Int = { + val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] + sym.asInstanceOf[symtab.Symbol].id + } + def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = { + val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] + val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]]) + subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree] + } /** Map a list of arguments to: @@ -362,4 +386,121 @@ private[async] trait TransformUtils { else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap())) } // ===================================== + + /** + * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`, + * and return a function that can be used on derived trees to efficiently test the + * same condition. + * + * If the derived tree contains synthetic wrapper trees, these will be recursed into + * in search of a sub tree that was decorated with the cached answer. + */ + final def containsAwaitCached(t: Tree): Tree => Boolean = { + def treeCannotContainAwait(t: Tree) = t match { + case _: Ident | _: TypeTree | _: Literal => true + case _ => false + } + def shouldAttach(t: Tree) = !treeCannotContainAwait(t) + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + def attachContainsAwait(t: Tree): Unit = if (shouldAttach(t)) { + val t1 = t.asInstanceOf[symtab.Tree] + t1.updateAttachment(ContainsAwait) + t1.removeAttachment[NoAwait.type] + } + def attachNoAwait(t: Tree): Unit = if (shouldAttach(t)) { + val t1 = t.asInstanceOf[symtab.Tree] + t1.updateAttachment(NoAwait) + } + object markContainsAwaitTraverser extends Traverser { + var stack: List[Tree] = Nil + + override def traverse(tree: Tree): Unit = { + stack ::= tree + try { + if (isAwait(tree)) + stack.foreach(attachContainsAwait) + else + attachNoAwait(tree) + super.traverse(tree) + } finally stack = stack.tail + } + } + markContainsAwaitTraverser.traverse(t) + + (t: Tree) => { + object traverser extends Traverser { + var containsAwait = false + override def traverse(tree: Tree): Unit = { + def castTree = tree.asInstanceOf[symtab.Tree] + if (!castTree.hasAttachment[NoAwait.type]) { + if (castTree.hasAttachment[ContainsAwait.type]) + containsAwait = true + else if (!treeCannotContainAwait(t)) + super.traverse(tree) + } + } + } + traverser.traverse(t) + traverser.containsAwait + } + } + + final def cleanupContainsAwaitAttachments(t: Tree): t.type = { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + t.foreach {t => + t.asInstanceOf[symtab.Tree].removeAttachment[ContainsAwait.type] + t.asInstanceOf[symtab.Tree].removeAttachment[NoAwait.type] + } + t + } + + // First modification to translated patterns: + // - Set the type of label jumps to `Unit` + // - Propagate this change to trees known to directly enclose them: + // ``If` / `Block`) adjust types of enclosing + final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = { + import definitions.UnitTpe + typingTransform(t, owner) { + (tree, api) => + tree match { + case Block(stats, expr) => + val stats1 = stats map api.recur + val expr1 = api.recur(expr) + if (expr1.tpe =:= UnitTpe) + internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe) + else + treeCopy.Block(tree, stats1, expr1) + case If(cond, thenp, elsep) => + val cond1 = api.recur(cond) + val thenp1 = api.recur(thenp) + val elsep1 = api.recur(elsep) + if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe) + internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe) + else + treeCopy.If(tree, cond1, thenp1, elsep1) + case Apply(fun, args) if isLabel(fun.symbol) => + internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe) + case t => api.default(t) + } + } + } + + final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = { + if (isPastTyper) { + // If we are running after the typer phase (ie being called from a compiler plugin) + // we have to create the trio of members manually. + val ACCESSOR = (1L << 27).asInstanceOf[FlagSet] + val STABLE = (1L << 22).asInstanceOf[FlagSet] + val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name + " ", TypeTree(tpt), init) + val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(tpnme.EMPTY), field.name)) + val setter = DefDef(Modifiers(ACCESSOR), name + "_=", Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(tpnme.EMPTY), field.name), Ident(TermName("x")))) + field :: getter :: setter :: Nil + } else { + val result = ValDef(NoMods, name, TypeTree(tpt), init) + result :: Nil + } + } } + +case object ContainsAwait +case object NoAwait \ No newline at end of file diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index d6c619f8..09fa69e4 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -82,6 +82,8 @@ object TreeInterrogation extends App { println(tree) val tree1 = tb.typeCheck(tree.duplicate) println(cm.universe.show(tree1)) + println(tb.eval(tree)) } + } diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala new file mode 100644 index 00000000..b8665271 --- /dev/null +++ b/src/test/scala/scala/async/run/late/LateExpansion.scala @@ -0,0 +1,170 @@ +package scala.async.run.late + +import java.io.File + +import junit.framework.Assert.assertEquals +import org.junit.Test + +import scala.annotation.StaticAnnotation +import scala.async.internal.{AsyncId, AsyncMacro} +import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader +import scala.tools.nsc._ +import scala.tools.nsc.plugins.{Plugin, PluginComponent} +import scala.tools.nsc.reporters.StoreReporter +import scala.tools.nsc.transform.TypingTransformers + +// Tests for customized use of the async transform from a compiler plugin, which +// calls it from a new phase that runs after patmat. +class LateExpansion { + @Test def test0(): Unit = { + val result = wrapAndRun( + """ + | @autoawait def id(a: String) = a + | id("foo") + id("bar") + | """.stripMargin) + assertEquals("foobar", result) + } + @Test def testGuard(): Unit = { + val result = wrapAndRun( + """ + | @autoawait def id[A](a: A) = a + | "" match { case _ if id(false) => ???; case _ => "okay" } + | """.stripMargin) + assertEquals("okay", result) + } + + @Test def testExtractor(): Unit = { + val result = wrapAndRun( + """ + | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) } + | "" match { case Extractor(a, b) if "".isEmpty => a == b } + | """.stripMargin) + assertEquals(true, result) + } + + @Test def testNestedMatchExtractor(): Unit = { + val result = wrapAndRun( + """ + | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) } + | "" match { + | case _ if "".isEmpty => + | "" match { case Extractor(a, b) => a == b } + | } + | """.stripMargin) + assertEquals(true, result) + } + + @Test def testCombo(): Unit = { + val result = wrapAndRun( + """ + | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) } + | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) } + | @autoawait def id(a: String) = a + | println("Test.test") + | val r1 = Predef.identity("blerg") match { + | case x if " ".isEmpty => "case 2: " + x + | case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y + | x match { + | case Extractor1(Extractor2(x), y: String) => + | case _ => + | } + | case Extractor2(x) => "case 3: " + x + | } + | r1 + | """.stripMargin) + assertEquals("case 3: blerg3", result) + } + + def wrapAndRun(code: String): Any = { + run( + s""" + |import scala.async.run.late.{autoawait,lateasync} + |object Test { + | @lateasync + | def test: Any = { + | $code + | } + |} + | """.stripMargin) + } + + def run(code: String): Any = { + val reporter = new StoreReporter + val settings = new Settings(println(_)) + settings.outdir.value = sys.props("java.io.tmpdir") + settings.embeddedDefaults(getClass.getClassLoader) + val isInSBT = !settings.classpath.isSetByUser + if (isInSBT) settings.usejavacp.value = true + val global = new Global(settings, reporter) { + self => + + object late extends { + val global: self.type = self + } with LatePlugin + + override protected def loadPlugins(): List[Plugin] = late :: Nil + } + import global._ + + val run = new Run + val source = newSourceFile(code) + run.compileSources(source :: Nil) + assert(!reporter.hasErrors, reporter.infos.mkString("\n")) + val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader) + val cls = loader.loadClass("Test") + cls.getMethod("test").invoke(null) + } +} + +abstract class LatePlugin extends Plugin { + import global._ + + override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers { + val global: LatePlugin.this.global.type = LatePlugin.this.global + + lazy val asyncIdSym = symbolOf[AsyncId.type] + lazy val asyncSym = asyncIdSym.info.member(TermName("async")) + lazy val awaitSym = asyncIdSym.info.member(TermName("await")) + lazy val autoAwaitSym = symbolOf[autoawait] + lazy val lateAsyncSym = symbolOf[lateasync] + + def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) { + override def transform(tree: Tree): Tree = { + super.transform(tree) match { + case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) => + localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil)) + case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) { + val expandee = localTyper.context.withMacrosDisabled( + localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs))) + ) + val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee) + val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs) + val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any]) + deriveDefDef(dd)(_ => localTyper.typed(code)) + } + case x => x + } + } + } + + override def newPhase(prev: Phase): Phase = new StdPhase(prev) { + override def apply(unit: CompilationUnit): Unit = { + val translated = newTransformer(unit).transformUnit(unit) + //println(show(unit.body)) + translated + } + } + + override val runsAfter: List[String] = "patmat" :: Nil + override val phaseName: String = "postpatmat" + + }) + override val description: String = "postpatmat" + override val name: String = "postpatmat" +} + +// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { }` +final class lateasync extends StaticAnnotation + +// Calls to methods with this annotation are translated to `AsyncId.await()` +final class autoawait extends StaticAnnotation