From 5c9e4ae1279960fb8062cd0ade61d977bce82fdb Mon Sep 17 00:00:00 2001 From: grohli <22306963+grohli@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:26:27 -0400 Subject: [PATCH] [qob] cancel stage if any partitions fail. --- .../hail/backend/service/ServiceBackend.scala | 11 ++++-- .../scala/is/hail/services/BatchClient.scala | 14 +++++--- .../is/hail/services/BatchClientSuite.scala | 35 ++++++++++++++++++- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index b0d06b8b0dc4..cd177c656ac4 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -1,6 +1,6 @@ package is.hail.backend.service -import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} +import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ @@ -202,6 +202,7 @@ class ServiceBackend( batch_id = batchConfig.batchId, absolute_parent_id = batchConfig.jobGroupId, token = token, + cancel_after_n_failures = Some(1), attributes = Map("name" -> stageIdentifier), jobs = jobs, ) @@ -280,7 +281,13 @@ class ServiceBackend( log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() - val r @ (error, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) { + val r @ (error, results) = runAll[Option, Array[Byte]](executor) { + /* A missing file means the job was cancelled because another job failed. Assumes that if any + * job was cancelled, then at least one job failed. We want to ignore the missing file + * exceptions and return one of the actual failure exceptions. */ + case (opt, _: FileNotFoundException) => opt + case (opt, e) => opt.orElse(Some(e)) + }(None) { (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => (() => readResult(root, jobIndex), partIdx) } diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index 571c03279425..8b956b7dc6ed 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -14,7 +14,9 @@ import java.nio.file.Path import org.apache.http.entity.ByteArrayEntity import org.apache.http.entity.ContentType.APPLICATION_JSON -import org.json4s.{CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JObject, JString} +import org.json4s.{ + CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JNull, JObject, JString, +} import org.json4s.JsonAST.{JArray, JBool} import org.json4s.jackson.JsonMethods @@ -29,6 +31,7 @@ case class JobGroupRequest( batch_id: Int, absolute_parent_id: Int, token: String, + cancel_after_n_failures: Option[Int] = None, attributes: Map[String, String] = Map.empty, jobs: IndexedSeq[JobRequest] = FastSeq(), ) @@ -52,9 +55,9 @@ case class JarUrl(url: String) extends JarSpec case class JobResources( preemptible: Boolean, - cpu: Option[String], - memory: Option[String], - storage: Option[String], + cpu: Option[String] = None, + memory: Option[String] = None, + storage: Option[String] = None, ) case class CloudfuseConfig( @@ -252,6 +255,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab JObject( "job_group_id" -> JInt(1), // job group id relative to the update "absolute_parent_id" -> JInt(jobGroup.absolute_parent_id), + "cancel_after_n_failures" -> jobGroup.cancel_after_n_failures.map(JInt(_)).getOrElse( + JNull + ), "attributes" -> Extraction.decompose(jobGroup.attributes), ) )), diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala index 26671c4b3ff0..529116e9455f 100644 --- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala @@ -2,6 +2,7 @@ package is.hail.services import is.hail.HAIL_REVISION import is.hail.backend.service.Main +import is.hail.services.JobGroupStates.Failure import is.hail.utils._ import java.lang.reflect.Method @@ -18,7 +19,7 @@ class BatchClientSuite extends TestNGSuite { @BeforeClass def createClientAndBatch(): Unit = { - client = BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json")) + client = BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json")) batchId = client.newBatch( BatchRequest( billing_project = "test", @@ -46,6 +47,38 @@ class BatchClientSuite extends TestNGSuite { def closeClient(): Unit = client.close() + @Test + def testCancelAfterNFailures(): Unit = { + val jobGroupId = client.newJobGroup( + req = JobGroupRequest( + batch_id = batchId, + absolute_parent_id = parentJobGroupId, + cancel_after_n_failures = Some(1), + token = tokenUrlSafe, + jobs = FastSeq( + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", "sleep 1d"), + ), + resources = Some(JobResources(preemptible = true)), + ), + JobRequest( + always_run = false, + process = BashJob( + image = "ubuntu:22.04", + command = Array("/bin/bash", "-c", "exit 1"), + ), + ), + ), + ) + ) + val result = client.waitForJobGroup(batchId, jobGroupId) + assert(result.state == Failure) + assert(result.n_cancelled == 1) + } + @Test def testNewJobGroup(): Unit = // The query driver submits a job group per stage with one job per partition