From 5a5e88c0e2eb56bb46c21471dc0f6fdd074da650 Mon Sep 17 00:00:00 2001 From: Jarrod Young Date: Mon, 26 Jul 2021 21:07:09 -0400 Subject: [PATCH 1/2] Handle nonclass Scala types and add tests --- .../org/http4s/rho/swagger/TypeBuilder.scala | 101 ++++++++++-------- .../http4s/rho/swagger/TypeBuilderSuite.scala | 17 ++- 2 files changed, 70 insertions(+), 48 deletions(-) diff --git a/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala b/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala index eb82558b2..806cc9a9a 100644 --- a/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala +++ b/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala @@ -19,16 +19,16 @@ object TypeBuilder { private[this] val logger = getLogger def collectModels(t: Type, alreadyKnown: Set[Model], sfs: SwaggerFormats, et: Type)(implicit - st: ShowType): Set[Model] = + st: ShowType): Set[Model] = try collectModels(t.dealias, alreadyKnown, ListSet.empty, sfs, et) catch { case NonFatal(_) => Set.empty } private def collectModels( - t: Type, - alreadyKnown: Set[Model], - known: TypeSet, - sfs: SwaggerFormats, - et: Type)(implicit st: ShowType): Set[Model] = { + t: Type, + alreadyKnown: Set[Model], + known: TypeSet, + sfs: SwaggerFormats, + et: Type)(implicit st: ShowType): Set[Model] = { def go(t: Type, alreadyKnown: Set[Model], known: TypeSet): Set[Model] = t.dealias match { @@ -50,7 +50,7 @@ object TypeBuilder { go(tpe.typeArgs.last, alreadyKnown, known + tpe) case tpe - if (tpe.isCollection || tpe.isOption || tpe.isEffect(et)) && tpe.typeArgs.nonEmpty => + if (tpe.isCollection || tpe.isOption || tpe.isEffect(et)) && tpe.typeArgs.nonEmpty => go(tpe.typeArgs.head, alreadyKnown, known + tpe) case tpe if tpe.isStream => @@ -75,7 +75,7 @@ object TypeBuilder { }.toSet case tpe @ TypeRef(_, sym: Symbol, tpeArgs: List[Type]) - if isCaseClass(sym) || isSumType(sym) => + if isCaseClass(sym) || isSumType(sym) => val symIsSumType = isSumType(sym) val maybeParentSumType = sym.asClass.baseClasses.drop(1).find(isSumType) @@ -157,7 +157,7 @@ object TypeBuilder { } private def modelToSwagger(tpe: Type, sfs: SwaggerFormats)(implicit - st: ShowType): Option[ModelImpl] = + st: ShowType): Option[ModelImpl] = try { val TypeRef(_, sym: Symbol, tpeArgs: List[Type]) = tpe val constructor = tpe.member(termNames.CONSTRUCTOR) @@ -185,7 +185,7 @@ object TypeBuilder { } private def paramSymToProp(sym: Symbol, tpeArgs: List[Type], sfs: SwaggerFormats)(pSym: Symbol)( - implicit st: ShowType): (String, Property) = { + implicit st: ShowType): (String, Property) = { val pType = pSym.typeSignature.substituteTypes(sym.asClass.typeParams, tpeArgs) val required = !(pSym.asTerm.isParamWithDefault || pType.isOption) val prop = typeToProperty(pType, sfs) @@ -210,27 +210,34 @@ object TypeBuilder { ArrayProperty(items = itemProperty) } else if (tpe.isOption) typeToProperty(tpe.typeArgs.head, sfs).withRequired(false) - else if (tpe.isAnyVal && !tpe.isPrimitive) - typeToProperty( - ptSym.asClass.primaryConstructor.asMethod.paramLists.flatten.head.typeSignature, - sfs - ) - else if (isCaseClass(ptSym) || (isSumType(ptSym) && !isObjectEnum(ptSym))) - RefProperty(tpe.simpleName) - else - DataType.fromType(tpe) match { - case DataType.ValueDataType(name, format, qName) => - AbstractProperty(`type` = name, description = qName, format = format) - case DataType.ComplexDataType(name, qName) => - AbstractProperty(`type` = name, description = qName) - case DataType.ContainerDataType(name, _, _) => - AbstractProperty(`type` = name) - case DataType.EnumDataType(enums) => - StringProperty(enums = enums) + else if (tpe.isAnyVal && !tpe.isPrimitive && ptSym.isClass) { + val symbolOption = ptSym.asClass.primaryConstructor.asMethod.paramLists.flatten.headOption + symbolOption match { + case Some(symbol) => + typeToProperty( + symbol.typeSignature, + sfs + ) + case None => dataTypeFromType(tpe) } + } else if (isCaseClass(ptSym) || (isSumType(ptSym) && !isObjectEnum(ptSym))) + RefProperty(tpe.simpleName) + else dataTypeFromType(tpe) } ) + private def dataTypeFromType(tpe: Type)(implicit showType: ShowType): Property = + DataType.fromType(tpe) match { + case DataType.ValueDataType(name, format, qName) => + AbstractProperty(`type` = name, description = qName, format = format) + case DataType.ComplexDataType(name, qName) => + AbstractProperty(`type` = name, description = qName) + case DataType.ContainerDataType(name, _, _) => + AbstractProperty(`type` = name) + case DataType.EnumDataType(enums) => + StringProperty(enums = enums) + } + sealed trait DataType { def name: String } @@ -238,15 +245,15 @@ object TypeBuilder { object DataType { case class ValueDataType( - name: String, - format: Option[String] = None, - qualifiedName: Option[String] = None) - extends DataType + name: String, + format: Option[String] = None, + qualifiedName: Option[String] = None) + extends DataType case class ContainerDataType( - name: String, - typeArg: Option[DataType] = None, - uniqueItems: Boolean = false) - extends DataType + name: String, + typeArg: Option[DataType] = None, + uniqueItems: Boolean = false) + extends DataType case class ComplexDataType(name: String, qualifiedName: Option[String] = None) extends DataType case class EnumDataType(enums: Set[String]) extends DataType { val name = "string" } @@ -288,7 +295,6 @@ object TypeBuilder { private[swagger] def fromType(t: Type)(implicit st: ShowType): DataType = { val klass = if (t.isOption && t.typeArgs.nonEmpty) t.typeArgs.head else t - if (klass.isNothingOrNull || klass.isUnitOrVoid) ComplexDataType("string", qualifiedName = Option(klass.fullName)) else if (isString(klass)) this.String @@ -312,21 +318,26 @@ object TypeBuilder { if (t.typeArgs.nonEmpty) GenArray(fromType(t.typeArgs(1))) else GenArray() } else if (klass <:< typeOf[AnyVal]) { - fromType( - klass.members - .filter(_.isConstructor) - .flatMap(_.asMethod.paramLists.flatten) - .head - .typeSignature - ) + val klassSymbolOption = klass.members + .filter(_.isConstructor) + .flatMap(_.asMethod.paramLists.flatten) + .headOption + klassSymbolOption match { + case Some(symbol) => fromType(symbol.typeSignature) + case None => fallBackDataTypeFromName(t) + } } else if (isObjectEnum(klass.typeSymbol)) { EnumDataType(klass.typeSymbol.asClass.knownDirectSubclasses.map(_.name.toString)) } else { - val stt = if (t.isOption) t.typeArgs.head else t - ComplexDataType("string", qualifiedName = Option(stt.fullName)) + fallBackDataTypeFromName(t) } } + private def fallBackDataTypeFromName(t: Type)(implicit st: ShowType): DataType = { + val stt = if (t.isOption) t.typeArgs.head else t + ComplexDataType("string", qualifiedName = Option(stt.fullName)) + } + private[this] val IntTypes = Set[Type]( typeOf[Int], diff --git a/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala b/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala index ff6eadfdf..462505c66 100644 --- a/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala +++ b/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala @@ -2,7 +2,6 @@ package org.http4s.rho.swagger import java.sql.Timestamp import java.util.Date - import cats.effect.IO import cats.syntax.all._ import fs2.Stream @@ -29,13 +28,15 @@ package object model { case class FooWithMap(l: Map[String, Int]) case class FooVal(foo: Foo) extends AnyVal case class BarWithFooVal(fooVal: FooVal) + case class AnyValClass(anyVal: AnyVal) + type AnyValType = AnyVal + case class ClassWithAnyValType(anyVal: AnyValType) @DiscriminatorField("foobar") sealed trait Sealed { def foo: String } case class FooSealed(a: Int, foo: String, foo2: Foo) extends Sealed case class BarSealed(str: String, foo: String) extends Sealed - sealed trait SealedEnum case object FooEnum extends SealedEnum case object BarEnum extends SealedEnum @@ -71,7 +72,7 @@ package object model { modelOfWithFormats(DefaultSwaggerFormats) def modelOfWithFormats[T]( - formats: SwaggerFormats)(implicit t: TypeTag[T], st: ShowType): Set[Model] = + formats: SwaggerFormats)(implicit t: TypeTag[T], st: ShowType): Set[Model] = TypeBuilder.collectModels(t.tpe, Set.empty, formats, typeOf[IO[_]]) } @@ -413,6 +414,16 @@ class TypeBuilderSuite extends FunSuite { assertEquals(model, modelOf[Sealed]) } + test("A TypeBuilder should fall back to the class name for a class containing an AnyVal") { + val m = modelOf[AnyValClass].head + assertEquals(m.description, "AnyValClass".some) + } + + test("A TypeBuilder should fall back to the class name for a custom type containing an AnyVal") { + val m = modelOf[ClassWithAnyValType].head + assertEquals(m.description, "ClassWithAnyValType".some) + } + test("A TypeBuilder should build a model for two-level sealed trait hierarchy") { val ms = modelOf[TopLevelSealedTrait] assertEquals(ms.size, 5) From 480eae148e71abc4a8113d51b2ecf0d09c7c8d90 Mon Sep 17 00:00:00 2001 From: Jarrod Young Date: Mon, 26 Jul 2021 21:09:22 -0400 Subject: [PATCH 2/2] Formatting --- .../org/http4s/rho/swagger/TypeBuilder.scala | 36 +++++++++---------- .../http4s/rho/swagger/TypeBuilderSuite.scala | 2 +- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala b/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala index 806cc9a9a..bcd02fe23 100644 --- a/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala +++ b/swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala @@ -19,16 +19,16 @@ object TypeBuilder { private[this] val logger = getLogger def collectModels(t: Type, alreadyKnown: Set[Model], sfs: SwaggerFormats, et: Type)(implicit - st: ShowType): Set[Model] = + st: ShowType): Set[Model] = try collectModels(t.dealias, alreadyKnown, ListSet.empty, sfs, et) catch { case NonFatal(_) => Set.empty } private def collectModels( - t: Type, - alreadyKnown: Set[Model], - known: TypeSet, - sfs: SwaggerFormats, - et: Type)(implicit st: ShowType): Set[Model] = { + t: Type, + alreadyKnown: Set[Model], + known: TypeSet, + sfs: SwaggerFormats, + et: Type)(implicit st: ShowType): Set[Model] = { def go(t: Type, alreadyKnown: Set[Model], known: TypeSet): Set[Model] = t.dealias match { @@ -50,7 +50,7 @@ object TypeBuilder { go(tpe.typeArgs.last, alreadyKnown, known + tpe) case tpe - if (tpe.isCollection || tpe.isOption || tpe.isEffect(et)) && tpe.typeArgs.nonEmpty => + if (tpe.isCollection || tpe.isOption || tpe.isEffect(et)) && tpe.typeArgs.nonEmpty => go(tpe.typeArgs.head, alreadyKnown, known + tpe) case tpe if tpe.isStream => @@ -75,7 +75,7 @@ object TypeBuilder { }.toSet case tpe @ TypeRef(_, sym: Symbol, tpeArgs: List[Type]) - if isCaseClass(sym) || isSumType(sym) => + if isCaseClass(sym) || isSumType(sym) => val symIsSumType = isSumType(sym) val maybeParentSumType = sym.asClass.baseClasses.drop(1).find(isSumType) @@ -157,7 +157,7 @@ object TypeBuilder { } private def modelToSwagger(tpe: Type, sfs: SwaggerFormats)(implicit - st: ShowType): Option[ModelImpl] = + st: ShowType): Option[ModelImpl] = try { val TypeRef(_, sym: Symbol, tpeArgs: List[Type]) = tpe val constructor = tpe.member(termNames.CONSTRUCTOR) @@ -185,7 +185,7 @@ object TypeBuilder { } private def paramSymToProp(sym: Symbol, tpeArgs: List[Type], sfs: SwaggerFormats)(pSym: Symbol)( - implicit st: ShowType): (String, Property) = { + implicit st: ShowType): (String, Property) = { val pType = pSym.typeSignature.substituteTypes(sym.asClass.typeParams, tpeArgs) val required = !(pSym.asTerm.isParamWithDefault || pType.isOption) val prop = typeToProperty(pType, sfs) @@ -245,15 +245,15 @@ object TypeBuilder { object DataType { case class ValueDataType( - name: String, - format: Option[String] = None, - qualifiedName: Option[String] = None) - extends DataType + name: String, + format: Option[String] = None, + qualifiedName: Option[String] = None) + extends DataType case class ContainerDataType( - name: String, - typeArg: Option[DataType] = None, - uniqueItems: Boolean = false) - extends DataType + name: String, + typeArg: Option[DataType] = None, + uniqueItems: Boolean = false) + extends DataType case class ComplexDataType(name: String, qualifiedName: Option[String] = None) extends DataType case class EnumDataType(enums: Set[String]) extends DataType { val name = "string" } diff --git a/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala b/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala index 462505c66..bef45736b 100644 --- a/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala +++ b/swagger/src/test/scala/org/http4s/rho/swagger/TypeBuilderSuite.scala @@ -72,7 +72,7 @@ package object model { modelOfWithFormats(DefaultSwaggerFormats) def modelOfWithFormats[T]( - formats: SwaggerFormats)(implicit t: TypeTag[T], st: ShowType): Set[Model] = + formats: SwaggerFormats)(implicit t: TypeTag[T], st: ShowType): Set[Model] = TypeBuilder.collectModels(t.tpe, Set.empty, formats, typeOf[IO[_]]) }