Skip to content

Commit

Permalink
Extension copy (#262)
Browse files Browse the repository at this point in the history
Co-authored-by: Kacper Korban <[email protected]>
  • Loading branch information
OndrejSpanel and KacperFKorban authored Dec 17, 2024
1 parent 7cf831e commit 663e750
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 64 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ excludeLintKeys in Global ++= Set(ideSkipProject)
val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
organization := "com.softwaremill.quicklens",
updateDocs := UpdateVersionInDocs(sLog.value, organization.value, version.value, List(file("README.md"))),
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all"
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all", "-Xcheck-macros"
ideSkipProject := (scalaVersion.value != scalaIdeaVersion)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ object QuicklensMacros {
def noSuchMember(tpeStr: String, name: String) =
s"$tpeStr has no member named $name"

def noSuitableMember(tpeStr: String, name: String, argNames: Iterable[String]) =
s"$tpeStr has no member $name with parameters ${argNames.mkString("(", ", ", ")")}"

def multipleMatchingMethods(tpeStr: String, name: String, syms: Seq[Symbol]) =
val symsStr = syms.map(s => s" - $s: ${s.termRef.dealias.widen.show}").mkString("\n", "\n", "")
s"Multiple methods named $name found in $tpeStr: $symsStr"
Expand Down Expand Up @@ -109,11 +112,14 @@ object QuicklensMacros {
case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree)))

enum PathSymbol:
case Field(name: String)
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
case Field(override val name: String)
case Extension(term: Term, override val name: String)
case FunctionDelegate(override val name: String, givn: Term, typeTree: TypeTree, args: List[Term])
def name: String

def equiv(other: Any): Boolean = (this, other) match
case (Field(name1), Field(name2)) => name1 == name2
case (Extension(term1, name1), Extension(term2, name2)) => term1 == term2 && name1 == name2
case (FunctionDelegate(name1, _, typeTree1, args1), FunctionDelegate(name2, _, typeTree2, args2)) =>
name1 == name2 && typeTree1.tpe == typeTree2.tpe && args1 == args2
case _ => false
Expand All @@ -133,6 +139,9 @@ object QuicklensMacros {
/** Method call with one type parameter and using clause */
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
/** Extension method, which is called e.g. as x(_$1) */
case Apply(obj@Select(term, member), Seq(deep)) if obj.symbol.flags.is(Flags.ExtensionMethod) =>
toPath(deep, focus) :+ PathSymbol.Extension(term, member)
/** Field access */
case Apply(deep, idents) =>
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
Expand All @@ -157,43 +166,104 @@ object QuicklensMacros {
def matchingTypeSymbol: Symbol = tpe.widenAll match {
case AndType(l, r) =>
val lSym = l.matchingTypeSymbol
if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
tpe.typeSymbol
case tpe if isProductLike(tpe.typeSymbol) =>
if lSym != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) || isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
Symbol.noSymbol
}

def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = {
val mem = sym.fieldMember(name)
if mem != Symbol.noSymbol then mem
else methodSymbolByNameOrError(sym, name)
extension (term: Term)
def appliedToIfNeeded(args: List[Term]): Term =
if args.isEmpty then term else term.appliedToArgs(args)

def symbolAccessorByNameOrError(obj: Term, name: String): Term = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol
// opaque types can find members of underlying types - ignore them (see https://github.com/scala/scala3/issues/22143)
val fieldMemberSym = objSymbol.fieldMember(name)
if !objSymbol.flags.is(Flags.Deferred) && fieldMemberSym.exists then
Select(obj, fieldMemberSym)
else
objSymbol.methodMember(name) match
case List(m) =>
Select(obj, m)
case lst =>
report.errorAndAbort(reportMethodError(objSymbol, name, lst))
}

def reportMethodError(sym: Symbol, name: String, lst: List[Symbol], maybeArgNames: Option[Iterable[String]] = None): String = {
(lst, maybeArgNames) match
case (Nil, _) => noSuchMember(sym.name, name)
case (lst, None) => multipleMatchingMethods(sym.name, name, lst)
case (lst, Some(argNames)) => noSuitableMember(sym.name, name, argNames)
}

def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
case lst => report.errorAndAbort(reportMethodError(sym, name, lst))
}

