Skip to content

Fix #23224: Optimize simple tuple extraction #23373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 73 additions & 20 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,27 @@ object desugar {
sel
end match

case class TuplePatternInfo(arity: Int, varNum: Int, wildcardNum: Int)
object TuplePatternInfo:
def apply(pat: Tree)(using Context): TuplePatternInfo = pat match
case Tuple(pats) =>
var arity = 0
var varNum = 0
var wildcardNum = 0
pats.foreach: p =>
arity += 1
p match
case id: Ident if !isBackquoted(id) =>
if id.name.isVarPattern then
varNum += 1
if id.name == nme.WILDCARD then
wildcardNum += 1
case _ =>
TuplePatternInfo(arity, varNum, wildcardNum)
case _ =>
TuplePatternInfo(-1, -1, -1)
end TuplePatternInfo

/** If `pat` is a variable pattern,
*
* val/var/lazy val p = e
Expand Down Expand Up @@ -1483,30 +1504,47 @@ object desugar {
|please bind to an identifier and use an alias given.""", bind)
false

def isTuplePattern(arity: Int): Boolean = pat match {
case Tuple(pats) if pats.size == arity =>
pats.forall(isVarPattern)
case _ => false
}

val isMatchingTuple: Tree => Boolean = {
case Tuple(es) => isTuplePattern(es.length) && !hasNamedArg(es)
case _ => false
}
val tuplePatternInfo = TuplePatternInfo(pat)

// When desugaring a PatDef in general, we use pattern matching on the rhs
// and collect the variable values in a tuple, then outside the match,
// we destructure the tuple to get the individual variables.
// We can achieve two kinds of tuple optimizations if the pattern is a tuple
// of simple variables or wildcards:
// 1. Full optimization:
// If the rhs is known to produce a literal tuple of the same arity,
// we can directly fetch the values from the tuple.
// For example: `val (x, y) = if ... then (1, "a") else (2, "b")` becomes
// `val $1$ = if ...; val x = $1$._1; val y = $1$._2`.
// 2. Partial optimization:
// If the rhs can be typed as a tuple and matched with correct arity, we can
// return the tuple itself in the case if there are no more than one variable
// in the pattern, or return the the value if there is only one variable.

val fullTupleOptimizable =
val isMatchingTuple: Tree => Boolean = {
case Tuple(es) => tuplePatternInfo.varNum == es.length && !hasNamedArg(es)
case _ => false
}
tuplePatternInfo.arity > 0
&& tuplePatternInfo.arity == tuplePatternInfo.varNum
&& forallResults(rhs, isMatchingTuple)

// We can only optimize `val pat = if (...) e1 else e2` if:
// - `e1` and `e2` are both tuples of arity N
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
val partialTupleOptimizable =
tuplePatternInfo.arity > 0
&& tuplePatternInfo.arity == tuplePatternInfo.varNum
// We exclude the case where there is only one variable,
// because it should be handled by `makeTuple` directly.
&& tuplePatternInfo.wildcardNum < tuplePatternInfo.arity - 1

val inAliasGenerator = original match
case _: GenAlias => true
case _ => false

val vars =
if (tupleOptimizable) // include `_`
val vars: List[VarInfo] =
if fullTupleOptimizable || partialTupleOptimizable then // include `_`
pat match
case Tuple(pats) => pats.map { case id: Ident => id -> TypeTree() }
case Tuple(pats) => pats.map { case id: Ident => (id, TypeTree()) }
else
getVariables(
tree = pat,
Expand All @@ -1517,12 +1555,27 @@ object desugar {
errorOnGivenBinding
) // no `_`

val ids = for ((named, _) <- vars) yield Ident(named.name)
val ids = for ((named, tpt) <- vars) yield Ident(named.name)

val matchExpr =
if (tupleOptimizable) rhs
if fullTupleOptimizable then rhs
else
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids).withAttachment(ForArtifact, ()))
val caseDef =
if partialTupleOptimizable then
val tmpTuple = UniqueName.fresh()
// Replace all variables with wildcards in the pattern
val pat1 = pat match
case Tuple(pats) =>
val wildcardPats = pats.map(p => Ident(nme.WILDCARD).withSpan(p.span))
Tuple(wildcardPats).withSpan(pat.span)
CaseDef(
Bind(tmpTuple, pat1),
EmptyTree,
Ident(tmpTuple).withAttachment(ForArtifact, ())
)
else CaseDef(pat, EmptyTree, makeTuple(ids).withAttachment(ForArtifact, ()))
Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)

vars match {
case Nil if !mods.is(Lazy) =>
matchExpr
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,16 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
}

/** Checks whether predicate `p` is true for all result parts of this expression,
* where we zoom into Ifs, Matches, and Blocks.
* where we zoom into Ifs, Matches, Tries, and Blocks.
*/
def forallResults(tree: Tree, p: Tree => Boolean): Boolean = tree match {
def forallResults(tree: Tree, p: Tree => Boolean): Boolean = tree match
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
case Match(_, cases) => cases.forall(c => forallResults(c.body, p))
case Try(_, cases, finalizer) =>
cases.forall(c => forallResults(c.body, p))
&& (finalizer.isEmpty || forallResults(finalizer, p))
case Block(_, expr) => forallResults(expr, p)
case _ => p(tree)
}

