Skip to content

Commit

Permalink
Added ability to check job status of individual job, made associated tests
Browse files: browse the repository at this point in the history
  • Loading branch information
grohli committed Oct 31, 2024
1 parent 9f6d876 commit c572812
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 8 deletions.
15 changes: 10 additions & 5 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@ import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{
Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField,
TableIR, TableReader, TypeCheck,
}
import is.hail.expr.ir.{Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
Expand Down Expand Up @@ -189,7 +186,10 @@ class ServiceBackend(
val jobs =
collection.indices.map { i =>
defaultJob.copy(
attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"),
attributes = Map(
"name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i",
"idx" -> i.toString,
),
process = defaultProcess.copy(
command = Array(Main.WORKER, root, s"$i", s"${collection.length}")
),
Expand Down Expand Up @@ -279,6 +279,11 @@ class ServiceBackend(

val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)

// Plan: pattern match on the state of `jobGroup`:
//   success   => read the result files
//   failure   => read only the failure
//   cancelled => propagate the failure message

log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
val startTime = System.nanoTime()
var r @ (err, results) = runAll[Option, Array[Byte]](executor) {
Expand Down
88 changes: 86 additions & 2 deletions hail/src/main/scala/is/hail/services/BatchClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import is.hail.utils._

import scala.util.Random

import java.net.URL
import java.net.{URL, URLEncoder}
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Path

Expand Down Expand Up @@ -87,6 +87,28 @@ object JobGroupStates {
case object Running extends JobGroupState
}

/** State of a single job, as carried by the `state` field of [[JobListEntry]] and
  * [[JobResponse]].
  *
  * The trait is sealed: the case objects in [[JobStates]] are its only instances.
  */
sealed trait JobState extends Product with Serializable

/** Namespace holding every [[JobState]] value, listed alphabetically. */
object JobStates {
  case object Cancelled extends JobState
  case object Creating extends JobState
  case object Error extends JobState
  case object Failed extends JobState
  case object Pending extends JobState
  case object Ready extends JobState
  case object Running extends JobState
  case object Success extends JobState
}

/** One element of the `jobs` array returned when listing the jobs of a job group.
  *
  * Field names mirror the JSON keys they are read from.
  *
  * @param batch_id  value of the `batch_id` key
  * @param job_id    value of the `job_id` key
  * @param state     value of the `state` key
  * @param exit_code value of the `exit_code` key. NOTE(review): typed as a plain `Int`;
  *                  confirm what the service sends for jobs that have not exited yet.
  */
case class JobListEntry(batch_id: Int, job_id: Int, state: JobState, exit_code: Int)

/** Response describing a single job.
  *
  * Field names mirror the JSON keys they are read from.
  *
  * @param job_id     value of the `job_id` key
  * @param state      value of the `state` key
  * @param attributes value of the `attributes` key, when one is present
  */
case class JobResponse(
  job_id: Int,
  state: JobState,
  attributes: Option[Map[String, String]],
)

object BatchClient {

private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] =
Expand Down Expand Up @@ -122,7 +144,10 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
JobProcessRequestSerializer +
JobGroupStateDeserializer +
JobGroupResponseDeserializer +
JarSpecSerializer
JarSpecSerializer +
JobStateDeserializer +
JobListEntryDeserializer +
JobResponseDeserializer

def newBatch(createRequest: BatchRequest): Int = {
val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest))
Expand Down Expand Up @@ -153,6 +178,17 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
.get(s"/api/v1alpha/batches/$batchId/job-groups/$jobGroupId")
.extract[JobGroupResponse]

/** Lists the jobs of a job group, optionally restricted to a single state.
  *
  * @param batchId    batch id interpolated into the request path
  * @param jobGroupId job group id interpolated into the request path
  * @param status     when defined, the query sent is `state=<lower-cased state name>`;
  *                   when empty, an empty query string is sent
  * @return the entries found under the `jobs` key of the response object
  */
def getJobGroupJobs(batchId: Int, jobGroupId: Int, status: Option[JobState] = None)
  : IndexedSeq[JobListEntry] = {
  val query = status.fold("")(state => s"state=${state.toString.toLowerCase}")
  // The query is passed as the value of the `q` parameter, so it must be URL-encoded.
  val encodedQuery = URLEncoder.encode(query, UTF_8)
  val response =
    req.get(s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=$encodedQuery")
  // NOTE(review): only this one response is read - confirm whether the endpoint pages
  // its results, in which case later pages are never fetched here.
  response.as { case body: JObject => (body \ "jobs").extract[IndexedSeq[JobListEntry]] }
}

/** Fetches one job of a batch and decodes the response as a [[JobResponse]].
  *
  * @param batchId batch id interpolated into the request path
  * @param jobId   job id interpolated into the request path
  */
def getJob(batchId: Int, jobId: Int): JobResponse = {
  val response = req.get(s"/api/v1alpha/batches/$batchId/jobs/$jobId")
  response.extract[JobResponse]
}

def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = {
val start = System.nanoTime()

Expand Down Expand Up @@ -298,6 +334,23 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
)
)

/** Decodes a JSON string into the [[JobState]] of the same, case-sensitive, name.
  *
  * Strings outside the table below are not matched by the partial function. The
  * serializing half is empty, so this only supports reading.
  */
private[this] object JobStateDeserializer
    extends CustomSerializer[JobState](_ => {
      val stateByName: Map[String, JobState] = Map(
        "Pending" -> JobStates.Pending,
        "Ready" -> JobStates.Ready,
        "Creating" -> JobStates.Creating,
        "Running" -> JobStates.Running,
        "Cancelled" -> JobStates.Cancelled,
        "Error" -> JobStates.Error,
        "Failed" -> JobStates.Failed,
        "Success" -> JobStates.Success,
      )
      (
        // The guard keeps the function undefined for unknown names.
        { case JString(name) if stateByName.contains(name) => stateByName(name) },
        PartialFunction.empty,
      )
    })

private[this] object JobGroupResponseDeserializer
extends CustomSerializer[JobGroupResponse](implicit fmts =>
(
Expand All @@ -319,6 +372,37 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
)
)

/** Decodes a JSON object into a [[JobListEntry]] by reading its `batch_id`, `job_id`,
  * `state` and `exit_code` keys. The serializing half is empty, so this only supports
  * reading.
  */
private[this] object JobListEntryDeserializer
    extends CustomSerializer[JobListEntry](implicit fmts =>
      (
        {
          case entry: JObject =>
            val batchId = (entry \ "batch_id").extract[Int]
            val jobId = (entry \ "job_id").extract[Int]
            val state = (entry \ "state").extract[JobState]
            val exitCode = (entry \ "exit_code").extract[Int]
            JobListEntry(batchId, jobId, state, exitCode)
        },
        PartialFunction.empty,
      )
    )

/** Decodes a JSON object into a [[JobResponse]] by reading its `job_id`, `state` and
  * `attributes` keys; `attributes` is read as an `Option`. The serializing half is
  * empty, so this only supports reading.
  */
private[this] object JobResponseDeserializer
    extends CustomSerializer[JobResponse](implicit fmts =>
      (
        {
          case body: JObject =>
            val jobId = (body \ "job_id").extract[Int]
            val state = (body \ "state").extract[JobState]
            val attributes = (body \ "attributes").extract[Option[Map[String, String]]]
            JobResponse(jobId, state, attributes)
        },
        PartialFunction.empty,
      )
    )

private[this] object JarSpecSerializer
extends CustomSerializer[JarSpec](_ =>
(
Expand Down
42 changes: 41 additions & 1 deletion hail/src/test/scala/is/hail/services/BatchClientSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class BatchClientSuite extends TestNGSuite {

@BeforeClass
def createClientAndBatch(): Unit = {
client = BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json"))
client = BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json"))
batchId = client.newBatch(
BatchRequest(
billing_project = "test",
Expand Down Expand Up @@ -79,6 +79,46 @@ class BatchClientSuite extends TestNGSuite {
assert(result.n_cancelled == 1)
}

/** Checks that `getJobGroupJobs` filters the jobs of a job group by state.
  *
  * The test submits its own job group - one job that exits 0 and one that exits 1 -
  * instead of querying a hard-coded batch id, so it does not depend on a batch that
  * happens to exist in one particular deployment.
  */
@Test
def testGetJobGroupJobsByState(): Unit = {
  def bashJob(name: String, exitCode: Int): JobRequest =
    JobRequest(
      always_run = false,
      attributes = Map("name" -> name),
      process = BashJob(
        image = "ubuntu:22.04",
        command = Array("/bin/bash", "-c", s"exit $exitCode"),
      ),
    )

  val jobGroupId = client.newJobGroup(
    req = JobGroupRequest(
      batch_id = batchId,
      absolute_parent_id = parentJobGroupId,
      token = tokenUrlSafe,
      jobs = IndexedSeq(bashJob("succeeds", 0), bashJob("fails", 1)),
    )
  )

  // Both jobs must have finished before their final states can be asserted on.
  val jobGroup = client.waitForJobGroup(batchId, jobGroupId)
  assert(jobGroup.n_jobs == 2)
  assert(jobGroup.n_failed == 1)
  assert(client.getJobGroupJobs(batchId, jobGroupId).length == 2)
  for (state <- Array(JobStates.Failed, JobStates.Success)) {
    val jobs = client.getJobGroupJobs(batchId, jobGroupId, Some(state))
    assert(jobs.length == 1)
    assert(jobs(0).state == state)
  }
}

/** Checks that `getJob` returns the attributes a job was submitted with.
  *
  * Submits a job group holding a single job, lists the group's jobs, and fetches each
  * listed job individually.
  */
@Test
def testGetJobs(): Unit = {
  val expectedAttributes = Map("foo" -> "bar")
  val jobGroupId = client.newJobGroup(
    req = JobGroupRequest(
      batch_id = batchId,
      absolute_parent_id = parentJobGroupId,
      token = tokenUrlSafe,
      jobs = IndexedSeq(
        JobRequest(
          always_run = false,
          attributes = expectedAttributes,
          process = BashJob(
            image = "ubuntu:22.04",
            command = Array("/bin/bash", "-c", "exit 0"),
          ),
        )
      ),
    )
  )
  val jobGroupJobs = client.getJobGroupJobs(batchId, jobGroupId)
  // Without this check an empty listing would skip the loop and pass vacuously.
  assert(jobGroupJobs.length == 1)
  for (entry <- jobGroupJobs) {
    val job = client.getJob(batchId, entry.job_id)
    assert(job.job_id == entry.job_id)
    // `contains` fails cleanly on None, where `.get` would throw instead.
    assert(job.attributes.contains(expectedAttributes))
  }
}

@Test
def testNewJobGroup(): Unit =
// The query driver submits a job group per stage with one job per partition
Expand Down

0 comments on commit c572812

Please sign in to comment.