Skip to content

Commit

Permalink
Handle nonclass Scala types and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jarrod Young committed Jul 27, 2021
1 parent 80962c7 commit 5a5e88c
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 48 deletions.
101 changes: 56 additions & 45 deletions swagger/src/main/scala/org/http4s/rho/swagger/TypeBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ object TypeBuilder {
private[this] val logger = getLogger

def collectModels(t: Type, alreadyKnown: Set[Model], sfs: SwaggerFormats, et: Type)(implicit
st: ShowType): Set[Model] =
st: ShowType): Set[Model] =
try collectModels(t.dealias, alreadyKnown, ListSet.empty, sfs, et)
catch { case NonFatal(_) => Set.empty }

private def collectModels(
t: Type,
alreadyKnown: Set[Model],
known: TypeSet,
sfs: SwaggerFormats,
et: Type)(implicit st: ShowType): Set[Model] = {
t: Type,
alreadyKnown: Set[Model],
known: TypeSet,
sfs: SwaggerFormats,
et: Type)(implicit st: ShowType): Set[Model] = {

def go(t: Type, alreadyKnown: Set[Model], known: TypeSet): Set[Model] =
t.dealias match {
Expand All @@ -50,7 +50,7 @@ object TypeBuilder {
go(tpe.typeArgs.last, alreadyKnown, known + tpe)

case tpe
if (tpe.isCollection || tpe.isOption || tpe.isEffect(et)) && tpe.typeArgs.nonEmpty =>
if (tpe.isCollection || tpe.isOption || tpe.isEffect(et)) && tpe.typeArgs.nonEmpty =>
go(tpe.typeArgs.head, alreadyKnown, known + tpe)

case tpe if tpe.isStream =>
Expand All @@ -75,7 +75,7 @@ object TypeBuilder {
}.toSet

case tpe @ TypeRef(_, sym: Symbol, tpeArgs: List[Type])
if isCaseClass(sym) || isSumType(sym) =>
if isCaseClass(sym) || isSumType(sym) =>
val symIsSumType = isSumType(sym)
val maybeParentSumType = sym.asClass.baseClasses.drop(1).find(isSumType)

Expand Down Expand Up @@ -157,7 +157,7 @@ object TypeBuilder {
}

private def modelToSwagger(tpe: Type, sfs: SwaggerFormats)(implicit
st: ShowType): Option[ModelImpl] =
st: ShowType): Option[ModelImpl] =
try {
val TypeRef(_, sym: Symbol, tpeArgs: List[Type]) = tpe
val constructor = tpe.member(termNames.CONSTRUCTOR)
Expand Down Expand Up @@ -185,7 +185,7 @@ object TypeBuilder {
}

private def paramSymToProp(sym: Symbol, tpeArgs: List[Type], sfs: SwaggerFormats)(pSym: Symbol)(
implicit st: ShowType): (String, Property) = {
implicit st: ShowType): (String, Property) = {
val pType = pSym.typeSignature.substituteTypes(sym.asClass.typeParams, tpeArgs)
val required = !(pSym.asTerm.isParamWithDefault || pType.isOption)
val prop = typeToProperty(pType, sfs)
Expand All @@ -210,43 +210,50 @@ object TypeBuilder {
ArrayProperty(items = itemProperty)
} else if (tpe.isOption)
typeToProperty(tpe.typeArgs.head, sfs).withRequired(false)
else if (tpe.isAnyVal && !tpe.isPrimitive)
typeToProperty(
ptSym.asClass.primaryConstructor.asMethod.paramLists.flatten.head.typeSignature,
sfs
)
else if (isCaseClass(ptSym) || (isSumType(ptSym) && !isObjectEnum(ptSym)))
RefProperty(tpe.simpleName)
else
DataType.fromType(tpe) match {
case DataType.ValueDataType(name, format, qName) =>
AbstractProperty(`type` = name, description = qName, format = format)
case DataType.ComplexDataType(name, qName) =>
AbstractProperty(`type` = name, description = qName)
case DataType.ContainerDataType(name, _, _) =>
AbstractProperty(`type` = name)
case DataType.EnumDataType(enums) =>
StringProperty(enums = enums)
else if (tpe.isAnyVal && !tpe.isPrimitive && ptSym.isClass) {
val symbolOption = ptSym.asClass.primaryConstructor.asMethod.paramLists.flatten.headOption
symbolOption match {
case Some(symbol) =>
typeToProperty(
symbol.typeSignature,
sfs
)
case None => dataTypeFromType(tpe)
}
} else if (isCaseClass(ptSym) || (isSumType(ptSym) && !isObjectEnum(ptSym)))
RefProperty(tpe.simpleName)
else dataTypeFromType(tpe)
}
)

private def dataTypeFromType(tpe: Type)(implicit showType: ShowType): Property =
DataType.fromType(tpe) match {
case DataType.ValueDataType(name, format, qName) =>
AbstractProperty(`type` = name, description = qName, format = format)
case DataType.ComplexDataType(name, qName) =>
AbstractProperty(`type` = name, description = qName)
case DataType.ContainerDataType(name, _, _) =>
AbstractProperty(`type` = name)
case DataType.EnumDataType(enums) =>
StringProperty(enums = enums)
}

sealed trait DataType {
def name: String
}

object DataType {

case class ValueDataType(
name: String,
format: Option[String] = None,
qualifiedName: Option[String] = None)
extends DataType
name: String,
format: Option[String] = None,
qualifiedName: Option[String] = None)
extends DataType
case class ContainerDataType(
name: String,
typeArg: Option[DataType] = None,
uniqueItems: Boolean = false)
extends DataType
name: String,
typeArg: Option[DataType] = None,
uniqueItems: Boolean = false)
extends DataType
case class ComplexDataType(name: String, qualifiedName: Option[String] = None) extends DataType
case class EnumDataType(enums: Set[String]) extends DataType { val name = "string" }

Expand Down Expand Up @@ -288,7 +295,6 @@ object TypeBuilder {

private[swagger] def fromType(t: Type)(implicit st: ShowType): DataType = {
val klass = if (t.isOption && t.typeArgs.nonEmpty) t.typeArgs.head else t

if (klass.isNothingOrNull || klass.isUnitOrVoid)
ComplexDataType("string", qualifiedName = Option(klass.fullName))
else if (isString(klass)) this.String
Expand All @@ -312,21 +318,26 @@ object TypeBuilder {
if (t.typeArgs.nonEmpty) GenArray(fromType(t.typeArgs(1)))
else GenArray()
} else if (klass <:< typeOf[AnyVal]) {
fromType(
klass.members
.filter(_.isConstructor)
.flatMap(_.asMethod.paramLists.flatten)
.head
.typeSignature
)
val klassSymbolOption = klass.members
.filter(_.isConstructor)
.flatMap(_.asMethod.paramLists.flatten)
.headOption
klassSymbolOption match {
case Some(symbol) => fromType(symbol.typeSignature)
case None => fallBackDataTypeFromName(t)
}
} else if (isObjectEnum(klass.typeSymbol)) {
EnumDataType(klass.typeSymbol.asClass.knownDirectSubclasses.map(_.name.toString))
} else {
val stt = if (t.isOption) t.typeArgs.head else t
ComplexDataType("string", qualifiedName = Option(stt.fullName))
fallBackDataTypeFromName(t)
}
}

private def fallBackDataTypeFromName(t: Type)(implicit st: ShowType): DataType = {
val stt = if (t.isOption) t.typeArgs.head else t
ComplexDataType("string", qualifiedName = Option(stt.fullName))
}

private[this] val IntTypes =
Set[Type](
typeOf[Int],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.http4s.rho.swagger

import java.sql.Timestamp
import java.util.Date

import cats.effect.IO
import cats.syntax.all._
import fs2.Stream
Expand All @@ -29,13 +28,15 @@ package object model {
case class FooWithMap(l: Map[String, Int])
case class FooVal(foo: Foo) extends AnyVal
case class BarWithFooVal(fooVal: FooVal)
case class AnyValClass(anyVal: AnyVal)
type AnyValType = AnyVal
case class ClassWithAnyValType(anyVal: AnyValType)
@DiscriminatorField("foobar")
sealed trait Sealed {
def foo: String
}
case class FooSealed(a: Int, foo: String, foo2: Foo) extends Sealed
case class BarSealed(str: String, foo: String) extends Sealed

sealed trait SealedEnum
case object FooEnum extends SealedEnum
case object BarEnum extends SealedEnum
Expand Down Expand Up @@ -71,7 +72,7 @@ package object model {
modelOfWithFormats(DefaultSwaggerFormats)

def modelOfWithFormats[T](
formats: SwaggerFormats)(implicit t: TypeTag[T], st: ShowType): Set[Model] =
formats: SwaggerFormats)(implicit t: TypeTag[T], st: ShowType): Set[Model] =
TypeBuilder.collectModels(t.tpe, Set.empty, formats, typeOf[IO[_]])
}

Expand Down Expand Up @@ -413,6 +414,16 @@ class TypeBuilderSuite extends FunSuite {
assertEquals(model, modelOf[Sealed])
}

test("A TypeBuilder should fall back to the class name for a class containing an AnyVal") {
val m = modelOf[AnyValClass].head
assertEquals(m.description, "AnyValClass".some)
}

test("A TypeBuilder should fall back to the class name for a custom type containing an AnyVal") {
val m = modelOf[ClassWithAnyValType].head
assertEquals(m.description, "ClassWithAnyValType".some)
}

test("A TypeBuilder should build a model for two-level sealed trait hierarchy") {
val ms = modelOf[TopLevelSealedTrait]
assertEquals(ms.size, 5)
Expand Down

0 comments on commit 5a5e88c

Please sign in to comment.