diff --git a/.travis.yml b/.travis.yml index 456414ebce..ea40887262 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,5 +6,5 @@ jdk: before_script: - mysql -e "create database IF NOT EXISTS cromwell_test;" -script: sbt -Dbackend.shared-filesystem.localization.0=copy clean coverage nodocker:test +script: sbt -Dbackend.shared-filesystem.localization.0=copy clean coverage nodocker:test assembly after_success: sbt coveralls diff --git a/build.sbt b/build.sbt index dc3b6d6e7a..db87b2def3 100644 --- a/build.sbt +++ b/build.sbt @@ -3,22 +3,32 @@ import sbtrelease.ReleasePlugin._ name := "cromwell" -version := "0.12" +version := "0.13" organization := "org.broadinstitute" scalaVersion := "2.11.7" +val lenthallV = "0.13" + val sprayV = "1.3.2" val DowngradedSprayV = "1.3.1" val akkaV = "2.3.12" +val slickV = "3.1.0" + val googleClientApiV = "1.20.0" +resolvers ++= Seq( + "Broad Artifactory Releases" at "https://artifactory.broadinstitute.org/artifactory/libs-release/", + "Broad Artifactory Snapshots" at "https://artifactory.broadinstitute.org/artifactory/libs-snapshot/") + libraryDependencies ++= Seq( - "com.gettyimages" %% "spray-swagger" % "0.5.1", + "org.broadinstitute" %% "lenthall" % lenthallV, + "com.typesafe.scala-logging" %% "scala-logging" % "3.1.0", + "org.joda" % "joda-convert" % "1.8.1", "org.webjars" % "swagger-ui" % "2.1.1", "io.spray" %% "spray-can" % sprayV, "io.spray" %% "spray-routing" % sprayV, @@ -33,14 +43,14 @@ libraryDependencies ++= Seq( "ch.qos.logback" % "logback-classic" % "1.1.3", "ch.qos.logback" % "logback-access" % "1.1.3", "org.codehaus.janino" % "janino" % "2.7.8", - "com.typesafe.slick" %% "slick" % "3.0.0", - "com.zaxxer" % "HikariCP" % "2.3.3", + "com.typesafe.slick" %% "slick" % slickV, + "com.typesafe.slick" %% "slick-hikaricp" % slickV, "org.hsqldb" % "hsqldb" % "2.3.2", "com.google.gcloud" % "gcloud-java" % "latest.integration", "com.google.api-client" % "google-api-client-java6" % googleClientApiV, "com.google.api-client" % 
"google-api-client-jackson2" % googleClientApiV, "com.google.oauth-client" % "google-oauth-client" % googleClientApiV, - "mysql" % "mysql-connector-java" % "5.1.35", + "mysql" % "mysql-connector-java" % "5.1.36", "org.scalaz" % "scalaz-core_2.11" % "7.1.3", //---------- Test libraries -------------------// "io.spray" %% "spray-testkit" % sprayV % Test, @@ -81,14 +91,20 @@ val customMergeStrategy: String => MergeStrategy = { MergeStrategy.filterDistinctLines case _ => MergeStrategy.deduplicate } - case "asm-license.txt" | "overview.html" => + case "asm-license.txt" | "overview.html" | "cobertura.properties" => MergeStrategy.discard case _ => MergeStrategy.deduplicate } assemblyMergeStrategy in assembly := customMergeStrategy -scalacOptions ++= Seq("-deprecation", "-unchecked", "-feature") +// The reason why -Xmax-classfile-name is set is because this will fail +// to build on Docker otherwise. The reason why it's 200 is because it +// fails if the value is too close to 256 (even 254 fails). 
For more info: +// +// https://github.com/sbt/sbt-assembly/issues/69 +// https://github.com/scala/pickling/issues/10 +scalacOptions ++= Seq("-deprecation", "-unchecked", "-feature", "-Xmax-classfile-name", "200") lazy val DockerTest = config("docker") extend Test @@ -107,3 +123,6 @@ testOptions in DockerTest += Tests.Argument("-n", "DockerTest") testOptions in NoDockerTest += Tests.Argument("-l", "DockerTest") test in assembly := {} + +parallelExecution := false + diff --git a/src/main/resources/application.conf b/src/main/resources/application.conf index 9797eb6605..148c78d2cb 100644 --- a/src/main/resources/application.conf +++ b/src/main/resources/application.conf @@ -8,6 +8,11 @@ akka { loggers = ["akka.event.slf4j.Slf4jLogger"] } +swagger { + docsPath = "swagger/cromwell.yaml" + uiVersion = "2.1.1" +} + spray.can { server { request-timeout = 40s @@ -94,18 +99,18 @@ database { main { hsqldb { - url = "jdbc:hsqldb:mem:${slick.uniqueSchema};shutdown=false;hsqldb.tx=mvcc" - driver = "org.hsqldb.jdbcDriver" - slick.driver = "slick.driver.HsqldbDriver" + db.url = "jdbc:hsqldb:mem:${slick.uniqueSchema};shutdown=false;hsqldb.tx=mvcc" + db.driver = "org.hsqldb.jdbcDriver" + driver = "slick.driver.HsqldbDriver$" slick.createSchema = true } } test { hsqldb { - url = "jdbc:hsqldb:mem:testdb;shutdown=false;hsqldb.tx=mvcc" - driver = "org.hsqldb.jdbcDriver" - slick.driver = "slick.driver.HsqldbDriver" + db.url = "jdbc:hsqldb:mem:testdb;shutdown=false;hsqldb.tx=mvcc" + db.driver = "org.hsqldb.jdbcDriver" + driver = "slick.driver.HsqldbDriver$" liquibase = { changelog = "src/main/migrations/changelog.xml" connection = "liquibase.database.jvm.HsqlConnection" @@ -113,11 +118,11 @@ database { } mysql { - url = "jdbc:mysql://localhost/cromwell_test" - user = "travis" - password = "" - driver = "com.mysql.jdbc.Driver" - slick.driver = "slick.driver.MySQLDriver" + db.url = "jdbc:mysql://localhost/cromwell_test" + db.user = "travis" + db.password = "" + db.driver = 
"com.mysql.jdbc.Driver" + driver = "slick.driver.MySQLDriver$" liquibase = { changelog = "src/main/migrations/changelog.xml" } diff --git a/src/main/scala/cromwell/binding/AstTools.scala b/src/main/scala/cromwell/binding/AstTools.scala index 9a77dc4c56..0425d4fcdd 100644 --- a/src/main/scala/cromwell/binding/AstTools.scala +++ b/src/main/scala/cromwell/binding/AstTools.scala @@ -52,6 +52,28 @@ object AstTools { } def wdlValue(wdlType: WdlType, wdlSyntaxErrorFormatter: WdlSyntaxErrorFormatter): WdlValue = { + + def astToMap(ast: Ast) = { + val mapType = wdlType.asInstanceOf[WdlMapType] + val elements = ast.getAttribute("map").asInstanceOf[AstList].asScala.toVector.map({ kvnode => + val k = kvnode.asInstanceOf[Ast].getAttribute("key").wdlValue(mapType.keyType, wdlSyntaxErrorFormatter) + val v = kvnode.asInstanceOf[Ast].getAttribute("value").wdlValue(mapType.valueType, wdlSyntaxErrorFormatter) + k -> v + }).toMap + + WdlMap(mapType, elements) + } + + def astToObject(ast: Ast) = { + val elements = ast.getAttribute("map").asInstanceOf[AstList].asScala.toVector.map({ kvnode => + val k = kvnode.asInstanceOf[Ast].getAttribute("key").sourceString + val v = kvnode.asInstanceOf[Ast].getAttribute("value").wdlValue(WdlStringType, wdlSyntaxErrorFormatter) + k -> v + }).toMap + + WdlObject(elements) + } + astNode match { case t: Terminal if t.getTerminalStr == "string" && wdlType == WdlStringType => WdlString(t.getSourceString) case t: Terminal if t.getTerminalStr == "string" && wdlType == WdlFileType => WdlFile(t.getSourceString) @@ -61,21 +83,15 @@ object AstTools { case "true" => WdlBoolean.True case "false" => WdlBoolean.False } - // TODO: The below cases, ArrayLiteral and MapLiteral are brittle. They recursively call this wdlValue(). + // TODO: The below cases, ArrayLiteral and MapLiteral, ObjectLiteral are brittle. They recursively call this wdlValue(). // However, those recursive calls might contain full-on expressions instead of just other literals. 
This // whole thing ought to be part of the regular expression evaluator, though I imagine that's non-trivial. case a: Ast if a.getName == "ArrayLiteral" && wdlType.isInstanceOf[WdlArrayType] => val arrType = wdlType.asInstanceOf[WdlArrayType] val elements = a.getAttribute("values").astListAsVector map {node => node.wdlValue(arrType.memberType, wdlSyntaxErrorFormatter)} WdlArray(arrType, elements) - case a: Ast if a.getName == "MapLiteral" && wdlType.isInstanceOf[WdlMapType] => - val mapType = wdlType.asInstanceOf[WdlMapType] - val elements = a.getAttribute("map").asInstanceOf[AstList].asScala.toVector.map({ kvnode => - val k = kvnode.asInstanceOf[Ast].getAttribute("key").wdlValue(mapType.keyType, wdlSyntaxErrorFormatter) - val v = kvnode.asInstanceOf[Ast].getAttribute("value").wdlValue(mapType.valueType, wdlSyntaxErrorFormatter) - k -> v - }).toMap - WdlMap(mapType, elements) + case a: Ast if a.getName == "MapLiteral" && wdlType.isInstanceOf[WdlMapType] => astToMap(a) + case a: Ast if a.getName == "ObjectLiteral" && wdlType == WdlObjectType => astToObject(a) case _ => throw new SyntaxError(s"Could not convert AST to a $wdlType (${Option(astNode).getOrElse("No AST").toString})") } } diff --git a/src/main/scala/cromwell/binding/WdlExpression.scala b/src/main/scala/cromwell/binding/WdlExpression.scala index ac941d5778..a751eb7505 100644 --- a/src/main/scala/cromwell/binding/WdlExpression.scala +++ b/src/main/scala/cromwell/binding/WdlExpression.scala @@ -26,6 +26,7 @@ object WdlExpression { def isMemberAccess: Boolean = ast.getName == "MemberAccess" def isArrayLiteral: Boolean = ast.getName == "ArrayLiteral" def isMapLiteral: Boolean = ast.getName == "MapLiteral" + def isObjectLiteral: Boolean = ast.getName == "ObjectLiteral" def isArrayOrMapLookup: Boolean = ast.getName == "ArrayOrMapLookup" def params = ast.getAttribute("params").asInstanceOf[AstList].asScala.toVector def name = ast.getAttribute("name").asInstanceOf[Terminal].getSourceString @@ -70,7 +71,8 @@ 
object WdlExpression { "read_float", "read_boolean", "read_lines", - "read_map" + "read_map", + "read_object" ) def evaluate(ast: AstNode, lookup: ScopedLookupFunction, functions: WdlFunctions[WdlValue]): Try[WdlValue] = diff --git a/src/main/scala/cromwell/binding/WdlNamespace.scala b/src/main/scala/cromwell/binding/WdlNamespace.scala index c1784f7d69..4fe6056f96 100644 --- a/src/main/scala/cromwell/binding/WdlNamespace.scala +++ b/src/main/scala/cromwell/binding/WdlNamespace.scala @@ -3,7 +3,7 @@ package cromwell.binding import java.io.File import cromwell.binding.AstTools.{AstNodeName, EnhancedAstNode, EnhancedAstSeq} -import cromwell.binding.expression.NoFunctions +import cromwell.binding.expression.WdlStandardLibraryFunctions import cromwell.binding.types._ import cromwell.binding.values._ import cromwell.parser.WdlParser._ @@ -107,13 +107,13 @@ case class NamespaceWithWorkflow(importedAs: Option[String], * For the declarations that have an expression attached to it already, evaluate the expression * and return the value for storage in the symbol store */ - def staticDeclarationsRecursive(userInputs: WorkflowCoercedInputs): Try[WorkflowCoercedInputs] = { + def staticDeclarationsRecursive(userInputs: WorkflowCoercedInputs, wdlFunctions: WdlStandardLibraryFunctions): Try[WorkflowCoercedInputs] = { import scala.collection.mutable val collected = mutable.Map[String, WdlValue]() val allDeclarations = workflow.declarations ++ workflow.calls.flatMap {_.task.declarations} val evaluatedDeclarations = allDeclarations.filter {_.expression.isDefined}.map {decl => - val value = decl.expression.get.evaluate(declarationLookupFunction(decl, collected.toMap ++ userInputs), new NoFunctions) + val value = decl.expression.get.evaluate(declarationLookupFunction(decl, collected.toMap ++ userInputs), wdlFunctions) collected += (decl.fullyQualifiedName -> value.get) val coercedValue = value match { case Success(s) => decl.wdlType.coerceRawValue(s) diff --git 
a/src/main/scala/cromwell/binding/command/ParameterCommandPart.scala b/src/main/scala/cromwell/binding/command/ParameterCommandPart.scala index 17b19d2059..72f21ac656 100644 --- a/src/main/scala/cromwell/binding/command/ParameterCommandPart.scala +++ b/src/main/scala/cromwell/binding/command/ParameterCommandPart.scala @@ -33,7 +33,7 @@ case class ParameterCommandPart(attributes: Map[String, String], expression: Wdl case Some(d) => throw new UnsupportedOperationException(s"Parameter ${v.variable} is required, but no value is specified") case None => throw new UnsupportedOperationException(s"This should not happen: could not find declaration for ${v.variable}") } - case _ => throw new UnsupportedOperationException(s"Could not evaluate expression: ${expression.toString}") + case e => throw new UnsupportedOperationException(s"Could not evaluate expression: ${expression.toWdlString}", e) } } diff --git a/src/main/scala/cromwell/binding/expression/WdlStandardLibraryFunctions.scala b/src/main/scala/cromwell/binding/expression/WdlStandardLibraryFunctions.scala index 20c522594a..da7fad6589 100644 --- a/src/main/scala/cromwell/binding/expression/WdlStandardLibraryFunctions.scala +++ b/src/main/scala/cromwell/binding/expression/WdlStandardLibraryFunctions.scala @@ -17,10 +17,22 @@ trait WdlStandardLibraryFunctions extends WdlFunctions[WdlValue] { protected def read_object(params: Seq[Try[WdlValue]]): Try[WdlObject] = fail("read_objects") protected def read_objects(params: Seq[Try[WdlValue]]): Try[WdlArray] = fail("read_objects") protected def read_json(params: Seq[Try[WdlValue]]): Try[WdlValue] = fail("read_json") - protected def read_int(params: Seq[Try[WdlValue]]): Try[WdlInteger] = fail("read_int") protected def read_string(params: Seq[Try[WdlValue]]): Try[WdlString] = fail("read_string") - protected def read_float(params: Seq[Try[WdlValue]]): Try[WdlFloat] = fail("read_float") - protected def read_boolean(params: Seq[Try[WdlValue]]): Try[WdlBoolean] = fail("read_boolean") 
+ /** + * Try to read an integer from the file referenced by the specified `WdlValue`. + */ + protected def read_int(params: Seq[Try[WdlValue]]): Try[WdlInteger] = + read_string(params) map { s => WdlInteger(s.value.trim.toInt) } + /** + * Try to read a float from the file referenced by the specified `WdlValue`. + */ + protected def read_float(params: Seq[Try[WdlValue]]): Try[WdlFloat] = + read_string(params) map { s => WdlFloat(s.value.trim.toDouble) } + /** + * Try to read a boolean from the file referenced by the specified `WdlValue`. + */ + protected def read_boolean(params: Seq[Try[WdlValue]]): Try[WdlBoolean] = + read_string(params) map { s => WdlBoolean(java.lang.Boolean.parseBoolean(s.value.trim.toLowerCase)) } protected def write_lines(params: Seq[Try[WdlValue]]): Try[WdlFile] = fail("write_lines") protected def write_tsv(params: Seq[Try[WdlValue]]): Try[WdlFile] = fail("write_tsv") protected def write_map(params: Seq[Try[WdlValue]]): Try[WdlFile] = fail("write_map") diff --git a/src/main/scala/cromwell/binding/types/WdlArrayType.scala b/src/main/scala/cromwell/binding/types/WdlArrayType.scala index e0f8bd0582..5ad67dc715 100644 --- a/src/main/scala/cromwell/binding/types/WdlArrayType.scala +++ b/src/main/scala/cromwell/binding/types/WdlArrayType.scala @@ -1,6 +1,6 @@ package cromwell.binding.types -import cromwell.binding.values.{WdlArray, WdlFile, WdlString} +import cromwell.binding.values.{WdlValue, WdlArray, WdlFile, WdlString} import spray.json.JsArray case class WdlArrayType(memberType: WdlType) extends WdlType { @@ -27,3 +27,14 @@ case class WdlArrayType(memberType: WdlType) extends WdlType { case _ => false } } + +object WdlArrayType { + + implicit class WdlArrayEnhanced(wdlType: WdlType) extends WdlType { + + override protected def coercion: PartialFunction[Any, WdlValue] = wdlType.coercion + override def toWdlString: String = wdlType.toWdlString + + def isAnArrayOf(genericType: WdlType) = wdlType.isInstanceOf[WdlArrayType] && 
wdlType.asInstanceOf[WdlArrayType].memberType == genericType + } +} \ No newline at end of file diff --git a/src/main/scala/cromwell/binding/types/WdlMapType.scala b/src/main/scala/cromwell/binding/types/WdlMapType.scala index b07827c0dd..0fc50ceb69 100644 --- a/src/main/scala/cromwell/binding/types/WdlMapType.scala +++ b/src/main/scala/cromwell/binding/types/WdlMapType.scala @@ -11,10 +11,12 @@ case class WdlMapType(keyType: WdlType, valueType: WdlType) extends WdlType { case m: Map[_, _] if m.isEmpty => WdlMap(WdlMapType(keyType, valueType), Map()) case js: JsObject if js.fields.nonEmpty => WdlMap.coerceMap(js.fields, this) case wdlMap: WdlMap => WdlMap.coerceMap(wdlMap.value, this) + case o: WdlObject => WdlMap.coerceMap(o.value, this) } override def isCoerceableFrom(otherType: WdlType): Boolean = otherType match { case m: WdlMapType => keyType.isCoerceableFrom(m.keyType) && valueType.isCoerceableFrom(m.valueType) + case WdlObjectType => keyType.isCoerceableFrom(WdlStringType) && valueType.isCoerceableFrom(WdlStringType) case _ => false } } diff --git a/src/main/scala/cromwell/binding/types/WdlObjectType.scala b/src/main/scala/cromwell/binding/types/WdlObjectType.scala index 185ddbdb53..69e7dc4be6 100644 --- a/src/main/scala/cromwell/binding/types/WdlObjectType.scala +++ b/src/main/scala/cromwell/binding/types/WdlObjectType.scala @@ -1,16 +1,52 @@ package cromwell.binding.types import cromwell.binding.Call -import cromwell.binding.values.{WdlCallOutputsObject, WdlObject} +import cromwell.binding.values._ +import spray.json.JsObject + +import scala.util.{Try, Failure, Success} case object WdlObjectType extends WdlType { val toWdlString: String = "Object" + private def handleCoercionFailures(tries: Try[_]*) = { + val errorMessages = tries collect { + case Failure(f) => f.getMessage + } mkString "," + + throw new UnsupportedOperationException(s"Coercion failed: $errorMessages") + } + override protected def coercion = { case o: WdlObject => o + case m: WdlMap if 
isMapCoercable(m) => + val coercedMap = m.value map { + case (k, v) => toWdlString(k) -> toWdlString(v) + } collect { + case (Success(k), Success(v)) => k.value -> v + case (k, v) => handleCoercionFailures(k, v) + } + + WdlObject(coercedMap) + case js: JsObject => + val coercedMap = WdlMap.coerceMap(js.fields, WdlMapType(WdlStringType, WdlAnyType)).value map { + // get is safe because coerceMap above would have failed already if k was not coerceable to WdlString + case (k, v) => toWdlString(k).get.value -> v + } + + WdlObject(coercedMap) + } + + private def toWdlString(v: WdlValue) = WdlStringType.coerceRawValue(v).map(_.asInstanceOf[WdlString]) + + override def isCoerceableFrom(otherType: WdlType) = otherType match { + case WdlObjectType => true + case t: WdlMapType if isMapTypeCoercable(t) => true + case _ => false } - override def fromWdlString(rawString: String) = ??? + def isMapTypeCoercable(t: WdlMapType) = WdlStringType.isCoerceableFrom(t.keyType) && WdlStringType.isCoerceableFrom(t.valueType) + def isMapCoercable(m: WdlMap) = isMapTypeCoercable(m.wdlType) } case class WdlCallOutputsObjectType(call: Call) extends WdlType { diff --git a/src/main/scala/cromwell/binding/types/WdlType.scala b/src/main/scala/cromwell/binding/types/WdlType.scala index a306694f33..51885b617d 100644 --- a/src/main/scala/cromwell/binding/types/WdlType.scala +++ b/src/main/scala/cromwell/binding/types/WdlType.scala @@ -43,6 +43,7 @@ trait WdlType { * @param wdlSource source code representing the WdlValue * @return The WdlValue */ + //TODO: return a Try ? 
def fromWdlString(wdlSource: WdlSource): WdlValue = { val tokens = WdlType.parser.lex(wdlSource, "string") val terminalMap = tokens.asScala.toVector.map {(_, wdlSource)}.toMap diff --git a/src/main/scala/cromwell/binding/values/WdlArray.scala b/src/main/scala/cromwell/binding/values/WdlArray.scala index cb544caaa0..f9596cb3af 100644 --- a/src/main/scala/cromwell/binding/values/WdlArray.scala +++ b/src/main/scala/cromwell/binding/values/WdlArray.scala @@ -1,6 +1,6 @@ package cromwell.binding.values -import cromwell.binding.types.{WdlArrayType, WdlPrimitiveType} +import cromwell.binding.types.{WdlObjectType, WdlArrayType, WdlPrimitiveType} import scala.util.{Failure, Success, Try} @@ -25,6 +25,7 @@ case class WdlArray(wdlType: WdlArrayType, value: Seq[WdlValue]) extends WdlValu def tsvSerialize: Try[String] = { wdlType.memberType match { case t: WdlPrimitiveType => Success(value.map(_.valueString).mkString("\n")) + case WdlObjectType => WdlObject.tsvSerializeArray(value map { _.asInstanceOf[WdlObject] }) case _ => Failure(new UnsupportedOperationException("Can only TSV serialize an Array[Primitive]")) } } diff --git a/src/main/scala/cromwell/binding/values/WdlObject.scala b/src/main/scala/cromwell/binding/values/WdlObject.scala index cbc764b2e7..18a79da3f6 100644 --- a/src/main/scala/cromwell/binding/values/WdlObject.scala +++ b/src/main/scala/cromwell/binding/values/WdlObject.scala @@ -1,14 +1,85 @@ package cromwell.binding.values import cromwell.binding.Call -import cromwell.binding.types.{WdlCallOutputsObjectType, WdlObjectType} +import cromwell.binding.types._ +import cromwell.util.FileUtil + +import scala.util.{Failure, Success, Try} trait WdlObjectLike { def value: Map[String, WdlValue] } +object WdlObject { + + def coerceObject(m: Map[String, String]): WdlObject = { + val coerced = WdlMap.coerceMap(m, WdlMapType(WdlStringType, WdlAnyType)).value map { + case (k, v) => k.valueString -> v + } + + WdlObject(coerced) + } + + def fromTsv(tsv: String): 
Try[Array[WdlObject]] = { + FileUtil.parseTsv(tsv) match { + case Success(table) if table.isEmpty => Failure(new UnsupportedOperationException("TSV file was empty or could not be parsed.")) + case Success(table) if table.length < 2 => Failure(new UnsupportedOperationException("TSV must be 2 rows (or more) to convert to an Object (Array[Object])")) + case Success(table) => Try { + table.tail map { line => coerceObject((table.head zip line).toMap) } + } + case Failure(e) => Failure(e) + } + } + + //TODO: Try to stream this out to avoid memory overhead + def tsvSerializeArray(input: Seq[WdlObject]): Try[String] = { + + /** + * Validates that all objects have the same attributes. + */ + def attributesMatch(objects: Seq[WdlObject]): Boolean = { + val attributesSet = objects map { _.orderedAttributes } + val intersection = attributesSet reduce (_.intersect(_)) + attributesSet forall { _ == intersection } + } + + input match { + case Nil => Failure(new UnsupportedOperationException("Cannot write empty objects array.")) + case objects if attributesMatch(objects) => + /* Note: this is arbitrary as it takes the first object in the array as a reference for the attributes. + * It has no impact though as we first made sure that all objects have the same attributes. 
+ */ + val attributes = objects.head.orderedAttributes + val attributesLine = attributes.mkString("\t") + val valuesLines = objects map { obj => + attributes map { obj.value(_).valueString } mkString "\t" + } mkString "\n" + + Success(s"$attributesLine\n$valuesLines") + case _ => Failure(new UnsupportedOperationException("Could not serialize array: Objects in the array have different attributes.")) + } + } + +} + case class WdlObject(value: Map[String, WdlValue]) extends WdlValue with WdlObjectLike { val wdlType = WdlObjectType + + override def toWdlString: String = + "object {" + value.map {case (k, v) => s"$k: ${v.toWdlString}"}.mkString(", ") + "}" + + lazy val orderedAttributes = value.keySet.toSeq + lazy val orderedValues = orderedAttributes map { value(_) } + + def tsvSerialize: Try[String] = Try { + val keysLine = orderedAttributes.mkString("\t") + val values = orderedValues map { + case v if v.isInstanceOf[WdlPrimitive] => v.valueString + case _ => throw new UnsupportedOperationException("Can only TSV serialize an Object with Primitive values.") + } + + s"$keysLine\n${values.mkString("\t")}" + } } case class WdlCallOutputsObject(call: Call, outputs: Map[String, WdlValue]) extends WdlValue with WdlObjectLike { diff --git a/src/main/scala/cromwell/binding/values/WdlValueJsonFormatter.scala b/src/main/scala/cromwell/binding/values/WdlValueJsonFormatter.scala index 5820207bf2..46ba7f28e8 100644 --- a/src/main/scala/cromwell/binding/values/WdlValueJsonFormatter.scala +++ b/src/main/scala/cromwell/binding/values/WdlValueJsonFormatter.scala @@ -11,7 +11,7 @@ object WdlValueJsonFormatter extends DefaultJsonProtocol { case f: WdlFloat => JsNumber(f.value) case b: WdlBoolean => JsBoolean(b.value) case f: WdlFile => JsString(f.value) - case o: WdlObject => JsObject() + case o: WdlObject => new JsObject(o.value map {case(k, v) => k -> write(v)}) case a: WdlArray => new JsArray(a.value.map(write).toVector) case m: WdlMap => new JsObject(m.value map {case(k,v) => 
k.valueString -> write(v)}) case e: WdlExpression => JsString(e.toWdlString) diff --git a/src/main/scala/cromwell/engine/CallActor.scala b/src/main/scala/cromwell/engine/CallActor.scala index aea555e7ad..548e331437 100644 --- a/src/main/scala/cromwell/engine/CallActor.scala +++ b/src/main/scala/cromwell/engine/CallActor.scala @@ -108,7 +108,7 @@ class CallActor(key: CallKey, locallyQualifiedInputs: CallInputs, backend: Backe private def handleFinished(call: Call, executionResult: ExecutionResult): CallActor.this.State = { executionResult match { - case SuccessfulExecution(outputs) => context.parent ! WorkflowActor.CallCompleted(key, outputs) + case SuccessfulExecution(outputs, returnCode) => context.parent ! WorkflowActor.CallCompleted(key, outputs, returnCode) case AbortedExecution => context.parent ! WorkflowActor.AbortComplete(key) case FailedExecution(e, returnCode) => context.parent ! WorkflowActor.CallFailed(key, returnCode, e.getMessage) } diff --git a/src/main/scala/cromwell/engine/CallExecutionActor.scala b/src/main/scala/cromwell/engine/CallExecutionActor.scala index bfa90fc844..04159bf1b3 100644 --- a/src/main/scala/cromwell/engine/CallExecutionActor.scala +++ b/src/main/scala/cromwell/engine/CallExecutionActor.scala @@ -25,7 +25,7 @@ class CallExecutionActor(backendCall: BackendCall) extends Actor with CromwellAc val result = backendCall.execute result match { - case SuccessfulExecution(_) => log.info(s"$tag: successful execution.") + case SuccessfulExecution(_, _) => log.info(s"$tag: successful execution.") case AbortedExecution => log.info(s"$tag: aborted.") case FailedExecution(e, returnCode) => log.error(e, s"$tag: failed execution, returnCode = $returnCode") } diff --git a/src/main/scala/cromwell/engine/backend/Backend.scala b/src/main/scala/cromwell/engine/backend/Backend.scala index 36c001e77c..d49e8c3da8 100644 --- a/src/main/scala/cromwell/engine/backend/Backend.scala +++ b/src/main/scala/cromwell/engine/backend/Backend.scala @@ -2,13 +2,13 @@ 
package cromwell.engine.backend import com.typesafe.config.Config import cromwell.binding._ +import cromwell.binding.expression.WdlStandardLibraryFunctions import cromwell.engine.ExecutionIndex.ExecutionIndex import cromwell.engine._ import cromwell.engine.backend.Backend.RestartableWorkflow import cromwell.engine.backend.jes.JesBackend import cromwell.engine.backend.local.LocalBackend import cromwell.engine.backend.sge.SgeBackend -import cromwell.engine.db.DataAccess import cromwell.engine.workflow.{CallKey, WorkflowOptions} import cromwell.parser.BackendType @@ -16,14 +16,20 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.Try object Backend { + lazy val LocalBackend = new LocalBackend + lazy val JesBackend = new JesBackend { JesConf } //forces configuration resolution to fail now if something is missing + lazy val SgeBackend = new SgeBackend + class StdoutStderrException(message: String) extends RuntimeException(message) + def from(backendConf: Config): Backend = from(backendConf.getString("backend")) def from(name: String) = name.toLowerCase match { - case "local" => new LocalBackend - case "jes" => new JesBackend { JesConf } //forces configuration resolution to fail now if something is missing - case "sge" => new SgeBackend + case "local" => LocalBackend + case "jes" => JesBackend + case "sge" => SgeBackend case doh => throw new IllegalArgumentException(s"$doh is not a recognized backend") } + case class RestartableWorkflow(id: WorkflowId, source: WorkflowSourceFiles) } @@ -67,6 +73,11 @@ trait Backend { locallyQualifiedInputs: CallInputs, abortRegistrationFunction: AbortRegistrationFunction): BackendCall + /** + * Engine functions that don't need a Call context (e.g. read_lines(), read_float(), etc) + */ + def engineFunctions: WdlStandardLibraryFunctions + /** * Do whatever is appropriate for this backend implementation to support restarting the specified workflows. 
*/ @@ -85,6 +96,10 @@ trait Backend { @throws[IllegalArgumentException]("if a value is missing / incorrect") def assertWorkflowOptions(options: WorkflowOptions): Unit = {} - def makeTag(backendCall: BackendCall): String = - s"${this.getClass.getSimpleName} [UUID(${backendCall.workflowDescriptor.shortId}):${backendCall.call.name}]" + def makeTag(backendCall: BackendCall): String = { + // Sometimes the class name is `anon$1`. In cases like that, don't print it in the log because it's not adding value + val cls = this.getClass.getSimpleName + val clsString = if (cls.startsWith("anon")) "" else s"$cls " + s"$clsString[UUID(${backendCall.workflowDescriptor.shortId}):${backendCall.call.name}]" + } } diff --git a/src/main/scala/cromwell/engine/backend/ExecutionResult.scala b/src/main/scala/cromwell/engine/backend/ExecutionResult.scala index bdd2a1c19f..ee00252140 100644 --- a/src/main/scala/cromwell/engine/backend/ExecutionResult.scala +++ b/src/main/scala/cromwell/engine/backend/ExecutionResult.scala @@ -13,7 +13,7 @@ sealed trait ExecutionResult /** * A successful execution with resolved outputs. */ -final case class SuccessfulExecution(outputs: CallOutputs) extends ExecutionResult +final case class SuccessfulExecution(outputs: CallOutputs, returnCode: Int) extends ExecutionResult /** * A user-requested abort of the command. 
diff --git a/src/main/scala/cromwell/engine/backend/jes/JesAttributes.scala b/src/main/scala/cromwell/engine/backend/jes/JesAttributes.scala index 1121b35d80..e82f43b2da 100644 --- a/src/main/scala/cromwell/engine/backend/jes/JesAttributes.scala +++ b/src/main/scala/cromwell/engine/backend/jes/JesAttributes.scala @@ -2,8 +2,9 @@ package cromwell.engine.backend.jes import java.net.URL -import com.typesafe.config.{Config, ConfigFactory} -import cromwell.util.ConfigUtil._ +import com.typesafe.config.ConfigFactory +import lenthall.config.ScalaConfig._ +import lenthall.config.ValidatedConfig._ import scala.language.postfixOps import scalaz.Scalaz._ @@ -23,6 +24,7 @@ object JesAttributes { "baseExecutionBucket", "endpointUrl", "authenticationMode", + "maximumPollingInterval", "dockerAccount", "dockerToken" ) diff --git a/src/main/scala/cromwell/engine/backend/jes/JesBackend.scala b/src/main/scala/cromwell/engine/backend/jes/JesBackend.scala index 0672ef7e6e..a97fa7fbec 100644 --- a/src/main/scala/cromwell/engine/backend/jes/JesBackend.scala +++ b/src/main/scala/cromwell/engine/backend/jes/JesBackend.scala @@ -7,21 +7,20 @@ import java.nio.file.{Path, Paths} import com.google.api.services.genomics.model.Parameter import com.typesafe.scalalogging.LazyLogging import cromwell.binding._ -import cromwell.binding.expression.NoFunctions +import cromwell.binding.expression.{NoFunctions, WdlStandardLibraryFunctions} import cromwell.binding.types.{WdlFileType, WdlType} import cromwell.binding.values._ import cromwell.engine.ExecutionIndex.ExecutionIndex import cromwell.engine.backend.Backend.RestartableWorkflow import cromwell.engine.backend._ import cromwell.engine.backend.jes.JesBackend._ -import cromwell.engine.workflow.CallKey -import cromwell.engine.{AbortFunction, AbortRegistrationFunction} -import cromwell.engine.db.DataAccess import cromwell.engine.workflow.{CallKey, WorkflowOptions} import cromwell.engine.{AbortFunction, AbortRegistrationFunction, WorkflowDescriptor} 
import cromwell.parser.BackendType +import cromwell.util.TryUtil import cromwell.util.google.GoogleCloudStoragePath +import scala.concurrent.duration._ import scala.concurrent.{ExecutionContext, Future} import scala.language.postfixOps import scala.util.{Failure, Success, Try} @@ -32,8 +31,6 @@ object JesBackend { where stdout.txt is input and output. Redirect stdout/stderr to a different name, but it'll be localized back in GCS as stdout/stderr. Yes, it's hacky. */ - val LocalStdoutValue = "job.stdout.txt" - val LocalStderrValue = "job.stderr.txt" val LocalWorkingDiskValue = "disk://local-disk" val WorkingDiskParamName = "working_disk" val ExtraConfigParamName = "__extra_config_gcs_path" @@ -45,19 +42,11 @@ object JesBackend { val GcsRootOptionKey = "jes_gcs_root" val OptionKeys = Set(AccountOptionKey, RefreshTokenOptionKey, GcsRootOptionKey) - def stderrJesOutput(callGcsPath: String): JesOutput = JesOutput(LocalStderrValue, s"$callGcsPath/$LocalStderrValue", Paths.get(LocalStderrValue)) - def stdoutJesOutput(callGcsPath: String): JesOutput = JesOutput(LocalStdoutValue, s"$callGcsPath/$LocalStdoutValue", Paths.get(LocalStdoutValue)) def localizationDiskInput(): JesInput = JesInput(WorkingDiskParamName, LocalWorkingDiskValue, new File(JesCromwellRoot).toPath) + def authGcsCredentialsPath(gcsPath: Option[String]): Option[JesInput] = gcsPath.map(JesInput(ExtraConfigParamName, _, Paths.get(""), "LITERAL")) - // For now we want to always redirect stdout and stderr. 
This could be problematic if that's what the WDL calls stuff, but oh well - def standardParameters(callGcsPath: String): Seq[JesParameter] = Seq( - stdoutJesOutput(callGcsPath), - stderrJesOutput(callGcsPath), - localizationDiskInput() - ) - // Decoration around WorkflowDescriptor to generate bucket names and the like implicit class JesWorkflowDescriptor(val descriptor: WorkflowDescriptor) extends JesBackend { def callDir(key: CallKey) = callGcsPath(descriptor, key.scope.name, key.index) @@ -177,11 +166,7 @@ class JesBackend extends Backend with LazyLogging { } override def stdoutStderr(descriptor: WorkflowDescriptor, callName: String, index: ExecutionIndex): StdoutStderr = { - val base = callGcsPath(descriptor, callName, index) - StdoutStderr( - stdout = WdlFile(s"$base/$LocalStdoutValue"), - stderr = WdlFile(s"$base/$LocalStderrValue") - ) + JesBackendCall.stdoutStderr(callGcsPath(descriptor, callName, index)) } override def bindCall(workflowDescriptor: WorkflowDescriptor, @@ -191,6 +176,8 @@ class JesBackend extends Backend with LazyLogging { JesBackendCall(this, workflowDescriptor, key, locallyQualifiedInputs, abortRegistrationFunction) } + override def engineFunctions: WdlStandardLibraryFunctions = new JesEngineFunctionsWithoutCallContext(JesConnection.storage) + def execute(backendCall: BackendCall): ExecutionResult = { val tag = makeTag(backendCall) val cmdInput = JesInput("exec", backendCall.gcsExecPath.toString, Paths.get("exec.sh")) @@ -231,13 +218,47 @@ class JesBackend extends Backend with LazyLogging { } } - private def runWithJes(backendCall: BackendCall, command: String, jesInputs: Seq[JesInput], jesOutputs: Seq[JesOutput]): ExecutionResult = { - val tag = makeTag(backendCall) - val jesParameters = standardParameters(backendCall.callGcsPath) ++ gcsAuthParameter(backendCall.workflowDescriptor) ++ jesInputs ++ jesOutputs - logger.info(s"$tag `$command`") - JesConnection.storage.uploadObject(backendCall.gcsExecPath, command) + private def 
uploadCommandScript(backendCall: BackendCall, command: String): Try[Unit] = { + val fileContent = + s""" + |#!/bin/bash + |$command + |echo $$? > ${JesBackendCall.RcFilename} + """.stripMargin.trim + + def attemptToUploadObject(priorAttempt: Option[Unit]) = JesConnection.storage.uploadObject(backendCall.gcsExecPath, fileContent) + + TryUtil.retryBlock( + fn = attemptToUploadObject, + retries = Some(10), + pollingInterval = 5 seconds, + pollingBackOffFactor = 1, + maxPollingInterval = 10 seconds, + failMessage = Some(s"${makeTag(backendCall)} Exception occurred while uploading script to ${backendCall.gcsExecPath}") + ) + } + + private def createJesRun(backendCall: BackendCall, jesParameters: Seq[JesParameter]): Try[Run] = { + def attemptToCreateJesRun(priorAttempt: Option[Run]): Run = Pipeline( + backendCall.jesCommandLine, + backendCall.workflowDescriptor, + backendCall.key, + jesParameters, + googleProject(backendCall.workflowDescriptor), + JesConnection + ).run + + TryUtil.retryBlock( + fn = attemptToCreateJesRun, + retries = Some(10), + pollingInterval = 5 seconds, + pollingBackOffFactor = 1, + maxPollingInterval = 10 seconds, + failMessage = Some(s"${makeTag(backendCall)} Exception occurred while creating JES Run") + ) + } - val run = Pipeline(s"/bin/bash exec.sh > $LocalStdoutValue 2> $LocalStderrValue", backendCall.workflowDescriptor, backendCall.key, jesParameters, googleProject(backendCall.workflowDescriptor), JesConnection).run + private def pollJesRun(run: Run, backendCall: BackendCall, jesOutputs: Seq[JesOutput]): ExecutionResult = { // Wait until the job starts (or completes/fails) before registering the abort to avoid awkward cancel-during-initialization behavior. 
val initializedStatus = run.waitUntilRunningOrComplete backendCall.callAbortRegistrationFunction.register(AbortFunction(() => run.abort())) @@ -271,21 +292,24 @@ class JesBackend extends Backend with LazyLogging { taskOutput.name -> attemptedValue } toMap - lazy val stderrLength: BigInteger = JesConnection.storage.objectSize(GoogleCloudStoragePath(stderrJesOutput(backendCall.callGcsPath).gcs)) - - if (backendCall.call.failOnStderr && stderrLength.intValue > 0) { - FailedExecution(new Throwable(s"Workflow ${backendCall.workflowDescriptor.id}: stderr has length $stderrLength for command: $command")) - } else status match { + lazy val stderrLength: BigInteger = JesConnection.storage.objectSize(GoogleCloudStoragePath(backendCall.stderrJesOutput.gcs)) + lazy val returnCode = backendCall.downloadRcFile.map(_.trim.toInt) + lazy val continueOnReturnCode = backendCall.call.continueOnReturnCode + + status match { + case Run.Success if backendCall.call.failOnStderr && stderrLength.intValue > 0 => + FailedExecution(new Throwable(s"${makeTag(backendCall)} execution failed: stderr has length $stderrLength")) + case Run.Success if returnCode.isFailure => + FailedExecution(new Throwable(s"${makeTag(backendCall)} execution failed: could not download or parse return code file", returnCode.failed.get)) + case Run.Success if !continueOnReturnCode.continueFor(returnCode.get) => + FailedExecution(new Throwable(s"${makeTag(backendCall)} execution failed: disallowed command return code: " + returnCode.get)) case Run.Success => - unwrapOutputValues(outputMappings, backendCall.workflowDescriptor) match { - case Success(outputs) => SuccessfulExecution(outputs) - case Failure(e) => FailedExecution(e) - } + handleSuccess(outputMappings, backendCall.workflowDescriptor, returnCode.get) case Run.Failed(errorCode, errorMessage) => val throwable = if (errorMessage contains "Operation canceled at") { new TaskAbortedException() } else { - new Throwable(s"Workflow ${backendCall.workflowDescriptor.id}: 
errorCode $errorCode for command: $command. Message: $errorMessage") + new Throwable(s"Task ${backendCall.workflowDescriptor.id}:${backendCall.call.name} failed: error code $errorCode. Message: $errorMessage") } FailedExecution(throwable) } @@ -295,14 +319,37 @@ class JesBackend extends Backend with LazyLogging { } } - private def unwrapOutputValues(outputMappings: Map[String, Try[WdlValue]], workflowDescriptor: WorkflowDescriptor): Try[Map[String, WdlValue]] = { + private def runWithJes(backendCall: BackendCall, command: String, jesInputs: Seq[JesInput], jesOutputs: Seq[JesOutput]): ExecutionResult = { + val tag = makeTag(backendCall) + val jesParameters = backendCall.standardParameters ++ gcsAuthParameter(backendCall.workflowDescriptor) ++ jesInputs ++ jesOutputs + logger.info(s"$tag `$command`") + + val jesJobSetup = for { + _ <- uploadCommandScript(backendCall, command) + run <- createJesRun(backendCall, jesParameters) + } yield run + + jesJobSetup match { + case Failure(ex) => + logger.error(s"$tag Failed to create a JES run", ex) + FailedExecution(ex) + case Success(run) => pollJesRun(run, backendCall, jesOutputs) + } + } + + private def handleSuccess(outputMappings: Map[String, Try[WdlValue]], workflowDescriptor: WorkflowDescriptor, returnCode: Int): ExecutionResult = { val taskOutputEvaluationFailures = outputMappings filter { _._2.isFailure } - if (taskOutputEvaluationFailures.isEmpty) { + val outputValues = if (taskOutputEvaluationFailures.isEmpty) { Success(outputMappings collect { case (name, Success(wdlValue)) => name -> wdlValue }) } else { val message = taskOutputEvaluationFailures collect { case (name, Failure(e)) => s"$name: $e" } mkString "\n" Failure(new Throwable(s"Workflow ${workflowDescriptor.id}: $message")) } + + outputValues match { + case Success(outputs) => SuccessfulExecution(outputs, returnCode) + case Failure(e) => FailedExecution(e) + } } override def handleCallRestarts(restartableWorkflows: Seq[RestartableWorkflow])(implicit ec: 
ExecutionContext): Future[Any] = Future("FIXME") @@ -329,8 +376,6 @@ class JesBackend extends Backend with LazyLogging { descriptor.workflowOptions.getOrElse("google_project", JesConf.project) } - - // Create an input parameter containing the path to this authentication file def gcsAuthParameter(descriptor: WorkflowDescriptor): Option[JesInput] = { if (JesConf.authMode == RefreshTokenMode || JesConf.dockerCredentials.isDefined) { diff --git a/src/main/scala/cromwell/engine/backend/jes/JesBackendCall.scala b/src/main/scala/cromwell/engine/backend/jes/JesBackendCall.scala index c60d97188f..70d5568111 100644 --- a/src/main/scala/cromwell/engine/backend/jes/JesBackendCall.scala +++ b/src/main/scala/cromwell/engine/backend/jes/JesBackendCall.scala @@ -1,20 +1,55 @@ package cromwell.engine.backend.jes +import java.nio.file.Paths + import cromwell.binding._ -import cromwell.engine.backend.{BackendCall, ExecutionResult} +import cromwell.binding.values.WdlFile +import cromwell.engine.backend.jes.JesBackend.JesOutput +import cromwell.engine.backend.{StdoutStderr, BackendCall, ExecutionResult} import cromwell.engine.workflow.CallKey import cromwell.engine.{AbortRegistrationFunction, WorkflowDescriptor} import cromwell.util.google.GoogleCloudStoragePath +import scala.util.Try + + +object JesBackendCall { + + def stdoutStderr(callGcsPath: String): StdoutStderr = { + StdoutStderr( + stdout = WdlFile(s"$callGcsPath/$StdoutFilename"), + stderr = WdlFile(s"$callGcsPath/$StderrFilename") + ) + } + + val StdoutFilename = "job.stdout.txt" + val StderrFilename = "job.stderr.txt" + val RcFilename = "job.rc.txt" + private def jesOutput(callGcsPath: String, filename: String): JesOutput = JesOutput(filename, s"$callGcsPath/$filename", Paths.get(filename)) +} + case class JesBackendCall(backend: JesBackend, workflowDescriptor: WorkflowDescriptor, key: CallKey, locallyQualifiedInputs: CallInputs, callAbortRegistrationFunction: AbortRegistrationFunction) extends BackendCall { + + import 
JesBackendCall._ + + def jesCommandLine = s"/bin/bash exec.sh > $StdoutFilename 2> $StderrFilename" + val callGcsPath = backend.callGcsPath(workflowDescriptor, call.name, key.index) val callDir = GoogleCloudStoragePath(callGcsPath) val gcsExecPath = GoogleCloudStoragePath(callGcsPath + "/exec.sh") val jesConnection = backend.JesConnection val engineFunctions = new JesEngineFunctions(this) + + lazy val stderrJesOutput = jesOutput(callGcsPath, StderrFilename) + lazy val stdoutJesOutput = jesOutput(callGcsPath, StdoutFilename) + lazy val rcJesOutput = jesOutput(callGcsPath, RcFilename) + + def standardParameters = Seq(stderrJesOutput, stdoutJesOutput, rcJesOutput) + def rcGcsPath = rcJesOutput.gcs def execute: ExecutionResult = backend.execute(this) + def downloadRcFile: Try[String] = GoogleCloudStoragePath.parse(callGcsPath + "/" + RcFilename).map(jesConnection.storage.slurpFile) } diff --git a/src/main/scala/cromwell/engine/backend/jes/JesEngineFunctions.scala b/src/main/scala/cromwell/engine/backend/jes/JesEngineFunctions.scala index 58d072192e..3a3d51877d 100644 --- a/src/main/scala/cromwell/engine/backend/jes/JesEngineFunctions.scala +++ b/src/main/scala/cromwell/engine/backend/jes/JesEngineFunctions.scala @@ -1,25 +1,19 @@ package cromwell.engine.backend.jes import cromwell.binding.expression.WdlStandardLibraryFunctions -import cromwell.binding.types.{WdlArrayType, WdlStringType} +import cromwell.binding.types.{WdlObjectType, WdlArrayType, WdlStringType} import cromwell.binding.values._ -import cromwell.util.google.GoogleCloudStoragePath +import cromwell.util.google.{GoogleCloudStorage, GoogleCloudStoragePath} import scala.util.{Failure, Success, Try} -/** - * Implementation of WDL standard library functions for the JES backend. 
- */ -case class JesEngineFunctions(jesBackendCall: JesBackendCall) extends WdlStandardLibraryFunctions { +class JesEngineFunctionsWithoutCallContext(storage: GoogleCloudStorage) extends WdlStandardLibraryFunctions { - private def readFromPath(value: String): String = { - val gcsPath = GoogleCloudStoragePath.parse(value) match { - case Success(path) => path - case Failure(ex) => GoogleCloudStoragePath(jesBackendCall.callDir + s"/$value") - } - jesBackendCall.jesConnection.storage.slurpFile(gcsPath) + protected def readFromPath(value: String): String = { + // .get here because engine functions should throw exception if they fail. Evaluator will catch it + storage.slurpFile(GoogleCloudStoragePath.parse(value).get) } - + /** * Read the entire contents of a file from the specified `WdlValue`, where the file can be * specified either as a path via a `WdlString` (with magical handling of "stdout"), or @@ -36,16 +30,9 @@ case class JesEngineFunctions(jesBackendCall: JesBackendCall) extends WdlStandar } } - override protected def stdout(params: Seq[Try[WdlValue]]): Try[WdlFile] = { - val newPath = GoogleCloudStoragePath(jesBackendCall.callDir.bucket, jesBackendCall.callDir.objectName + "/" + JesBackend.LocalStdoutValue) - Success(WdlFile(newPath.toString)) - } - - override protected def stderr(params: Seq[Try[WdlValue]]): Try[WdlFile] = { - val newPath = GoogleCloudStoragePath(jesBackendCall.callDir.bucket, jesBackendCall.callDir.objectName + "/" + JesBackend.LocalStderrValue) - Success(WdlFile(newPath.toString)) - } - + /** + * Read all lines from the file referenced by the first parameter + */ override protected def read_lines(params: Seq[Try[WdlValue]]): Try[WdlArray] = { for { singleArgument <- extractSingleArgument(params) @@ -53,18 +40,29 @@ case class JesEngineFunctions(jesBackendCall: JesBackendCall) extends WdlStandar } yield WdlArray(WdlArrayType(WdlStringType), lines) } - /** - * Try to read an integer from the file referenced by the specified `WdlValue`. 
- */ - override protected def read_int(params: Seq[Try[WdlValue]]): Try[WdlInteger] = { - read_string(params) map { s => WdlInteger(s.value.trim.toInt) } + override protected def read_map(params: Seq[Try[WdlValue]]): Try[WdlMap] = { + for { + singleArgument <- extractSingleArgument(params) + contents <- Success(fileContentsToString(singleArgument)) + wdlMap <- WdlMap.fromTsv(contents) + } yield wdlMap } - /** - * Try to read an float from the file referenced by the specified `WdlValue`. - */ - override protected def read_float(params: Seq[Try[WdlValue]]): Try[WdlFloat] = { - read_string(params) map { s => WdlFloat(s.value.trim.toDouble) } + private def extractObjectOrArray(params: Seq[Try[WdlValue]]) = for { + singleArgument <- extractSingleArgument(params) + contents <- Success(fileContentsToString(singleArgument)) + wdlObjects <- WdlObject.fromTsv(contents) + } yield wdlObjects + + override protected def read_object(params: Seq[Try[WdlValue]]): Try[WdlObject] = { + extractObjectOrArray(params) map { + case array if array.length == 1 => array.head + case _ => throw new IllegalArgumentException("read_object yields an Object and thus can only read 2-rows TSV files. Try using read_objects instead.") + } + } + + override def read_objects(params: Seq[Try[WdlValue]]): Try[WdlArray] = { + extractObjectOrArray(params) map { WdlArray(WdlArrayType(WdlObjectType), _) } } /** @@ -76,4 +74,29 @@ case class JesEngineFunctions(jesBackendCall: JesBackendCall) extends WdlStandar string = fileContentsToString(singleArgument) } yield WdlString(string.trim) } + +} + +/** + * Implementation of WDL standard library functions for the JES backend. 
+ */ +class JesEngineFunctions(jesBackendCall: JesBackendCall) extends JesEngineFunctionsWithoutCallContext(jesBackendCall.jesConnection.storage) { + + override def readFromPath(value: String): String = { + val gcsPath = GoogleCloudStoragePath.parse(value) match { + case Success(path) => path + case Failure(ex) => GoogleCloudStoragePath(jesBackendCall.callDir + s"/$value") + } + jesBackendCall.jesConnection.storage.slurpFile(gcsPath) + } + + override protected def stdout(params: Seq[Try[WdlValue]]): Try[WdlFile] = { + val newPath = GoogleCloudStoragePath(jesBackendCall.stdoutJesOutput.gcs) + Success(WdlFile(newPath.toString)) + } + + override protected def stderr(params: Seq[Try[WdlValue]]): Try[WdlFile] = { + val newPath = GoogleCloudStoragePath(jesBackendCall.stderrJesOutput.gcs) + Success(WdlFile(newPath.toString)) + } } diff --git a/src/main/scala/cromwell/engine/backend/jes/Pipeline.scala b/src/main/scala/cromwell/engine/backend/jes/Pipeline.scala index dfff6f7559..96233cecf6 100644 --- a/src/main/scala/cromwell/engine/backend/jes/Pipeline.scala +++ b/src/main/scala/cromwell/engine/backend/jes/Pipeline.scala @@ -3,7 +3,6 @@ package cromwell.engine.backend.jes import com.google.api.services.genomics.Genomics import com.google.api.services.genomics.model.CreatePipelineRequest import com.typesafe.scalalogging.LazyLogging -import cromwell.binding.Call import cromwell.engine.WorkflowDescriptor import cromwell.engine.backend.jes.JesBackend._ import cromwell.engine.workflow.CallKey @@ -35,8 +34,8 @@ object Pipeline extends LazyLogging { pipelineId, projectId, gcsPath, - workflow, - call, + workflow, + key, jesParameters, runtimeInfo, jesConnection.genomics) @@ -49,7 +48,7 @@ case class Pipeline(command: String, projectId: String, gcsPath: String, workflow: WorkflowDescriptor, - call: Call, + key: CallKey, jesParameters: Seq[JesParameter], runtimeInfo: JesRuntimeInfo, genomicsService: Genomics) { diff --git a/src/main/scala/cromwell/engine/backend/jes/Run.scala 
b/src/main/scala/cromwell/engine/backend/jes/Run.scala index 4c2a1835f1..9583a3236f 100644 --- a/src/main/scala/cromwell/engine/backend/jes/Run.scala +++ b/src/main/scala/cromwell/engine/backend/jes/Run.scala @@ -4,25 +4,28 @@ import com.google.api.services.genomics.model.{CancelOperationRequest, Logging, import com.typesafe.config.ConfigFactory import cromwell.engine.backend.jes.JesBackend.JesParameter import cromwell.engine.backend.jes.Run.{Failed, Running, Success, _} +import cromwell.engine.db.DataAccess._ import cromwell.engine.db.{JesCallBackendInfo, JesId, JesStatus} -import cromwell.engine.{ExecutionStatus, db} +import cromwell.engine.workflow.CallKey +import cromwell.util.TryUtil import cromwell.util.google.GoogleScopes import org.slf4j.LoggerFactory -import scala.annotation.tailrec import scala.collection.JavaConverters._ +import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Failure object Run { val JesServiceAccount = new ServiceAccount().setEmail("default").setScopes(GoogleScopes.Scopes.asJava) lazy val Log = LoggerFactory.getLogger("main") - lazy val maximumPollingInterval = ConfigFactory.load.getConfig("backend").getConfig("jes").getInt("maximumPollingInterval") * 1000 - val initialPollingInterval = 500 - val pollingBackoffFactor = 1.1 + lazy val MaximumPollingInterval = Duration(ConfigFactory.load.getConfig("backend").getConfig("jes").getInt("maximumPollingInterval"), "seconds") + val InitialPollingInterval = 5 seconds + val PollingBackoffFactor = 1.1 def apply(pipeline: Pipeline): Run = { val rpr = new RunPipelineRequest().setPipelineId(pipeline.id).setProjectId(pipeline.projectId).setServiceAccount(JesServiceAccount) - val tag = s"JES Run [UUID(${pipeline.workflow.shortId}):${pipeline.call.name}]" + val tag = s"JES Run [UUID(${pipeline.workflow.shortId}):${pipeline.key.scope.name}]" rpr.setInputs(pipeline.jesParameters.filter(_.isInput).toRunMap) Log.info(s"$tag 
Inputs:\n${stringifyMap(rpr.getInputs.asScala.toMap)}") @@ -61,40 +64,12 @@ object Run { // Don't want to include errorMessage or code in the snappy status toString: override def toString = "Failed" } - - @tailrec - private final def waitForStatus(run: Run, previousStatus: Option[RunStatus], pollingInterval: Double, breakout: RunStatus => Boolean): RunStatus = { - val currentStatus = run.status() - - if (!(previousStatus contains currentStatus)) { - // If this is the first time checking the status, we log the transition as '-' to 'currentStatus'. Otherwise - // just use the state names. - val prevStateName = previousStatus map { _.toString } getOrElse "-" - Log.info(s"${run.tag}: Status change from $prevStateName to $currentStatus") - - // Update the database state: - val newBackendInfo = JesCallBackendInfo(db.CallStatus(ExecutionStatus.Running.toString), Option(JesId(run.jesId)), Option(JesStatus(currentStatus.toString))) - // TODO: Re-add - //globalDataAccess.updateExecutionBackendInfo(run.workflowId, run.call, newBackendInfo) - } - - if (breakout(currentStatus)) { - currentStatus - } else { - Thread.sleep(pollingInterval.toInt) - waitForStatus(run, Option(currentStatus), nextPollingInterval(pollingInterval, maximumPollingInterval), breakout) - } - } - - private final def nextPollingInterval(previousPollingInterval: Double, maximumPollingInterval: Int): Double = { - Math.min(previousPollingInterval * pollingBackoffFactor, maximumPollingInterval) - } } case class Run(jesId: String, pipeline: Pipeline, tag: String) { lazy val workflowId = pipeline.workflow.id - lazy val call = pipeline.call + lazy val call = pipeline.key.scope def status(): RunStatus = { val op = pipeline.genomicsService.operations().get(jesId).execute @@ -109,19 +84,53 @@ case class Run(jesId: String, pipeline: Pipeline, tag: String) { } } + private final def waitForStatus(previousStatus: Option[RunStatus], breakout: RunStatus => Boolean): RunStatus = { + + def checkStatus(previousStatus: 
Option[RunStatus]): RunStatus = { + val currentStatus = status() + + if (!(previousStatus contains currentStatus)) { + // If this is the first time checking the status, we log the transition as '-' to 'currentStatus'. Otherwise + // just use the state names. + val prevStateName = previousStatus map { _.toString } getOrElse "-" + Log.info(s"$tag: Status change from $prevStateName to $currentStatus") + + // Update the database state: + val newBackendInfo = JesCallBackendInfo(Option(JesId(jesId)), Option(JesStatus(currentStatus.toString))) + globalDataAccess.updateExecutionBackendInfo(workflowId, CallKey(call, pipeline.key.index), newBackendInfo) + } + + currentStatus + } + + val attemptedStatus = TryUtil.retryBlock( + fn = checkStatus, + isSuccess = breakout, + retries = None, + pollingInterval = InitialPollingInterval, + pollingBackOffFactor = PollingBackoffFactor, + maxPollingInterval = MaximumPollingInterval, + priorValue = previousStatus + ) + + attemptedStatus match { + case util.Success(x) => x + case Failure(_) => Failed(-1, "Unexpectedly stopped checking status.") // Assuming TryUtil.retryBlock works, this should not happen + } + } + final def waitUntilComplete(previousStatus: RunStatus): TerminalRunStatus = { - val terminalStatus = Run.waitForStatus(this, Option(previousStatus), initialPollingInterval, { + val terminalStatus = waitForStatus(Option(previousStatus), { case x: TerminalRunStatus => true case _ => false }) - terminalStatus match { case x: TerminalRunStatus => x - case _ => Failed(-1, "Unexpectedly stopped checking status.") // Assuming waitForStatus works, this never happens. 
+ case _ => Failed(-1, "Unexpectedly stopped checking status") } } - final def waitUntilRunningOrComplete: RunStatus = Run.waitForStatus(this, None, initialPollingInterval, { + final def waitUntilRunningOrComplete: RunStatus = waitForStatus(None, { case Running => true case x: TerminalRunStatus => true case _ => false diff --git a/src/main/scala/cromwell/engine/backend/local/LocalBackend.scala b/src/main/scala/cromwell/engine/backend/local/LocalBackend.scala index 11c4de5208..a96b9e2ea3 100644 --- a/src/main/scala/cromwell/engine/backend/local/LocalBackend.scala +++ b/src/main/scala/cromwell/engine/backend/local/LocalBackend.scala @@ -1,6 +1,6 @@ package cromwell.engine.backend.local -import java.io.{FileWriter, BufferedWriter, Writer} +import java.io.Writer import java.nio.file.{Files, Path, Paths} import com.typesafe.scalalogging.LazyLogging @@ -10,13 +10,11 @@ import cromwell.engine._ import cromwell.engine.backend.Backend.RestartableWorkflow import cromwell.engine.backend._ import cromwell.engine.db.DataAccess._ -import cromwell.engine.db.{CallStatus, DataAccess, ExecutionDatabaseKey} -import cromwell.engine.workflow.{CallKey, WorkflowOptions} +import cromwell.engine.db.{CallStatus, ExecutionDatabaseKey} +import cromwell.engine.workflow.CallKey import cromwell.parser.BackendType import cromwell.util.FileUtil._ -import cromwell.util.TailedWriter -import scala.collection.immutable.Queue import scala.concurrent.{ExecutionContext, Future} import scala.language.postfixOps import scala.sys.process._ @@ -176,9 +174,9 @@ class LocalBackend extends Backend with SharedFileSystem with LazyLogging { s"Workflow ${backendCall.workflowDescriptor.id}: stderr has length $stderrFileLength")) } else { - def processSuccess() = { + def processSuccess(rc: Int) = { postProcess(backendCall) match { - case Success(outputs) => SuccessfulExecution(outputs) + case Success(outputs) => SuccessfulExecution(outputs, rc) case Failure(e) => val message = Option(e.getMessage) map { ": " + _ } 
getOrElse "" FailedExecution(new Throwable("Failed post processing of outputs" + message, e)) @@ -196,7 +194,7 @@ class LocalBackend extends Backend with SharedFileSystem with LazyLogging { val continueOnReturnCode = backendCall.call.continueOnReturnCode returnCode match { case Success(143) => AbortedExecution // Special case to check for SIGTERM exit code - implying abort - case Success(otherReturnCode) if continueOnReturnCode.continueFor(otherReturnCode) => processSuccess() + case Success(otherReturnCode) if continueOnReturnCode.continueFor(otherReturnCode) => processSuccess(otherReturnCode) case Success(badReturnCode) => FailedExecution(new Exception(badReturnCodeMessage), Option(badReturnCode)) case Failure(e) => FailedExecution(new Throwable(badReturnCodeMessage, e)) } diff --git a/src/main/scala/cromwell/engine/backend/local/LocalEngineFunctions.scala b/src/main/scala/cromwell/engine/backend/local/LocalEngineFunctions.scala index 1c775cf342..dbf181b2a6 100644 --- a/src/main/scala/cromwell/engine/backend/local/LocalEngineFunctions.scala +++ b/src/main/scala/cromwell/engine/backend/local/LocalEngineFunctions.scala @@ -2,15 +2,69 @@ package cromwell.engine.backend.local import java.io.File import java.nio.file.{Path, Paths} + import cromwell.binding.expression.WdlStandardLibraryFunctions -import cromwell.binding.types.{WdlArrayType, WdlFileType, WdlMapType, WdlStringType} +import cromwell.binding.types.WdlArrayType._ +import cromwell.binding.types._ import cromwell.binding.values._ import cromwell.util.FileUtil import cromwell.util.FileUtil.{EnhancedFile, EnhancedPath} import scala.util.{Failure, Success, Try} -class LocalEngineFunctions(cwd: Path, stdout: Path, stderr: Path) extends WdlStandardLibraryFunctions { +class LocalEngineFunctionsWithoutCallContext extends WdlStandardLibraryFunctions { + protected def fileContentsToString(value: WdlValue): String = { + value match { + case f: WdlFile => new File(f.value).slurp + case s: WdlString => 
Paths.get(s.value).slurp + case e => throw new UnsupportedOperationException("Unsupported argument " + e) + } + } + + override protected def read_lines(params: Seq[Try[WdlValue]]): Try[WdlArray] = { + for { + singleArgument <- extractSingleArgument(params) + lines = fileContentsToString(singleArgument).split("\n").map{WdlString} + } yield WdlArray(WdlArrayType(WdlStringType), lines) + } + + override protected def read_map(params: Seq[Try[WdlValue]]): Try[WdlMap] = { + for { + singleArgument <- extractSingleArgument(params) + contents <- Success(fileContentsToString(singleArgument)) + wdlMap <- WdlMap.fromTsv(contents) + } yield wdlMap + } + + private def extractObjectArray(params: Seq[Try[WdlValue]]): Try[Array[WdlObject]] = for { + singleArgument <- extractSingleArgument(params) + contents <- Success(fileContentsToString(singleArgument)) + wdlObjects <- WdlObject.fromTsv(contents) + } yield wdlObjects + + override protected def read_object(params: Seq[Try[WdlValue]]): Try[WdlObject] = { + extractObjectArray(params) map { + case array if array.length == 1 => array.head + case _ => throw new IllegalArgumentException("read_object yields an Object and thus can only read 2-rows TSV files. Try using read_objects instead.") + } + } + + override def read_objects(params: Seq[Try[WdlValue]]): Try[WdlArray] = { + extractObjectArray(params) map { WdlArray(WdlArrayType(WdlObjectType), _) } + } + + /** + * Try to read a string from the file referenced by the specified `WdlValue`. 
+ */ + override protected def read_string(params: Seq[Try[WdlValue]]): Try[WdlString] = { + for { + singleArgument <- extractSingleArgument(params) + string = fileContentsToString(singleArgument) + } yield WdlString(string.stripSuffix("\n")) + } +} + +class LocalEngineFunctions(cwd: Path, stdout: Path, stderr: Path) extends LocalEngineFunctionsWithoutCallContext { /** * Read the entire contents of a file from the specified `WdlValue`, where the file can be @@ -20,7 +74,7 @@ class LocalEngineFunctions(cwd: Path, stdout: Path, stderr: Path) extends WdlSta * @throws UnsupportedOperationException for an unrecognized file reference, as this is intended * to be wrapped in a `Try`. */ - private def fileContentsToString(value: WdlValue): String = { + override def fileContentsToString(value: WdlValue): String = { value match { case f: WdlFile => new File(f.value).slurp case s: WdlString => cwd.resolve(s.value).slurp @@ -44,57 +98,51 @@ class LocalEngineFunctions(cwd: Path, stdout: Path, stderr: Path) extends WdlSta } } - override protected def read_lines(params: Seq[Try[WdlValue]]): Try[WdlArray] = { + override protected def write_lines(params: Seq[Try[WdlValue]]): Try[WdlFile] = { for { singleArgument <- extractSingleArgument(params) - lines = fileContentsToString(singleArgument).split("\n").map{WdlString} - } yield WdlArray(WdlArrayType(WdlStringType), lines) + if singleArgument.wdlType.isInstanceOf[WdlArrayType] + tsvSerialized <- singleArgument.asInstanceOf[WdlArray].tsvSerialize + file <- writeContent("array", tsvSerialized) + } yield file } - override protected def read_map(params: Seq[Try[WdlValue]]): Try[WdlMap] = { + override protected def write_map(params: Seq[Try[WdlValue]]): Try[WdlFile] = { for { singleArgument <- extractSingleArgument(params) - if singleArgument.wdlType == WdlFileType - contents <- Success(Paths.get(singleArgument.asInstanceOf[WdlFile].valueString).slurp) - wdlMap <- WdlMap.fromTsv(contents) - } yield wdlMap + if 
singleArgument.wdlType.isInstanceOf[WdlMapType] + tsvSerialized <- singleArgument.asInstanceOf[WdlMap].tsvSerialize + file <- writeContent("map", tsvSerialized) + } yield file } - /** - * Try to read an integer from the file referenced by the specified `WdlValue`. - */ - override protected def read_int(params: Seq[Try[WdlValue]]): Try[WdlInteger] = - read_string(params).map { s => WdlInteger(s.value.trim.toInt) } - - /** - * Try to read a string from the file referenced by the specified `WdlValue`. - */ - override protected def read_string(params: Seq[Try[WdlValue]]): Try[WdlString] = { + override protected def write_object(params: Seq[Try[WdlValue]]): Try[WdlFile] = { for { singleArgument <- extractSingleArgument(params) - string = fileContentsToString(singleArgument) - } yield WdlString(string.stripSuffix("\n")) + if singleArgument.wdlType == WdlObjectType + tsvSerialized <- singleArgument.asInstanceOf[WdlObject].tsvSerialize + file <- writeContent("object", tsvSerialized) + } yield file } - override protected def write_lines(params: Seq[Try[WdlValue]]): Try[WdlFile] = { + override protected def write_objects(params: Seq[Try[WdlValue]]): Try[WdlFile] = { for { singleArgument <- extractSingleArgument(params) - if singleArgument.wdlType.isInstanceOf[WdlArrayType] + if singleArgument.wdlType.isAnArrayOf(WdlObjectType) tsvSerialized <- singleArgument.asInstanceOf[WdlArray].tsvSerialize - (path, writer) = FileUtil.tempFileAndWriter("array", cwd.toFile) - _ <- Try(writer.write(tsvSerialized)) - _ <- Success(writer.close()) - } yield WdlFile(path.toAbsolutePath.toString) + file <- writeContent("array", tsvSerialized) + } yield file } - override protected def write_map(params: Seq[Try[WdlValue]]): Try[WdlFile] = { - for { - singleArgument <- extractSingleArgument(params) - if singleArgument.wdlType.isInstanceOf[WdlMapType] - tsvSerialized <- singleArgument.asInstanceOf[WdlMap].tsvSerialize - (path, writer) = FileUtil.tempFileAndWriter("map", cwd.toFile) - _ <- 
Try(writer.write(tsvSerialized)) - _ <- Success(writer.close()) - } yield WdlFile(path.toAbsolutePath.toString) + protected def writeContent(baseName: String, content: String) = { + val (path, writer) = FileUtil.tempFileAndWriter(baseName, cwd.toFile) + try { + writer.write(content) + Success(WdlFile(path.toAbsolutePath.toString)) + } catch { + case t: Throwable => Failure(t) + } finally { + writer.close() + } + } } diff --git a/src/main/scala/cromwell/engine/backend/local/SharedFileSystem.scala b/src/main/scala/cromwell/engine/backend/local/SharedFileSystem.scala index 29a2b1d201..cdb85a85d7 100644 --- a/src/main/scala/cromwell/engine/backend/local/SharedFileSystem.scala +++ b/src/main/scala/cromwell/engine/backend/local/SharedFileSystem.scala @@ -6,12 +6,12 @@ import java.security.MessageDigest import com.typesafe.config.ConfigFactory import cromwell.binding._ -import cromwell.binding.types.{WdlArrayType, WdlFileType, WdlMapType} +import cromwell.binding.expression.WdlStandardLibraryFunctions +import cromwell.binding.types.WdlFileType import cromwell.binding.values.{WdlValue, _} import cromwell.engine.ExecutionIndex._ import cromwell.engine.WorkflowDescriptor import cromwell.engine.backend.{LocalFileSystemBackendCall, StdoutStderr} -import cromwell.engine.db.DataAccess import org.apache.commons.io.FileUtils import scala.collection.JavaConverters._ @@ -67,6 +67,8 @@ trait SharedFileSystem { import SharedFileSystem._ + val engineFunctions: WdlStandardLibraryFunctions = new LocalEngineFunctionsWithoutCallContext + def postProcess(backendCall: LocalFileSystemBackendCall): Try[CallOutputs] = { // Evaluate output expressions, performing conversions from String -> File where required.
val outputMappings = backendCall.call.task.outputs map { taskOutput => @@ -232,15 +234,9 @@ trait SharedFileSystem { */ private def outputAutoConversion(backendCall: LocalFileSystemBackendCall, taskOutput: TaskOutput, rawOutputValue: WdlValue): Try[WdlValue] = { rawOutputValue match { - case v: WdlString if taskOutput.wdlType == WdlFileType => assertTaskOutputPathExists(hostAbsoluteFilePath(backendCall, v.value), taskOutput, backendCall.call.fullyQualifiedName) - case m: WdlMap if taskOutput.wdlType.isInstanceOf[WdlMapType] => taskOutput.wdlType.coerceRawValue(m) - case a: WdlArray if taskOutput.wdlType.isInstanceOf[WdlArrayType] => taskOutput.wdlType.coerceRawValue(a) - case v if v.wdlType == taskOutput.wdlType => Success(v) - case _ => Failure(new RuntimeException( - s"""Error processing '${backendCall.call.fullyQualifiedName}.${taskOutput.name}': - | - |Value $rawOutputValue cannot be converted to ${taskOutput.wdlType.toWdlString} - """.stripMargin)) + case rhs if rhs.wdlType == taskOutput.wdlType => Success(rhs) + case rhs: WdlString if taskOutput.wdlType == WdlFileType => assertTaskOutputPathExists(hostAbsoluteFilePath(backendCall, rhs.value), taskOutput, backendCall.call.fullyQualifiedName) + case rhs => taskOutput.wdlType.coerceRawValue(rhs) } } diff --git a/src/main/scala/cromwell/engine/backend/sge/SgeBackend.scala b/src/main/scala/cromwell/engine/backend/sge/SgeBackend.scala index c88b53df60..6dd22768ab 100644 --- a/src/main/scala/cromwell/engine/backend/sge/SgeBackend.scala +++ b/src/main/scala/cromwell/engine/backend/sge/SgeBackend.scala @@ -7,8 +7,9 @@ import cromwell.binding.{Call, CallInputs} import cromwell.engine.backend.Backend.RestartableWorkflow import cromwell.engine.backend._ import cromwell.engine.backend.local.{LocalBackend, SharedFileSystem} -import cromwell.engine.db.{CallStatus, SgeCallBackendInfo} -import cromwell.engine.workflow.{CallKey, WorkflowOptions} +import cromwell.engine.db.DataAccess._ +import 
cromwell.engine.db.SgeCallBackendInfo +import cromwell.engine.workflow.CallKey import cromwell.engine.{AbortRegistrationFunction, _} import cromwell.parser.BackendType import cromwell.util.FileUtil._ @@ -64,7 +65,7 @@ class SgeBackend extends Backend with SharedFileSystem with LazyLogging { private def statusString(result: ExecutionResult): String = (result match { case AbortedExecution => ExecutionStatus.Aborted case FailedExecution(_, _) => ExecutionStatus.Failed - case SuccessfulExecution(_) => ExecutionStatus.Done + case SuccessfulExecution(_, _) => ExecutionStatus.Done }).toString private def recordDatabaseFailure(call: Call, status: String, rc: Int): PartialFunction[Throwable, Unit] = { @@ -72,10 +73,8 @@ class SgeBackend extends Backend with SharedFileSystem with LazyLogging { } private def updateSgeJobTable(call: BackendCall, status: String, rc: Option[Int], sgeJobId: Option[Int]): Future[Unit] = { - val backendInfo = SgeCallBackendInfo(CallStatus(status, rc), sgeJobId) - // TODO: re-add - //globalDataAccess.updateExecutionBackendInfo(call.workflowDescriptor.id, call.call, backendInfo) - Future.successful(()) + val backendInfo = SgeCallBackendInfo(sgeJobId) + globalDataAccess.updateExecutionBackendInfo(call.workflowDescriptor.id, CallKey(call.call, call.key.index), backendInfo) } // TODO: Not much thought was given to this function @@ -175,9 +174,9 @@ class SgeBackend extends Backend with SharedFileSystem with LazyLogging { FailedExecution(new Exception(s"$tag SGE job failed because of return code: $r"), Option(r)) case (_, stderrLength) if stderrLength > 0 && backendCall.call.failOnStderr => FailedExecution(new Throwable(s"$tag SGE job failed because there were $stderrLength bytes on standard error"), Option(0)) - case (_, _) => + case (r, _) => postProcess(backendCall) match { - case Success(callOutputs) => SuccessfulExecution(callOutputs) + case Success(callOutputs) => SuccessfulExecution(callOutputs, r) case Failure(e) => FailedExecution(e) } } diff 
--git a/src/main/scala/cromwell/engine/db/CallBackendInfo.scala b/src/main/scala/cromwell/engine/db/CallBackendInfo.scala index dea0e323b2..c004c5992e 100644 --- a/src/main/scala/cromwell/engine/db/CallBackendInfo.scala +++ b/src/main/scala/cromwell/engine/db/CallBackendInfo.scala @@ -1,11 +1,9 @@ package cromwell.engine.db -sealed trait CallBackendInfo { - val status: CallStatus -} +sealed trait CallBackendInfo -final case class LocalCallBackendInfo(status: CallStatus, processId: Option[Int]) extends CallBackendInfo +final case class LocalCallBackendInfo(processId: Option[Int]) extends CallBackendInfo -final case class JesCallBackendInfo(status: CallStatus, jesId: Option[JesId], jesStatus: Option[JesStatus]) extends CallBackendInfo +final case class JesCallBackendInfo(jesId: Option[JesId], jesStatus: Option[JesStatus]) extends CallBackendInfo -final case class SgeCallBackendInfo(status: CallStatus, sgeJobNumber: Option[Int]) extends CallBackendInfo +final case class SgeCallBackendInfo(sgeJobNumber: Option[Int]) extends CallBackendInfo diff --git a/src/main/scala/cromwell/engine/db/DataAccess.scala b/src/main/scala/cromwell/engine/db/DataAccess.scala index a95458eab1..5364998bbf 100644 --- a/src/main/scala/cromwell/engine/db/DataAccess.scala +++ b/src/main/scala/cromwell/engine/db/DataAccess.scala @@ -5,7 +5,7 @@ import cromwell.binding.values.WdlValue import cromwell.engine.ExecutionStatus.ExecutionStatus import cromwell.engine.backend.Backend import cromwell.engine.db.slick._ -import cromwell.engine.workflow.{ExecutionStoreKey, OutputKey} +import cromwell.engine.workflow.{CallKey, ExecutionStoreKey, OutputKey} import cromwell.engine.{SymbolStoreEntry, WorkflowDescriptor, WorkflowId, WorkflowState} import scala.concurrent.Future @@ -32,7 +32,7 @@ trait DataAccess { def getExecutionBackendInfo(workflowId: WorkflowId, call: Call): Future[CallBackendInfo] - def updateExecutionBackendInfo(workflowId: WorkflowId, call: Call, backendInfo: CallBackendInfo): Future[Unit] 
+ def updateExecutionBackendInfo(workflowId: WorkflowId, callKey: CallKey, backendInfo: CallBackendInfo): Future[Unit] def updateWorkflowState(workflowId: WorkflowId, workflowState: WorkflowState): Future[Unit] diff --git a/src/main/scala/cromwell/engine/db/slick/DataAccessComponent.scala b/src/main/scala/cromwell/engine/db/slick/DataAccessComponent.scala index 0f5fc89995..5d858effbb 100644 --- a/src/main/scala/cromwell/engine/db/slick/DataAccessComponent.scala +++ b/src/main/scala/cromwell/engine/db/slick/DataAccessComponent.scala @@ -1,10 +1,7 @@ package cromwell.engine.db.slick -import cromwell.engine.WorkflowId import slick.driver.JdbcProfile -import scala.reflect.runtime._ - class DataAccessComponent(val driver: JdbcProfile) extends DriverComponent with WorkflowExecutionComponent @@ -17,10 +14,6 @@ class DataAccessComponent(val driver: JdbcProfile) import driver.api._ - def this(driverName: String) { - this(DataAccessComponent.getObject[JdbcProfile](driverName)) - } - lazy val schema = workflowExecutions.schema ++ workflowExecutionAuxes.schema ++ @@ -29,44 +22,4 @@ class DataAccessComponent(val driver: JdbcProfile) localJobs.schema ++ jesJobs.schema ++ sgeJobs.schema - - def jesJobInfo(id: WorkflowId): Query[(Executions, JesJobs), (Execution, JesJob), Seq] = for { - workflowExecution <- workflowExecutions - execution <- executions - jesJob <- jesJobs - if workflowExecution.workflowExecutionUuid === id.toString - if execution.workflowExecutionId === workflowExecution.workflowExecutionId - if jesJob.executionId === execution.executionId - } yield (execution, jesJob) - - def localJobInfo(id: WorkflowId) = for { - workflowExecution <- workflowExecutions - execution <- executions - localJob <- localJobs - if workflowExecution.workflowExecutionUuid === id.toString - if execution.workflowExecutionId === workflowExecution.workflowExecutionId - if localJob.executionId === execution.executionId - } yield (execution, localJob) - - def sgeJobInfo(id: WorkflowId) = for { - 
workflowExecution <- workflowExecutions - execution <- executions - sgeJob <- sgeJobs - if workflowExecution.workflowExecutionUuid === id.toString - if execution.workflowExecutionId === workflowExecution.workflowExecutionId - if sgeJob.executionId === execution.executionId - } yield (execution, sgeJob) -} - -object DataAccessComponent { - // TODO: move to DSDE common util? - private def getObject[T](objectName: String): T = { - // via - // http://stackoverflow.com/questions/23466782/scala-object-get-reference-from-string-in-scala-2-10 - // https://github.com/anvie/slick-test/blob/045f4db610d3b91bf928a53f2bc7b6ae17c35985/slick-util/src/main/scala/scala/slick/codegen/ModelGenerator.scala - val staticModule = currentMirror.staticModule(objectName) - val reflectModule = currentMirror.reflectModule(staticModule) - val instance = reflectModule.instance - instance.asInstanceOf[T] - } } diff --git a/src/main/scala/cromwell/engine/db/slick/DatabaseConfig.scala b/src/main/scala/cromwell/engine/db/slick/DatabaseConfig.scala deleted file mode 100644 index a80623f927..0000000000 --- a/src/main/scala/cromwell/engine/db/slick/DatabaseConfig.scala +++ /dev/null @@ -1,11 +0,0 @@ -package cromwell.engine.db.slick - -import com.typesafe.config.ConfigFactory -import slick.util.ConfigExtensionMethods._ - -object DatabaseConfig { - private val config = ConfigFactory.load() - lazy val rootDatabaseConfig = config.getConfig("database") - private val databaseConfigName = rootDatabaseConfig.getStringOpt("config") - lazy val databaseConfig = databaseConfigName.map(rootDatabaseConfig.getConfig).getOrElse(rootDatabaseConfig) -} diff --git a/src/main/scala/cromwell/engine/db/slick/ExecutionComponent.scala b/src/main/scala/cromwell/engine/db/slick/ExecutionComponent.scala index b23cb0428b..bcb47a013b 100644 --- a/src/main/scala/cromwell/engine/db/slick/ExecutionComponent.scala +++ b/src/main/scala/cromwell/engine/db/slick/ExecutionComponent.scala @@ -66,6 +66,15 @@ trait ExecutionComponent { if 
execution.callFqn === callFqn } yield (execution.callFqn, execution.index, execution.status, execution.rc)) + val executionsByWorkflowExecutionUuidAndCallFqnAndShardIndex = Compiled( + (workflowExecutionUuid: Rep[String], callFqn: Rep[String], index: Rep[Int]) => for { + execution <- executions + if execution.callFqn === callFqn + if execution.index === index + workflowExecution <- execution.workflowExecution + if workflowExecution.workflowExecutionUuid === workflowExecutionUuid + } yield execution) + val executionsByWorkflowExecutionUuidAndCallFqn = Compiled( (workflowExecutionUuid: Rep[String], callFqn: Rep[String]) => for { execution <- executions diff --git a/src/main/scala/cromwell/engine/db/slick/JesJobComponent.scala b/src/main/scala/cromwell/engine/db/slick/JesJobComponent.scala index 20f767ba10..848e562555 100644 --- a/src/main/scala/cromwell/engine/db/slick/JesJobComponent.scala +++ b/src/main/scala/cromwell/engine/db/slick/JesJobComponent.scala @@ -9,7 +9,7 @@ case class JesJob ) trait JesJobComponent { - this: DriverComponent with ExecutionComponent => + this: DriverComponent with ExecutionComponent with WorkflowExecutionComponent => import driver.api._ @@ -45,4 +45,11 @@ trait JesJobComponent { jesJob <- jesJobs if jesJob.executionId === executionId } yield (jesJob.jesId, jesJob.jesStatus)) + + val jesJobsWithExecutionsByWorkflowExecutionUuid = Compiled( + (workflowExecutionUuid: Rep[String]) => for { + workflowExecution <- workflowExecutions if workflowExecution.workflowExecutionUuid === workflowExecutionUuid + execution <- executions if execution.workflowExecutionId === workflowExecution.workflowExecutionId + jesJob <- jesJobs if jesJob.executionId === execution.executionId + } yield (execution, jesJob)) } diff --git a/src/main/scala/cromwell/engine/db/slick/LocalJobComponent.scala b/src/main/scala/cromwell/engine/db/slick/LocalJobComponent.scala index f614c40a43..85eb1e5e60 100644 --- a/src/main/scala/cromwell/engine/db/slick/LocalJobComponent.scala 
+++ b/src/main/scala/cromwell/engine/db/slick/LocalJobComponent.scala @@ -8,7 +8,7 @@ case class LocalJob ) trait LocalJobComponent { - this: DriverComponent with ExecutionComponent => + this: DriverComponent with ExecutionComponent with WorkflowExecutionComponent => import driver.api._ @@ -40,4 +40,14 @@ trait LocalJobComponent { localJob <- localJobs if localJob.executionId === executionId } yield localJob.pid) + + val localJobsWithExecutionsByWorkflowExecutionUuid = Compiled( + (workflowExecutionUuid: Rep[String]) => for { + workflowExecution <- workflowExecutions + if workflowExecution.workflowExecutionUuid === workflowExecutionUuid + execution <- executions + if execution.workflowExecutionId === workflowExecution.workflowExecutionId + localJob <- localJobs + if localJob.executionId === execution.executionId + } yield (execution, localJob)) } diff --git a/src/main/scala/cromwell/engine/db/slick/RunMysql.scala b/src/main/scala/cromwell/engine/db/slick/RunMysql.scala deleted file mode 100644 index fe798d9615..0000000000 --- a/src/main/scala/cromwell/engine/db/slick/RunMysql.scala +++ /dev/null @@ -1,203 +0,0 @@ -package cromwell.engine.db.slick - -import java.sql.{Connection, DriverManager} - -import com.typesafe.config.{ConfigFactory, ConfigRenderOptions} -import slick.jdbc.SimpleJdbcAction - -import scala.collection.JavaConverters._ -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.postfixOps - -object RunMysql { - Class.forName("com.mysql.jdbc.Driver").newInstance() - - val user: String = getConfig("user") - val password: String = getConfig("password") - val url: String = getConfig("url") - val urlWithSsl: String = withSSL(url) - val urlWithoutSsl: String = stripSSL(url) - val configWithoutConnection = DatabaseConfig.databaseConfig. - withoutPath("url"). - withoutPath("driver"). - withoutPath("dataSourceClass"). - withoutPath("properties.url"). - withoutPath("properties.useSSL"). 
- withoutPath("properties.requireSSL") - - var successes: Seq[String] = Seq.empty - var failures: Seq[String] = Seq.empty - - // Hiding this main() from sbt, but leaving class for future debugging. - //def main(args: Array[String]) = runMain() - - def runMain() = { - println("user: " + user) - println("password: " + ("*" * password.length)) - println("url: " + url) - println("urlWithSsl: " + urlWithSsl) - println("urlWithoutSsl: " + urlWithoutSsl) - - tryRun("jdbcMain", jdbcMain()) - tryRun("slickMain", slickMain()) - tryRun("jdbcRequireSsl", jdbcRequireSSL()) - - tryRun("dataSourceRaw", dataSourceTest(urlWithoutSsl, Option(false))) - tryRun("dataSourceRawProp", dataSourceTest(urlWithoutSsl, Option(true))) - tryRun("dataSourceRawNoProp", dataSourceTest(urlWithoutSsl, None)) - - tryRun("dataSourceSslUrl", dataSourceTest(urlWithSsl, Option(false))) - tryRun("dataSourceSslUrlProp", dataSourceTest(urlWithSsl, Option(true))) - tryRun("dataSourceSslNoProp", dataSourceTest(urlWithSsl, None)) - - tryRun("slickRaw", slickTest(urlWithoutSsl, sslProp = Option(false))) - tryRun("slickRawProp", slickTest(urlWithoutSsl, sslProp = Option(true))) - tryRun("slickRawNoProp", slickTest(urlWithoutSsl, sslProp = None)) - tryRun("slickRawDriver", slickTest(urlWithoutSsl, dataSource = false)) - - tryRun("slickSslUrl", slickTest(urlWithSsl, sslProp = Option(false))) - tryRun("slickSslUrlProp", slickTest(urlWithSsl, sslProp = Option(true))) - tryRun("slickSslUrlNoProp", slickTest(urlWithSsl, sslProp = None)) - tryRun("slickSslDriver", slickTest(urlWithSsl, dataSource = false)) - - println("Successes: " + successes.size) - successes.foreach(s => println(" " + s)) - println("Failures: " + failures.size) - failures.foreach(f => println(" " + f)) - } - - // Tests - - def jdbcMain() = { - val connectionString = s"$url?useSSL=true" - println(s"Testing connection '$connectionString'") - val connection = DriverManager.getConnection(connectionString, user, password) - testConnection(connection) - } 
- - def jdbcRequireSSL() = { - val connectionString = s"$url?useSSL=true&requireSSL=true" - println(s"Testing connection '$connectionString'") - val connection = DriverManager.getConnection(connectionString, user, password) - testConnection(connection) - } - - def slickMain(): Unit = { - val dataAccess = new SlickDataAccess() - val connectionFuture = dataAccess.database.run(SimpleJdbcAction(context => testConnection(context.connection))) - Await.result(connectionFuture, 10 seconds) - } - - def dataSourceMain(): Unit = { - val ds = new com.mysql.jdbc.jdbc2.optional.MysqlDataSource - ds.setURL(url) - ds.setUseSSL(true) - ds.setRequireSSL(true) - testConnection(ds.getConnection(user, password)) - } - - // Datasource permutations - - def dataSourceTest(testUrl: String, useSsl: Option[Boolean]): Unit = { - val ds = new com.mysql.jdbc.jdbc2.optional.MysqlDataSource - ds.setURL(testUrl) - useSsl foreach ds.setUseSSL - useSsl foreach ds.setRequireSSL - testConnection(ds.getConnection(user, password)) - } - - // Slick tests - - def slickTest(urlVal: String, - dataSource: Boolean = true, - sslProp: Option[Boolean] = None): Unit = { - - var map: Map[String, AnyRef] = Map.empty - - if (dataSource) { - map += "dataSourceClass" -> "com.mysql.jdbc.jdbc2.optional.MysqlDataSource" - map += "properties.url" -> urlVal - sslProp foreach { useSsl => - map += "properties.useSSL" -> Boolean.box(useSsl) - map += "properties.requireSSL" -> Boolean.box(useSsl) - } - } else { - map += "driver" -> "com.mysql.jdbc.Driver" - map += "url" -> urlVal - } - - val config = ConfigFactory.parseMap(map.asJava, "debug slick options").withFallback(configWithoutConnection) - println("config:") - println(config.root().render(ConfigRenderOptions.defaults())) - val dataAccess = new SlickDataAccess(config) - val connectionFuture = dataAccess.database.run(SimpleJdbcAction(context => testConnection(context.connection))) - Await.result(connectionFuture, 10 seconds) - } - - // Utilities - - private def 
getConfig(key: String) = - getConfigOpt("properties." + key) orElse getConfigOpt(key) getOrElse "" - - private def getConfigOpt(key: String): Option[String] = { - if (DatabaseConfig.databaseConfig.hasPath(key)) { - Option(DatabaseConfig.databaseConfig.getString(key)) - } else { - None - } - } - - private def tryRun(label: String, test: => Unit): Unit = { - try { - println("Running: " + label) - test - println("Success: " + label) - successes :+= label - } catch { - case e: Exception => - failures :+= label - println("Failed: " + label) - e.printStackTrace() - } - println() - } - - private def testConnection(connection: Connection) = { - try { - println("Connection url: " + connection.getMetaData.getURL) - println("Connection valid? " + connection.isValid(5)) - val resultSet = connection.createStatement().executeQuery("select count(*) from WORKFLOW_EXECUTION") - val first = resultSet.next() - if (!first) throw new RuntimeException("No first row!") - val workflowExecutions = resultSet.getInt(1) - println(s"Found $workflowExecutions workflow executions.") - } finally{ - try { - connection.close() - } catch { - case e: Exception => - /* ignore */ - } - } - } - - private def stripSSL(url: String): String = { - var newUrl = url - newUrl = newUrl.replaceAll( """\?useSSL=true\&""", "?") - newUrl = newUrl.replaceAll( """\?useSSL=true$""", "") - newUrl = newUrl.replaceAll( """\&useSSL=true""", "") - newUrl = newUrl.replaceAll( """\?requireSSL=true\&""", "?") - newUrl = newUrl.replaceAll( """\?requireSSL=true$""", "") - newUrl = newUrl.replaceAll( """\&requireSSL=true""", "") - newUrl - } - - private def withSSL(url: String): String = { - val newUrl = stripSSL(url) - if (newUrl.contains("?")) - newUrl + "&useSSL=true&requireSSL=true" - else - newUrl + "?useSSL=true&requireSSL=true" - } -} diff --git a/src/main/scala/cromwell/engine/db/slick/SgeJobComponent.scala b/src/main/scala/cromwell/engine/db/slick/SgeJobComponent.scala index 07a7b4881e..190f4f39a5 100644 --- 
a/src/main/scala/cromwell/engine/db/slick/SgeJobComponent.scala +++ b/src/main/scala/cromwell/engine/db/slick/SgeJobComponent.scala @@ -5,7 +5,7 @@ case class SgeJob (executionId: Int, sgeJobId: Option[Int] = None) trait SgeJobComponent { - this: DriverComponent with ExecutionComponent => + this: DriverComponent with ExecutionComponent with WorkflowExecutionComponent => import driver.api._ @@ -33,4 +33,13 @@ trait SgeJobComponent { if sgeJob.executionId === executionId } yield sgeJob.sgeJobNumber) + val sgeJobsWithExecutionsByWorkflowExecutionUuid = Compiled( + (workflowExecutionUuid: Rep[String]) => for { + workflowExecution <- workflowExecutions + if workflowExecution.workflowExecutionUuid === workflowExecutionUuid + execution <- executions + if execution.workflowExecutionId === workflowExecution.workflowExecutionId + sgeJob <- sgeJobs + if sgeJob.executionId === execution.executionId + } yield (execution, sgeJob)) } diff --git a/src/main/scala/cromwell/engine/db/slick/SlickDataAccess.scala b/src/main/scala/cromwell/engine/db/slick/SlickDataAccess.scala index 32670a1cff..1a3712fa8f 100644 --- a/src/main/scala/cromwell/engine/db/slick/SlickDataAccess.scala +++ b/src/main/scala/cromwell/engine/db/slick/SlickDataAccess.scala @@ -4,7 +4,8 @@ import java.sql.{Clob, Timestamp} import java.util.{Date, UUID} import javax.sql.rowset.serial.SerialClob -import _root_.slick.util.ConfigExtensionMethods._ +import _root_.slick.backend.DatabaseConfig +import _root_.slick.driver.JdbcProfile import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import cromwell.binding._ import cromwell.binding.types.{WdlPrimitiveType, WdlType} @@ -17,6 +18,7 @@ import cromwell.engine.backend.local.LocalBackend import cromwell.engine.backend.sge.SgeBackend import cromwell.engine.db._ import cromwell.engine.workflow.{CallKey, ExecutionStoreKey, OutputKey, ScatterKey} +import lenthall.config.ScalaConfig._ import org.slf4j.LoggerFactory import scala.concurrent.duration.Duration @@ 
-28,6 +30,28 @@ object SlickDataAccess { val IoInput = "INPUT" val IoOutput = "OUTPUT" + lazy val rootConfig = ConfigFactory.load() + /* + VERY TEMPORARY! + Turns out, Slick has a way to load databases from configs: DatabaseConfig + http://slick.typesafe.com/doc/3.0.0/database.html?highlight=databaseconfig#databaseconfig + + To switch over to this config format, we need to rename: + + 1. Property name "driver" renamed to "db.driver" + 2. Property name "slick.driver" renamed to "driver" + 3. Property value with the slick driver needs to append "$" + + To make sure the code continues to run during this switch, the application.conf's have been updated ahead of time + with the temporary "databaseSlickDriverConfigSwitch" configuration. + */ + private lazy val rootDatabaseConfig = rootConfig.getConfig( + if (rootConfig.hasPath("databaseSlickDriverConfigSwitch")) "databaseSlickDriverConfigSwitch" else "database") + private lazy val databaseConfigName = rootDatabaseConfig.getStringOption("config") + lazy val defaultDatabaseConfig = databaseConfigName.map(getDatabaseConfig).getOrElse(rootDatabaseConfig) + + def getDatabaseConfig(path: String) = rootDatabaseConfig.getConfig(path) + implicit class DateToTimestamp(val date: Date) extends AnyVal { def toTimestamp = new Timestamp(date.getTime) } @@ -44,7 +68,7 @@ object SlickDataAccess { /** * Returns either the "url" or "properties.url" */ - def urlKey = if (config.hasPath("url")) "url" else "properties.url" + def urlKey = if (config.hasPath("db.url")) "db.url" else "db.properties.url" /** * Returns the value of either the "url" or "properties.url" @@ -85,25 +109,25 @@ object SlickDataAccess { * if the list is empty. In every use case as of the writing of this comment, * those exceptions would have been wrapped in a failed Future and returned. 
*/ -class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponent) extends DataAccess { +class SlickDataAccess(databaseConfig: Config) extends DataAccess { - def this(databaseConfig: Config) = this( - databaseConfig, - new DataAccessComponent(databaseConfig.getString("slick.driver"))) + import SlickDataAccess._ + + def this() = this(SlickDataAccess.defaultDatabaseConfig) + + private val configWithUniqueSchema = this.databaseConfig.withUniqueSchema - def this() = this(DatabaseConfig.databaseConfig) + val slickConfig = DatabaseConfig.forConfig[JdbcProfile]("", configWithUniqueSchema) + val dataAccess = new DataAccessComponent(slickConfig.driver) // NOTE: Used for slick flatMap. May switch to custom ExecutionContext the future private implicit val executionContext = ExecutionContext.global - import SlickDataAccess._ - // Allows creation of a Database, plus implicits for running transactions import dataAccess.driver.api._ // NOTE: if you want to refactor database is inner-class type: this.dataAccess.driver.backend.DatabaseFactory - private val configWithUniqueSchema = databaseConfig.withUniqueSchema - val database = Database.forConfig("", configWithUniqueSchema) + val database = slickConfig.db // Possibly create the database { @@ -124,7 +148,7 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen // generate unique schema instances that don't conflict. // // Otherwise, create one DataAccess and hold on to the reference. 
- if (databaseConfig.getBooleanOr("slick.createSchema")) { + if (this.databaseConfig.getBooleanOr("slick.createSchema")) { Await.result(database.run(dataAccess.schema.create), Duration.Inf) } } @@ -157,8 +181,8 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen backend: Backend): Future[Unit] = { val scopeKeys: Traversable[ExecutionStoreKey] = scopes collect { - case call: Call => CallKey(call, None, None) - case scatter: Scatter => ScatterKey(scatter, None, None) + case call: Call => CallKey(call, None) + case scatter: Scatter => ScatterKey(scatter, None) } val action = for { @@ -378,18 +402,12 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen jobResultOption = localJobResultOption orElse jesJobResultOption orElse sgeJobResultOption backendInfo = jobResultOption match { case Some(localJobResult: LocalJob) => - new LocalCallBackendInfo( - CallStatus(executionResult.status, executionResult.rc), - localJobResult.pid) + new LocalCallBackendInfo(localJobResult.pid) case Some(jesJobResult: JesJob) => - new JesCallBackendInfo( - CallStatus(executionResult.status, executionResult.rc), - jesJobResult.jesId map JesId, + new JesCallBackendInfo(jesJobResult.jesId map JesId, jesJobResult.jesStatus map JesStatus) case Some(sgeJobResult: SgeJob) => - new SgeCallBackendInfo( - CallStatus(executionResult.status, executionResult.rc), - sgeJobResult.sgeJobNumber) + new SgeCallBackendInfo(sgeJobResult.sgeJobNumber) case _ => throw new IllegalArgumentException( s"Unknown backend from db for (uuid, fqn): " + @@ -401,25 +419,15 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen runTransaction(action) } - // TODO it's confusing that CallBackendInfo has a CallStatus in it when that information doesn't go to the - // backend info tables. But this method does use the CallStatus data from the CallBackendInfo to update the - // Execution table. 
override def updateExecutionBackendInfo(workflowId: WorkflowId, - call: Call, + callKey: CallKey, backendInfo: CallBackendInfo): Future[Unit] = { require(backendInfo != null, "backend info is null") - val callStatus = backendInfo.status + import ExecutionIndex._ val action = for { - executionResult <- dataAccess.executionsByWorkflowExecutionUuidAndCallFqn( - workflowId.toString, call.fullyQualifiedName).result.head - - executionStatusQuery = dataAccess.executionStatusesAndReturnCodesByExecutionId( - executionResult.executionId.get) - - executionUpdate <- executionStatusQuery.update(callStatus.executionStatus.toString, callStatus.returnCode) - - _ = require(executionUpdate == 1, s"Unexpected execution update count $executionUpdate") + executionResult <- dataAccess.executionsByWorkflowExecutionUuidAndCallFqnAndShardIndex( + workflowId.toString, callKey.scope.fullyQualifiedName, callKey.index.fromIndex).result.head backendUpdate <- backendInfo match { case localBackendInfo: LocalCallBackendInfo => @@ -447,6 +455,9 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen runTransaction(action) } + private def toSymbolStoreEntries(symbolResults: Traversable[Symbol]) = + symbolResults map toSymbolStoreEntry + private def toSymbolStoreEntry(symbolResult: Symbol) = { val wdlType = WdlType.fromWdlString(symbolResult.wdlType) new SymbolStoreEntry( @@ -463,21 +474,19 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen override def getFullyQualifiedName(workflowId: WorkflowId, fqn: FullyQualifiedName): Future[Traversable[SymbolStoreEntry]] = { val Array(scope, varName) = fqn.split("\\.(?=[^\\.]+$)") // e.g. 
"a.b.c.d" => Seq("a.b.c", "d") - val action = for { - symbolResults <- dataAccess.symbolsByScopeAndName(workflowId.toString, scope, varName).result - symbolStoreEntries = symbolResults map toSymbolStoreEntry - } yield symbolStoreEntries + val action = dataAccess.symbolsByScopeAndName(workflowId.toString, scope, varName).result - runTransaction(action) + val futureResults = runTransaction(action) + + futureResults map toSymbolStoreEntries } override def getAll(workflowId: WorkflowId): Future[Traversable[SymbolStoreEntry]] = { - val action = for { - symbolResults <- dataAccess.allSymbols(workflowId.toString).result - symbolStoreEntries = symbolResults map toSymbolStoreEntry - } yield symbolStoreEntries + val action = dataAccess.allSymbols(workflowId.toString).result - runTransaction(action) + val futureResults = runTransaction(action) + + futureResults map toSymbolStoreEntries } /** Get all inputs for the scope of this key. */ @@ -499,14 +508,13 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen private def getSymbols(workflowId: WorkflowId, ioValue: IoValue, callFqnOption: Option[FullyQualifiedName] = None, callIndexOption: Option[Int] = None): Future[Traversable[SymbolStoreEntry]] = { - val action = for { - symbolResults <- dataAccess.symbolsByWorkflowExecutionUuidAndIoAndMaybeScope( - workflowId.toString, ioValue, callFqnOption, callIndexOption - ).result - symbolStoreEntries = symbolResults map toSymbolStoreEntry - } yield symbolStoreEntries + val action = dataAccess.symbolsByWorkflowExecutionUuidAndIoAndMaybeScope( + workflowId.toString, ioValue, callFqnOption, callIndexOption + ).result - runTransaction(action) + val futureResults = runTransaction(action) + + futureResults map toSymbolStoreEntries } /** Should fail if a value is already set. The keys in the Map are locally qualified names. 
*/ @@ -550,50 +558,42 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen } override def getExecutions(id: WorkflowId): Future[Traversable[Execution]] = { - val action = for { - executions <- dataAccess.executionsByWorkflowExecutionUuid(id.toString).result - } yield executions + val action = dataAccess.executionsByWorkflowExecutionUuid(id.toString).result runTransaction(action) } override def getWorkflowExecution(workflowId: WorkflowId): Future[WorkflowExecution] = { - val action = for { - workflowExecution <- dataAccess.workflowExecutionsByWorkflowExecutionUuid(workflowId.toString).result.headOption - } yield workflowExecution + val action = dataAccess.workflowExecutionsByWorkflowExecutionUuid(workflowId.toString).result.headOption runTransaction(action) map { _.getOrElse(throw new NoSuchElementException(s"Workflow $workflowId not found.")) } } override def getWorkflowExecutionAux(id: WorkflowId): Future[WorkflowExecutionAux] = { - val action = for { - workflowExecutionAux <- dataAccess.workflowExecutionAuxesByWorkflowExecutionUuid(id.toString).result.headOption - } yield workflowExecutionAux + val action = dataAccess.workflowExecutionAuxesByWorkflowExecutionUuid(id.toString).result.headOption runTransaction(action) map { _.getOrElse(throw new NoSuchElementException(s"No workflow execution aux found for ID '$id'.")) } } override def getAllInputs(workflowId: WorkflowId): Future[Traversable[SymbolStoreEntry]] = { - val action = for { - inputs <- dataAccess.symbolsByWorkflowExecutionUuidAndIo(workflowId.toString, IoInput).result - symbolStoreInputs = inputs map toSymbolStoreEntry - } yield symbolStoreInputs + val action = dataAccess.symbolsByWorkflowExecutionUuidAndIo(workflowId.toString, IoInput).result - runTransaction(action) + val futureResults = runTransaction(action) + + futureResults map toSymbolStoreEntries } override def getAllOutputs(workflowId: WorkflowId): Future[Traversable[SymbolStoreEntry]] = { - val action = for { - 
outputs <- dataAccess.symbolsByWorkflowExecutionUuidAndIo(workflowId.toString, IoOutput).result - symbolStoreOutputs = outputs map toSymbolStoreEntry - } yield symbolStoreOutputs + val action = dataAccess.symbolsByWorkflowExecutionUuidAndIo(workflowId.toString, IoOutput).result - runTransaction(action) + val futureResults = runTransaction(action) + + futureResults map toSymbolStoreEntries } override def jesJobInfo(id: WorkflowId): Future[Map[ExecutionDatabaseKey, JesJob]] = { val action = for { - executionAndJob <- dataAccess.jesJobInfo(id).result + executionAndJob <- dataAccess.jesJobsWithExecutionsByWorkflowExecutionUuid(id.toString).result } yield executionAndJob val futureResults = runTransaction(action) @@ -606,7 +606,7 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen override def localJobInfo(id: WorkflowId): Future[Map[ExecutionDatabaseKey, LocalJob]] = { val action = for { - executionAndJob <- dataAccess.localJobInfo(id).result + executionAndJob <- dataAccess.localJobsWithExecutionsByWorkflowExecutionUuid(id.toString).result } yield executionAndJob val futureResults = runTransaction(action) @@ -619,7 +619,7 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen override def sgeJobInfo(id: WorkflowId): Future[Map[ExecutionDatabaseKey, SgeJob]] = { val action = for { - executionAndJob <- dataAccess.sgeJobInfo(id).result + executionAndJob <- dataAccess.sgeJobsWithExecutionsByWorkflowExecutionUuid(id.toString).result } yield executionAndJob val futureResults = runTransaction(action) @@ -639,5 +639,4 @@ class SlickDataAccess(databaseConfig: Config, val dataAccess: DataAccessComponen runTransaction(action) } - } diff --git a/src/main/scala/cromwell/engine/package.scala b/src/main/scala/cromwell/engine/package.scala index e1be7c480b..fc567b0f76 100644 --- a/src/main/scala/cromwell/engine/package.scala +++ b/src/main/scala/cromwell/engine/package.scala @@ -11,7 +11,7 @@ import 
cromwell.engine.workflow.WorkflowOptions import spray.json._ import scala.language.implicitConversions -import scala.util.{Success, Try, Failure} +import scala.util.{Failure, Success, Try} /** * ==Cromwell Execution Engine== @@ -44,12 +44,12 @@ package object engine { case Failure(ex) => throw ex } - val backendType = Backend.from(workflowOptions.getOrElse("default_backend", ConfigFactory.load.getConfig("backend").getString("backend"))) - val namespace = NamespaceWithWorkflow.load(sourceFiles.wdlSource, backendType.backendType) + val backend = Backend.from(workflowOptions.getOrElse("default_backend", ConfigFactory.load.getConfig("backend").getString("backend"))) + val namespace = NamespaceWithWorkflow.load(sourceFiles.wdlSource, backend.backendType) val name = namespace.workflow.name val shortId = id.toString.split("-")(0) - backendType.assertWorkflowOptions(workflowOptions) + backend.assertWorkflowOptions(workflowOptions) val rawInputs = Try(sourceFiles.inputsJson.parseJson) match { case Success(JsObject(inputs)) => inputs @@ -58,7 +58,7 @@ package object engine { // Currently we are throwing an exception if construction of the workflow descriptor fails, hence .get on the Trys val coercedInputs = namespace.coerceRawInputs(rawInputs).get - val declarations = namespace.staticDeclarationsRecursive(coercedInputs).get + val declarations = namespace.staticDeclarationsRecursive(coercedInputs, backend.engineFunctions).get val actualInputs: WorkflowCoercedInputs = coercedInputs ++ declarations } diff --git a/src/main/scala/cromwell/engine/workflow/CallMetadataBuilder.scala b/src/main/scala/cromwell/engine/workflow/CallMetadataBuilder.scala index 2002e25458..bf30def24d 100644 --- a/src/main/scala/cromwell/engine/workflow/CallMetadataBuilder.scala +++ b/src/main/scala/cromwell/engine/workflow/CallMetadataBuilder.scala @@ -37,8 +37,8 @@ object CallMetadataBuilder { def extract(job: Any): BackendValues = { job match { case ji: LocalJob => BackendValues("Local") - case ji: 
JesJob => BackendValues("JES", jobId = Option(ji.jesId.toString), status = ji.jesStatus) - case ji: SgeJob => BackendValues("SGE", jobId = Option(ji.sgeJobNumber.toString)) + case ji: JesJob => BackendValues("JES", jobId = ji.jesId, status = ji.jesStatus) + case ji: SgeJob => BackendValues("SGE", jobId = ji.sgeJobNumber map { _.toString }) } } } @@ -163,11 +163,6 @@ object CallMetadataBuilder { val executionMap = executionMapTransformers.foldLeft(Map.empty: ExecutionMap) { case (map, transformer) => map ++ transformer(map) } - def symbolToMapEntry(symbol: Symbol) = { - val clob = symbol.wdlValue.get - symbol.name -> clob.getSubString(1, clob.length().toInt) - } - // Convert from the convenience AssembledCallMetadata format to the CallMetadata format // that the endpoint needs to serve up. def constructCallMetadata(metadata: AssembledCallMetadata): CallMetadata = { diff --git a/src/main/scala/cromwell/engine/workflow/ExecutionStoreKey.scala b/src/main/scala/cromwell/engine/workflow/ExecutionStoreKey.scala index abf1fb2858..1bcf7977f9 100644 --- a/src/main/scala/cromwell/engine/workflow/ExecutionStoreKey.scala +++ b/src/main/scala/cromwell/engine/workflow/ExecutionStoreKey.scala @@ -9,17 +9,16 @@ import scala.language.postfixOps sealed trait ExecutionStoreKey { def scope: Scope def index: Option[Int] - def parent: Option[ExecutionStoreKey] } trait OutputKey extends ExecutionStoreKey -case class CallKey(scope: Call, index: Option[Int], parent: Option[ExecutionStoreKey]) extends OutputKey -case class CollectorKey(scope: Call, parent: Option[ExecutionStoreKey]) extends OutputKey { +case class CallKey(scope: Call, index: Option[Int]) extends OutputKey +case class CollectorKey(scope: Call) extends OutputKey { override val index: Option[Int] = None } -case class ScatterKey(scope: Scatter, index: Option[Int], parent: Option[ExecutionStoreKey]) extends ExecutionStoreKey { +case class ScatterKey(scope: Scatter, index: Option[Int]) extends ExecutionStoreKey { /** * Creates a 
sub-ExecutionStore with Starting entries for each of the scoped children. @@ -35,8 +34,8 @@ case class ScatterKey(scope: Scatter, index: Option[Int], parent: Option[Executi val parent = Option(this) scope match { case call: Call => - val shards = (0 until count) map { i => CallKey(call, Option(i), parent) } - shards :+ CollectorKey(call, parent) + val shards = (0 until count) map { i => CallKey(call, Option(i)) } + shards :+ CollectorKey(call) case scatter: Scatter => throw new UnsupportedOperationException("Nested Scatters are not supported (yet).") case e => diff --git a/src/main/scala/cromwell/engine/workflow/ValidateActor.scala b/src/main/scala/cromwell/engine/workflow/ValidateActor.scala index a513d0dab1..ee70449869 100644 --- a/src/main/scala/cromwell/engine/workflow/ValidateActor.scala +++ b/src/main/scala/cromwell/engine/workflow/ValidateActor.scala @@ -43,7 +43,6 @@ class ValidateActor(wdlSource: WdlSource, wdlJson: WdlJson) namespaceWithWorkflow <- Future(NamespaceWithWorkflow.load(wdlSource, WorkflowManagerActor.BackendType)) inputs <- Future(wdlJson.parseJson).map(_.asJsObject.fields) coercedInputs <- Future.fromTry(namespaceWithWorkflow.coerceRawInputs(inputs)) - _ <- Future.fromTry(namespaceWithWorkflow.staticDeclarationsRecursive(coercedInputs)) } yield () // Validate that the future run and return `Success[Unit]` aka (), or `Failure[Exception]` futureValidation onComplete { diff --git a/src/main/scala/cromwell/engine/workflow/WorkflowActor.scala b/src/main/scala/cromwell/engine/workflow/WorkflowActor.scala index 916a34d1ab..cb0c957ea2 100644 --- a/src/main/scala/cromwell/engine/workflow/WorkflowActor.scala +++ b/src/main/scala/cromwell/engine/workflow/WorkflowActor.scala @@ -23,16 +23,35 @@ import scala.util.{Failure, Success, Try} object WorkflowActor { sealed trait WorkflowActorMessage - case object Start extends WorkflowActorMessage - case object Restart extends WorkflowActorMessage case object Complete extends WorkflowActorMessage case object 
GetFailureMessage extends WorkflowActorMessage case object AbortWorkflow extends WorkflowActorMessage case class AbortComplete(call: OutputKey) extends WorkflowActorMessage case class CallStarted(call: OutputKey) extends WorkflowActorMessage - case class CallCompleted(call: OutputKey, callOutputs: CallOutputs) extends WorkflowActorMessage + case class CallCompleted(call: OutputKey, callOutputs: CallOutputs, returnCode: Int) extends WorkflowActorMessage case class CallFailed(call: OutputKey, returnCode: Option[Int], failure: String) extends WorkflowActorMessage case object Terminate extends WorkflowActorMessage + case class ExecutionStoreCreated(startMode: StartMode) extends WorkflowActorMessage + case class AsyncFailure(t: Throwable) extends WorkflowActorMessage + + sealed trait StartMode { + def runInitialization(actor: WorkflowActor): Future[Unit] + } + + case object Start extends WorkflowActorMessage with StartMode { + override def runInitialization(actor: WorkflowActor): Future[Unit] = { + // This only does the initialization for a newly created workflow. For a restarted workflow we should be able + // to assume the adjusted symbols already exist in the DB, but is it safe to assume the staged files are in place? 
+ actor.initializeWorkflow match { + case Success(inputs) => actor.createWorkflow(inputs) + case Failure(ex) => Future.failed(ex) + } + } + } + + case object Restart extends WorkflowActorMessage with StartMode { + override def runInitialization(actor: WorkflowActor): Future[Unit] = Future.successful(()) + } def props(descriptor: WorkflowDescriptor, backend: Backend): Props = { Props(WorkflowActor(descriptor, backend)) @@ -61,20 +80,18 @@ object WorkflowActor { case class WorkflowActor(workflow: WorkflowDescriptor, backend: Backend) extends LoggingFSM[WorkflowState, WorkflowFailure] with CromwellActor { + + def createWorkflow(inputs: HostInputs): Future[Unit] = { + globalDataAccess.createWorkflow( + workflow, buildSymbolStoreEntries(workflow.namespace, inputs), workflow.namespace.workflow.children, backend) + } + private var executionStore: ExecutionStore = _ val tag: String = s"WorkflowActor [UUID(${workflow.shortId})]" override val log = Logging(context.system, classOf[WorkflowActor]) startWith(WorkflowSubmitted, NoFailureMessage) - def initWorkflow(initialization: Future[Unit] = Future.successful(())): ExecutionStore = { - val futureStore = for { - _ <- initialization - store <- createStore - } yield store - Await.result(futureStore, AkkaTimeout) - } - /** * Try to generate output for a collector call, by collecting outputs for all of its shards. * It's fail-fast on shard output retrieval @@ -108,24 +125,45 @@ case class WorkflowActor(workflow: WorkflowDescriptor, } } + private def initializeExecutionStore(initializationCode: Future[Unit], startMode: StartMode): Unit = { + val futureStore = for { + _ <- initializationCode + store <- createStore + } yield store + + futureStore onComplete { + case Success(store) => + executionStore = store + self ! ExecutionStoreCreated(startMode) + case Failure(t) => + self ! 
AsyncFailure(t) + } + } + + private def initializeWorkflow: Try[HostInputs] = backend.initializeForWorkflow(workflow) + when(WorkflowSubmitted) { - case Event(Restart, NoFailureMessage) => - executionStore = initWorkflow() - startRunnableCalls() - case Event(Start, NoFailureMessage) => - log.info(s"$tag Start message received") - executionStore = initWorkflow(createWorkflow()) - symbolsMarkdownTable foreach { table => log.info(s"Initial symbols:\n\n$table") } - executionsMarkdownTable foreach { table => log.info(s"Initial executions:\n\n$table") } + case Event(startMode: StartMode, NoFailureMessage) => + log.info(s"$tag $startMode message received") + initializeExecutionStore(startMode.runInitialization(this), startMode) + stay() + case Event(ExecutionStoreCreated(startMode), NoFailureMessage) => + if (startMode == Restart) { + symbolsMarkdownTable foreach { table => log.info(s"Initial symbols:\n\n$table") } + executionsMarkdownTable foreach { table => log.info(s"Initial executions:\n\n$table") } + } startRunnableCalls() + case Event(AsyncFailure(t), NoFailureMessage) => + log.error(t.getMessage, t) + goto(WorkflowFailed) } when(WorkflowRunning) { case Event(CallStarted(callKey), NoFailureMessage) => persistStatus(callKey, ExecutionStatus.Running) stay() - case Event(CallCompleted(callKey, outputs), NoFailureMessage) => - awaitCallComplete(callKey, outputs) match { + case Event(CallCompleted(callKey, outputs, returnCode), NoFailureMessage) => + awaitCallComplete(callKey, outputs, returnCode) match { case Success(_) => if (isWorkflowDone) goto(WorkflowSucceeded) else startRunnableCalls() case Failure(e) => @@ -165,8 +203,8 @@ case class WorkflowActor(workflow: WorkflowDescriptor, case Event(CallFailed(callKey, returnCode, failure), NoFailureMessage) => persistStatus(callKey, ExecutionStatus.Failed, returnCode) if (isWorkflowAborted) goto(WorkflowAborted) using NoFailureMessage else stay() - case Event(CallCompleted(callKey, outputs), NoFailureMessage) => - 
awaitCallComplete(callKey, outputs) + case Event(CallCompleted(callKey, outputs, returnCode), NoFailureMessage) => + awaitCallComplete(callKey, outputs, returnCode) if (isWorkflowAborted) goto(WorkflowAborted) using NoFailureMessage else stay() case m => log.error("Unexpected message in Aborting state: " + m.getClass.getSimpleName) @@ -190,50 +228,59 @@ case class WorkflowActor(workflow: WorkflowDescriptor, */ onTransition { case fromState -> toState => - log.info(s"$tag transitioning from $fromState to $toState.") - globalDataAccess.updateWorkflowState(workflow.id, toState) - /* - Send a message to self to trigger an actor shutdown. Run on a short timer to help enable some - unit test instrumentation - */ - if (toState.isTerminal) { - backend.cleanUpForWorkflow(workflow) - globalDataAccess.updateWorkflowOptions(workflow.id, workflow.workflowOptions.clearEncryptedValues) - setTimer(s"WorkflowActor termination message: $tag", Terminate, AkkaTimeout, DontRepeatTimer) + def handleTerminalWorkflow: Future[Unit] = { + for { + _ <- backend.cleanUpForWorkflow(workflow) + _ <- globalDataAccess.updateWorkflowOptions(workflow.id, workflow.workflowOptions.clearEncryptedValues) + // Send a message to self to trigger an actor shutdown. Run on a short timer to help enable some + // unit test instrumentation + _ = setTimer(s"WorkflowActor termination message: $tag", Terminate, AkkaTimeout, DontRepeatTimer) + } yield () } + + for { + // Write the new workflow state before logging the change, tests assume the change is in effect when + // the message is logged. 
+ _ <- globalDataAccess.updateWorkflowState(workflow.id, toState) + _ = log.info(s"$tag transitioning from $fromState to $toState.") + _ <- if (toState.isTerminal) handleTerminalWorkflow else Future.successful({}) + } yield () } private def persistStatus(key: ExecutionStoreKey, status: ExecutionStatus, - returnCode: Option[Int] = None): Future[Unit] = { - persistStatuses(Iterable(key), status, returnCode) + returnCode: Option[Int] = None): Unit = { + Await.result(persistStatuses(Iterable(key), status, returnCode), Duration.Inf) } private def persistStatuses(key: Traversable[ExecutionStoreKey], executionStatus: ExecutionStatus, returnCode: Option[Int] = None): Future[Unit] = { - executionStore ++= key map { _ -> executionStatus } - key foreach { k => + val databaseKeys = key map { k => val indexLog = k.index.map(i => s" (shard $i)").getOrElse("") log.info(s"$tag persisting status of ${k.scope.fullyQualifiedName}$indexLog to $executionStatus.") + ExecutionDatabaseKey(k.scope.fullyQualifiedName, k.index) } - globalDataAccess.setStatus(workflow.id, key map { k => - ExecutionDatabaseKey(k.scope.fullyQualifiedName, k.index) - }, CallStatus(executionStatus, returnCode)) - globalDataAccess.setStatus(workflow.id, key map { k => ExecutionDatabaseKey(k.scope.fullyQualifiedName, k.index) }, CallStatus(executionStatus, returnCode)) + for { + // Write the status to the database before updating the store, the store is what is examined to + // determine workflow doneness and if that persisted workflow representation is not consistent, + // tests may see unexpected values. 
+ _ <- globalDataAccess.setStatus(workflow.id, databaseKeys, CallStatus(executionStatus, returnCode)) + _ = executionStore ++= key map { _ -> executionStatus } + } yield () } - private def awaitCallComplete(key: OutputKey, outputs: CallOutputs): Try[Unit] = { - val callFuture = handleCallCompleted(key, outputs) + private def awaitCallComplete(key: OutputKey, outputs: CallOutputs, returnCode: Int): Try[Unit] = { + val callFuture = handleCallCompleted(key, outputs, returnCode) Await.ready(callFuture, AkkaTimeout) callFuture.value.get } - private def handleCallCompleted(key: OutputKey, outputs: CallOutputs): Future[Unit] = { + private def handleCallCompleted(key: OutputKey, outputs: CallOutputs, returnCode: Int): Future[Unit] = { log.info(s"$tag handling completion of call '${key.scope.fullyQualifiedName}'.") for { _ <- globalDataAccess.setOutputs(workflow.id, key, outputs) - _ <- persistStatus(key, ExecutionStatus.Done, Option(0)) + _ = persistStatus(key, ExecutionStatus.Done, Option(returnCode)) } yield() } @@ -489,10 +536,10 @@ case class WorkflowActor(workflow: WorkflowDescriptor, globalDataAccess.getExecutionStatuses(workflow.id) map { statuses => statuses map { case (k, v) => val key: ExecutionStoreKey = (workflow.namespace.resolve(k.fqn), k.index) match { - case (Some(c: Call), Some(i)) => CallKey(c, Some(i), None) - case (Some(c: Call), None) if isInScatterBlock(c) => CollectorKey(c, None) - case (Some(c: Call), None) => CallKey(c, None, None) - case (Some(s: Scatter), None) => ScatterKey(s, None, None) + case (Some(c: Call), Some(i)) => CallKey(c, Some(i)) + case (Some(c: Call), None) if isInScatterBlock(c) => CollectorKey(c) + case (Some(c: Call), None) => CallKey(c, None) + case (Some(s: Scatter), None) => ScatterKey(s, None) case _ => throw new UnsupportedOperationException(s"Execution entry invalid: $k -> $v") } key -> v.executionStatus @@ -511,17 +558,6 @@ case class WorkflowActor(workflow: WorkflowDescriptor, inputSymbols.toSet ++ callSymbols.toSet } - 
def createWorkflow(): Future[Unit] = { - val workflowDescriptor = WorkflowDescriptor(workflow.id, workflow.sourceFiles) - // This only does the initialization for a newly created workflow. For a restarted workflow we should be able - // to assume the adjusted symbols already exist in the DB, but is it safe to assume the staged files are in place? - backend.initializeForWorkflow(workflow) match { - case Success(inputs) => - globalDataAccess.createWorkflow(workflowDescriptor, buildSymbolStoreEntries(workflow.namespace, inputs), workflow.namespace.workflow.children, backend) - case Failure(ex) => Future.failed(ex) - } - } - /** * This is the lookup function used to evaluate scatter collection expressions. * @@ -554,11 +590,11 @@ case class WorkflowActor(workflow: WorkflowDescriptor, collection match { case Success(a: WdlArray) => Try { val newEntries = scatterKey.populate(a.value.size) + persistStatus(scatterKey, ExecutionStatus.Starting, None) val createScatter = for { - _ <- persistStatus(scatterKey, ExecutionStatus.Starting, None) _ <- globalDataAccess.insertCalls(workflow.id, newEntries.keys, backend) _ <- persistStatuses(newEntries.keys, ExecutionStatus.NotStarted, None) - _ <- persistStatus(scatterKey, ExecutionStatus.Done, Some(0)) + _ = persistStatus(scatterKey, ExecutionStatus.Done, Some(0)) } yield () Await.result(createScatter, AkkaTimeout) newEntries.keys @@ -577,7 +613,7 @@ case class WorkflowActor(workflow: WorkflowDescriptor, self ! CallFailed(collector, None, e.getMessage) case Success(outputs) => log.info(s"Collection complete for Scattered Call ${collector.scope.fullyQualifiedName}.") - self ! CallCompleted(collector, outputs) + self ! 
CallCompleted(collector, outputs, 0) } Success(Seq.empty[ExecutionStoreKey]) @@ -625,14 +661,14 @@ case class WorkflowActor(workflow: WorkflowDescriptor, private def executionsAsTable: Seq[Seq[String]] = { val futureRows = globalDataAccess.getExecutionStatuses(workflow.id) map { entries => entries.map({ case(k, v) => - Seq(k.fqn.toString, k.index.getOrElse("").toString, v.toString) + Seq(k.fqn.toString, k.index.getOrElse("").toString, v.executionStatus.toString, v.returnCode.getOrElse("").toString) }) } Await.result(futureRows, AkkaTimeout).toSeq } private def executionsMarkdownTable: Option[String] = { - val header = Seq("SCOPE", "INDEX", "STATUS") + val header = Seq("SCOPE", "INDEX", "STATUS", "RETURN CODE") executionsAsTable match { case rows: Seq[Seq[String]] if rows.isEmpty => None case rows => Some(TerminalUtil.mdTable(rows.toSeq, header)) diff --git a/src/main/scala/cromwell/server/CromwellServer.scala b/src/main/scala/cromwell/server/CromwellServer.scala index 76caa049b4..8a6fa58e74 100644 --- a/src/main/scala/cromwell/server/CromwellServer.scala +++ b/src/main/scala/cromwell/server/CromwellServer.scala @@ -1,38 +1,28 @@ package cromwell.server -import akka.io.IO -import akka.pattern.ask import akka.util.Timeout import com.typesafe.config.ConfigFactory -import cromwell.webservice.CromwellApiServiceActor -import spray.can.Http +import cromwell.webservice.{CromwellApiServiceActor, SwaggerService} +import lenthall.spray.SprayCanHttpService._ import scala.concurrent.duration._ -import scala.util.{Failure, Success} // Note that as per the language specification, this is instantiated lazily and only used when necessary (i.e. server mode) object CromwellServer extends DefaultWorkflowManagerSystem { val conf = ConfigFactory.load() - // NOTE: Currently the this.dataAccess is passed in to this.workflowManagerActor. 
- // The actor could technically restart with the same instance of the dataAccess, - // So, we're not shutting down dataAccess during this.workflowManagerActor.postStop() nor this.service.postStop(). - // Not sure otherwise when this server is really shutting down, so this.dataAccess currently never explicitly closed. - // Shouldn't be an issue unless perhaps test code tries to launch multiple servers and leaves dangling connections. - val service = actorSystem.actorOf(CromwellApiServiceActor.props(workflowManagerActor), "cromwell-service") + val service = actorSystem.actorOf(CromwellApiServiceActor.props(workflowManagerActor, SwaggerService.from(conf)), "cromwell-service") implicit val timeout = Timeout(5.seconds) val webserviceConf = conf.getConfig("webservice") import scala.concurrent.ExecutionContext.Implicits.global - (IO(Http) ? Http.Bind(service, interface = webserviceConf.getString("interface"), port = webserviceConf.getInt("port"))).onComplete { - case Success(Http.CommandFailed(failure)) => - actorSystem.log.error("could not bind to port: " + failure.toString) - actorSystem.shutdown() - case Failure(t) => - actorSystem.log.error(t, "could not bind to port") - actorSystem.shutdown() + + service.bindOrShutdown( + interface = webserviceConf.getString("interface"), + port = webserviceConf.getInt("port") + ) onSuccess { case _ => actorSystem.log.info("Cromwell service started...") } diff --git a/src/main/scala/cromwell/util/ConfigUtil.scala b/src/main/scala/cromwell/util/ConfigUtil.scala deleted file mode 100644 index 6d82dd839c..0000000000 --- a/src/main/scala/cromwell/util/ConfigUtil.scala +++ /dev/null @@ -1,78 +0,0 @@ -package cromwell.util - -import java.net.{MalformedURLException, URL} - -import com.typesafe.config.{ConfigValue, Config, ConfigException, ConfigFactory} -import org.slf4j.{LoggerFactory, Logger} -import scala.reflect.{ClassTag, classTag} -import scala.collection.JavaConversions._ -import scala.util.Try -import scalaz._ -import Scalaz._ - 
-object ConfigUtil { - - val validationLogger = LoggerFactory.getLogger("ConfigurationValidation") - - class ConfigValidationException(context: String, validationException: ConfigException.ValidationFailed) - extends ConfigException.ValidationFailed(validationException.problems()) { - override def getMessage: String = { - val problems = validationException.problems().map(_.problem()).mkString(", ") - s"$context configuration validation failed : $problems" - } - } - - implicit class EnhancedConfig(val config: Config) extends AnyVal { - - def keys = config.entrySet().toSet map { v: java.util.Map.Entry[String, ConfigValue] => v.getKey } - - def getStringOption(key: String): Option[String] = { - Try(config.getString(key)) match { - case scala.util.Success(value) => Option(value) - case scala.util.Failure(e: ConfigException.Missing) => None - case scala.util.Failure(e) => throw e - } - } - - /** - * For keys that are in the configuration but not in the reference keySet, log a warning. - */ - def warnNotRecognized(keySet: Set[String], context: String) = { - keys.diff(keySet) match { - case warnings if warnings.nonEmpty => validationLogger.warn(s"Unrecognized configuration key(s) for $context: ${warnings.mkString(", ")}") - case _ => - } - } - - /** - * Validates that the value for this key is a well formed URL. - */ - def validateURL(key: String): ValidationNel[String, URL] = key.validateAny { url => - new URL(config.getString(url)) - } - - def validateString(key: String): ValidationNel[String, String] = try { - config.getString(key).successNel - } catch { - case e: ConfigException.Missing => "Could not find key: $key".failureNel - } - - } - - implicit class EnhancedValidation[I <: AnyRef](val value: I) extends AnyVal { - /** - * Validates this value by applying validationFunction to it and returning a Validation: - * Returns successNel upon success. - * If an exception is thrown and is a subtype of E, return failureNel with the exception message. 
- * @param validationFunction function that should throw an exception if this value is found not to be valid - * @tparam O return type of validationFunction - * @tparam E Restricts the subtype of Exception that should be caught during validation - */ - def validateAny[O, E <: Exception: ClassTag](validationFunction: I => O): ValidationNel[String, O] = try { - validationFunction(value).successNel - } catch { - case e if classTag[E].runtimeClass.isInstance(e) => e.getMessage.failureNel - } - } - -} diff --git a/src/main/scala/cromwell/util/TryUtil.scala b/src/main/scala/cromwell/util/TryUtil.scala index 05c1234b5d..c6bab1fcae 100644 --- a/src/main/scala/cromwell/util/TryUtil.scala +++ b/src/main/scala/cromwell/util/TryUtil.scala @@ -2,9 +2,12 @@ package cromwell.util import java.io.{PrintWriter, StringWriter} -import scala.util.{Failure, Try} +import com.typesafe.scalalogging.LazyLogging -object TryUtil { +import scala.concurrent.duration.Duration +import scala.util.{Success, Failure, Try} + +object TryUtil extends LazyLogging { private def stringifyFailure[T](failure: Try[T]): String = { val stringWriter = new StringWriter() val writer = new PrintWriter(stringWriter) @@ -16,4 +19,54 @@ object TryUtil { def stringifyFailures[T](possibleFailures: Traversable[Try[T]]): Traversable[String] = possibleFailures.collect { case failure: Failure[T] => stringifyFailure(failure) } + + private def defaultSuccessFunction(a: Any): Boolean = true + + /** + * Runs a block of code (`fn`) `retries` number of times until it succeeds. + * It will wait `pollingInterval` amount of time between retry attempts and + * The `pollingBackOffFactor` is for exponentially backing off the `pollingInterval` + * on subsequent retries. The `pollingInterval` shall not exceed `maxPollingInterval` + * + * Returns a Try[T] where T is the return value of `fn`, the function to be retried. + * If the return value is Success[T] then at least one retry succeeded. 
+ * + * The isSuccess function is optional but if provided, then isSuccess(fn) must be true + * or it will trigger another retry. if isSuccess is omitted, the only way the fn can + * fail is if it throws an exception. + * + * Use `retries` value of None indicates to retry indefinitely. + */ + @annotation.tailrec + def retryBlock[T](fn: Option[T] => T, + isSuccess: T => Boolean = defaultSuccessFunction _, + retries: Option[Int], + pollingInterval: Duration, + pollingBackOffFactor: Double, + maxPollingInterval: Duration, + failMessage: Option[String] = None, + priorValue: Option[T] = None): Try[T] = { + Try { fn(priorValue) } match { + case Success(x) if isSuccess(x) => Success(x) + case value if (retries.isDefined && retries.get > 1) || retries.isEmpty => + + val retryCountMessage = if (retries.getOrElse(0) > 0) s" (${retries.getOrElse(0) - 1} more retries) " else "" + val retryMessage = s"Retrying in $pollingInterval$retryCountMessage..." + failMessage foreach { m => logger.warn(s"$m. $retryMessage") } + + Thread.sleep(pollingInterval.toMillis) + + retryBlock( + fn, + isSuccess, + retries.map(_ - 1), + Duration(Math.min((pollingInterval.toMillis * pollingBackOffFactor).toLong, maxPollingInterval.toMillis), "milliseconds"), + pollingBackOffFactor, + maxPollingInterval, + failMessage, + value.toOption + ) + case f => f + } + } } diff --git a/src/main/scala/cromwell/util/google/GoogleCloudStoragePath.scala b/src/main/scala/cromwell/util/google/GoogleCloudStoragePath.scala index 184d18e77b..ba303d2810 100644 --- a/src/main/scala/cromwell/util/google/GoogleCloudStoragePath.scala +++ b/src/main/scala/cromwell/util/google/GoogleCloudStoragePath.scala @@ -20,7 +20,7 @@ object GoogleCloudStoragePath { val gsUriRegex = """gs://([^/]*)/(.*)""".r value match { case gsUriRegex(bucket, objectName) => Success(GoogleCloudStoragePath(bucket, objectName)) - case _ => Failure(new IllegalArgumentException()) + case _ => Failure(new IllegalArgumentException(s"Not a valid Google Cloud 
Storage URI: ${value}")) } } } \ No newline at end of file diff --git a/src/main/scala/cromwell/webservice/ApiDataModels.scala b/src/main/scala/cromwell/webservice/ApiDataModels.scala index f5673b53b8..3502caf5a1 100644 --- a/src/main/scala/cromwell/webservice/ApiDataModels.scala +++ b/src/main/scala/cromwell/webservice/ApiDataModels.scala @@ -1,97 +1,25 @@ package cromwell.webservice -import com.wordnik.swagger.annotations.{ApiModel, ApiModelProperty} import cromwell.binding.FullyQualifiedName import cromwell.binding.values.WdlValue import cromwell.engine.backend.{CallMetadata, StdoutStderr} import org.joda.time.DateTime import spray.json.JsObject -import scala.annotation.meta.field +case class WorkflowValidateResponse(valid: Boolean, error: Option[String]) -@ApiModel(value = "WorkflowValidate") -case class WorkflowValidateResponse -( - @(ApiModelProperty@field)(required = true, value = "The validation of the workflow") - valid: Boolean, - @(ApiModelProperty@field)(required = false, value = "The validation error of the workflow") - error: Option[String] - ) +case class WorkflowStatusResponse(id: String, status: String) -@ApiModel(value = "WorkflowStatus") -case class WorkflowStatusResponse -( - @(ApiModelProperty@field)(required = true, value = "The identifier of the workflow") - id: String, - @(ApiModelProperty@field)(required = true, value = "The status of the workflow") - status: String - ) +case class WorkflowSubmitResponse(id: String, status: String) -@ApiModel(value = "WorkflowSubmit") -case class WorkflowSubmitResponse -( - @(ApiModelProperty@field)(required = true, value = "The identifier of the workflow") - id: String, - @(ApiModelProperty@field)(required = true, value = "The status of the workflow") - status: String - ) +case class WorkflowOutputResponse(id: String, outputs: Map[FullyQualifiedName, WdlValue]) -@ApiModel(value = "WorkflowOutputs") -case class WorkflowOutputResponse -( - @(ApiModelProperty@field)(required = true, value = "The identifier of 
the workflow") - id: String, - @(ApiModelProperty@field)(required = true, value = "The outputs of the workflow") - outputs: Map[FullyQualifiedName, WdlValue] - ) +case class WorkflowAbortResponse(id: String, status: String) -@ApiModel(value = "WorkflowAbort") -case class WorkflowAbortResponse -( - @(ApiModelProperty@field)(required = true, value = "The identifier of the workflow") - id: String, - @(ApiModelProperty@field)(required = true, value = "The status of the workflow") - status: String - ) +case class CallOutputResponse(id: String, callFqn: String, outputs: Map[FullyQualifiedName, WdlValue]) -@ApiModel(value = "CallOutputs") -case class CallOutputResponse -( - @(ApiModelProperty@field)(required = true, value = "The identifier of the workflow") - id: String, - @(ApiModelProperty@field)(required = true, value = "The fully qualified name of the call") - callFqn: String, - @(ApiModelProperty@field)(required = true, value = "The outputs of the workflow") - outputs: Map[FullyQualifiedName, WdlValue] - ) - -@ApiModel(value = "CallStdoutStderr") -case class CallStdoutStderrResponse -( - @(ApiModelProperty@field)(required = true, value = "The identifier of the workflow") - id: String, - @(ApiModelProperty@field)(required = true, value = "The fully qualified name of the call") - logs: Map[String, Seq[StdoutStderr]] -) - -@ApiModel(value = "WorkflowMetadata") -case class WorkflowMetadataResponse -( - @(ApiModelProperty@field)(required = true, value = "The identifier of the workflow") - id: String, - @(ApiModelProperty@field)(required = true, value = "Workflow status") - status: String, - @(ApiModelProperty@field)(required = true, value = "Date and time the workflow was submitted") - submission: DateTime, - @(ApiModelProperty@field)(required = true, value = "Date and time the workflow started execution") - start: Option[DateTime], - @(ApiModelProperty@field)(required = true, value = "Date and time the workflow ended execution") - end: Option[DateTime], - 
@(ApiModelProperty@field)(required = true, value = "Workflow inputs") - inputs: JsObject, - @(ApiModelProperty@field)(required = true, value = "Workflow outputs") - outputs: Option[Map[String, WdlValue]], - @(ApiModelProperty@field)(required = true, value = "The fully qualified name of the call") - calls: Map[String, Seq[CallMetadata]] -) +case class CallStdoutStderrResponse(id: String, logs: Map[String, Seq[StdoutStderr]]) +case class WorkflowMetadataResponse(id: String, status: String, submission: DateTime, start: Option[DateTime], + end: Option[DateTime], inputs: JsObject, outputs: Option[Map[String, WdlValue]], + calls: Map[String, Seq[CallMetadata]]) diff --git a/src/main/scala/cromwell/webservice/CromwellApiService.scala b/src/main/scala/cromwell/webservice/CromwellApiService.scala index 93a7e5e8c9..db4c6cdbb8 100644 --- a/src/main/scala/cromwell/webservice/CromwellApiService.scala +++ b/src/main/scala/cromwell/webservice/CromwellApiService.scala @@ -1,26 +1,46 @@ package cromwell.webservice -import akka.actor.{Actor, ActorRef, Props} +import akka.actor.{Actor, ActorRef, ActorRefFactory, Props} +import com.typesafe.config.Config import cromwell.engine.workflow.{ValidateActor, WorkflowOptions} import cromwell.engine.{WorkflowId, WorkflowSourceFiles} +import lenthall.spray.{SwaggerUiInfo, ConfigSwaggerUiHttpService} import spray.http.StatusCodes import spray.json._ import spray.routing.Directive.pimpApply import spray.routing._ +import scala.reflect.runtime.universe._ import scala.util.{Failure, Success, Try} + +object SwaggerService { + /* + Because of the implicit arg requirement apply() doesn't work here, so falling back to the less + idiomatic (but not unheard of) from(). 
+ */ + def from(conf: Config)(implicit actorRefFactory: ActorRefFactory): SwaggerService = { + new SwaggerService(conf.getConfig("swagger")) + } +} + +class SwaggerService(override val swaggerUiConfig: Config) + (implicit val actorRefFactory: ActorRefFactory) + extends ConfigSwaggerUiHttpService { +} + + object CromwellApiServiceActor { - def props(workflowManagerActorRef: ActorRef): Props = { - Props(new CromwellApiServiceActor(workflowManagerActorRef)) + def props(workflowManagerActorRef: ActorRef, swaggerService: SwaggerService): Props = { + Props(new CromwellApiServiceActor(workflowManagerActorRef, swaggerService)) } } -class CromwellApiServiceActor(val workflowManager: ActorRef) extends Actor with CromwellApiService { +class CromwellApiServiceActor(val workflowManager: ActorRef, swaggerService: SwaggerService) extends Actor with CromwellApiService { implicit def executionContext = actorRefFactory.dispatcher def actorRefFactory = context - def possibleRoutes = options { complete(StatusCodes.OK) } ~ docsRoute ~ workflowRoutes + def possibleRoutes = options { complete(StatusCodes.OK) } ~ docsRoute ~ swaggerService.swaggerUiRoutes ~ workflowRoutes def receive = runRoute(possibleRoutes) } diff --git a/src/test/scala/cromwell/CromwellTestkitSpec.scala b/src/test/scala/cromwell/CromwellTestkitSpec.scala index 23192662f0..c18a9ecdf5 100644 --- a/src/test/scala/cromwell/CromwellTestkitSpec.scala +++ b/src/test/scala/cromwell/CromwellTestkitSpec.scala @@ -108,12 +108,7 @@ with DefaultTimeout with ImplicitSender with WordSpecLike with Matchers with Bef } def buildWorkflowDescriptor(sampleWdl: SampleWdl, runtime: String, uuid: UUID): WorkflowDescriptor = { - val source = sampleWdl.wdlSource(runtime) - val namespace = NamespaceWithWorkflow.load(source, BackendType.LOCAL) - val coercedInputs = namespace.coerceRawInputs(sampleWdl.rawInputs).get - val declarations = namespace.staticDeclarationsRecursive(coercedInputs).get - val inputs = coercedInputs ++ declarations - val 
workflowSources = WorkflowSourceFiles(source, sampleWdl.wdlJson, "{}") + val workflowSources = WorkflowSourceFiles(sampleWdl.wdlSource(runtime), sampleWdl.wdlJson, "{}") WorkflowDescriptor(WorkflowId(uuid), workflowSources) } diff --git a/src/test/scala/cromwell/SimpleWorkflowActorSpec.scala b/src/test/scala/cromwell/SimpleWorkflowActorSpec.scala index 1479dfb31f..4944b7ec50 100644 --- a/src/test/scala/cromwell/SimpleWorkflowActorSpec.scala +++ b/src/test/scala/cromwell/SimpleWorkflowActorSpec.scala @@ -2,10 +2,10 @@ package cromwell import java.util.UUID -import akka.pattern.ask +import akka.actor +import akka.actor.ActorInitializationException import akka.testkit._ import cromwell.binding._ -import cromwell.binding.values.WdlString import cromwell.engine._ import cromwell.engine.backend.local.LocalBackend import cromwell.engine.workflow.WorkflowActor @@ -14,65 +14,58 @@ import cromwell.parser.BackendType import cromwell.util.SampleWdl import cromwell.util.SampleWdl.HelloWorld.Addressee -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps class SimpleWorkflowActorSpec extends CromwellTestkitSpec("SimpleWorkflowActorSpec") { - private def buildWorkflowFSMRef(sampleWdl: SampleWdl, rawInputsOverride: Option[WorkflowRawInputs] = None): + private def buildWorkflowFSMRef(sampleWdl: SampleWdl, rawInputsOverride: String): TestFSMRef[WorkflowState, WorkflowFailure, WorkflowActor] = { - - val namespace = NamespaceWithWorkflow.load(sampleWdl.wdlSource(), BackendType.LOCAL) - val rawInputs = rawInputsOverride.getOrElse(sampleWdl.rawInputs) - val coercedInputs = namespace.coerceRawInputs(rawInputs).get - val workflowSources = WorkflowSourceFiles(sampleWdl.wdlSource(), sampleWdl.wdlJson, "{}") + val workflowSources = WorkflowSourceFiles(sampleWdl.wdlSource(), rawInputsOverride, "{}") val descriptor = WorkflowDescriptor(WorkflowId(UUID.randomUUID()), workflowSources) TestFSMRef(new WorkflowActor(descriptor, new LocalBackend)) } - val 
TestExecutionTimeout = 5000 milliseconds + val TestExecutionTimeout = 5.seconds.dilated "A WorkflowActor" should { "start, run, succeed and die" in { - val fsm = buildWorkflowFSMRef(SampleWdl.HelloWorld) - val probe = TestProbe() - probe watch fsm - assert(fsm.stateName == WorkflowSubmitted) startingCallsFilter("hello.hello") { - fsm ! Start + val fsm = buildWorkflowFSMRef(SampleWdl.HelloWorld, SampleWdl.HelloWorld.wdlJson) + val probe = TestProbe() + probe watch fsm within(TestExecutionTimeout) { - awaitCond(fsm.stateName == WorkflowRunning) - awaitCond(fsm.stateName == WorkflowSucceeded) + waitForPattern("transitioning from Submitted to Running") { + waitForPattern("transitioning from Running to Succeeded") { + fsm ! Start + } + } } - probe.expectTerminated(fsm, 10 seconds) + probe.expectTerminated(fsm, 10.seconds.dilated) } } "fail to construct with missing inputs" in { - intercept[UnsatisfiedInputsException] { - buildWorkflowFSMRef(SampleWdl.HelloWorld, rawInputsOverride = Some(Map.empty)) + intercept[UnsatisfiedInputsException] { + buildWorkflowFSMRef(SampleWdl.HelloWorld, rawInputsOverride = "{}") } } "fail to construct with inputs of the wrong type" in { intercept[UnsatisfiedInputsException] { - buildWorkflowFSMRef(SampleWdl.HelloWorld, rawInputsOverride = Some(Map(Addressee -> 3))) + buildWorkflowFSMRef(SampleWdl.HelloWorld, rawInputsOverride = s""" { "$Addressee" : 3} """) } } "fail when a call fails" in { - val fsm = buildWorkflowFSMRef(SampleWdl.GoodbyeWorld) - assert(fsm.stateName == WorkflowSubmitted) startingCallsFilter("goodbye.goodbye") { - waitForPattern("persisting status of goodbye.goodbye to Starting.") { - waitForPattern("persisting status of goodbye.goodbye to Running.") { - waitForPattern("persisting status of goodbye.goodbye to Failed.") { - waitForPattern("WorkflowActor .+ transitioning from Running to Failed\\.") { + waitForPattern("WorkflowActor .+ transitioning from Submitted to Running\\.") { + waitForPattern("persisting status of 
goodbye.goodbye to Starting.") { + waitForPattern("persisting status of goodbye.goodbye to Running.") { + waitForPattern("persisting status of goodbye.goodbye to Failed.") { + val fsm = buildWorkflowFSMRef(SampleWdl.GoodbyeWorld, SampleWdl.GoodbyeWorld.wdlJson) fsm ! Start - awaitCond(fsm.stateName == WorkflowRunning) - awaitCond(fsm.stateName == WorkflowFailed) } } } @@ -81,10 +74,10 @@ class SimpleWorkflowActorSpec extends CromwellTestkitSpec("SimpleWorkflowActorSp } "gracefully handle malformed WDL" in { - val fsm = buildWorkflowFSMRef(SampleWdl.CoercionNotDefined) - assert(fsm.stateName == WorkflowSubmitted) - fsm ! Start within(TestExecutionTimeout) { + val fsm = buildWorkflowFSMRef(SampleWdl.CoercionNotDefined, SampleWdl.CoercionNotDefined.wdlJson) + fsm ! Start + awaitCond(fsm.stateName == WorkflowSubmitted) awaitCond(fsm.stateName == WorkflowFailed) } } diff --git a/src/test/scala/cromwell/WdlFunctionsAtWorkflowLevelSpec.scala b/src/test/scala/cromwell/WdlFunctionsAtWorkflowLevelSpec.scala new file mode 100644 index 0000000000..4be4fd8900 --- /dev/null +++ b/src/test/scala/cromwell/WdlFunctionsAtWorkflowLevelSpec.scala @@ -0,0 +1,30 @@ +package cromwell + +import akka.testkit._ +import cromwell.binding.types.{WdlMapType, WdlStringType, WdlArrayType} +import cromwell.binding.values.{WdlMap, WdlArray, WdlString} +import cromwell.CromwellSpec.DockerTest +import cromwell.util.SampleWdl + +import scala.language.postfixOps + +class WdlFunctionsAtWorkflowLevelSpec extends CromwellTestkitSpec("WdlFunctionsAtWorkflowLevelSpec") { + val outputMap = WdlMap(WdlMapType(WdlStringType, WdlStringType), Map( + WdlString("k1") -> WdlString("v1"), + WdlString("k2") -> WdlString("v2"), + WdlString("k3") -> WdlString("v3") + )) + + "A workflow with a read_lines() and read_map() at the workflow level" should { + "execute those functions properly" in { + runWdlAndAssertOutputs( + sampleWdl = SampleWdl.WdlFunctionsAtWorkflowLevel, + eventFilter = EventFilter.info(pattern = 
s"starting calls: w.a", occurrences = 1), + expectedOutputs = Map( + "w.a.x" -> WdlString("one two three four five"), + "w.a.y" -> outputMap + ) + ) + } + } +} diff --git a/src/test/scala/cromwell/binding/types/WdlMapTypeSpec.scala b/src/test/scala/cromwell/binding/types/WdlMapTypeSpec.scala index 9ed89b985e..3d10776835 100644 --- a/src/test/scala/cromwell/binding/types/WdlMapTypeSpec.scala +++ b/src/test/scala/cromwell/binding/types/WdlMapTypeSpec.scala @@ -1,6 +1,6 @@ package cromwell.binding.types -import cromwell.binding.values.{WdlMap, WdlInteger, WdlString} +import cromwell.binding.values.{WdlObject, WdlMap, WdlInteger, WdlString} import cromwell.parser.WdlParser.SyntaxError import org.scalatest.{FlatSpec, Matchers} import spray.json.{JsObject, JsArray, JsNumber} @@ -13,6 +13,12 @@ class WdlMapTypeSpec extends FlatSpec with Matchers { WdlString("b") -> WdlInteger(2), WdlString("c") -> WdlInteger(3) )) + val coerceableObject = WdlObject(Map( + "a" -> WdlString("1"), + "b" -> WdlString("2"), + "c" -> WdlString("3") + )) + "WdlMap" should "stringify its value" in { stringIntMap.toWdlString shouldEqual "{\"a\": 1, \"b\": 2, \"c\": 3}" } @@ -34,6 +40,14 @@ class WdlMapTypeSpec extends FlatSpec with Matchers { it should "convert WDL source code to WdlMap" in { WdlMapType(WdlStringType, WdlIntegerType).fromWdlString("{\"a\": 1, \"b\": 2, \"c\": 3}") shouldEqual stringIntMap } + it should "coerce a coerceable object into a WdlMap" in { + WdlMapType(WdlStringType, WdlIntegerType).coerceRawValue(coerceableObject) match { + case Success(v) => + v.wdlType shouldEqual WdlMapType(WdlStringType, WdlIntegerType) + v.toWdlString shouldEqual stringIntMap.toWdlString + case Failure(f) => fail("Failed to coerce a map to an object") + } + } it should "NOT successfully convert WDL source code to WdlMap if passed a bogus AST" in { try { WdlMapType(WdlStringType, WdlIntegerType).fromWdlString("workflow wf{}") diff --git a/src/test/scala/cromwell/binding/types/WdlObjectTypeSpec.scala 
b/src/test/scala/cromwell/binding/types/WdlObjectTypeSpec.scala new file mode 100644 index 0000000000..dfe628c2d1 --- /dev/null +++ b/src/test/scala/cromwell/binding/types/WdlObjectTypeSpec.scala @@ -0,0 +1,106 @@ +package cromwell.binding.types + +import cromwell.binding.values.{WdlInteger, WdlString, WdlMap, WdlObject} +import cromwell.parser.WdlParser.SyntaxError +import org.scalatest.{FlatSpec, Matchers} + +import scala.util.{Failure, Success} + +class WdlObjectTypeSpec extends FlatSpec with Matchers { + val abcObject = WdlObject(Map( + "a" -> WdlString("one"), + "b" -> WdlString("two"), + "c" -> WdlString("three") + )) + + val coerceableMap = WdlMap(WdlMapType(WdlStringType, WdlStringType), Map( + WdlString("a") -> WdlString("one"), + WdlString("b") -> WdlString("two"), + WdlString("c") -> WdlString("three")) + ) + + val nonCoerceableMap1 = WdlMap(WdlMapType(WdlIntegerType, WdlStringType), Map( + WdlInteger(1) -> WdlString("one"), + WdlInteger(2) -> WdlString("two"), + WdlInteger(3) -> WdlString("three")) + ) + + val nonCoerceableMap2 = WdlMap(WdlMapType(WdlStringType, WdlObjectType), Map( + WdlString("a") -> WdlObject(Map.empty), + WdlString("b") -> WdlObject(Map.empty), + WdlString("c") -> WdlObject(Map.empty)) + ) + + "WdlObject" should "stringify its value" in { + abcObject.toWdlString shouldEqual "object {a: \"one\", b: \"two\", c: \"three\"}" + } + + it should "stringify its type" in { + abcObject.wdlType.toWdlString shouldEqual "Object" + } + + it should "convert WDL source code to WdlMap" in { + WdlObjectType.fromWdlString("object {a: \"one\", b: \"two\", c: \"three\"}") shouldEqual abcObject + } + + it should "coerce a coerceable map into a WdlObject" in { + WdlObjectType.coerceRawValue(coerceableMap) match { + case Success(v) => + v.wdlType shouldEqual WdlObjectType + v.toWdlString shouldEqual abcObject.toWdlString + case Failure(f) => fail("Failed to coerce a map to an object") + } + } + + it should "NOT successfully coerce a NON coerceable map into 
a WdlObject" in { + WdlObjectType.coerceRawValue(nonCoerceableMap1) match { + case Success(v) => fail("should not have succeeded") + case Failure(f) => // expected + } + } + + it should "NOT successfully coerce a NON coerceable map into a WdlObject (2)" in { + WdlObjectType.coerceRawValue(nonCoerceableMap2) match { + case Success(v) => fail("should not have succeeded") + case Failure(f) => // expected + } + } + + it should "NOT successfully convert WDL source code to WdlMap if passed a bogus AST" in { + try { + WdlObjectType.fromWdlString("workflow wf{}") + fail("should not have succeeded") + } catch { + case _: SyntaxError => // expected + } + } + + it should "NOT successfully convert WDL source code to WdlMap if passed a bogus AST (2)" in { + try { + WdlObjectType.fromWdlString("100") + fail("should not have succeeded") + } catch { + case _: SyntaxError => // expected + } + } + + it should "NOT successfully convert WDL source code to WdlMap if passed a bogus AST (3)" in { + try { + WdlObjectType.fromWdlString("{1:x(),2:stdout()}") + fail("should not have succeeded") + } catch { + case _: SyntaxError => // expected + } + } + + it should "NOT successfully convert WDL source code to WdlMap if passed a bogus AST (4)" in { + try { + WdlObjectType.fromWdlString("{1:var,2:var}") + fail("should not have succeeded") + } catch { + case _: SyntaxError => // expected + } + } + + +} diff --git a/src/test/scala/cromwell/binding/values/WdlObjectSpec.scala b/src/test/scala/cromwell/binding/values/WdlObjectSpec.scala new file mode 100644 index 0000000000..632af9af5b --- /dev/null +++ b/src/test/scala/cromwell/binding/values/WdlObjectSpec.scala @@ -0,0 +1,81 @@ +package cromwell.binding.values + +import cromwell.binding.types.{WdlObjectType, WdlArrayType} +import org.scalatest.{FlatSpec, Matchers, TryValues} + +class WdlObjectSpec extends FlatSpec with Matchers with TryValues { + + val correctTSV = "one\ttwo\tthree\tfour\none\tfour\tnine\tsixteen" + val emptyTSV = "" + val 
oneRowTSV = "one\ttwo\tthree\tfour" + val nonHomogeneousTS = "onet\ttwo\tthree\none\ttwo" + val arrayTSV = correctTSV + "\none\teight\ttwentyseven\tsixtyfour" + + it should "read an Object from a correct TSV file" in { + val parsed = WdlObject.fromTsv(correctTSV) + parsed should be a 'success + val array: Array[WdlObject] = parsed.success.value + array should have size 1 + + //Attributes + array.head.value should contain key "one" + array.head.value should contain key "two" + array.head.value should contain key "three" + array.head.value should contain key "four" + + //Values + array.head.value.get("one") shouldBe Some(WdlString("one")) + array.head.value.get("two") shouldBe Some(WdlString("four")) + array.head.value.get("three") shouldBe Some(WdlString("nine")) + array.head.value.get("four") shouldBe Some(WdlString("sixteen")) + } + + it should "NOT read from a TSV file with less than 2 rows" in { + WdlObject.fromTsv(emptyTSV) should be a 'failure + WdlObject.fromTsv(oneRowTSV) should be a 'failure + } + + it should "NOT read from a non homogeneous TSV file" in { + WdlObject.fromTsv(nonHomogeneousTS) should be a 'failure + } + + it should "serialize to TSV" in { + val obj = WdlObject.fromTsv(correctTSV).get.head + val serialized = obj.tsvSerialize + serialized should be a 'success + serialized.success.value shouldEqual correctTSV + } + + it should "read a WdlArray[WdlObject] from a correct TSV file" in { + val parsed = WdlObject.fromTsv(arrayTSV) + parsed should be a 'success + val array: Array[WdlObject] = parsed.success.value + array should have size 2 + + //Attributes + array foreach { _.value should contain key "one" } + array foreach { _.value should contain key "two" } + array foreach { _.value should contain key "three" } + array foreach { _.value should contain key "four" } + + //Values + array.head.value.get("one") shouldBe Some(WdlString("one")) + array.head.value.get("two") shouldBe Some(WdlString("four")) + array.head.value.get("three") shouldBe 
Some(WdlString("nine")) + array.head.value.get("four") shouldBe Some(WdlString("sixteen")) + + array(1).value.get("one") shouldBe Some(WdlString("one")) + array(1).value.get("two") shouldBe Some(WdlString("eight")) + array(1).value.get("three") shouldBe Some(WdlString("twentyseven")) + array(1).value.get("four") shouldBe Some(WdlString("sixtyfour")) + } + + it should "serialize a WdlArray[WdlObject] to TSV" in { + val array = WdlArray(WdlArrayType(WdlObjectType), WdlObject.fromTsv(arrayTSV).get) + val serialized = array.tsvSerialize + serialized should be a 'success + serialized.success.value shouldEqual arrayTSV + + } + +} diff --git a/src/test/scala/cromwell/binding/values/WdlValueSpec.scala b/src/test/scala/cromwell/binding/values/WdlValueSpec.scala index 627aeb0ff7..cb1720f686 100644 --- a/src/test/scala/cromwell/binding/values/WdlValueSpec.scala +++ b/src/test/scala/cromwell/binding/values/WdlValueSpec.scala @@ -1,8 +1,7 @@ package cromwell.binding.values -import cromwell.binding.{WdlExpression, WdlNamespace} -import cromwell.parser.BackendType -import cromwell.util.SampleWdl +import cromwell.binding.WdlExpression +import cromwell.binding.types.{WdlMapType, WdlStringType} import org.scalatest.prop.TableDrivenPropertyChecks import org.scalatest.{FlatSpec, Matchers} @@ -25,7 +24,9 @@ class WdlValueSpec extends FlatSpec with Matchers { (WdlInteger(0), "0"), (WdlInteger(Int.MaxValue), "2147483647"), (WdlInteger(Int.MinValue), "-2147483648"), - (WdlString(""), "\"\"") + (WdlString(""), "\"\""), + (WdlObject(Map("one" -> WdlString("two"))), "object {one: \"two\"}"), + (WdlMap(WdlMapType(WdlStringType, WdlStringType), Map(WdlString("one") -> WdlString("two"))), "{\"one\": \"two\"}") ) forAll(wdlValueRawStrings) { (wdlValue, rawString) => @@ -86,17 +87,4 @@ class WdlValueSpec extends FlatSpec with Matchers { fromRawString.wdlType should be(wdlType) } } - - val notImplementRawString = Table( - "wdlValue", - WdlObject(Map("key" -> WdlString("value"))), - 
WdlNamespace.load(SampleWdl.HelloWorld.wdlSource(), BackendType.LOCAL)) - - forAll(notImplementRawString) { wdlValue => - it should s"not implement a ${wdlValue.typeName} raw string" in { - a [NotImplementedError] should be thrownBy wdlValue.toWdlString - val wdlType = wdlValue.wdlType - a [NotImplementedError] should be thrownBy wdlType.fromWdlString("") - } - } } diff --git a/src/test/scala/cromwell/engine/WorkflowManagerActorSpec.scala b/src/test/scala/cromwell/engine/WorkflowManagerActorSpec.scala index 7317fe6a12..256d35368f 100644 --- a/src/test/scala/cromwell/engine/WorkflowManagerActorSpec.scala +++ b/src/test/scala/cromwell/engine/WorkflowManagerActorSpec.scala @@ -3,7 +3,7 @@ package cromwell.engine import java.util.UUID import akka.pattern.ask -import akka.testkit.{EventFilter, TestActorRef} +import akka.testkit.{EventFilter, TestActorRef, _} import cromwell.binding._ import cromwell.binding.command.CommandPart import cromwell.binding.types.{WdlArrayType, WdlStringType} @@ -30,7 +30,7 @@ class WorkflowManagerActorSpec extends CromwellTestkitSpec("WorkflowManagerActor "A WorkflowManagerActor" should { - val TestExecutionTimeout = 5000 milliseconds + val TestExecutionTimeout = 5.seconds.dilated "run the Hello World workflow" in { diff --git a/src/test/scala/cromwell/engine/backend/local/LocalBackendSpec.scala b/src/test/scala/cromwell/engine/backend/local/LocalBackendSpec.scala index 3f835d148a..c9fb0dc19c 100644 --- a/src/test/scala/cromwell/engine/backend/local/LocalBackendSpec.scala +++ b/src/test/scala/cromwell/engine/backend/local/LocalBackendSpec.scala @@ -37,10 +37,10 @@ class LocalBackendSpec extends CromwellTestkitSpec("LocalBackendSpec") { def testFailOnStderr(descriptor: WorkflowDescriptor, expectSuccess: Boolean): Unit = { val call = descriptor.namespace.workflow.calls.head val backend = new LocalBackend() - val backendCall = backend.bindCall(descriptor, CallKey(call, None, None), Map.empty[String, WdlValue], AbortRegistrationFunction(_ => ())) 
+ val backendCall = backend.bindCall(descriptor, CallKey(call, None), Map.empty[String, WdlValue], AbortRegistrationFunction(_ => ())) backendCall.execute match { case FailedExecution(e, _) => if (expectSuccess) fail("A call in a failOnStderr test which should have succeeded has failed ", e) - case SuccessfulExecution(_) => if (!expectSuccess) fail("A call in a failOnStderr test which should have failed has succeeded") + case SuccessfulExecution(_, _) => if (!expectSuccess) fail("A call in a failOnStderr test which should have failed has succeeded") case AbortedExecution => fail("Not expecting this at all") } } diff --git a/src/test/scala/cromwell/engine/db/slick/SlickDataAccessSpec.scala b/src/test/scala/cromwell/engine/db/slick/SlickDataAccessSpec.scala index 72cf62806d..86082b504f 100644 --- a/src/test/scala/cromwell/engine/db/slick/SlickDataAccessSpec.scala +++ b/src/test/scala/cromwell/engine/db/slick/SlickDataAccessSpec.scala @@ -4,6 +4,7 @@ import java.sql.SQLException import java.util.UUID import cromwell.binding._ +import cromwell.binding.expression.WdlStandardLibraryFunctions import cromwell.binding.types.{WdlArrayType, WdlStringType} import cromwell.binding.values.{WdlArray, WdlString} import cromwell.engine.ExecutionIndex.ExecutionIndex @@ -59,6 +60,9 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { abortRegistrationFunction: AbortRegistrationFunction): BackendCall = throw new NotImplementedError + override def engineFunctions: WdlStandardLibraryFunctions = + throw new NotImplementedError + override def execute(bc: BackendCall): ExecutionResult = throw new NotImplementedError @@ -308,7 +312,7 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { status.executionStatus should be(if (updateStatus) ExecutionStatus.Running else ExecutionStatus.NotStarted) status.returnCode should be(None) } - _ <- dataAccess.insertCalls(workflowId, Seq(CallKey(call, Option(0), None)), localBackend) + _ <- 
dataAccess.insertCalls(workflowId, Seq(CallKey(call, Option(0))), localBackend) _ <- dataAccess.setStatus(workflowId, Seq(shardKey), CallStatus(ExecutionStatus.Done, Option(0))) _ <- dataAccess.getExecutionStatuses(workflowId) map { result => result.size should be(2) @@ -337,7 +341,7 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { val workflowInfo = new WorkflowDescriptor(workflowId, testSources) val task = new Task("taskName", Nil, Nil, Nil, null, BackendType.LOCAL) val call = new Call(None, callFqn, task, Set.empty[FullyQualifiedName], Map.empty, None) - val callKey = CallKey(call, None, None) + val callKey = CallKey(call, None) (for { _ <- dataAccess.createWorkflow(workflowInfo, Nil, Nil, localBackend) _ <- dataAccess.updateWorkflowState(workflowId, WorkflowRunning) @@ -443,8 +447,8 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { (for { _ <- dataAccess.createWorkflow(workflowInfo, Nil, Nil, localBackend) _ <- dataAccess.updateWorkflowState(workflowId, WorkflowRunning) - _ <- dataAccess.setOutputs(workflowId, CallKey(call, None, None), Map(symbolLqn -> new WdlString("testStringValue"))) - _ <- dataAccess.setOutputs(workflowId, CallKey(call, Option(0), None), Map(symbolLqn -> new WdlString("testStringValueShard"))) + _ <- dataAccess.setOutputs(workflowId, CallKey(call, None), Map(symbolLqn -> new WdlString("testStringValue"))) + _ <- dataAccess.setOutputs(workflowId, CallKey(call, Option(0)), Map(symbolLqn -> new WdlString("testStringValueShard"))) _ <- dataAccess.getOutputs(workflowId) map { results => results.size should be(1) //getOutputs on a workflowId does NOT return shards outputs @@ -460,6 +464,40 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { } yield ()).futureValue } + it should "set and get shard statuses" in { + assume(canConnect || testRequired) + val callFqn1 = "call.fully.qualified.scope$s1" + val callFqn2 = "call.fully.qualified.scope$s2" + val 
symbolLqn = "symbol" + val workflowId = WorkflowId(UUID.randomUUID()) + val workflowInfo = new WorkflowDescriptor(workflowId, testSources) + val task = new Task("taskName", Nil, Nil, Nil, null, BackendType.LOCAL) + val call1 = new Call(None, callFqn1, task, Set.empty[FullyQualifiedName], Map.empty, None) + val shardIndex1 = Option(0) + val pid1 = Option(123) + val backendInfo1 = new LocalCallBackendInfo(pid1) + val call2 = new Call(None, callFqn2, task, Set.empty[FullyQualifiedName], Map.empty, None) + val shardIndex2 = Option(1) + val pid2 = Option(987) + val backendInfo2 = new LocalCallBackendInfo(pid2) + + (for { + _ <- dataAccess.createWorkflow(workflowInfo, Nil, Nil, localBackend) + _ <- dataAccess.insertCalls(workflowId, Seq(CallKey(call1, shardIndex1)), localBackend) + _ <- dataAccess.insertCalls(workflowId, Seq(CallKey(call2, shardIndex2)), localBackend) + _ <- dataAccess.updateExecutionBackendInfo(workflowId, CallKey(call1, shardIndex1), backendInfo1) + _ <- dataAccess.updateExecutionBackendInfo(workflowId, CallKey(call2, shardIndex2), backendInfo2) + _ <- dataAccess.getExecutionBackendInfo(workflowId, call1) map { + case LocalCallBackendInfo(processId: Option[Int]) => assertResult(pid1) { processId } + case _ => fail("Unexpected backend info type returned") + } + _ <- dataAccess.getExecutionBackendInfo(workflowId, call2) map { + case LocalCallBackendInfo(processId: Option[Int]) => assertResult(pid2) { processId } + case _ => fail("Unexpected backend info type returned") + } + } yield ()).futureValue + } + it should "set and get an output by call" in { assume(canConnect || testRequired) val callFqn = "call.fully.qualified.scope" @@ -472,8 +510,8 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { (for { _ <- dataAccess.createWorkflow(workflowInfo, Nil, Nil, localBackend) _ <- dataAccess.updateWorkflowState(workflowId, WorkflowRunning) - _ <- dataAccess.setOutputs(workflowId, CallKey(call, None, None), Map(symbolLqn -> new 
WdlString("testStringValue"))) - _ <- dataAccess.setOutputs(workflowId, CallKey(call, Option(0), None), Map(symbolLqn -> new WdlString("testStringValueShard"))) + _ <- dataAccess.setOutputs(workflowId, CallKey(call, None), Map(symbolLqn -> new WdlString("testStringValue"))) + _ <- dataAccess.setOutputs(workflowId, CallKey(call, Option(0)), Map(symbolLqn -> new WdlString("testStringValueShard"))) callOutput <- dataAccess.getOutputs(workflowId, ExecutionDatabaseKey(call.fullyQualifiedName, None)) map { results => results.head.key.index should be(None) results.head.wdlValue.get should be(new WdlString("testStringValue")) @@ -549,7 +587,7 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { (for { _ <- dataAccess.createWorkflow(workflowInfo, Seq(entry), Nil, localBackend) _ <- dataAccess.updateWorkflowState(workflowId, WorkflowRunning) - _ <- dataAccess.setOutputs(workflowId, CallKey(call, None, None), Map(symbolLqn -> new WdlString("testStringValue"))) + _ <- dataAccess.setOutputs(workflowId, CallKey(call, None), Map(symbolLqn -> new WdlString("testStringValue"))) _ <- dataAccess.getOutputs(workflowId, ExecutionDatabaseKey(call.fullyQualifiedName, None)) map { results => results.size should be(1) val resultSymbol = results.head @@ -580,7 +618,7 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { (for { _ <- dataAccess.createWorkflow(workflowInfo, Seq(entry), Nil, localBackend) _ <- dataAccess.updateWorkflowState(workflowId, WorkflowRunning) - _ <- dataAccess.setOutputs(workflowId, CallKey(call, None, None), Map(symbolFqn -> new WdlString("testStringValue"))) + _ <- dataAccess.setOutputs(workflowId, CallKey(call, None), Map(symbolFqn -> new WdlString("testStringValue"))) } yield ()).failed.futureValue should be(a[SQLException]) } @@ -590,17 +628,14 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { val workflowInfo = new WorkflowDescriptor(workflowId, testSources) val task = new 
Task("taskName", Nil, Nil, Nil, null, BackendType.LOCAL) val call = new Call(None, "fully.qualified.name", task, Set.empty[FullyQualifiedName], Map.empty, None) - val backendInfo = new LocalCallBackendInfo(CallStatus(ExecutionStatus.Running, None), Option(123)) + val backendInfo = new LocalCallBackendInfo(Option(123)) (for { _ <- dataAccess.createWorkflow(workflowInfo, Nil, Seq(call), localBackend) - _ <- dataAccess.updateWorkflowState(workflowId, WorkflowRunning) - _ <- dataAccess.updateExecutionBackendInfo(workflowId, call, backendInfo) + _ <- dataAccess.updateExecutionBackendInfo(workflowId, CallKey(call, None), backendInfo) _ <- dataAccess.getExecutionBackendInfo(workflowId, call) map { insertResultCall => insertResultCall should be(a[LocalCallBackendInfo]) val insertResultLocalCall = insertResultCall.asInstanceOf[LocalCallBackendInfo] - insertResultLocalCall.status.executionStatus should be(ExecutionStatus.Running) - insertResultLocalCall.status.returnCode shouldBe None insertResultLocalCall.processId shouldNot be(empty) insertResultLocalCall.processId.get should be(123) } @@ -616,29 +651,23 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { val workflowInfo2 = new WorkflowDescriptor(workflowId2, testSources) val task = new Task("taskName", Nil, Nil, Nil, null, BackendType.LOCAL) val call = new Call(None, "fully.qualified.name", task, Set.empty[FullyQualifiedName], Map.empty, None) - val backendInfo1 = new LocalCallBackendInfo(CallStatus(ExecutionStatus.Running, None), Option(123)) - val backendInfo2 = new LocalCallBackendInfo(CallStatus(ExecutionStatus.Failed, Option(1)), Option(321)) + val backendInfo1 = new LocalCallBackendInfo(Option(123)) + val backendInfo2 = new LocalCallBackendInfo(Option(321)) (for { _ <- dataAccess.createWorkflow(workflowInfo1, Nil, Seq(call), localBackend) _ <- dataAccess.createWorkflow(workflowInfo2, Nil, Seq(call), localBackend) - _ <- dataAccess.updateWorkflowState(workflowId1, WorkflowRunning) - _ <- 
dataAccess.updateWorkflowState(workflowId2, WorkflowRunning) - _ <- dataAccess.updateExecutionBackendInfo(workflowId1, call, backendInfo1) - _ <- dataAccess.updateExecutionBackendInfo(workflowId2, call, backendInfo2) + _ <- dataAccess.updateExecutionBackendInfo(workflowId1, CallKey(call, None), backendInfo1) + _ <- dataAccess.updateExecutionBackendInfo(workflowId2, CallKey(call, None), backendInfo2) _ <- dataAccess.getExecutionBackendInfo(workflowId1, call) map { insertResultCall => insertResultCall should be(a[LocalCallBackendInfo]) val insertResultLocalCall = insertResultCall.asInstanceOf[LocalCallBackendInfo] - insertResultLocalCall.status.executionStatus should be(ExecutionStatus.Running) - insertResultLocalCall.status.returnCode should be(None) insertResultLocalCall.processId shouldNot be(empty) insertResultLocalCall.processId.get should be(123) } _ <- dataAccess.getExecutionBackendInfo(workflowId2, call) map { insertResultCall => insertResultCall should be(a[LocalCallBackendInfo]) val insertResultLocalCall = insertResultCall.asInstanceOf[LocalCallBackendInfo] - insertResultLocalCall.status.executionStatus should be(ExecutionStatus.Failed) - insertResultLocalCall.status.returnCode should be(Some(1)) insertResultLocalCall.processId shouldNot be(empty) insertResultLocalCall.processId.get should be(321) } @@ -655,7 +684,7 @@ class SlickDataAccessSpec extends FlatSpec with Matchers with ScalaFutures { (for { _ <- dataAccess.createWorkflow(workflowInfo, Nil, Seq(call), localBackend) _ <- dataAccess.updateWorkflowState(workflowId, WorkflowRunning) - _ <- dataAccess.updateExecutionBackendInfo(workflowId, call, null) + _ <- dataAccess.updateExecutionBackendInfo(workflowId, CallKey(call, None), null) } yield ()).failed.futureValue should be(an[IllegalArgumentException]) } diff --git a/src/test/scala/cromwell/engine/db/slick/TestSlickDatabase.scala b/src/test/scala/cromwell/engine/db/slick/TestSlickDatabase.scala index 78c6ffb7b1..87258055b5 100644 --- 
a/src/test/scala/cromwell/engine/db/slick/TestSlickDatabase.scala +++ b/src/test/scala/cromwell/engine/db/slick/TestSlickDatabase.scala @@ -7,16 +7,20 @@ import liquibase.Liquibase import liquibase.database.DatabaseConnection import liquibase.resource.{FileSystemResourceAccessor, ResourceAccessor} import org.slf4j.LoggerFactory +import slick.backend.DatabaseConfig +import slick.driver.JdbcProfile import scala.concurrent.{ExecutionContext, Future} class TestSlickDatabase(configPath: String) { - private lazy val databaseConfig = DatabaseConfig.rootDatabaseConfig.getConfig(configPath) + private lazy val databaseConfig = SlickDataAccess.getDatabaseConfig(configPath) private lazy val log = LoggerFactory.getLogger(classOf[TestSlickDatabase]) + private lazy val slickConfig = DatabaseConfig.forConfig[JdbcProfile]("", databaseConfig) + // NOTE: Using the import below for isValidConnection, but maybe not the lazy instance if the check fails. - lazy val dataAccessComponent: DataAccessComponent = new DataAccessComponent(databaseConfig.getString("slick.driver")) + lazy val dataAccessComponent: DataAccessComponent = new DataAccessComponent(slickConfig.driver) import dataAccessComponent.driver.api._ @@ -32,7 +36,7 @@ class TestSlickDatabase(configPath: String) { implicit val executionContext = ExecutionContext.global Future { log.debug("Opening test connection setup for " + configPath) - Database.forConfig("", databaseConfig) + slickConfig.db } flatMap { database => database.run(SimpleDBIO(_.connection.isValid(1))) andThen { case _ => @@ -47,10 +51,10 @@ class TestSlickDatabase(configPath: String) { } lazy val slickDataAccess = - if (this.databaseConfig == DatabaseConfig.databaseConfig) + if (this.databaseConfig == SlickDataAccess.defaultDatabaseConfig) new SlickDataAccess() // Test the no-args constructor else - new SlickDataAccess(databaseConfig, dataAccessComponent) + new SlickDataAccess(databaseConfig) def useLiquibase = databaseConfig.hasPath("liquibase") diff --git 
a/src/test/scala/cromwell/util/SampleWdl.scala b/src/test/scala/cromwell/util/SampleWdl.scala index 69a522ef44..9d97488166 100644 --- a/src/test/scala/cromwell/util/SampleWdl.scala +++ b/src/test/scala/cromwell/util/SampleWdl.scala @@ -1237,4 +1237,54 @@ object SampleWdl { "file_passing.f" -> createCannedFile("canned", fileContents).getAbsolutePath ) } + + object WdlFunctionsAtWorkflowLevel extends SampleWdl { + val CannedArray = + """one + |two + |three + |four + |five + """.stripMargin.trim + + val CannedMap = + s"""k1\tv1 + |k2\tv2 + |k3\tv3 + """.stripMargin.trim + + override def wdlSource(runtime: String): WdlSource = + """ + |task a { + | Array[String] array + | Map[String, String] map + | + | command { + | echo ${sep=' ' array} > concat + | } + | output { + | String x = read_string("concat") + | Map[String, String] y = map + | } + |} + | + |workflow w { + | File array_file + | File map_file + | Array[String] in_array = read_lines(array_file) + | Map[String, String] in_map = read_map(map_file) + | call a {input: + | array=in_array, + | map=in_map + | } + |} + """.stripMargin.replaceAll("RUNTIME", runtime) + + override val rawInputs = { + Map( + "w.array_file" -> createCannedFile("array.txt", CannedArray).getAbsolutePath, + "w.map_file" -> createCannedFile("map.txt", CannedMap).getAbsolutePath + ) + } + } } diff --git a/src/test/scala/cromwell/util/TryUtilSpec.scala b/src/test/scala/cromwell/util/TryUtilSpec.scala new file mode 100644 index 0000000000..2b749b1376 --- /dev/null +++ b/src/test/scala/cromwell/util/TryUtilSpec.scala @@ -0,0 +1,45 @@ +package cromwell.util + +import org.scalatest.{FlatSpec, Matchers} + +import scala.language.postfixOps +import scala.concurrent.duration._ +import scala.util.Success + +class TryUtilSpec extends FlatSpec with Matchers { + def failNTimes(n: Int): Option[Int] => Int = { + var counter = n + def func(prior: Option[Int]): Int = { + if (counter > 0) { + counter -= 1 + throw new Exception("Failed") + } + 9 + } + func + } + + 
it should "Retry a function until it works" in { + val value = TryUtil.retryBlock( + fn = failNTimes(4), + retries = Some(5), + pollingInterval = 50 milliseconds, + pollingBackOffFactor = 1, + maxPollingInterval = 10 seconds, + failMessage = Some(s"failed attempt (on purpose)") + ) + value shouldEqual Success(9) + } + + it should "Fail if it hits the max retry count" in { + val value = TryUtil.retryBlock( + fn = failNTimes(4), + retries = Some(4), + pollingInterval = 50 milliseconds, + pollingBackOffFactor = 1, + maxPollingInterval = 10 seconds, + failMessage = Some(s"failed attempt (on purpose)") + ) + value.isFailure shouldEqual true + } +}