From 03a6c9360c04d842970388293d189fa2b8fa965e Mon Sep 17 00:00:00 2001 From: grohli <22306963+grohli@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:53:16 -0400 Subject: [PATCH] Added ability to check job status of individual job, made associated tests --- .../hail/backend/service/ServiceBackend.scala | 122 ++++++++++------ .../src/is/hail/backend/service/Worker.scala | 15 +- .../src/is/hail/services/BatchClient.scala | 97 ++++++++++++- .../src/is/hail/utils/ErrorHandling.scala | 2 +- .../is/hail/backend/ServiceBackendSuite.scala | 135 +++++++++++++++++- .../is/hail/services/BatchClientSuite.scala | 21 ++- 6 files changed, 323 insertions(+), 69 deletions(-) diff --git a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala index c2cf1f215e0..e3693642a54 100644 --- a/hail/hail/src/is/hail/backend/service/ServiceBackend.scala +++ b/hail/hail/src/is/hail/backend/service/ServiceBackend.scala @@ -15,7 +15,7 @@ import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.linalg.BlockMatrix import is.hail.services.{BatchClient, JobGroupRequest, _} -import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} +import is.hail.services.JobGroupStates.{Cancelled, Failure, Success} import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType @@ -172,7 +172,7 @@ class ServiceBackend( token: String, root: String, stageIdentifier: String, - ): JobGroupResponse = { + ): (JobGroupResponse, Int) = { val defaultProcess = JvmJob( command = null, @@ -199,14 +199,28 @@ class ServiceBackend( val jobs = collection.indices.map { i => defaultJob.copy( - attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"), + attributes = Map( + "name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i", + "idx" -> i.toString, + ), process = defaultProcess.copy( command = Array(Main.WORKER, root, s"$i", s"${collection.length}") ), ) } - val jobGroupId = + // When we create a JobGroup with n jobs, Batch gives us the absolute JobGroupId, + // and the startJobId for the first job. + /* This means that all JobId's in the JobGroup will have values in range (startJobId, startJobId + * + n). */ + // Therefore, we know the partition index for a given job by using this startJobId offset. + + // Why do we do this? + // Consider a situation where we're submitting thousands of jobs in a job group. + /* If one of those jobs fails, we don't want to make thousands of requests to batch to get a + * partition index */ + // that that job corresponds to. 
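+    // For illustration only (hypothetical numbers, not part of the behaviour being added):
+    // if newJobGroup returns startJobId = 500 for a job group of n = 3 jobs, Batch assigns
+    // those jobs the ids 500, 501 and 502, and the partition index of any one of them is
+    // recovered locally as
+    //   partitionIndex = job_id - startJobId   // e.g. 502 - 500 == partition 2
+    // so a single "failed jobs" query is enough to locate the failing partition.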
+ val (jobGroupId, startJobId) = batchClient.newJobGroup( JobGroupRequest( batch_id = batchConfig.batchId, @@ -221,21 +235,27 @@ class ServiceBackend( stageCount += 1 Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms - batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) + val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) + (response, startJobId) } - private[this] def readResult(root: String, i: Int): Array[Byte] = { - val bytes = fs.readNoCompression(s"$root/result.$i") - if (bytes(0) != 0) { - bytes.slice(1, bytes.length) - } else { - val errorInformationBytes = bytes.slice(1, bytes.length) - val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes)) - val shortMessage = readString(is) - val expandedMessage = readString(is) - val errorId = is.readInt() - throw new HailWorkerException(i, shortMessage, expandedMessage, errorId) - } + private[this] def readPartitionResult(root: String, i: Int): Array[Byte] = { + val file = s"$root/result.$i" + val bytes = fs.readNoCompression(file) + assert(bytes(0) != 0, s"$file is not a valid result.") + bytes.slice(1, bytes.length) + } + + private[this] def readPartitionError(root: String, i: Int): HailWorkerException = { + val file = s"$root/result.$i" + val bytes = fs.readNoCompression(file) + assert(bytes(0) == 0, s"$file did not contain an error") + val errorInformationBytes = bytes.slice(1, bytes.length) + val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes)) + val shortMessage = readString(is) + val expandedMessage = readString(is) + val errorId = is.readInt() + new HailWorkerException(i, shortMessage, expandedMessage, errorId) } override def parallelizeAndComputeWithIndex( @@ -288,37 +308,49 @@ class ServiceBackend( uploadFunction.get() uploadContexts.get() - val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier) - + val (jobGroup, startJobId) = + submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier) log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() - var r @ (err, results) = runAll[Option, Array[Byte]](executor) { - /* A missing file means the job was cancelled because another job failed. Assumes that if any - * job was cancelled, then at least one job failed. We want to ignore the missing file - * exceptions and return one of the actual failure exceptions. 
*/ - case (opt, _: FileNotFoundException) => opt - case (opt, e) => opt.orElse(Some(e)) - }(None) { - (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => - (() => readResult(root, jobIndex), partIdx) - } - } - if (jobGroup.state != Success && err.isEmpty) { - assert(jobGroup.state != Running) - val error = - jobGroup.state match { - case Failure => - new HailBatchFailure( - s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error" - ) - case Cancelled => - new CancellationException( - s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} was cancelled" - ) - } - r = (Some(error), results) - } + def streamSuccessfulJobResults: Stream[(Array[Byte], Int)] = + for { + successes <- batchClient.getJobGroupJobs( + jobGroup.batch_id, + jobGroup.job_group_id, + Some(JobStates.Success), + ) + job <- successes + partIdx = job.job_id - startJobId + } yield (readPartitionResult(root, partIdx), partIdx) + + val r @ (_, results) = + jobGroup.state match { + case Success => + runAllKeepFirstError(executor) { + (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => + (() => readPartitionResult(root, jobIndex), partIdx) + } + } + case Failure => + val failedEntries = batchClient.getJobGroupJobs( + jobGroup.batch_id, + jobGroup.job_group_id, + Some(JobStates.Failed), + ) + assert( + failedEntries.nonEmpty, + s"Job group ${jobGroup.job_group_id} failed, but no failed jobs found.", + ) + val error = readPartitionError(root, failedEntries.head.head.job_id - startJobId) + + (Some(error), streamSuccessfulJobResults.toIndexedSeq) + case Cancelled => + val error = + new CancellationException(s"Job Group ${jobGroup.job_group_id} was cancelled.") + + (Some(error), streamSuccessfulJobResults.toIndexedSeq) + } val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0 val rate = results.length / resultsReadingSeconds diff --git a/hail/hail/src/is/hail/backend/service/Worker.scala b/hail/hail/src/is/hail/backend/service/Worker.scala index 4f596de3b4a..27be34613b6 100644 --- a/hail/hail/src/is/hail/backend/service/Worker.scala +++ b/hail/hail/src/is/hail/backend/service/Worker.scala @@ -98,6 +98,14 @@ object Worker { out.write(bytes) } + def writeException(out: DataOutputStream, e: Throwable): Unit = { + val (shortMessage, expandedMessage, errorId) = handleForPython(e) + out.writeBoolean(false) + writeString(out, shortMessage) + writeString(out, expandedMessage) + out.writeInt(errorId) + } + def main(argv: Array[String]): Unit = { val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) @@ -219,12 +227,7 @@ object Worker { dos.writeBoolean(true) dos.write(bytes) case Left(throwableWhileExecutingUserCode) => - val (shortMessage, expandedMessage, errorId) = - handleForPython(throwableWhileExecutingUserCode) - dos.writeBoolean(false) - writeString(dos, shortMessage) - writeString(dos, expandedMessage) - dos.writeInt(errorId) + writeException(dos, throwableWhileExecutingUserCode) } } } diff --git a/hail/hail/src/is/hail/services/BatchClient.scala b/hail/hail/src/is/hail/services/BatchClient.scala index 4e4f9125282..b6605deb6eb 100644 --- a/hail/hail/src/is/hail/services/BatchClient.scala +++ b/hail/hail/src/is/hail/services/BatchClient.scala @@ -3,15 +3,16 @@ package is.hail.services import is.hail.expr.ir.ByteArrayBuilder import is.hail.services.BatchClient.{ BunchMaxSizeBytes, JarSpecSerializer, JobGroupResponseDeserializer, JobGroupStateDeserializer, - JobProcessRequestSerializer, + JobListEntryDeserializer, 
JobProcessRequestSerializer, JobStateDeserializer, } import is.hail.services.oauth2.CloudCredentials import is.hail.services.requests.Requester import is.hail.utils._ +import scala.collection.immutable.Stream.cons import scala.util.Random -import java.net.URL +import java.net.{URL, URLEncoder} import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.Path @@ -90,6 +91,26 @@ object JobGroupStates { case object Running extends JobGroupState } +sealed trait JobState extends Product with Serializable + +object JobStates { + case object Pending extends JobState + case object Ready extends JobState + case object Creating extends JobState + case object Running extends JobState + case object Cancelled extends JobState + case object Error extends JobState + case object Failed extends JobState + case object Success extends JobState +} + +case class JobListEntry( + batch_id: Int, + job_id: Int, + state: JobState, + exit_code: Int, +) + object BatchClient { val BunchMaxSizeBytes: Int = 1024 * 1024 @@ -181,6 +202,39 @@ object BatchClient { }, ) ) + + object JobStateDeserializer + extends CustomSerializer[JobState](_ => + ( + { + case JString("Pending") => JobStates.Pending + case JString("Ready") => JobStates.Ready + case JString("Creating") => JobStates.Creating + case JString("Running") => JobStates.Running + case JString("Cancelled") => JobStates.Cancelled + case JString("Error") => JobStates.Error + case JString("Failed") => JobStates.Failed + case JString("Success") => JobStates.Success + }, + PartialFunction.empty, + ) + ) + + object JobListEntryDeserializer + extends CustomSerializer[JobListEntry](implicit fmts => + ( + { + case o: JObject => + JobListEntry( + batch_id = (o \ "batch_id").extract[Int], + job_id = (o \ "job_id").extract[Int], + state = (o \ "state").extract[JobState], + exit_code = (o \ "exit_code").extract[Int], + ) + }, + PartialFunction.empty, + ) + ) } case class BatchClient private (req: Requester) extends Logging with AutoCloseable { @@ -190,7 +244,14 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab JobProcessRequestSerializer + JobGroupStateDeserializer + JobGroupResponseDeserializer + - JarSpecSerializer + JarSpecSerializer + + JobStateDeserializer + + JobListEntryDeserializer + + private[this] def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = { + val (a, s1) = f(s0) + cons(a, paginated(s1)(f)) + } def newBatch(createRequest: BatchRequest): Int = { val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest)) @@ -199,9 +260,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab batchId } - def newJobGroup(req: JobGroupRequest): Int = { + def newJobGroup(req: JobGroupRequest): (Int, Int) = { val nJobs = req.jobs.length - val (updateId, startJobGroupId) = beginUpdate(req.batch_id, req.token, nJobs) + val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs) log.info(s"Began update '$updateId' for batch '${req.batch_id}'.") createJobGroup(updateId, req) @@ -213,7 +274,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab commitUpdate(req.batch_id, updateId) log.info(s"Committed update $updateId for batch ${req.batch_id}.") - startJobGroupId + (startJobGroupId, startJobId) } def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = @@ -221,6 +282,26 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab .get(s"/api/v1alpha/batches/$batchId/job-groups/$jobGroupId") 
.extract[JobGroupResponse] + def getJobGroupJobs(batchId: Int, jobGroupId: Int, status: Option[JobState] = None) + : Stream[IndexedSeq[JobListEntry]] = { + val q = status.map(s => s"state=${s.toString.toLowerCase}").getOrElse("") + paginated(Some(0): Option[Int]) { + case Some(jobId) => + req.get( + s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=${URLEncoder.encode(q, UTF_8)}&last_job_id=$jobId" + ) + .as { case obj: JObject => + ( + (obj \ "jobs").extract[IndexedSeq[JobListEntry]], + (obj \ "last_job_id").extract[Option[Int]], + ) + } + case None => + (IndexedSeq.empty, None) + } + .takeWhile(_.nonEmpty) + } + def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = { val start = System.nanoTime() @@ -301,7 +382,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab ) } - private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int) = + private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int, Int) = req .post( s"/api/v1alpha/batches/$batchId/updates/create", @@ -315,6 +396,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab ( (obj \ "update_id").extract[Int], (obj \ "start_job_group_id").extract[Int], + (obj \ "start_job_id").extract[Int], ) } @@ -335,4 +417,5 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab ) )), ) + } diff --git a/hail/hail/src/is/hail/utils/ErrorHandling.scala b/hail/hail/src/is/hail/utils/ErrorHandling.scala index 176df006080..beb3ce0edc5 100644 --- a/hail/hail/src/is/hail/utils/ErrorHandling.scala +++ b/hail/hail/src/is/hail/utils/ErrorHandling.scala @@ -8,7 +8,7 @@ class HailException(val msg: String, val logMsg: Option[String], cause: Throwabl def this(msg: String, errorId: Int) = this(msg, None, null, errorId) } -class HailWorkerException( +case class HailWorkerException( val partitionId: Int, val shortMessage: String, val expandedMessage: String, diff --git a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala index 461eed812f4..63a161c146f 100644 --- a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala @@ -2,13 +2,16 @@ package is.hail.backend import is.hail.HailFeatureFlags import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload} +import is.hail.backend.service.{ + ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload, Worker, +} import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} import is.hail.services._ -import is.hail.services.JobGroupStates.Success -import is.hail.utils.{tokenUrlSafe, using} +import is.hail.services.JobGroupStates.{Cancelled, Failure, Success} +import is.hail.utils.{handleForPython, tokenUrlSafe, using, HailWorkerException} import scala.collection.mutable +import scala.concurrent.CancellationException import scala.reflect.io.{Directory, Path} import scala.util.Random @@ -18,7 +21,7 @@ import org.mockito.ArgumentMatchersSugar.any import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when import org.scalatest.OptionValues -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} import org.scalatestplus.testng.TestNGSuite import org.testng.annotations.Test @@ -50,8 +53,7 @@ class ServiceBackendSuite extends TestNGSuite with 
IdiomaticMockito with OptionV storage = Some(rpcConfig.storage), ) } - - backend.batchConfig.jobGroupId + 1 + (backend.batchConfig.jobGroupId + 1, 1) } // the service backend expects that each job write its output to a well-known @@ -96,6 +98,127 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV } } + @Test def testFailedJobGroup(): Unit = + withMockDriverContext { rpcConfig => + val batchClient = mock[BatchClient] + using(ServiceBackend(batchClient, rpcConfig)) { backend => + val contexts = Array.tabulate(100)(_.toString.getBytes) + val startJobGroupId = 2356 + when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { + jobGroup: JobGroupRequest => (backend.batchConfig.jobGroupId + 1, startJobGroupId) + } + val successes = Array(13, 34, 65, 81) // arbitrary indices + val failures = Array(21, 44) + val expectedCause = new NoSuchMethodError("") + when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + val resultsDir = + Path(backend.serviceBackendContext.remoteTmpDir) / + "parallelizeAndComputeWithIndex" / + tokenUrlSafe + + resultsDir.createDirectory() + for (i <- successes) (resultsDir / f"result.$i").toFile.writeAll("11") + + for (i <- failures) + backend.fs.writePDOS((resultsDir / f"result.$i").toString()) { + os => Worker.writeException(os, expectedCause) + } + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Failure, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = successes.length, + n_failed = failures.length, + n_cancelled = contexts.length - failures.length - successes.length, + ) + } + when(batchClient.getJobGroupJobs(any[Int], any[Int], any[Option[JobState]])) thenAnswer { + (batchId: Int, _: Int, s: Option[JobState]) => + s match { + case Some(JobStates.Failed) => + Stream(failures.map(i => + JobListEntry(batchId, i + startJobGroupId, JobStates.Failed, 1) + ).toIndexedSeq) + + case Some(JobStates.Success) => + Stream(successes.map(i => + JobListEntry(batchId, i + startJobGroupId, JobStates.Success, 1) + ).toIndexedSeq) + } + + } + + val (failure, result) = + backend.parallelizeAndComputeWithIndex( + backend.serviceBackendContext, + backend.fs, + contexts, + "stage1", + )((bytes, _, _, _) => bytes) + val (shortMessage, expanded, id) = handleForPython(expectedCause) + failure.value shouldBe new HailWorkerException(failures.head, shortMessage, expanded, id) + result.map(_._2) shouldBe successes + } + } + + @Test def testCancelledJobGroup(): Unit = + withMockDriverContext { rpcConfig => + val batchClient = mock[BatchClient] + using(ServiceBackend(batchClient, rpcConfig)) { backend => + val contexts = Array.tabulate(100)(_.toString.getBytes) + val startJobGroupId = 2356 + when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { + jobGroup: JobGroupRequest => (backend.batchConfig.jobGroupId + 1, startJobGroupId) + } + val successes = Array(13, 34, 65, 81) // arbitrary indices + when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + val resultsDir = + Path(backend.serviceBackendContext.remoteTmpDir) / + "parallelizeAndComputeWithIndex" / + tokenUrlSafe + + resultsDir.createDirectory() + for (i <- successes) (resultsDir / f"result.$i").toFile.writeAll("11") + + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Cancelled, + complete = true, + n_jobs = contexts.length, + n_completed = contexts.length, + n_succeeded = successes.length, + n_failed = 0, + 
n_cancelled = contexts.length - successes.length, + ) + } + when(batchClient.getJobGroupJobs(any[Int], any[Int], any[Option[JobState]])) thenAnswer { + (batchId: Int, _: Int, s: Option[JobState]) => + s match { + case Some(JobStates.Success) => + Stream(successes.map(i => + JobListEntry(batchId, i + startJobGroupId, JobStates.Success, 1) + ).toIndexedSeq) + } + } + + val (failure, result) = + backend.parallelizeAndComputeWithIndex( + backend.serviceBackendContext, + backend.fs, + contexts, + "stage1", + )((bytes, _, _, _) => bytes) + failure.value shouldBe a[CancellationException] + result.map(_._2) shouldBe successes + } + } + def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = { val flags = HailFeatureFlags.fromEnv() val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) diff --git a/hail/hail/test/src/is/hail/services/BatchClientSuite.scala b/hail/hail/test/src/is/hail/services/BatchClientSuite.scala index 079870efce2..50bb2557912 100644 --- a/hail/hail/test/src/is/hail/services/BatchClientSuite.scala +++ b/hail/hail/test/src/is/hail/services/BatchClientSuite.scala @@ -40,7 +40,7 @@ class BatchClientSuite extends TestNGSuite { attributes = Map("name" -> m.getName), jobs = FastSeq(), ) - ) + )._1 } @AfterClass @@ -49,7 +49,7 @@ class BatchClientSuite extends TestNGSuite { @Test def testCancelAfterNFailures(): Unit = { - val jobGroupId = client.newJobGroup( + val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, absolute_parent_id = parentJobGroupId, @@ -79,11 +79,24 @@ class BatchClientSuite extends TestNGSuite { assert(result.n_cancelled == 1) } + @Test + def testGetJobGroupJobsByState(): Unit = { + val jobGroup = client.getJobGroup(8218901, 2) + assert(jobGroup.n_jobs == 2) + assert(jobGroup.n_failed == 1) + assert(client.getJobGroupJobs(8218901, 2).head.length == 2) + for (state <- Array(JobStates.Failed, JobStates.Success)) + for (jobs <- client.getJobGroupJobs(8218901, 2, Some(state))) { + assert(jobs.length == 1) + assert(jobs(0).state == state) + } + } + @Test def testNewJobGroup(): Unit = // The query driver submits a job group per stage with one job per partition for (i <- 1 to 2) { - val jobGroupId = client.newJobGroup( + val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, absolute_parent_id = parentJobGroupId, @@ -107,7 +120,7 @@ class BatchClientSuite extends TestNGSuite { @Test def testJvmJob(): Unit = { - val jobGroupId = client.newJobGroup( + val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, absolute_parent_id = parentJobGroupId,
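
For reference, a minimal sketch (not part of this patch) of how a driver could combine the new `newJobGroup` return value with `getJobGroupJobs` to map failed jobs back to partition indices. The object and method names below are hypothetical; `startJobId` is assumed to be the second element of the pair returned by `newJobGroup`, and `client` an open `BatchClient`:

    import is.hail.services.{BatchClient, JobListEntry, JobStates}

    object JobGroupFailureLookup {
      /** Partition indices of all failed jobs in a job group.
        * `startJobId` must come from the newJobGroup call that created the group.
        */
      def failedPartitionIndices(client: BatchClient, batchId: Int, jobGroupId: Int, startJobId: Int)
          : IndexedSeq[Int] =
        client
          .getJobGroupJobs(batchId, jobGroupId, Some(JobStates.Failed)) // paged results
          .flatten                                                      // Stream[IndexedSeq[_]] -> Stream[_]
          .map((job: JobListEntry) => job.job_id - startJobId)          // job id -> partition index
          .toIndexedSeq
    }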