diff --git a/fixtureJava/src/cellar/fixture/java/CellarJavaClass.java b/fixtureJava/src/cellar/fixture/java/CellarJavaClass.java index 9fee579..427dc62 100644 --- a/fixtureJava/src/cellar/fixture/java/CellarJavaClass.java +++ b/fixtureJava/src/cellar/fixture/java/CellarJavaClass.java @@ -38,4 +38,16 @@ public T getDefault() { public static int staticHelper(int n) { return n * 2; } + + public String format(int value) { + return String.valueOf(value); + } + + public String format(String value) { + return value; + } + + public String format(int value, boolean verbose) { + return verbose ? "value=" + value : String.valueOf(value); + } } diff --git a/fixtureScala2/src/cellar/fixture/scala2/CellarOverloaded.scala b/fixtureScala2/src/cellar/fixture/scala2/CellarOverloaded.scala new file mode 100644 index 0000000..2c904d4 --- /dev/null +++ b/fixtureScala2/src/cellar/fixture/scala2/CellarOverloaded.scala @@ -0,0 +1,9 @@ +package cellar.fixture.scala2 + +/** Fixture for testing that overloaded methods are preserved in Scala 2 artifacts. */ +trait CellarOverloaded { + def process(value: Int): String + def process(value: String): String + def process(value: Int, flag: Boolean): String + def unique: Int +} diff --git a/fixtureScala3/src/cellar/fixture/scala3/CellarOverloaded.scala b/fixtureScala3/src/cellar/fixture/scala3/CellarOverloaded.scala new file mode 100644 index 0000000..5c90c0a --- /dev/null +++ b/fixtureScala3/src/cellar/fixture/scala3/CellarOverloaded.scala @@ -0,0 +1,16 @@ +package cellar.fixture.scala3 + +/** Fixture for testing that overloaded methods are all preserved in output. */ +trait CellarOverloaded: + def process(value: Int): String + def process(value: String): String + def process(value: Int, flag: Boolean): String + def unique: Int + +/** Fixture for testing overloads split across inheritance. */ +trait CellarOverloadedBase: + def action(value: Int): String + +trait CellarOverloadedChild extends CellarOverloadedBase: + def action(value: String): String + def childOnly: Int diff --git a/lib/src/cellar/GetFormatter.scala b/lib/src/cellar/GetFormatter.scala index ef1b9ad..2a84903 100644 --- a/lib/src/cellar/GetFormatter.scala +++ b/lib/src/cellar/GetFormatter.scala @@ -73,32 +73,12 @@ object GetFormatter: private def renderMembers(sym: Symbol)(using ctx: Context): Option[String] = sym match case cls: ClassSymbol => - val members = collectClassMembers(cls) + val members = SymbolResolver.collectClassMembers(cls) .filter(PublicApiFilter.isPublic) .map(m => TypePrinter.printSymbolSignatureSafe(m).linesIterator.mkString(" ").trim) if members.isEmpty then None else Some(members.mkString("\n")) case _ => None - private val universalBaseClasses = Set("scala.Any", "scala.AnyRef", "java.lang.Object") - - private def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] = - val seen = scala.collection.mutable.Set.empty[String] - val result = List.newBuilder[TermOrTypeSymbol] - val linearization = - try cls.linearization - catch case _: Exception => List(cls) - for - klass <- linearization if !universalBaseClasses.contains(klass.displayFullName) - decl <- - try klass.declarations - catch case _: Exception => Nil - do - val key = decl.name.toString - if !seen.contains(key) then - seen += key - result += decl - result.result() - private def renderCompanion(sym: Symbol)(using ctx: Context): Option[String] = sym match case cls: ClassSymbol => diff --git a/lib/src/cellar/SymbolLister.scala b/lib/src/cellar/SymbolLister.scala index 290de87..5c150be 100644 --- a/lib/src/cellar/SymbolLister.scala +++ b/lib/src/cellar/SymbolLister.scala @@ -3,7 +3,7 @@ package cellar import cats.effect.IO import fs2.Stream import tastyquery.Contexts.Context -import tastyquery.Symbols.{ClassSymbol, PackageSymbol, Symbol, TermOrTypeSymbol} +import tastyquery.Symbols.{ClassSymbol, PackageSymbol, Symbol} sealed trait ListTarget object ListTarget: @@ -40,26 +40,6 @@ object SymbolLister: case ListTarget.Cls(cls) => Stream - .eval(IO.blocking(collectClassMembers(cls))) + .eval(IO.blocking(SymbolResolver.collectClassMembers(cls))) .flatMap(syms => Stream.emits(syms)) .filter(PublicApiFilter.isPublic) - - private def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] = - // Walk linearization (MRO), collecting decls. Deduplicate by unsigned name, keeping - // the most-derived occurrence (first seen in MRO order). - val seen = scala.collection.mutable.Set.empty[String] - val result = List.newBuilder[TermOrTypeSymbol] - val linearization = - try cls.linearization - catch case _: Exception => List(cls) - for - klass <- linearization - decl <- - try klass.declarations - catch case _: Exception => Nil - do - val key = decl.name.toString - if !seen.contains(key) then - seen += key - result += decl - result.result() diff --git a/lib/src/cellar/SymbolResolver.scala b/lib/src/cellar/SymbolResolver.scala index 397c785..d44b8f2 100644 --- a/lib/src/cellar/SymbolResolver.scala +++ b/lib/src/cellar/SymbolResolver.scala @@ -4,7 +4,7 @@ import cats.effect.IO import tastyquery.Contexts.Context import tastyquery.Exceptions.MemberNotFoundException import tastyquery.Names.{termName, typeName} -import tastyquery.Symbols.{ClassSymbol, Symbol} +import tastyquery.Symbols.{ClassSymbol, Symbol, TermOrTypeSymbol} sealed trait LookupResult object LookupResult: @@ -145,6 +145,33 @@ object SymbolResolver: i += 1 best + private val universalBaseClasses = Set("scala.Any", "scala.AnyRef", "java.lang.Object") + + /** + * Walk the linearization (MRO) of a class, collecting all declarations. + * Uses tastyquery's `overridingSymbol` to skip declarations that are + * overridden by a more-derived class, while preserving distinct overloads. + */ + def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] = + val seen = scala.collection.mutable.Set.empty[TermOrTypeSymbol] + val result = List.newBuilder[TermOrTypeSymbol] + val linearization = + try cls.linearization + catch case _: Exception => List(cls) + linearization.filterNot(k => universalBaseClasses.contains(k.displayFullName)).foreach { klass => + val decls = + try klass.declarations + catch case _: Exception => Nil + decls.foreach { decl => + if !seen.contains(decl) then + val dominated = decl.overridingSymbol(cls).exists(_ != decl) + if !dominated then + seen += decl + result += decl + } + } + result.result() + private def tryOrNone[A](thunk: => A): Option[A] = try Some(thunk) catch diff --git a/lib/test/src/cellar/GetFormatterTest.scala b/lib/test/src/cellar/GetFormatterTest.scala index 294aa3e..5a1dcf9 100644 --- a/lib/test/src/cellar/GetFormatterTest.scala +++ b/lib/test/src/cellar/GetFormatterTest.scala @@ -107,3 +107,65 @@ class GetFormatterTest extends CatsEffectSuite: assert(!output.contains("---"), s"Unexpected separator: $output") } } + + test("formatSymbol members includes all overloaded methods"): + withCtx { ctx => + IO.blocking { + given Context = ctx + val cls = ctx.findStaticClass("cellar.fixture.scala3.CellarOverloaded") + val output = GetFormatter.formatSymbol(cls) + val processCount = output.linesIterator.count(_.contains("def process(")) + assertEquals(processCount, 3, s"Expected 3 process overloads in:\n$output") + } + } + + test("formatSymbol members includes inherited overloaded methods"): + withCtx { ctx => + IO.blocking { + given Context = ctx + val cls = ctx.findStaticClass("cellar.fixture.scala3.CellarOverloadedChild") + val output = GetFormatter.formatSymbol(cls) + val actionCount = output.linesIterator.count(_.contains("def action(")) + assertEquals(actionCount, 2, s"Expected 2 action overloads in:\n$output") + } + } + + private def withJavaCtx[A](body: Context => IO[A]): IO[A] = + TestFixtures.assumeFixturesAvailable() + for + jrePaths <- JreClasspath.jrtPath() + jars <- CoursierFetchClient.fetchClasspath( + TestFixtures.javaCoord, Seq(TestFixtures.localM2Repo)) + result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) } + yield result + + private def withScala2Ctx[A](body: Context => IO[A]): IO[A] = + TestFixtures.assumeFixturesAvailable() + for + jrePaths <- JreClasspath.jrtPath() + jars <- CoursierFetchClient.fetchClasspath( + TestFixtures.scala2Coord, Seq(TestFixtures.localM2Repo)) + result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) } + yield result + + test("formatSymbol members includes all overloaded methods (Java)"): + withJavaCtx { ctx => + IO.blocking { + given Context = ctx + val cls = ctx.findStaticClass("cellar.fixture.java.CellarJavaClass") + val output = GetFormatter.formatSymbol(cls) + val formatCount = output.linesIterator.count(_.contains("def format(")) + assertEquals(formatCount, 3, s"Expected 3 format overloads in:\n$output") + } + } + + test("formatSymbol members includes all overloaded methods (Scala 2)"): + withScala2Ctx { ctx => + IO.blocking { + given Context = ctx + val cls = ctx.findStaticClass("cellar.fixture.scala2.CellarOverloaded") + val output = GetFormatter.formatSymbol(cls) + val processCount = output.linesIterator.count(_.contains("def process(")) + assertEquals(processCount, 3, s"Expected 3 process overloads in:\n$output") + } + } diff --git a/lib/test/src/cellar/SymbolListerTest.scala b/lib/test/src/cellar/SymbolListerTest.scala index 1b44b96..c3ef7d5 100644 --- a/lib/test/src/cellar/SymbolListerTest.scala +++ b/lib/test/src/cellar/SymbolListerTest.scala @@ -84,3 +84,73 @@ class SymbolListerTest extends CatsEffectSuite: case other => fail(s"CellarA not found, got $other") } } + + test("listMembers includes all overloaded methods"): + withCtx { ctx => + given Context = ctx + SymbolLister.resolve("cellar.fixture.scala3.CellarOverloaded").flatMap { + case ListResolveResult.Found(target) => + SymbolLister.listMembers(target).compile.toList.map { syms => + val processCount = syms.count(_.name.toString == "process") + assertEquals(processCount, 3, s"Expected 3 process overloads, got names: ${syms.map(_.name)}") + } + case other => fail(s"CellarOverloaded not found, got $other") + } + } + + test("listMembers includes inherited overloaded methods"): + withCtx { ctx => + given Context = ctx + SymbolLister.resolve("cellar.fixture.scala3.CellarOverloadedChild").flatMap { + case ListResolveResult.Found(target) => + SymbolLister.listMembers(target).compile.toList.map { syms => + val actionCount = syms.count(_.name.toString == "action") + assertEquals(actionCount, 2, s"Expected 2 action overloads, got names: ${syms.map(_.name)}") + } + case other => fail(s"CellarOverloadedChild not found, got $other") + } + } + + private def withJavaCtx[A](body: Context => IO[A]): IO[A] = + TestFixtures.assumeFixturesAvailable() + for + jrePaths <- JreClasspath.jrtPath() + jars <- CoursierFetchClient.fetchClasspath( + TestFixtures.javaCoord, Seq(TestFixtures.localM2Repo)) + result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) } + yield result + + private def withScala2Ctx[A](body: Context => IO[A]): IO[A] = + TestFixtures.assumeFixturesAvailable() + for + jrePaths <- JreClasspath.jrtPath() + jars <- CoursierFetchClient.fetchClasspath( + TestFixtures.scala2Coord, Seq(TestFixtures.localM2Repo)) + result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) } + yield result + + test("listMembers includes all overloaded methods (Java)"): + withJavaCtx { ctx => + given Context = ctx + SymbolLister.resolve("cellar.fixture.java.CellarJavaClass").flatMap { + case ListResolveResult.Found(target) => + SymbolLister.listMembers(target).compile.toList.map { syms => + val formatCount = syms.count(_.name.toString == "format") + assertEquals(formatCount, 3, s"Expected 3 format overloads, got names: ${syms.map(_.name)}") + } + case other => fail(s"CellarJavaClass not found, got $other") + } + } + + test("listMembers includes all overloaded methods (Scala 2)"): + withScala2Ctx { ctx => + given Context = ctx + SymbolLister.resolve("cellar.fixture.scala2.CellarOverloaded").flatMap { + case ListResolveResult.Found(target) => + SymbolLister.listMembers(target).compile.toList.map { syms => + val processCount = syms.count(_.name.toString == "process") + assertEquals(processCount, 3, s"Expected 3 process overloads, got names: ${syms.map(_.name)}") + } + case other => fail(s"CellarOverloaded not found, got $other") + } + }