def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = {
def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = {
val argNames = argsMap.keys
sym.methodMember(name).filter{ msym =>
allMethods.filter { msym =>
// for copy, we filter out the methods that don't have the desired parameter names
val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name)
argNames.forall(paramNames.contains)
} match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst @ (m :: _) =>
case List(m) => Some(m)
case Nil => None
case lst@(m :: _) =>
// if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method
val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic))
syntheticCopies match
case List(mSynth) => mSynth
case _ => m
case List(mSynth) => Some(mSynth)
case _ => Some(m)
}

def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Either[String, Symbol] = {
if !sym.flags.is(Flags.Deferred) then
val memberMethods = sym.methodMember(name)
filterMethodsByNameAndArgs(memberMethods, argsMap)
.toRight(reportMethodError(sym, name, memberMethods, Some(argsMap.keys)))
else Left(s"Deferred type ${sym.name}")
}

/**
* @param argsMap normal methods receive one parameter list, extensions methods two, the first one contains the value
* on which the extension is called
* */
def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = {
require(argsMap.size == 1 || argsMap.size == 2, s"argsMap.size should be either 1 or 2, got: ${argsMap.size} ($argsMap)")
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol

val typeParams = objTpe.typeArgs
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap)
.map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
.flatten.toList

val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
val i = _i + 1
def defaultMethod: Term =
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
// default values in extension methods take the extension receiver as the first parameter
val defaultMethodArgs = argsMap.dropRight(1).flatMap(_.values)
obj.select(methodSymbol).appliedToIfNeeded(defaultMethodArgs)
n -> v.getOrElse(defaultMethod)
}.toMap

val argLists: List[List[Term]] = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))

if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}"
)

val withTypeParamsApplied = obj.select(copy).appliedToTypes(typeParams)
argLists.foldLeft(withTypeParamsApplied)(Apply(_, _))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
Expand All @@ -210,15 +280,32 @@ object QuicklensMacros {
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def findCompanionLikeObject(objSymbol: Symbol): Symbol = {
if objSymbol.companionModule.exists then
objSymbol.companionModule
else
val namedFromOwnerScope = objSymbol.owner.fieldMember(objSymbol.name)
if namedFromOwnerScope.flags.is(Flags.Module) then namedFromOwnerScope
else Symbol.noSymbol
}

def hasExtensionNamed(sym: Symbol, methodName: String): List[Symbol] = {
val companionSymbol = findCompanionLikeObject(sym)
if companionSymbol.exists then
companionSymbol.methodMember(methodName).filter(s => s.name == methodName && s.flags.is(Flags.ExtensionMethod))
else
Nil
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").size >= 1
sym.methodMember("copy").nonEmpty || hasExtensionNamed(sym, "copy").nonEmpty
}

def caseClassCopy(
owner: Symbol,
mod: Expr[A => A],
obj: Term,
fields: Seq[(PathSymbol.Field, Seq[PathTree])]
fields: Seq[(PathSymbol.Field | PathSymbol.Extension, Seq[PathTree])]
): Term = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol
Expand Down Expand Up @@ -248,50 +335,39 @@ object QuicklensMacros {
}

val elseThrow = '{ throw new IllegalStateException() }.asTerm

ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val fieldMethod = symbolAccessorByNameOrError(objSymbol, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
val fieldMethod = field match {
case PathSymbol.Field(name) =>
symbolAccessorByNameOrError(obj, name)
case PathSymbol.Extension(term, name) =>
val extensionMethod = symbolAccessorByNameOrError(term, name)
Apply(extensionMethod, List(obj))
}
val resTerm: Term = trees.foldLeft[Term](fieldMethod) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap
val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap)

val typeParams = objTpe match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)

