Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions fixtureJava/src/cellar/fixture/java/CellarJavaClass.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,16 @@ public T getDefault() {
public static int staticHelper(int n) {
return n * 2;
}

public String format(int value) {
return String.valueOf(value);
}

public String format(String value) {
return value;
}

public String format(int value, boolean verbose) {
return verbose ? "value=" + value : String.valueOf(value);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package cellar.fixture.scala2

/** Fixture for testing that overloaded methods are preserved in Scala 2 artifacts. */
trait CellarOverloaded {
def process(value: Int): String
def process(value: String): String
def process(value: Int, flag: Boolean): String
def unique: Int
}
16 changes: 16 additions & 0 deletions fixtureScala3/src/cellar/fixture/scala3/CellarOverloaded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package cellar.fixture.scala3

/** Fixture for testing that overloaded methods are all preserved in output. */
trait CellarOverloaded:
def process(value: Int): String
def process(value: String): String
def process(value: Int, flag: Boolean): String
def unique: Int

/** Fixture for testing overloads split across inheritance. */
trait CellarOverloadedBase:
def action(value: Int): String

trait CellarOverloadedChild extends CellarOverloadedBase:
def action(value: String): String
def childOnly: Int
22 changes: 1 addition & 21 deletions lib/src/cellar/GetFormatter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,32 +73,12 @@ object GetFormatter:
private def renderMembers(sym: Symbol)(using ctx: Context): Option[String] =
sym match
case cls: ClassSymbol =>
val members = collectClassMembers(cls)
val members = SymbolResolver.collectClassMembers(cls)
.filter(PublicApiFilter.isPublic)
.map(m => TypePrinter.printSymbolSignatureSafe(m).linesIterator.mkString(" ").trim)
if members.isEmpty then None else Some(members.mkString("\n"))
case _ => None

private val universalBaseClasses = Set("scala.Any", "scala.AnyRef", "java.lang.Object")

private def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] =
val seen = scala.collection.mutable.Set.empty[String]
val result = List.newBuilder[TermOrTypeSymbol]
val linearization =
try cls.linearization
catch case _: Exception => List(cls)
for
klass <- linearization if !universalBaseClasses.contains(klass.displayFullName)
decl <-
try klass.declarations
catch case _: Exception => Nil
do
val key = decl.name.toString
if !seen.contains(key) then
seen += key
result += decl
result.result()

private def renderCompanion(sym: Symbol)(using ctx: Context): Option[String] =
sym match
case cls: ClassSymbol =>
Expand Down
24 changes: 2 additions & 22 deletions lib/src/cellar/SymbolLister.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package cellar
import cats.effect.IO
import fs2.Stream
import tastyquery.Contexts.Context
import tastyquery.Symbols.{ClassSymbol, PackageSymbol, Symbol, TermOrTypeSymbol}
import tastyquery.Symbols.{ClassSymbol, PackageSymbol, Symbol}

sealed trait ListTarget
object ListTarget:
Expand Down Expand Up @@ -40,26 +40,6 @@ object SymbolLister:

case ListTarget.Cls(cls) =>
Stream
.eval(IO.blocking(collectClassMembers(cls)))
.eval(IO.blocking(SymbolResolver.collectClassMembers(cls)))
.flatMap(syms => Stream.emits(syms))
.filter(PublicApiFilter.isPublic)

private def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] =
// Walk linearization (MRO), collecting decls. Deduplicate by unsigned name, keeping
// the most-derived occurrence (first seen in MRO order).
val seen = scala.collection.mutable.Set.empty[String]
val result = List.newBuilder[TermOrTypeSymbol]
val linearization =
try cls.linearization
catch case _: Exception => List(cls)
for
klass <- linearization
decl <-
try klass.declarations
catch case _: Exception => Nil
do
val key = decl.name.toString
if !seen.contains(key) then
seen += key
result += decl
result.result()
29 changes: 28 additions & 1 deletion lib/src/cellar/SymbolResolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import cats.effect.IO
import tastyquery.Contexts.Context
import tastyquery.Exceptions.MemberNotFoundException
import tastyquery.Names.{termName, typeName}
import tastyquery.Symbols.{ClassSymbol, Symbol}
import tastyquery.Symbols.{ClassSymbol, Symbol, TermOrTypeSymbol}

