Added ability to check job status of individual job, made associated tests
grohli committed Jan 28, 2025
1 parent aa8a2e8 commit d149f11
Showing 6 changed files with 322 additions and 66 deletions.
118 changes: 76 additions & 42 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
@@ -15,7 +15,7 @@ import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Success}
import is.hail.types._
import is.hail.types.physical._
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
@@ -172,7 +172,7 @@ class ServiceBackend(
token: String,
root: String,
stageIdentifier: String,
): JobGroupResponse = {
): (JobGroupResponse, Int) = {
val defaultProcess =
JvmJob(
command = null,
@@ -199,14 +199,28 @@
val jobs =
collection.indices.map { i =>
defaultJob.copy(
attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"),
attributes = Map(
"name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i",
"idx" -> i.toString,
),
process = defaultProcess.copy(
command = Array(Main.WORKER, root, s"$i", s"${collection.length}")
),
)
}

val jobGroupId =
// When we create a JobGroup with n jobs, Batch gives us the absolute JobGroupId
// and the startJobId of the first job. This means that all JobIds in the JobGroup
// fall in the range [startJobId, startJobId + n), so we can recover the partition
// index of a given job from this startJobId offset.
//
// Why do we do this? Consider a situation where we're submitting thousands of jobs
// in a job group. If one of those jobs fails, we don't want to make thousands of
// requests to Batch just to find the partition index that the failed job corresponds to.
val (jobGroupId, startJobId) =
batchClient.newJobGroup(
JobGroupRequest(
batch_id = batchConfig.batchId,
@@ -221,21 +235,27 @@
stageCount += 1

Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
(response, startJobId)
}

private[this] def readResult(root: String, i: Int): Array[Byte] = {
val bytes = fs.readNoCompression(s"$root/result.$i")
if (bytes(0) != 0) {
bytes.slice(1, bytes.length)
} else {
val errorInformationBytes = bytes.slice(1, bytes.length)
val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
val shortMessage = readString(is)
val expandedMessage = readString(is)
val errorId = is.readInt()
throw new HailWorkerException(i, shortMessage, expandedMessage, errorId)
}
private[this] def readPartitionResult(root: String, i: Int): Array[Byte] = {
val file = s"$root/result.$i"
val bytes = fs.readNoCompression(file)
assert(bytes(0) != 0, s"$file is not a valid result.")
bytes.slice(1, bytes.length)
}

private[this] def readPartitionError(root: String, i: Int): HailWorkerException = {
val file = s"$root/result.$i"
val bytes = fs.readNoCompression(file)
assert(bytes(0) == 0, s"$file did not contain an error")
val errorInformationBytes = bytes.slice(1, bytes.length)
val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
val shortMessage = readString(is)
val expandedMessage = readString(is)
val errorId = is.readInt()
new HailWorkerException(i, shortMessage, expandedMessage, errorId)
}

override def parallelizeAndComputeWithIndex(
@@ -288,37 +308,51 @@
uploadFunction.get()
uploadContexts.get()

val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)

val (jobGroup, startJobId) =
submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)
log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
val startTime = System.nanoTime()
var r @ (err, results) = runAll[Option, Array[Byte]](executor) {
/* A missing file means the job was cancelled because another job failed. Assumes that if any
* job was cancelled, then at least one job failed. We want to ignore the missing file
* exceptions and return one of the actual failure exceptions. */
case (opt, _: FileNotFoundException) => opt
case (opt, e) => opt.orElse(Some(e))
}(None) {
(partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) =>
(() => readResult(root, jobIndex), partIdx)
}
}
if (jobGroup.state != Success && err.isEmpty) {
assert(jobGroup.state != Running)
val error =
jobGroup.state match {
case Failure =>
new HailBatchFailure(
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error"
)
case Cancelled =>

def streamSuccessfulJobResults: Stream[(Array[Byte], Int)] =
for {
successes <- batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Success),
)
job <- successes
partIdx = job.job_id - startJobId
} yield (readPartitionResult(root, partIdx), partIdx)

val r @ (_, results) =
jobGroup.state match {
case Success =>
runAllKeepFirstError(executor) {
(partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) =>
(() => readPartitionResult(root, jobIndex), partIdx)
}
}
case Failure =>
val failedEntries = batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Failed),
)
assert(
failedEntries.nonEmpty,
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed, but no failed jobs found.",
)
val error = readPartitionError(root, failedEntries.head.head.job_id - startJobId)

(Some(error), streamSuccessfulJobResults.toIndexedSeq)
case Cancelled =>
val error =
new CancellationException(
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} was cancelled"
)
}

r = (Some(error), results)
}
(Some(error), streamSuccessfulJobResults.toIndexedSeq)
}

val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
val rate = results.length / resultsReadingSeconds
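
As a quick illustration of the startJobId offset described in the comments above (a sketch, not part of the commit): Batch assigns consecutive job ids within a job group, so the partition index of any job can be recovered by subtracting startJobId.

// Hypothetical illustration of the offset arithmetic used by
// parallelizeAndComputeWithIndex: a job group of n jobs submitted in
// partition order occupies job ids [startJobId, startJobId + n), so
// partIdx = job_id - startJobId.
def partitionIndex(jobId: Int, startJobId: Int): Int = jobId - startJobId

// e.g. a group of 3 jobs starting at job id 57 covers ids 57, 58, 59,
// which map back to partitions 0, 1 and 2.
assert(partitionIndex(58, startJobId = 57) == 1)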
15 changes: 9 additions & 6 deletions hail/hail/src/is/hail/backend/service/Worker.scala
@@ -98,6 +98,14 @@ object Worker {
out.write(bytes)
}

def writeException(out: DataOutputStream, e: Throwable): Unit = {
val (shortMessage, expandedMessage, errorId) = handleForPython(e)
out.writeBoolean(false)
writeString(out, shortMessage)
writeString(out, expandedMessage)
out.writeInt(errorId)
}

def main(argv: Array[String]): Unit = {
val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())

@@ -219,12 +227,7 @@
dos.writeBoolean(true)
dos.write(bytes)
case Left(throwableWhileExecutingUserCode) =>
val (shortMessage, expandedMessage, errorId) =
handleForPython(throwableWhileExecutingUserCode)
dos.writeBoolean(false)
writeString(dos, shortMessage)
writeString(dos, expandedMessage)
dos.writeInt(errorId)
writeException(dos, throwableWhileExecutingUserCode)
}
}
}
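
For context, the result.$i files that readPartitionResult and readPartitionError parse on the driver use a one-byte success flag followed by either the payload or the serialized error fields written by the worker. A minimal sketch of that framing, using DataOutputStream.writeUTF as a stand-in for the writeString helper used in the diff (the actual string encoding lives elsewhere in Worker.scala):

import java.io.{ByteArrayOutputStream, DataOutputStream}

// Sketch of the flag-byte framing assumed above: the first byte distinguishes
// a successful payload (non-zero) from a serialized error (zero).
def frameSuccess(payload: Array[Byte]): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val dos = new DataOutputStream(bos)
  dos.writeBoolean(true) // non-zero flag: readPartitionResult strips it and returns the rest
  dos.write(payload)
  dos.flush()
  bos.toByteArray
}

def frameError(shortMessage: String, expandedMessage: String, errorId: Int): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val dos = new DataOutputStream(bos)
  dos.writeBoolean(false)    // zero flag: readPartitionError takes this branch
  dos.writeUTF(shortMessage) // stand-in for writeString in the diff
  dos.writeUTF(expandedMessage)
  dos.writeInt(errorId)
  dos.flush()
  bos.toByteArray
}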
97 changes: 90 additions & 7 deletions hail/hail/src/is/hail/services/BatchClient.scala
@@ -3,15 +3,16 @@ package is.hail.services
import is.hail.expr.ir.ByteArrayBuilder
import is.hail.services.BatchClient.{
BunchMaxSizeBytes, JarSpecSerializer, JobGroupResponseDeserializer, JobGroupStateDeserializer,
JobProcessRequestSerializer,
JobListEntryDeserializer, JobProcessRequestSerializer, JobStateDeserializer,
}
import is.hail.services.oauth2.CloudCredentials
import is.hail.services.requests.Requester
import is.hail.utils._

import scala.collection.immutable.Stream.cons
import scala.util.Random

import java.net.URL
import java.net.{URL, URLEncoder}
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Path

@@ -90,6 +91,26 @@ object JobGroupStates {
case object Running extends JobGroupState
}

sealed trait JobState extends Product with Serializable

object JobStates {
case object Pending extends JobState
case object Ready extends JobState
case object Creating extends JobState
case object Running extends JobState
case object Cancelled extends JobState
case object Error extends JobState
case object Failed extends JobState
case object Success extends JobState
}

case class JobListEntry(
batch_id: Int,
job_id: Int,
state: JobState,
exit_code: Int,
)

object BatchClient {

val BunchMaxSizeBytes: Int = 1024 * 1024
@@ -181,6 +202,39 @@
},
)
)

object JobStateDeserializer
extends CustomSerializer[JobState](_ =>
(
{
case JString("Pending") => JobStates.Pending
case JString("Ready") => JobStates.Ready
case JString("Creating") => JobStates.Creating
case JString("Running") => JobStates.Running
case JString("Cancelled") => JobStates.Cancelled
case JString("Error") => JobStates.Error
case JString("Failed") => JobStates.Failed
case JString("Success") => JobStates.Success
},
PartialFunction.empty,
)
)

object JobListEntryDeserializer
extends CustomSerializer[JobListEntry](implicit fmts =>
(
{
case o: JObject =>
JobListEntry(
batch_id = (o \ "batch_id").extract[Int],
job_id = (o \ "job_id").extract[Int],
state = (o \ "state").extract[JobState],
exit_code = (o \ "exit_code").extract[Int],
)
},
PartialFunction.empty,
)
)
}

case class BatchClient private (req: Requester) extends Logging with AutoCloseable {
@@ -190,7 +244,14 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
JobProcessRequestSerializer +
JobGroupStateDeserializer +
JobGroupResponseDeserializer +
JarSpecSerializer
JarSpecSerializer +
JobStateDeserializer +
JobListEntryDeserializer

private[this] def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = {
val (a, s1) = f(s0)
cons(a, paginated(s1)(f))
}

def newBatch(createRequest: BatchRequest): Int = {
val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest))
@@ -199,9 +260,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
batchId
}

def newJobGroup(req: JobGroupRequest): Int = {
def newJobGroup(req: JobGroupRequest): (Int, Int) = {
val nJobs = req.jobs.length
val (updateId, startJobGroupId) = beginUpdate(req.batch_id, req.token, nJobs)
val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs)
log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")

createJobGroup(updateId, req)
@@ -213,14 +274,34 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
commitUpdate(req.batch_id, updateId)
log.info(s"Committed update $updateId for batch ${req.batch_id}.")

startJobGroupId
(startJobGroupId, startJobId)
}

def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse =
req
.get(s"/api/v1alpha/batches/$batchId/job-groups/$jobGroupId")
.extract[JobGroupResponse]

def getJobGroupJobs(batchId: Int, jobGroupId: Int, status: Option[JobState] = None)
: Stream[IndexedSeq[JobListEntry]] = {
val q = status.map(s => s"state=${s.toString.toLowerCase}").getOrElse("")
paginated(Some(0): Option[Int]) {
case Some(jobId) =>
req.get(
s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=${URLEncoder.encode(q, UTF_8)}&last_job_id=$jobId"
)
.as { case obj: JObject =>
(
(obj \ "jobs").extract[IndexedSeq[JobListEntry]],
(obj \ "last_job_id").extract[Option[Int]],
)
}
case None =>
(IndexedSeq.empty, None)
}
.takeWhile(_.nonEmpty)
}

def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = {
val start = System.nanoTime()

@@ -301,7 +382,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
)
}

private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int) =
private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int, Int) =
req
.post(
s"/api/v1alpha/batches/$batchId/updates/create",
@@ -315,6 +396,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
(
(obj \ "update_id").extract[Int],
(obj \ "start_job_group_id").extract[Int],
(obj \ "start_job_id").extract[Int],
)
}

@@ -335,4 +417,5 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
)
)),
)

}
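
A hedged usage sketch for the new getJobGroupJobs wrapper (the batchClient instance and the batchId/jobGroupId values are placeholders not shown in this commit): the method returns a Stream of pages, so callers typically flatten it and stop consuming once they have the entries they need.

// Hypothetical caller; `batchClient`, `batchId` and `jobGroupId` are assumed to exist.
// Each Stream element is one page of job-list entries, and later pages are only
// requested from Batch as the stream is consumed.
val failedJobs: IndexedSeq[JobListEntry] =
  batchClient
    .getJobGroupJobs(batchId, jobGroupId, Some(JobStates.Failed))
    .flatten
    .toIndexedSeq

failedJobs.headOption.foreach { job =>
  println(s"first failed job: id=${job.job_id}, exit_code=${job.exit_code}")
}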
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/utils/ErrorHandling.scala
@@ -8,7 +8,7 @@ class HailException(val msg: String, val logMsg: Option[String], cause: Throwable
def this(msg: String, errorId: Int) = this(msg, None, null, errorId)
}

class HailWorkerException(
case class HailWorkerException(
val partitionId: Int,
val shortMessage: String,
val expandedMessage: String,