val args = copyParamNames.zipWithIndex.map { (n, _i) =>
val i = _i + 1
val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString))
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
argsMap.getOrElse(
n,
defaultMethod
)
}.toList

if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
case Right(copy) =>
callMethod(obj, copy, List(argsMap))
case Left(error) =>
val objCompanion = findCompanionLikeObject(objSymbol)
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap).toOption match
case Some(copy) =>
// now try to call the extension as a method, assume the object is its first parameter
val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap)
callMethod(Ref(objCompanion), copy, argsWithObj)
case None => report.errorAndAbort(error)
} else
report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
report.errorAndAbort(s"Unsupported source object: must be a case class, sealed trait or class with copy method, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
}

def applyFunctionDelegate(
Expand Down Expand Up @@ -331,9 +407,9 @@ object QuicklensMacros {
case Nil =>
objTerm

case (_: PathSymbol.Field, _) :: _ =>
val (fs, funs) = pathSymbols.span(_._1.isInstanceOf[PathSymbol.Field])
val fields = fs.collect { case (p: PathSymbol.Field, trees) => p -> trees }
case (_: (PathSymbol.Field | PathSymbol.Extension), _) :: _ =>
val (fs, funs) = pathSymbols.span((ps, _) => ps.isInstanceOf[PathSymbol.Field] || ps.isInstanceOf[PathSymbol.Extension])
val fields = fs.collect { case (p: (PathSymbol.Field | PathSymbol.Extension), trees) => p -> trees }
val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields)
accumulateToCopy(owner, mod, withCopiedFields, funs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ package object quicklens {
def map[A](fa: M[A], f: A => A): M[A] = {
val mapped = fa.view.mapValues(f)
(fa match {
case sfa: SortedMap[K, A] => sfa.sortedMapFactory.from(mapped)(using sfa.ordering)
case sfa: SortedMap[K, A]@unchecked => sfa.sortedMapFactory.from(mapped)(using sfa.ordering)
case _ => mapped.to(fa.mapFactory)
}).asInstanceOf[M[A]]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package com.softwaremill.quicklens
package test

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -33,7 +34,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
def paths(paths: Paths): Docs = copy(paths = paths)
}
val docs = Docs()
docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem()))
val r = docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem()))
r.paths.pathItems should contain ("a" -> PathItem())
}

it should "modify a case class with an additional explicit copy" in {
Expand All @@ -42,7 +44,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
}

val f = Frozen("A", 0)
f.modify(_.state).setTo("B")
val r = f.modify(_.state).setTo("B")
r.state shouldEqual "B"
}

it should "modify a case class with an ambiguous additional explicit copy" in {
Expand All @@ -51,7 +54,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
}

val f = Frozen("A", 0)
f.modify(_.state).setTo("B")
val r = f.modify(_.state).setTo("B")
r.state shouldEqual "B"
}

it should "modify a class with two explicit copy methods" in {
Expand All @@ -61,7 +65,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
}

val f = new Frozen("A", 0)
f.modify(_.state).setTo("B")
val r = f.modify(_.state).setTo("B")
r.state shouldEqual "B"
}

it should "modify a case class with an ambiguous additional explicit copy and pick the synthetic one first" in {
Expand All @@ -77,6 +82,19 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
accessed shouldEqual 0
}

it should "not compile when modifying a field which is not present as a copy parameter" in {
"""
case class Content(x: String)
class A(val c: Content) {
def copy(x: String = c.x): A = new A(Content(x))
}
val a = new A(Content("A"))
val am = a.modify(_.c).setTo(Content("B"))
""" shouldNot compile
}

// TODO: Would be nice to be able to handle this case. Based on the types, it
// is obvious, that the explicit copy should be picked, but I'm not sure if we
// can get that information
Expand All @@ -90,5 +108,4 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
// val f = Frozen("A", 0)
// f.modify(_.state).setTo('B')
// }

}
Loading

0 comments on commit 663e750

Please sign in to comment.