Skip to content

Commit

Permalink
feat: support Array types (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
mscheong01 authored Jun 16, 2023
1 parent f7ad339 commit 9ee64ae
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
package io.github.mscheong01.interfaice

import io.github.mscheong01.interfaice.util.isSuspendingFunction
import kotlinx.coroutines.flow.Flow
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import java.lang.reflect.Method
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
Expand Down Expand Up @@ -110,6 +107,28 @@ open class TypeSpecification<T : Any>(
)
}

val isArray: Boolean
get() {
return Array<Any>::class.java.isAssignableFrom(this.klazz.java)
}

val arrayTypeArgument: TypeSpecification<*>
get() {
if (!isArray) {
throw IllegalStateException("Type: ${this.javaType.typeName} is not an array")
}
val entryType = (javaType as Class<*>).componentType
val klazz = if (entryType is ParameterizedType) {
(entryType.rawType as Class<*>).kotlin
} else {
(entryType as Class<*>).kotlin
}
return TypeSpecification(
klazz = klazz,
javaType = entryType
)
}

val isReactiveWrapper: Boolean
get() {
val qualifiedName = klazz.qualifiedName
Expand All @@ -127,12 +146,6 @@ open class TypeSpecification<T : Any>(
javaType = obj::class.java
)
}

val REACTIVE_WRAPPER_TYPES = listOf(
Mono::class,
Flux::class,
Flow::class
)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ class TextObjectTranscoder(
val rule = match(TypeSpecification.from(obj))
return rule.encode(this, obj)
}
fun <T : Any> encode(obj: T?, type: TypeSpecification<T>): String {
if (obj == null) return "NULL"
val rule = match(type)
return rule.encode(this, obj)
}

fun <T : Any> decode(str: String, type: TypeSpecification<T>): T {
val rule = match(type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ object TranscodingRules {
type.klazz == Instant::class -> INSTANT
type.klazz == Duration::class -> DURATION
type.klazz == kotlin.time.Duration::class -> KOTLIN_DURATION
type.isArray -> {
val entryType = type.arrayTypeArgument
ArrayRule(entryType)
}
type.klazz.isSubclassOf(Collection::class) -> {
val entryType = type.typeArguments.first()
when {
Expand All @@ -41,11 +45,13 @@ object TranscodingRules {
else -> throw IllegalArgumentException("unsupported type: $type")
}
}

type.klazz.isSubclassOf(Map::class) -> {
val keyType = type.typeArguments[0]
val valueType = type.typeArguments[1]
MapRule(keyType, valueType)
}

type.klazz.isSubclassOf(Enum::class) -> EnumRule(type as TypeSpecification<out Enum<*>>)
else -> ObjectRule(type)
} as Rule<T>
Expand Down Expand Up @@ -186,6 +192,39 @@ object TranscodingRules {
decoder = { kotlin.time.Duration.parse(it) }
)

class ArrayRule<T : Any>(
val entryType: TypeSpecification<T>
) : Rule<Array<T>> {
override fun encodeDescription(transcoder: TextObjectTranscoder): String {
return """
Json array with the following entry format:
%s
""".trimIndent().format(
transcoder.match(entryType).encodeDescription(transcoder)
)
}

override fun encode(transcoder: TextObjectTranscoder, value: Array<T>): String {
return ObjectRule.mapper.writeValueAsString(value)
}

override fun decode(transcoder: TextObjectTranscoder, value: String): Array<T> {
val node = ObjectRule.mapper.readTree(value)
require(node.isArray) { "expected json array. actual: $value" }
val arrayNode = node as ArrayNode
val list = arrayNode.map {
if (it.isValueNode) {
transcoder.decode(it.asText(), entryType)
} else {
transcoder.decode(it.toString(), entryType)
}
}
val array = java.lang.reflect.Array.newInstance(entryType.klazz.java, list.size)
list.forEachIndexed { index, entry -> java.lang.reflect.Array.set(array, index, entry) }
return array as Array<T>
}
}

class ListRule<T : Any>(
override val entryType: TypeSpecification<T>
) : CollectionRule<T>(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.github.mscheong01.interfaice;

import org.junit.jupiter.api.Test;

import io.github.mscheong01.interfaice.openai.OpenAiProxyFactory;

public class JavaPrimitiveTest {
TextObjectTranscoder transcoder = new TextObjectTranscoder();

interface TestInterface {
int sum(int a, int b);
}

TestInterface proxy = OpenAiProxyFactory.of(System.getenv("OPENAI_API_KEY"))
.create(TestInterface.class);

@Test
public void test() {
int result = proxy.sum(1, 2);
System.out.println(result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class ProxyTest {
println(result2)
}

@Test
fun arrayTest() {
val result = proxy.arrayOfMountainNamesInCountry("South Korea")
println(result.contentToString())
}

@Test
fun collectionTest() {
val result = proxy.listOfMountainNamesInCountry("Nepal")
Expand Down Expand Up @@ -131,6 +137,9 @@ class ProxyTest {
@OpenAiChat
fun greetingMessage(name: String?): String

@OpenAiChat
fun arrayOfMountainNamesInCountry(countryName: String): Array<String>

@OpenAiChat
fun listOfMountainNamesInCountry(countryName: String): List<String>

Expand Down

0 comments on commit 9ee64ae

Please sign in to comment.