[query] Improve error handling and job tracking in ServiceBackend #14751

Open · wants to merge 1 commit into base: qob-fast-cancel
118 changes: 76 additions & 42 deletions hail/hail/src/is/hail/backend/service/ServiceBackend.scala
@@ -15,7 +15,7 @@ import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Success}
import is.hail.types._
import is.hail.types.physical._
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
@@ -172,7 +172,7 @@ class ServiceBackend(
token: String,
root: String,
stageIdentifier: String,
): JobGroupResponse = {
): (JobGroupResponse, Int) = {
val defaultProcess =
JvmJob(
command = null,
@@ -199,14 +199,28 @@
val jobs =
collection.indices.map { i =>
defaultJob.copy(
attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"),
attributes = Map(
"name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i",
"idx" -> i.toString,
),
process = defaultProcess.copy(
command = Array(Main.WORKER, root, s"$i", s"${collection.length}")
),
)
}

val jobGroupId =
/* When we create a JobGroup with n jobs, Batch gives us the absolute JobGroupId
 * and the startJobId of the first job.
 * This means that all JobIds in the JobGroup fall in the range [startJobId, startJobId + n),
 * so we can recover a job's partition index from this startJobId offset.
 *
 * Why do we do this?
 * Consider a situation where we're submitting thousands of jobs in a job group.
 * If one of those jobs fails, we don't want to make thousands of requests to Batch just
 * to find out which partition index that job corresponds to. */

val (jobGroupId, startJobId) =
batchClient.newJobGroup(
JobGroupRequest(
batch_id = batchConfig.batchId,
Expand All @@ -221,21 +235,27 @@ class ServiceBackend(
stageCount += 1

Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
(response, startJobId)
}

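A minimal illustration of the offset arithmetic described in the comment above, using made-up IDs (Batch assigns the real `startJobId`):

```scala
// Hypothetical job group of n = 4 jobs whose first job was assigned job_id 117.
// Each job's partition index is recovered locally, with no extra calls to Batch.
val startJobId = 117
val nJobs = 4
val jobIds = startJobId until (startJobId + nJobs) // 117, 118, 119, 120
val partitionIndices = jobIds.map(_ - startJobId)  // 0, 1, 2, 3
```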
private[this] def readResult(root: String, i: Int): Array[Byte] = {
val bytes = fs.readNoCompression(s"$root/result.$i")
if (bytes(0) != 0) {
bytes.slice(1, bytes.length)
} else {
val errorInformationBytes = bytes.slice(1, bytes.length)
val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
val shortMessage = readString(is)
val expandedMessage = readString(is)
val errorId = is.readInt()
throw new HailWorkerException(i, shortMessage, expandedMessage, errorId)
}
private[this] def readPartitionResult(root: String, i: Int): Array[Byte] = {
val file = s"$root/result.$i"
val bytes = fs.readNoCompression(file)
assert(bytes(0) != 0, s"$file is not a valid result.")
bytes.slice(1, bytes.length)
}

private[this] def readPartitionError(root: String, i: Int): HailWorkerException = {
val file = s"$root/result.$i"
val bytes = fs.readNoCompression(file)
assert(bytes(0) == 0, s"$file did not contain an error")
val errorInformationBytes = bytes.slice(1, bytes.length)
val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
val shortMessage = readString(is)
val expandedMessage = readString(is)
val errorId = is.readInt()
new HailWorkerException(i, shortMessage, expandedMessage, errorId)
}
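For context, the `result.$i` files these helpers read start with a one-byte success tag, followed by either the payload or the serialized error written by the worker (see `writeException` in Worker.scala below). A minimal, self-contained sketch of both sides of that layout, substituting `writeUTF`/`readUTF` for Hail's `writeString`/`readString` helpers so the snippet runs standalone:

```scala
import java.io.{ByteArrayInputStream, DataInputStream, DataOutputStream}

// Sketch only: byte 0 is the success flag; the rest is the payload (success)
// or shortMessage/expandedMessage/errorId (failure).
def writeSuccess(out: DataOutputStream, payload: Array[Byte]): Unit = {
  out.writeBoolean(true) // non-zero tag byte => success
  out.write(payload)
}

def writeFailure(out: DataOutputStream, short: String, expanded: String, errorId: Int): Unit = {
  out.writeBoolean(false) // zero tag byte => error information follows
  out.writeUTF(short)
  out.writeUTF(expanded)
  out.writeInt(errorId)
}

def readTagged(bytes: Array[Byte]): Either[(String, String, Int), Array[Byte]] =
  if (bytes(0) != 0) Right(bytes.drop(1))
  else {
    val in = new DataInputStream(new ByteArrayInputStream(bytes.drop(1)))
    Left((in.readUTF(), in.readUTF(), in.readInt()))
  }
```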

override def parallelizeAndComputeWithIndex(
@@ -288,37 +308,51 @@
uploadFunction.get()
uploadContexts.get()

val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)

val (jobGroup, startJobId) =
submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)
log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
val startTime = System.nanoTime()
var r @ (err, results) = runAll[Option, Array[Byte]](executor) {
/* A missing file means the job was cancelled because another job failed. Assumes that if any
* job was cancelled, then at least one job failed. We want to ignore the missing file
* exceptions and return one of the actual failure exceptions. */
case (opt, _: FileNotFoundException) => opt
case (opt, e) => opt.orElse(Some(e))
}(None) {
(partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) =>
(() => readResult(root, jobIndex), partIdx)
}
}
if (jobGroup.state != Success && err.isEmpty) {
assert(jobGroup.state != Running)
val error =
jobGroup.state match {
case Failure =>
new HailBatchFailure(
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error"
)
case Cancelled =>

def streamSuccessfulJobResults: Stream[(Array[Byte], Int)] =
for {
successes <- batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Success),
)
job <- successes
partIdx = job.job_id - startJobId
} yield (readPartitionResult(root, partIdx), partIdx)

val r @ (_, results) =
jobGroup.state match {
case Success =>
runAllKeepFirstError(executor) {
(partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) =>
(() => readPartitionResult(root, jobIndex), partIdx)
}
}
case Failure =>
val failedEntries = batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Failed),
)
assert(
failedEntries.nonEmpty,
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed, but no failed jobs found.",
)
val error = readPartitionError(root, failedEntries.head.head.job_id - startJobId)

(Some(error), streamSuccessfulJobResults.toIndexedSeq)
case Cancelled =>
val error =
new CancellationException(
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} was cancelled"
)
}

r = (Some(error), results)
}
(Some(error), streamSuccessfulJobResults.toIndexedSeq)
}

val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
val rate = results.length / resultsReadingSeconds
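The `streamSuccessfulJobResults` for-comprehension above flattens a lazy stream of job pages into individual `(bytes, partitionIndex)` pairs. A toy illustration of that flattening, with stand-in integers instead of Batch responses and a hypothetical `startJobId` of 10:

```scala
val startJobId = 10
// Stand-in for the pages returned by getJobGroupJobs: two pages of job ids.
val pages: Stream[IndexedSeq[Int]] = Stream(IndexedSeq(10, 11), IndexedSeq(12))

val flattened: Stream[(Int, Int)] =
  for {
    page  <- pages
    jobId <- page
  } yield (jobId, jobId - startJobId)

// flattened.toList == List((10, 0), (11, 1), (12, 2))
```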
15 changes: 9 additions & 6 deletions hail/hail/src/is/hail/backend/service/Worker.scala
@@ -98,6 +98,14 @@ object Worker {
out.write(bytes)
}

def writeException(out: DataOutputStream, e: Throwable): Unit = {
val (shortMessage, expandedMessage, errorId) = handleForPython(e)
out.writeBoolean(false)
writeString(out, shortMessage)
writeString(out, expandedMessage)
out.writeInt(errorId)
}

def main(argv: Array[String]): Unit = {
val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())

@@ -219,12 +227,7 @@
dos.writeBoolean(true)
dos.write(bytes)
case Left(throwableWhileExecutingUserCode) =>
val (shortMessage, expandedMessage, errorId) =
handleForPython(throwableWhileExecutingUserCode)
dos.writeBoolean(false)
writeString(dos, shortMessage)
writeString(dos, expandedMessage)
dos.writeInt(errorId)
writeException(dos, throwableWhileExecutingUserCode)
}
}
}
97 changes: 90 additions & 7 deletions hail/hail/src/is/hail/services/BatchClient.scala
@@ -3,15 +3,16 @@ package is.hail.services
import is.hail.expr.ir.ByteArrayBuilder
import is.hail.services.BatchClient.{
BunchMaxSizeBytes, JarSpecSerializer, JobGroupResponseDeserializer, JobGroupStateDeserializer,
JobProcessRequestSerializer,
JobListEntryDeserializer, JobProcessRequestSerializer, JobStateDeserializer,
}
import is.hail.services.oauth2.CloudCredentials
import is.hail.services.requests.Requester
import is.hail.utils._

import scala.collection.immutable.Stream.cons
import scala.util.Random

import java.net.URL
import java.net.{URL, URLEncoder}
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Path

@@ -90,6 +91,26 @@ object JobGroupStates {
case object Running extends JobGroupState
}

sealed trait JobState extends Product with Serializable

object JobStates {
case object Pending extends JobState
case object Ready extends JobState
case object Creating extends JobState
case object Running extends JobState
case object Cancelled extends JobState
case object Error extends JobState
case object Failed extends JobState
case object Success extends JobState
}

case class JobListEntry(
batch_id: Int,
job_id: Int,
state: JobState,
exit_code: Int,
)

object BatchClient {

val BunchMaxSizeBytes: Int = 1024 * 1024
@@ -181,6 +202,39 @@ object BatchClient {
},
)
)

object JobStateDeserializer
extends CustomSerializer[JobState](_ =>
(
{
case JString("Pending") => JobStates.Pending
case JString("Ready") => JobStates.Ready
case JString("Creating") => JobStates.Creating
case JString("Running") => JobStates.Running
case JString("Cancelled") => JobStates.Cancelled
case JString("Error") => JobStates.Error
case JString("Failed") => JobStates.Failed
case JString("Success") => JobStates.Success
},
PartialFunction.empty,
)
)

object JobListEntryDeserializer
extends CustomSerializer[JobListEntry](implicit fmts =>
(
{
case o: JObject =>
JobListEntry(
batch_id = (o \ "batch_id").extract[Int],
job_id = (o \ "job_id").extract[Int],
state = (o \ "state").extract[JobState],
exit_code = (o \ "exit_code").extract[Int],
)
},
PartialFunction.empty,
)
)
}
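A small usage sketch of the two custom serializers, assuming json4s' jackson backend is available and that the serializer objects are visible to the caller; the JSON values are invented for the example:

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

implicit val fmts: Formats =
  DefaultFormats + BatchClient.JobStateDeserializer + BatchClient.JobListEntryDeserializer

// A made-up job entry, in the shape the Batch jobs endpoint returns.
val entry = parse("""{"batch_id": 7, "job_id": 123, "state": "Failed", "exit_code": 1}""")
  .extract[JobListEntry]

// entry == JobListEntry(batch_id = 7, job_id = 123, state = JobStates.Failed, exit_code = 1)
```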

case class BatchClient private (req: Requester) extends Logging with AutoCloseable {
@@ -190,7 +244,14 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
JobProcessRequestSerializer +
JobGroupStateDeserializer +
JobGroupResponseDeserializer +
JarSpecSerializer
JarSpecSerializer +
JobStateDeserializer +
JobListEntryDeserializer

private[this] def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = {
val (a, s1) = f(s0)
cons(a, paginated(s1)(f))
}
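`paginated` builds an infinite lazy stream of pages from a stateful fetch function and leaves termination to the caller (the `takeWhile(_.nonEmpty)` in `getJobGroupJobs` below). A self-contained toy example of the same pattern, with the helper repeated so the snippet stands alone:

```scala
import scala.collection.immutable.Stream.cons

// Same shape as the private helper above.
def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = {
  val (a, s1) = f(s0)
  cons(a, paginated(s1)(f))
}

// Toy pager: pages of up to 2 elements drawn from a fixed vector,
// with the next offset as the paging state.
val data = Vector(1, 2, 3, 4, 5)
def fetchPage(offset: Int): (Vector[Int], Int) =
  (data.slice(offset, offset + 2), offset + 2)

val pages = paginated(0)(fetchPage).takeWhile(_.nonEmpty)
// pages.toList == List(Vector(1, 2), Vector(3, 4), Vector(5))
```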

def newBatch(createRequest: BatchRequest): Int = {
val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest))
Expand All @@ -199,9 +260,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
batchId
}

def newJobGroup(req: JobGroupRequest): Int = {
def newJobGroup(req: JobGroupRequest): (Int, Int) = {
val nJobs = req.jobs.length
val (updateId, startJobGroupId) = beginUpdate(req.batch_id, req.token, nJobs)
val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs)
log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")

createJobGroup(updateId, req)
@@ -213,14 +274,34 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
commitUpdate(req.batch_id, updateId)
log.info(s"Committed update $updateId for batch ${req.batch_id}.")

startJobGroupId
(startJobGroupId, startJobId)
}

def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse =
req
.get(s"/api/v1alpha/batches/$batchId/job-groups/$jobGroupId")
.extract[JobGroupResponse]

def getJobGroupJobs(batchId: Int, jobGroupId: Int, status: Option[JobState] = None)
: Stream[IndexedSeq[JobListEntry]] = {
val q = status.map(s => s"state=${s.toString.toLowerCase}").getOrElse("")
paginated(Some(0): Option[Int]) {
case Some(jobId) =>
req.get(
s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=${URLEncoder.encode(q, UTF_8)}&last_job_id=$jobId"
)
.as { case obj: JObject =>
(
(obj \ "jobs").extract[IndexedSeq[JobListEntry]],
(obj \ "last_job_id").extract[Option[Int]],
)
}
case None =>
(IndexedSeq.empty, None)
}
.takeWhile(_.nonEmpty)
}

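A hedged sketch of how a caller might consume `getJobGroupJobs`, mirroring the Failure branch in ServiceBackend above (the helper and its name are illustrative, not part of this PR):

```scala
// Find the first failed job in a job group and map it back to its partition
// index via the startJobId offset; pages are fetched lazily, so only as many
// pages as needed are requested.
def firstFailedPartition(
  client: BatchClient,
  batchId: Int,
  jobGroupId: Int,
  startJobId: Int,
): Option[Int] =
  client
    .getJobGroupJobs(batchId, jobGroupId, Some(JobStates.Failed))
    .flatten
    .headOption
    .map(_.job_id - startJobId)
```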
def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = {
val start = System.nanoTime()

@@ -301,7 +382,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
)
}

private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int) =
private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int, Int) =
req
.post(
s"/api/v1alpha/batches/$batchId/updates/create",
@@ -315,6 +396,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
(
(obj \ "update_id").extract[Int],
(obj \ "start_job_group_id").extract[Int],
(obj \ "start_job_id").extract[Int],
)
}

Expand All @@ -335,4 +417,5 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
)
)),
)

}
2 changes: 1 addition & 1 deletion hail/hail/src/is/hail/utils/ErrorHandling.scala
@@ -8,7 +8,7 @@ class HailException(val msg: String, val logMsg: Option[String], cause: Throwabl
def this(msg: String, errorId: Int) = this(msg, None, null, errorId)
}

class HailWorkerException(
case class HailWorkerException(
val partitionId: Int,
val shortMessage: String,
val expandedMessage: String,
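One practical effect of making `HailWorkerException` a case class is structural matching and equality on worker errors. A minimal, hypothetical usage sketch (assuming the four constructor parameters used by `readPartitionError` above):

```scala
// `err` stands in for the Option[Throwable] produced by parallelizeAndComputeWithIndex.
def describe(err: Option[Throwable]): String =
  err match {
    case Some(HailWorkerException(partId, shortMessage, _, _)) =>
      s"partition $partId failed: $shortMessage"
    case Some(other) => s"failed: ${other.getMessage}"
    case None => "all partitions succeeded"
  }
```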