sealed trait LookupResult
object LookupResult:
Expand Down Expand Up @@ -145,6 +145,33 @@ object SymbolResolver:
i += 1
best

private val universalBaseClasses = Set("scala.Any", "scala.AnyRef", "java.lang.Object")

/**
* Walk the linearization (MRO) of a class, collecting all declarations.
* Uses tastyquery's `overridingSymbol` to skip declarations that are
* overridden by a more-derived class, while preserving distinct overloads.
*/
def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] =
val seen = scala.collection.mutable.Set.empty[TermOrTypeSymbol]
val result = List.newBuilder[TermOrTypeSymbol]
val linearization =
try cls.linearization
catch case _: Exception => List(cls)
linearization.filterNot(k => universalBaseClasses.contains(k.displayFullName)).foreach { klass =>
val decls =
try klass.declarations
catch case _: Exception => Nil
decls.foreach { decl =>
if !seen.contains(decl) then
val dominated = decl.overridingSymbol(cls).exists(_ != decl)
if !dominated then
seen += decl
result += decl
}
}
result.result()

private def tryOrNone[A](thunk: => A): Option[A] =
try Some(thunk)
catch
Expand Down
62 changes: 62 additions & 0 deletions lib/test/src/cellar/GetFormatterTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,65 @@ class GetFormatterTest extends CatsEffectSuite:
assert(!output.contains("---"), s"Unexpected separator: $output")
}
}

test("formatSymbol members includes all overloaded methods"):
withCtx { ctx =>
IO.blocking {
given Context = ctx
val cls = ctx.findStaticClass("cellar.fixture.scala3.CellarOverloaded")
val output = GetFormatter.formatSymbol(cls)
val processCount = output.linesIterator.count(_.contains("def process("))
assertEquals(processCount, 3, s"Expected 3 process overloads in:\n$output")
}
}

test("formatSymbol members includes inherited overloaded methods"):
withCtx { ctx =>
IO.blocking {
given Context = ctx
val cls = ctx.findStaticClass("cellar.fixture.scala3.CellarOverloadedChild")
val output = GetFormatter.formatSymbol(cls)
val actionCount = output.linesIterator.count(_.contains("def action("))
assertEquals(actionCount, 2, s"Expected 2 action overloads in:\n$output")
}
}

private def withJavaCtx[A](body: Context => IO[A]): IO[A] =
TestFixtures.assumeFixturesAvailable()
for
jrePaths <- JreClasspath.jrtPath()
jars <- CoursierFetchClient.fetchClasspath(
TestFixtures.javaCoord, Seq(TestFixtures.localM2Repo))
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
yield result

private def withScala2Ctx[A](body: Context => IO[A]): IO[A] =
TestFixtures.assumeFixturesAvailable()
for
jrePaths <- JreClasspath.jrtPath()
jars <- CoursierFetchClient.fetchClasspath(
TestFixtures.scala2Coord, Seq(TestFixtures.localM2Repo))
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
yield result

test("formatSymbol members includes all overloaded methods (Java)"):
withJavaCtx { ctx =>
IO.blocking {
given Context = ctx
val cls = ctx.findStaticClass("cellar.fixture.java.CellarJavaClass")
val output = GetFormatter.formatSymbol(cls)
val formatCount = output.linesIterator.count(_.contains("def format("))
assertEquals(formatCount, 3, s"Expected 3 format overloads in:\n$output")
}
}

test("formatSymbol members includes all overloaded methods (Scala 2)"):
withScala2Ctx { ctx =>
IO.blocking {
given Context = ctx
val cls = ctx.findStaticClass("cellar.fixture.scala2.CellarOverloaded")
val output = GetFormatter.formatSymbol(cls)
val processCount = output.linesIterator.count(_.contains("def process("))
assertEquals(processCount, 3, s"Expected 3 process overloads in:\n$output")
}
}
70 changes: 70 additions & 0 deletions lib/test/src/cellar/SymbolListerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,73 @@ class SymbolListerTest extends CatsEffectSuite:
case other => fail(s"CellarA not found, got $other")
}
}

