Skip to content

Commit 26b06c4

Browse files
rochalaclaude
andauthored
Fix method overloads being lost in get/list output (#16)
Replace name-based deduplication in collectClassMembers with tastyquery's overridingSymbol to properly distinguish overrides from overloads. This preserves all distinct overloads (including those split across inheritance) while still deduplicating true overrides. Extract shared collectClassMembers into SymbolResolver to eliminate duplication between GetFormatter and SymbolLister. Add overload fixtures and tests for Scala 3, Scala 2, and Java. Fixes #14 Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3094937 commit 26b06c4

8 files changed

Lines changed: 200 additions & 44 deletions

File tree

fixtureJava/src/cellar/fixture/java/CellarJavaClass.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,16 @@ public T getDefault() {
3838
public static int staticHelper(int n) {
3939
return n * 2;
4040
}
41+
42+
public String format(int value) {
43+
return String.valueOf(value);
44+
}
45+
46+
public String format(String value) {
47+
return value;
48+
}
49+
50+
public String format(int value, boolean verbose) {
51+
return verbose ? "value=" + value : String.valueOf(value);
52+
}
4153
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package cellar.fixture.scala2
2+
3+
/** Fixture for testing that overloaded methods are preserved in Scala 2 artifacts. */
4+
trait CellarOverloaded {
5+
def process(value: Int): String
6+
def process(value: String): String
7+
def process(value: Int, flag: Boolean): String
8+
def unique: Int
9+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package cellar.fixture.scala3
2+
3+
/** Fixture for testing that overloaded methods are all preserved in output. */
4+
trait CellarOverloaded:
5+
def process(value: Int): String
6+
def process(value: String): String
7+
def process(value: Int, flag: Boolean): String
8+
def unique: Int
9+
10+
/** Fixture for testing overloads split across inheritance. */
11+
trait CellarOverloadedBase:
12+
def action(value: Int): String
13+
14+
trait CellarOverloadedChild extends CellarOverloadedBase:
15+
def action(value: String): String
16+
def childOnly: Int

lib/src/cellar/GetFormatter.scala

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,32 +73,12 @@ object GetFormatter:
7373
private def renderMembers(sym: Symbol)(using ctx: Context): Option[String] =
7474
sym match
7575
case cls: ClassSymbol =>
76-
val members = collectClassMembers(cls)
76+
val members = SymbolResolver.collectClassMembers(cls)
7777
.filter(PublicApiFilter.isPublic)
7878
.map(m => TypePrinter.printSymbolSignatureSafe(m).linesIterator.mkString(" ").trim)
7979
if members.isEmpty then None else Some(members.mkString("\n"))
8080
case _ => None
8181

82-
private val universalBaseClasses = Set("scala.Any", "scala.AnyRef", "java.lang.Object")
83-
84-
private def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] =
85-
val seen = scala.collection.mutable.Set.empty[String]
86-
val result = List.newBuilder[TermOrTypeSymbol]
87-
val linearization =
88-
try cls.linearization
89-
catch case _: Exception => List(cls)
90-
for
91-
klass <- linearization if !universalBaseClasses.contains(klass.displayFullName)
92-
decl <-
93-
try klass.declarations
94-
catch case _: Exception => Nil
95-
do
96-
val key = decl.name.toString
97-
if !seen.contains(key) then
98-
seen += key
99-
result += decl
100-
result.result()
101-
10282
private def renderCompanion(sym: Symbol)(using ctx: Context): Option[String] =
10383
sym match
10484
case cls: ClassSymbol =>

lib/src/cellar/SymbolLister.scala

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package cellar
33
import cats.effect.IO
44
import fs2.Stream
55
import tastyquery.Contexts.Context
6-
import tastyquery.Symbols.{ClassSymbol, PackageSymbol, Symbol, TermOrTypeSymbol}
6+
import tastyquery.Symbols.{ClassSymbol, PackageSymbol, Symbol}
77

88
sealed trait ListTarget
99
object ListTarget:
@@ -40,26 +40,6 @@ object SymbolLister:
4040

4141
case ListTarget.Cls(cls) =>
4242
Stream
43-
.eval(IO.blocking(collectClassMembers(cls)))
43+
.eval(IO.blocking(SymbolResolver.collectClassMembers(cls)))
4444
.flatMap(syms => Stream.emits(syms))
4545
.filter(PublicApiFilter.isPublic)
46-
47-
private def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] =
48-
// Walk linearization (MRO), collecting decls. Deduplicate by unsigned name, keeping
49-
// the most-derived occurrence (first seen in MRO order).
50-
val seen = scala.collection.mutable.Set.empty[String]
51-
val result = List.newBuilder[TermOrTypeSymbol]
52-
val linearization =
53-
try cls.linearization
54-
catch case _: Exception => List(cls)
55-
for
56-
klass <- linearization
57-
decl <-
58-
try klass.declarations
59-
catch case _: Exception => Nil
60-
do
61-
val key = decl.name.toString
62-
if !seen.contains(key) then
63-
seen += key
64-
result += decl
65-
result.result()

lib/src/cellar/SymbolResolver.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import cats.effect.IO
44
import tastyquery.Contexts.Context
55
import tastyquery.Exceptions.MemberNotFoundException
66
import tastyquery.Names.{termName, typeName}
7-
import tastyquery.Symbols.{ClassSymbol, Symbol}
7+
import tastyquery.Symbols.{ClassSymbol, Symbol, TermOrTypeSymbol}
88

99
sealed trait LookupResult
1010
object LookupResult:
@@ -145,6 +145,33 @@ object SymbolResolver:
145145
i += 1
146146
best
147147

148+
private val universalBaseClasses = Set("scala.Any", "scala.AnyRef", "java.lang.Object")
149+
150+
/**
151+
* Walk the linearization (MRO) of a class, collecting all declarations.
152+
* Uses tastyquery's `overridingSymbol` to skip declarations that are
153+
* overridden by a more-derived class, while preserving distinct overloads.
154+
*/
155+
def collectClassMembers(cls: ClassSymbol)(using ctx: Context): List[TermOrTypeSymbol] =
156+
val seen = scala.collection.mutable.Set.empty[TermOrTypeSymbol]
157+
val result = List.newBuilder[TermOrTypeSymbol]
158+
val linearization =
159+
try cls.linearization
160+
catch case _: Exception => List(cls)
161+
linearization.filterNot(k => universalBaseClasses.contains(k.displayFullName)).foreach { klass =>
162+
val decls =
163+
try klass.declarations
164+
catch case _: Exception => Nil
165+
decls.foreach { decl =>
166+
if !seen.contains(decl) then
167+
val dominated = decl.overridingSymbol(cls).exists(_ != decl)
168+
if !dominated then
169+
seen += decl
170+
result += decl
171+
}
172+
}
173+
result.result()
174+
148175
private def tryOrNone[A](thunk: => A): Option[A] =
149176
try Some(thunk)
150177
catch

lib/test/src/cellar/GetFormatterTest.scala

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,65 @@ class GetFormatterTest extends CatsEffectSuite:
107107
assert(!output.contains("---"), s"Unexpected separator: $output")
108108
}
109109
}
110+
111+
test("formatSymbol members includes all overloaded methods"):
112+
withCtx { ctx =>
113+
IO.blocking {
114+
given Context = ctx
115+
val cls = ctx.findStaticClass("cellar.fixture.scala3.CellarOverloaded")
116+
val output = GetFormatter.formatSymbol(cls)
117+
val processCount = output.linesIterator.count(_.contains("def process("))
118+
assertEquals(processCount, 3, s"Expected 3 process overloads in:\n$output")
119+
}
120+
}
121+
122+
test("formatSymbol members includes inherited overloaded methods"):
123+
withCtx { ctx =>
124+
IO.blocking {
125+
given Context = ctx
126+
val cls = ctx.findStaticClass("cellar.fixture.scala3.CellarOverloadedChild")
127+
val output = GetFormatter.formatSymbol(cls)
128+
val actionCount = output.linesIterator.count(_.contains("def action("))
129+
assertEquals(actionCount, 2, s"Expected 2 action overloads in:\n$output")
130+
}
131+
}
132+
133+
private def withJavaCtx[A](body: Context => IO[A]): IO[A] =
134+
TestFixtures.assumeFixturesAvailable()
135+
for
136+
jrePaths <- JreClasspath.jrtPath()
137+
jars <- CoursierFetchClient.fetchClasspath(
138+
TestFixtures.javaCoord, Seq(TestFixtures.localM2Repo))
139+
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
140+
yield result
141+
142+
private def withScala2Ctx[A](body: Context => IO[A]): IO[A] =
143+
TestFixtures.assumeFixturesAvailable()
144+
for
145+
jrePaths <- JreClasspath.jrtPath()
146+
jars <- CoursierFetchClient.fetchClasspath(
147+
TestFixtures.scala2Coord, Seq(TestFixtures.localM2Repo))
148+
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
149+
yield result
150+
151+
test("formatSymbol members includes all overloaded methods (Java)"):
152+
withJavaCtx { ctx =>
153+
IO.blocking {
154+
given Context = ctx
155+
val cls = ctx.findStaticClass("cellar.fixture.java.CellarJavaClass")
156+
val output = GetFormatter.formatSymbol(cls)
157+
val formatCount = output.linesIterator.count(_.contains("def format("))
158+
assertEquals(formatCount, 3, s"Expected 3 format overloads in:\n$output")
159+
}
160+
}
161+
162+
test("formatSymbol members includes all overloaded methods (Scala 2)"):
163+
withScala2Ctx { ctx =>
164+
IO.blocking {
165+
given Context = ctx
166+
val cls = ctx.findStaticClass("cellar.fixture.scala2.CellarOverloaded")
167+
val output = GetFormatter.formatSymbol(cls)
168+
val processCount = output.linesIterator.count(_.contains("def process("))
169+
assertEquals(processCount, 3, s"Expected 3 process overloads in:\n$output")
170+
}
171+
}

