[qob] cancel stage if any partitions fail.
grohli committed Oct 25, 2024
1 parent 71b00f3 commit a0b5577
Showing 3 changed files with 52 additions and 6 deletions.
hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala (9 additions & 2 deletions)
@@ -1,6 +1,6 @@
 package is.hail.backend.service
 
-import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags}
+import is.hail.{HailContext, HailFeatureFlags}
 import is.hail.annotations._
 import is.hail.asm4s._
 import is.hail.backend._
@@ -202,6 +202,7 @@ class ServiceBackend(
       batch_id = batchConfig.batchId,
       absolute_parent_id = batchConfig.jobGroupId,
       token = token,
+      cancel_after_n_failures = Some(1),
       attributes = Map("name" -> stageIdentifier),
       jobs = jobs,
     )
@@ -280,7 +281,13 @@

     log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
     val startTime = System.nanoTime()
-    val r @ (error, results) = runAllKeepFirstError(new CancellingExecutorService(executor)) {
+    val r @ (error, results) = runAll[Option, Array[Byte]](executor) {
+      /* A missing file means the job was cancelled because another job failed. Assumes that if any
+       * job was cancelled, then at least one job failed. We want to ignore the missing file
+       * exceptions and return one of the actual failure exceptions. */
+      case (opt, _: FileNotFoundException) => opt
+      case (opt, e) => opt.orElse(Some(e))
+    }(None) {
       (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) =>
         (() => readResult(root, jobIndex), partIdx)
       }
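The dropped `runAllKeepFirstError` / `CancellingExecutorService` pair is replaced by `runAll` with an explicit error-combining function: a `FileNotFoundException` (a cancelled job never wrote its result file) is ignored, and the first real failure is kept. Below is a minimal self-contained sketch of that combining rule; `firstRealFailure` is a hypothetical helper for illustration, not Hail's `runAll` API:

```scala
import java.io.FileNotFoundException

// Fold per-partition outcomes, ignoring FileNotFoundException (a cancelled job
// whose result file was never written) and keeping the first real failure.
def firstRealFailure(outcomes: Seq[Either[Throwable, Array[Byte]]]): Option[Throwable] =
  outcomes.foldLeft(Option.empty[Throwable]) {
    case (acc, Left(_: FileNotFoundException)) => acc          // cancelled: ignore
    case (acc, Left(e)) => acc.orElse(Some(e))                 // first real failure wins
    case (acc, Right(_)) => acc                                // success: nothing to record
  }

// One real failure followed by two cancelled partitions yields the failure.
val boom = new RuntimeException("partition 0 failed")
val outcomes: Seq[Either[Throwable, Array[Byte]]] = Seq(
  Left(boom),
  Left(new FileNotFoundException("result file for cancelled partition 1")),
  Left(new FileNotFoundException("result file for cancelled partition 2"))
)
assert(firstRealFailure(outcomes).contains(boom))
```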
hail/src/main/scala/is/hail/services/BatchClient.scala (10 additions & 4 deletions)
@@ -14,7 +14,9 @@ import java.nio.file.Path

 import org.apache.http.entity.ByteArrayEntity
 import org.apache.http.entity.ContentType.APPLICATION_JSON
-import org.json4s.{CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JObject, JString}
+import org.json4s.{
+  CustomSerializer, DefaultFormats, Extraction, Formats, JInt, JNull, JObject, JString,
+}
 import org.json4s.JsonAST.{JArray, JBool}
 import org.json4s.jackson.JsonMethods

@@ -29,6 +31,7 @@ case class JobGroupRequest(
   batch_id: Int,
   absolute_parent_id: Int,
   token: String,
+  cancel_after_n_failures: Option[Int] = None,
   attributes: Map[String, String] = Map.empty,
   jobs: IndexedSeq[JobRequest] = FastSeq(),
 )
@@ -52,9 +55,9 @@ case class JarUrl(url: String) extends JarSpec

 case class JobResources(
   preemptible: Boolean,
-  cpu: Option[String],
-  memory: Option[String],
-  storage: Option[String],
+  cpu: Option[String] = None,
+  memory: Option[String] = None,
+  storage: Option[String] = None,
 )
 
 case class CloudfuseConfig(
@@ -252,6 +255,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
         JObject(
           "job_group_id" -> JInt(1), // job group id relative to the update
           "absolute_parent_id" -> JInt(jobGroup.absolute_parent_id),
+          "cancel_after_n_failures" -> jobGroup.cancel_after_n_failures.map(JInt(_)).getOrElse(
+            JNull
+          ),
           "attributes" -> Extraction.decompose(jobGroup.attributes),
         )
       )),
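For reference, a small sketch of how the new `Option[Int]` field lands in the JSON payload, using only json4s calls the file already imports; `cancelField` is a hypothetical helper, not part of the diff:

```scala
import org.json4s.{JInt, JNull, JObject, JValue}
import org.json4s.jackson.JsonMethods

// Mirrors the serializer change above: Some(n) becomes a JSON number, None becomes null.
def cancelField(n: Option[Int]): JValue = n.map(JInt(_)).getOrElse(JNull)

println(JsonMethods.compact(JsonMethods.render(
  JObject("cancel_after_n_failures" -> cancelField(Some(1)))
))) // {"cancel_after_n_failures":1}

println(JsonMethods.compact(JsonMethods.render(
  JObject("cancel_after_n_failures" -> cancelField(None))
))) // {"cancel_after_n_failures":null}
```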
hail/src/test/scala/is/hail/services/BatchClientSuite.scala (33 additions & 0 deletions)
@@ -2,6 +2,7 @@ package is.hail.services

 import is.hail.HAIL_REVISION
 import is.hail.backend.service.Main
+import is.hail.services.JobGroupStates.Failure
 import is.hail.utils._
 
 import java.lang.reflect.Method
@@ -46,6 +47,38 @@ class BatchClientSuite extends TestNGSuite {
   def closeClient(): Unit =
     client.close()
 
+  @Test
+  def testCancelAfterNFailures(): Unit = {
+    val jobGroupId = client.newJobGroup(
+      req = JobGroupRequest(
+        batch_id = batchId,
+        absolute_parent_id = parentJobGroupId,
+        cancel_after_n_failures = Some(1),
+        token = tokenUrlSafe,
+        jobs = FastSeq(
+          JobRequest(
+            always_run = false,
+            process = BashJob(
+              image = "ubuntu:22.04",
+              command = Array("/bin/bash", "-c", "sleep 1d"),
+            ),
+            resources = Some(JobResources(preemptible = true)),
+          ),
+          JobRequest(
+            always_run = false,
+            process = BashJob(
+              image = "ubuntu:22.04",
+              command = Array("/bin/bash", "-c", "exit 1"),
+            ),
+          ),
+        ),
+      )
+    )
+    val result = client.waitForJobGroup(batchId, jobGroupId)
+    assert(result.state == Failure)
+    assert(result.n_cancelled == 1)
+  }
+
   @Test
   def testNewJobGroup(): Unit =
     // The query driver submits a job group per stage with one job per partition
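To make the assertions easier to read, here is a hedged toy model of the `cancel_after_n_failures` policy the test exercises; it is illustrative only, not Batch's actual scheduler logic:

```scala
// Toy model: once failures reach the threshold, still-running jobs are cancelled.
sealed trait JobState
case object Running extends JobState
case object Failed extends JobState
case object Cancelled extends JobState

def applyCancelPolicy(states: Vector[JobState], cancelAfterNFailures: Int): Vector[JobState] =
  if (states.count(_ == Failed) >= cancelAfterNFailures)
    states.map { case Running => Cancelled; case s => s }
  else states

// Job 0 sleeps (Running), job 1 exits 1 (Failed); with a threshold of 1 the
// sleeper is cancelled, matching `result.n_cancelled == 1` in the test above.
val after = applyCancelPolicy(Vector(Running, Failed), cancelAfterNFailures = 1)
assert(after == Vector(Cancelled, Failed))
assert(after.count(_ == Cancelled) == 1)
```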
