diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index a36a6a3c..949550c6 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -95,7 +95,9 @@ class SdkScalaTypeTest { datetime: SdkBindingData[Instant], duration: SdkBindingData[Duration], blob: SdkBindingData[Blob], - generic: SdkBindingData[ScalarNested] + generic: SdkBindingData[ScalarNested], + none: SdkBindingData[Option[String]], + some: SdkBindingData[Option[String]] ) case class CollectionInput( @@ -105,7 +107,8 @@ class SdkScalaTypeTest { booleans: SdkBindingData[List[Boolean]], datetimes: SdkBindingData[List[Instant]], durations: SdkBindingData[List[Duration]], - generics: SdkBindingData[List[ScalarNested]] + generics: SdkBindingData[List[ScalarNested]], + options: SdkBindingData[List[Option[String]]] ) case class MapInput( @@ -115,7 +118,8 @@ class SdkScalaTypeTest { booleanMap: SdkBindingData[Map[String, Boolean]], datetimeMap: SdkBindingData[Map[String, Instant]], durationMap: SdkBindingData[Map[String, Duration]], - genericMap: SdkBindingData[Map[String, ScalarNested]] + genericMap: SdkBindingData[Map[String, ScalarNested]], + optionMap: SdkBindingData[Map[String, Option[String]]] ) case class ComplexInput( @@ -196,7 +200,9 @@ class SdkScalaTypeTest { .literalType(LiteralType.ofBlobType(BlobType.DEFAULT)) .description("") .build(), - "generic" -> createVar(SimpleType.STRUCT) + "generic" -> createVar(SimpleType.STRUCT), + "none" -> createVar(SimpleType.STRUCT), + "some" -> createVar(SimpleType.STRUCT) ).asJava val output = SdkScalaType[ScalarInput].getVariableMap @@ -274,6 +280,16 @@ class SdkScalaTypeTest { ).asJava ) ) + ), + "none" -> Literal.ofScalar( + Scalar.ofGeneric( + Struct.of(Map.empty[String, Struct.Value].asJava) + ) + ), + "some" -> Literal.ofScalar( + Scalar.ofGeneric( + Struct.of(Map("value" -> Struct.Value.ofStringValue("hello")).asJava) + ) ) ).asJava @@ -295,6 +311,14 @@ class SdkScalaTypeTest { List(ScalarNestedNested("foo", Some("bar"))), Map("foo" -> ScalarNestedNested("foo", Some("bar"))) ) + ), + none = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + Option(null) + ), + some = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + Option("hello") ) ) @@ -323,7 +347,11 @@ class SdkScalaTypeTest { List(ScalarNestedNested("foo", Some("bar"))), Map("foo" -> ScalarNestedNested("foo", Some("bar"))) ) - ) + ), + none = + SdkBindingDataFactory.of(SdkLiteralTypes.generics(), Option(null)), + some = + SdkBindingDataFactory.of(SdkLiteralTypes.generics(), Option("hello")) ) val expected = Map( @@ -399,6 +427,23 @@ class SdkScalaTypeTest { ).asJava ) ) + ), + "none" -> Literal.ofScalar( + Scalar.ofGeneric( + Struct.of( + Map(__TYPE -> Struct.Value.ofStringValue("scala.None$")).asJava + ) + ) + ), + "some" -> Literal.ofScalar( + Scalar.ofGeneric( + Struct.of( + Map( + "value" -> Struct.Value.ofStringValue("hello"), + __TYPE -> Struct.Value.ofStringValue("scala.Some") + ).asJava + ) + ) ) ).asJava @@ -416,7 +461,8 @@ class SdkScalaTypeTest { "booleans" -> createCollectionVar(SimpleType.BOOLEAN), "datetimes" -> createCollectionVar(SimpleType.DATETIME), "durations" -> createCollectionVar(SimpleType.DURATION), - "generics" -> createCollectionVar(SimpleType.STRUCT) + "generics" -> createCollectionVar(SimpleType.STRUCT), + "options" -> createCollectionVar(SimpleType.STRUCT) ).asJava val output = SdkScalaType[CollectionInput].getVariableMap @@ -443,6 +489,14 @@ class SdkScalaTypeTest { List(ScalarNestedNested("foo", Some("bar"))), Map("foo" -> ScalarNestedNested("foo", Some("bar"))) ) + ), + none = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + Option(null) + ), + some = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + Option("hello") ) ) @@ -465,6 +519,14 @@ class SdkScalaTypeTest { List(ScalarNestedNested("foo", Some("bar"))), Map("foo" -> ScalarNestedNested("foo", Some("bar"))) ) + ), + "none" -> SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + Option(null) + ), + "some" -> SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + Option("hello") ) ).asJava @@ -531,6 +593,10 @@ class SdkScalaTypeTest { Map("foo2" -> ScalarNestedNested("foo2", Some("bar2"))) ) ) + ), + options = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + List(Option("hello"), Option(null)) ) ) @@ -550,7 +616,8 @@ class SdkScalaTypeTest { "booleanMap" -> createMapVar(SimpleType.BOOLEAN), "datetimeMap" -> createMapVar(SimpleType.DATETIME), "durationMap" -> createMapVar(SimpleType.DURATION), - "genericMap" -> createMapVar(SimpleType.STRUCT) + "genericMap" -> createMapVar(SimpleType.STRUCT), + "optionMap" -> createMapVar(SimpleType.STRUCT) ).asJava val output = SdkScalaType[MapInput].getVariableMap @@ -598,6 +665,10 @@ class SdkScalaTypeTest { Map("foo2" -> ScalarNestedNested("foo2", Some("bar2"))) ) ) + ), + optionMap = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[Option[String]](), + Map("none" -> Option(null), "some" -> Option("hello")) ) ) diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index fb128dd5..ef5cc625 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -28,6 +28,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe} import scala.reflect.runtime.universe import scala.reflect.{ClassTag, classTag} import scala.reflect.runtime.universe.{ + ClassSymbol, NoPrefix, Symbol, Type, @@ -72,7 +73,7 @@ object SdkLiteralTypes { blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]] case t if t =:= typeOf[Binary] => binary().asInstanceOf[SdkLiteralType[T]] - case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) => + case t if t <:< typeOf[Product] => generics().asInstanceOf[SdkLiteralType[T]] case t if t =:= typeOf[List[Long]] => @@ -391,24 +392,37 @@ object SdkLiteralTypes { ) } - val clazz = typeOf[S].typeSymbol.asClass - val classMirror = mirror.reflectClass(clazz) - val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod - val constructorMirror = classMirror.reflectConstructor(constructor) - - val constructorArgs = - constructor.paramLists.flatten.map((param: Symbol) => { - val paramName = param.name.toString - val value = map.getOrElse( - paramName, - throw new IllegalArgumentException( - s"Map is missing required parameter named $paramName" + def instantiateViaConstructor(cls: ClassSymbol): S = { + val classMirror = mirror.reflectClass(cls) + val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod + val constructorMirror = classMirror.reflectConstructor(constructor) + + val constructorArgs = + constructor.paramLists.flatten.map((param: Symbol) => { + val paramName = param.name.toString + val value = map.getOrElse( + paramName, + throw new IllegalArgumentException( + s"Map is missing required parameter named $paramName" + ) ) - ) - valueToParamValue(value, param.typeSignature.dealias) - }) + valueToParamValue(value, param.typeSignature.dealias) + }) + + constructorMirror(constructorArgs: _*).asInstanceOf[S] + } + + val clazz = typeOf[S].typeSymbol.asClass + // special handling of scala.Option as it is a Product, but can't be instantiated like common + // case classes + if (clazz.name.toString == "Option") + map + .get("value") + .map(valueToParamValue(_, typeOf[S].typeArgs.head)) + .asInstanceOf[S] + else + instantiateViaConstructor(clazz) - constructorMirror(constructorArgs: _*).asInstanceOf[S] } def structValueToAny(value: Struct.Value): Any = { diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 00cbdea5..c4868094 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -232,11 +232,8 @@ object SdkScalaType { implicit def durationLiteralType: SdkScalaLiteralType[Duration] = DelegateLiteralType(SdkLiteralTypes.durations()) - // more specific matching to fail the usage of SdkBindingData[Option[_]] - implicit def optionLiteralType: SdkScalaLiteralType[Option[_]] = ??? - // fixme: using Product is just an approximation for case class because Product - // is also super class of, for example, Option and Tuple + // is also super class of, for example, Either or Try implicit def productLiteralType[T <: Product: TypeTag: ClassTag] : SdkScalaLiteralType[T] = DelegateLiteralType(SdkLiteralTypes.generics()) diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/package.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/package.scala index 47c6332b..b5bcc208 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/package.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/package.scala @@ -30,7 +30,11 @@ package object flytekitscala { } catch { case _: Throwable => // fall back to java's way, less reliable and with limitations - product.getClass.getDeclaredFields.map(_.getName).toList + val methodNames = product.getClass.getDeclaredMethods.map(_.getName) + product.getClass.getDeclaredFields + .map(_.getName) + .filter(methodNames.contains) + .toList } } }