lib/test/src/cellar/SymbolListerTest.scala

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,73 @@ class SymbolListerTest extends CatsEffectSuite:
8484
case other => fail(s"CellarA not found, got $other")
8585
}
8686
}
87+
88+
test("listMembers includes all overloaded methods"):
89+
withCtx { ctx =>
90+
given Context = ctx
91+
SymbolLister.resolve("cellar.fixture.scala3.CellarOverloaded").flatMap {
92+
case ListResolveResult.Found(target) =>
93+
SymbolLister.listMembers(target).compile.toList.map { syms =>
94+
val processCount = syms.count(_.name.toString == "process")
95+
assertEquals(processCount, 3, s"Expected 3 process overloads, got names: ${syms.map(_.name)}")
96+
}
97+
case other => fail(s"CellarOverloaded not found, got $other")
98+
}
99+
}
100+
101+
test("listMembers includes inherited overloaded methods"):
102+
withCtx { ctx =>
103+
given Context = ctx
104+
SymbolLister.resolve("cellar.fixture.scala3.CellarOverloadedChild").flatMap {
105+
case ListResolveResult.Found(target) =>
106+
SymbolLister.listMembers(target).compile.toList.map { syms =>
107+
val actionCount = syms.count(_.name.toString == "action")
108+
assertEquals(actionCount, 2, s"Expected 2 action overloads, got names: ${syms.map(_.name)}")
109+
}
110+
case other => fail(s"CellarOverloadedChild not found, got $other")
111+
}
112+
}
113+
114+
private def withJavaCtx[A](body: Context => IO[A]): IO[A] =
115+
TestFixtures.assumeFixturesAvailable()
116+
for
117+
jrePaths <- JreClasspath.jrtPath()
118+
jars <- CoursierFetchClient.fetchClasspath(
119+
TestFixtures.javaCoord, Seq(TestFixtures.localM2Repo))
120+
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
121+
yield result
122+
123+
private def withScala2Ctx[A](body: Context => IO[A]): IO[A] =
124+
TestFixtures.assumeFixturesAvailable()
125+
for
126+
jrePaths <- JreClasspath.jrtPath()
127+
jars <- CoursierFetchClient.fetchClasspath(
128+
TestFixtures.scala2Coord, Seq(TestFixtures.localM2Repo))
129+
result <- ContextResource.make(jars, jrePaths).use { (ctx, _) => body(ctx) }
130+
yield result
131+
132+
test("listMembers includes all overloaded methods (Java)"):
133+
withJavaCtx { ctx =>
134+
given Context = ctx
135+
SymbolLister.resolve("cellar.fixture.java.CellarJavaClass").flatMap {
136+
case ListResolveResult.Found(target) =>
137+
SymbolLister.listMembers(target).compile.toList.map { syms =>
138+
val formatCount = syms.count(_.name.toString == "format")
139+
assertEquals(formatCount, 3, s"Expected 3 format overloads, got names: ${syms.map(_.name)}")
140+
}
141+
case other => fail(s"CellarJavaClass not found, got $other")
142+
}
143+
}
144+
145+
test("listMembers includes all overloaded methods (Scala 2)"):
146+
withScala2Ctx { ctx =>
147+
given Context = ctx
148+
SymbolLister.resolve("cellar.fixture.scala2.CellarOverloaded").flatMap {
149+
case ListResolveResult.Found(target) =>
150+
SymbolLister.listMembers(target).compile.toList.map { syms =>
151+
val processCount = syms.count(_.name.toString == "process")
152+
assertEquals(processCount, 3, s"Expected 3 process overloads, got names: ${syms.map(_.name)}")
153+
}
154+
case other => fail(s"CellarOverloaded not found, got $other")
155+
}
156+
}

0 commit comments

Comments
 (0)