/** The tree stripped of the possibly nested applications (term and type).
* The original tree if it's not an application.
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,8 @@ trait Applications extends Compatibility {
if selType <:< unapplyArgType then
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
fullyDefinedType(unapplyArgType, "pattern selector", tree.srcPos)
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
if selType.isBottomType then unapplyArgType
else selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
else
if !ctx.mode.is(Mode.InTypeTest) then
checkMatchable(selType, tree.srcPos, pattern = true)
Expand All @@ -1711,7 +1712,7 @@ trait Applications extends Compatibility {
val unapplyPatterns = UnapplyArgs(unapplyApp.tpe, unapplyFn, unadaptedArgs, tree.srcPos)
.typedPatterns(qual, this)
val result = assignType(cpy.UnApply(tree)(newUnapplyFn, unapplyImplicits(dummyArg, unapplyApp), unapplyPatterns), ownType)
if (ownType.stripped eq selType.stripped) || ownType.isError then result
if (ownType.stripped eq selType.stripped) || selType.isBottomType || ownType.isError then result
else tryWithTypeTest(Typed(result, TypeTree(ownType)), selType)
case tp =>
val unapplyErr = if (tp.isError) unapplyFn else notAnExtractor(unapplyFn)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2816,6 +2816,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if isStableIdentifierOrLiteral || isNamedTuplePattern then pt
else if isWildcardStarArg(body1)
|| pt == defn.ImplicitScrutineeTypeRef
|| pt.isBottomType
|| body1.tpe <:< pt // There is some strange interaction with gadt matching.
// and implicit scopes.
// run/t2755.scala fails to compile if this subtype test is omitted
Expand Down
9 changes: 9 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ trait DottyBytecodeTest {
}, l.stringLines)
}

def assertNoInvoke(m: MethodNode, receiver: String, method: String): Unit =
assertNoInvoke(instructionsFromMethod(m), receiver, method)
def assertNoInvoke(l: List[Instruction], receiver: String, method: String): Unit = {
assert(!l.exists {
case Invoke(_, `receiver`, `method`, _, _) => true
case _ => false
}, s"Found unexpected invoke of $receiver.$method in:\n${l.stringLines}")
}

def diffInstructions(isa: List[Instruction], isb: List[Instruction]): String = {
val len = Math.max(isa.length, isb.length)
val sb = new StringBuilder
Expand Down
15 changes: 15 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,21 @@ class DottyBytecodeTests extends DottyBytecodeTest {
}
}

@Test
def simpleTupleExtraction(): Unit = {
val code =
"""class C {
| def f1(t: (Int, String)) =
| val (i, s) = t
| i + s.length
|}
""".stripMargin
checkBCode(code) { dir =>
val c = loadClassNode(dir.lookupName("C.class", directory = false).input)
assertNoInvoke(getMethod(c, "f1"), "scala/Tuple2$", "apply") // no Tuple2.apply call
}
}

@Test
def deprecation(): Unit = {
val code =
Expand Down
12 changes: 6 additions & 6 deletions tests/neg/i7294.check
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
-- [E007] Type Mismatch Error: tests/neg/i7294.scala:7:15 --------------------------------------------------------------
-- [E007] Type Mismatch Error: tests/neg/i7294.scala:7:18 --------------------------------------------------------------
7 | case x: T => x.g(10) // error
| ^
| Found: (x : Nothing)
| Required: ?{ g: ? }
| Note that implicit conversions were not tried because the result of an implicit conversion
| must be more specific than ?{ g: [applied to (10) returning T] }
| ^^^^^^^
| Found: Any
| Required: T
|
| where: T is a type in given instance f with bounds <: foo.Foo
|
| longer explanation available when compiling with `-explain`
43 changes: 43 additions & 0 deletions tests/pos/simple-tuple-extract.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

class Test:
def f1: (Int, String, AnyRef) = (1, "2", "3")
def f2: (x: Int, y: String) = (0, "y")

def test1 =
val (a, b, c) = f1
// Desugared to:
// val $2$: (Int, String, AnyRef) =
// this.f1:(Int, String, AnyRef) @unchecked match
// {
// case $1$ @ Tuple3.unapply[Int, String, Object](_, _, _) =>
// $1$:(Int, String, AnyRef)
// }
// val a: Int = $2$._1
// val b: String = $2$._2
// val c: AnyRef = $2$._3
a + b.length() + c.toString.length()

// This pattern will not be optimized:
// val (a1, b1, c1: String) = f1

def test2 =
val (_, b, c) = f1
b.length() + c.toString.length()

val (a2, _, c2) = f1
a2 + c2.toString.length()

val (a3, _, _) = f1
a3 + 1

def test3 =
val (_, b, _) = f1
b.length() + 1

def test4 =
val (x, y) = f2
x + y.length()

def test5 =
val (_, b) = f2
b.length() + 1