Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions ai-core/src/main/scala/wvlet/ai/core/weaver/CaseClassWeaver.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package wvlet.ai.core.weaver

import scala.deriving.Mirror
import scala.compiletime.{erasedValue, summonInline}
import wvlet.ai.core.msgpack.spi.{Packer, Unpacker}

// Removed duplicate ObjectWeaver trait.
// The canonical one is in ObjectWeaver.scala

class CaseClassWeaver[A](using m: Mirror.ProductOf[A]) extends ObjectWeaver[A] {

// Note: elementWeavers are now of type ObjectWeaver from the canonical definition
private inline def summonElementWeavers[Elems <: Tuple]: List[ObjectWeaver[?]] =
inline erasedValue[Elems] match {
case _: (elem *: elemsTail) =>
summonInline[ObjectWeaver[elem]] :: summonElementWeavers[elemsTail]
case _: EmptyTuple =>
Nil
}

private val elementWeavers: List[ObjectWeaver[?]] = summonElementWeavers[m.MirroredElemTypes]

override def pack(packer: Packer, v: A, config: WeaverConfig): Unit = {
val product = v.asInstanceOf[Product]
if (product.productArity != elementWeavers.size) {
// TODO: More specific error handling using WeaverContext
throw new IllegalArgumentException(s"Element count mismatch. Expected: ${elementWeavers.size}, Got: ${product.productArity}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The TODO comment on the preceding line (// TODO: More specific error handling using WeaverContext) suggests using WeaverContext for error handling here. However, the pack method signature override def pack(packer: Packer, v: A, config: WeaverConfig): Unit doesn't include a WeaverContext parameter, unlike the unpack method where WeaverContext is used for error reporting.

Could you clarify the intention behind this TODO?

  • Is the goal to introduce a more specific custom exception type (e.g., WeaverPackingException) that could be thrown here instead of IllegalArgumentException?
  • If WeaverContext involvement is desired, how was it envisioned to be integrated with the current pack signature, or would this require changes to the ObjectWeaver trait?
  • Alternatively, is the current IllegalArgumentException considered sufficient for this kind of precondition failure (element count mismatch)?

Addressing or clarifying this TODO would help ensure the error handling strategy is consistent and clear.

}
packer.packArrayHeader(elementWeavers.size)
product.productIterator.zip(elementWeavers).foreach { case (elem, weaver) =>
// This cast is generally safe due to how elementWeavers is constructed.
// The individual element's weaver will handle its specific packing.
(weaver.asInstanceOf[ObjectWeaver[Any]]).pack(packer, elem, config)
}
}

override def unpack(unpacker: Unpacker, context: WeaverContext): Unit = {
val numElements = unpacker.unpackArrayHeader()
if (numElements != elementWeavers.size) {
context.setError(new IllegalArgumentException(s"Element count mismatch. Expected: ${elementWeavers.size}, Got: ${numElements}"))
// TODO: Potentially consume unexpected fields from unpacker to allow recovery or partial unpack
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This TODO (// TODO: Potentially consume unexpected fields from unpacker to allow recovery or partial unpack) raises a good point about enhancing robustness, particularly for schema evolution.

Currently, if the packed data has a different number of elements than expected, unpack will set an error and return. Allowing the unpacker to optionally consume (skip) unexpected extra fields could make deserialization more lenient and improve forward compatibility (e.g., an older client reading data produced by a newer version of a case class with additional fields).

What are your thoughts on the priority of implementing this? If it's not planned for this PR, perhaps creating a follow-up issue to track this potential enhancement would be beneficial for future development.

return
}

val elements = new Array[Any](elementWeavers.size)
var i = 0
var failed = false
while (i < elementWeavers.size && !failed) {
val weaver = elementWeavers(i)
// Create a new context for each element to isolate errors and values
val elementContext = WeaverContext(context.config)
weaver.unpack(unpacker, elementContext)

if (elementContext.hasError) {
context.setError(new RuntimeException(s"Failed to unpack element $i: ${elementContext.getError.get.getMessage}", elementContext.getError.get))
failed = true
} else {
elements(i) = elementContext.getLastValue
}
i += 1
}

if (!failed) {
try {
val instance = m.fromProduct(new Product {
override def productArity: Int = elements.length
override def productElement(n: Int): Any = elements(n)
override def canEqual(that: Any): Boolean = that.isInstanceOf[Product] && that.asInstanceOf[Product].productArity == productArity
})
context.setLastValue(instance)
} catch {
case e: Throwable =>
context.setError(new RuntimeException("Failed to instantiate case class from product", e))
}
}
// If failed, context will already have an error set.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,5 @@ object ObjectWeaver:
): A = weaver.fromJson(json, config)

export PrimitiveWeaver.given

inline given [A](using m: scala.deriving.Mirror.ProductOf[A]): ObjectWeaver[A] = CaseClassWeaver[A](using m)
98 changes: 98 additions & 0 deletions ai-core/src/test/scala/wvlet/ai/core/weaver/WeaverTest.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
package wvlet.ai.core.weaver

import wvlet.airspec.AirSpec
import wvlet.ai.core.weaver.ObjectWeaver // Ensure ObjectWeaver is imported if not already fully covered
import scala.jdk.CollectionConverters.*

// Define case classes for testing
case class SimpleCase(i: Int, s: String, b: Boolean)
case class NestedCase(name: String, simple: SimpleCase)
case class OptionCase(id: Int, opt: Option[String])
case class SeqCase(key: String, values: Seq[Int])

class WeaverTest extends AirSpec:

test("weave int") {
Expand Down Expand Up @@ -497,4 +504,95 @@ class WeaverTest extends AirSpec:
result.get.getMessage.contains("Cannot convert") shouldBe true
}

// Tests for SimpleCase
test("weave SimpleCase") {
val v = SimpleCase(10, "test case", true)
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[SimpleCase](msgpack)
v shouldBe v2
}

test("SimpleCase toJson") {
val v = SimpleCase(20, "json test", false)
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[SimpleCase](json)
v shouldBe v2
}

// Tests for NestedCase
test("weave NestedCase") {
val v = NestedCase("nested", SimpleCase(30, "inner", true))
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[NestedCase](msgpack)
v shouldBe v2
}

test("NestedCase toJson") {
val v = NestedCase("nested json", SimpleCase(40, "inner json", false))
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[NestedCase](json)
v shouldBe v2
}

// Tests for OptionCase
test("weave OptionCase with Some") {
val v = OptionCase(50, Some("option value"))
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[OptionCase](msgpack)
v shouldBe v2
}

test("OptionCase toJson with Some") {
val v = OptionCase(60, Some("option json"))
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[OptionCase](json)
v shouldBe v2
}

test("weave OptionCase with None") {
val v = OptionCase(70, None)
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[OptionCase](msgpack)
v shouldBe v2
}

test("OptionCase toJson with None") {
val v = OptionCase(80, None)
val json = ObjectWeaver.toJson(v)
// Check against expected JSON for None, as direct None might be ambiguous for fromJson
// Depending on JSON library, None might be represented as null or omitted
// Assuming it's represented as null or handled by the weaver
val v2 = ObjectWeaver.fromJson[OptionCase](json)
v shouldBe v2
}

// Tests for SeqCase
test("weave SeqCase with non-empty Seq") {
val v = SeqCase("seq test", Seq(1, 2, 3, 4))
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[SeqCase](msgpack)
v shouldBe v2
}

test("SeqCase toJson with non-empty Seq") {
val v = SeqCase("seq json", Seq(5, 6, 7))
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[SeqCase](json)
v shouldBe v2
}

test("weave SeqCase with empty Seq") {
val v = SeqCase("empty seq", Seq.empty[Int])
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[SeqCase](msgpack)
v shouldBe v2
}

test("SeqCase toJson with empty Seq") {
val v = SeqCase("empty seq json", Seq.empty[Int])
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[SeqCase](json)
v shouldBe v2
}

end WeaverTest
Loading