test("listMembers includes all overloaded methods"):
withCtx { ctx =>
given Context = ctx
SymbolLister.resolve("cellar.fixture.scala3.CellarOverloaded").flatMap {
case ListResolveResult.Found(target) =>
SymbolLister.listMembers(target).compile.toList.map { syms =>
val processCount = syms.count(_.name.toString == "process")
assertEquals(processCount, 3, s"Expected 3 process overloads, got names: ${syms.map(_.name)}")
}
case other => fail(s"CellarOverloaded not found, got $other")
}
}
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new test only covers overloads declared within a single trait. Given the name-based MRO deduplication in collectClassMembers, overloads split across inheritance (base trait defines one overload, derived trait adds another) are a likely regression case; adding a fixture/test for that scenario would better validate the intended behavior end-to-end.

Suggested change
}
}
test("listMembers includes inherited overloaded methods"):
withCtx { ctx =>
given Context = ctx
SymbolLister.resolve("cellar.fixture.scala3.CellarOverloadedInherited").flatMap {
case ListResolveResult.Found(target) =>
SymbolLister.listMembers(target).compile.toList.map { syms =>
val processCount = syms.count(_.name.toString == "process")
assertEquals(
processCount,
2,
s"Expected 2 inherited process overloads (base + derived), got names: ${syms.map(_.name)}"
)
}
case other => fail(s"CellarOverloadedInherited not found, got $other")
}
}

Copilot uses AI. Check for mistakes.

test("listMembers includes inherited overloaded methods"):
withCtx { ctx =>
given Context = ctx
SymbolLister.resolve("cellar.fixture.scala3.CellarOverloadedChild").flatMap {
case ListResolveResult.Found(target) =>
SymbolLister.listMembers(target).compile.toList.map { syms =>
val actionCount = syms.count(_.name.toString == "action")
assertEquals(actionCount, 2, s"Expected 2 action overloads, got names: ${syms.map(_.name)}")
}
case other => fail(s"CellarOverloadedChild not found, got $other")
}
}

private def withJavaCtx[A](body: Context => IO[A]): IO[A] =
TestFixtures.assumeFixturesAvailable()
for
jrePaths <- JreClasspath.jrtPath()
jars <- CoursierFetchClient.fetchClasspath(
TestFixtures.javaCoord, Seq(TestFixtures.localM2Repo))
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
yield result

private def withScala2Ctx[A](body: Context => IO[A]): IO[A] =
TestFixtures.assumeFixturesAvailable()
for
jrePaths <- JreClasspath.jrtPath()
jars <- CoursierFetchClient.fetchClasspath(
TestFixtures.scala2Coord, Seq(TestFixtures.localM2Repo))
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
yield result

test("listMembers includes all overloaded methods (Java)"):
withJavaCtx { ctx =>
given Context = ctx
SymbolLister.resolve("cellar.fixture.java.CellarJavaClass").flatMap {
case ListResolveResult.Found(target) =>
SymbolLister.listMembers(target).compile.toList.map { syms =>
val formatCount = syms.count(_.name.toString == "format")
assertEquals(formatCount, 3, s"Expected 3 format overloads, got names: ${syms.map(_.name)}")
}
case other => fail(s"CellarJavaClass not found, got $other")
}
}

test("listMembers includes all overloaded methods (Scala 2)"):
withScala2Ctx { ctx =>
given Context = ctx
SymbolLister.resolve("cellar.fixture.scala2.CellarOverloaded").flatMap {
case ListResolveResult.Found(target) =>
SymbolLister.listMembers(target).compile.toList.map { syms =>
val processCount = syms.count(_.name.toString == "process")
assertEquals(processCount, 3, s"Expected 3 process overloads, got names: ${syms.map(_.name)}")
}
case other => fail(s"CellarOverloaded not found, got $other")
}
}
Loading