Query batch for failed jobGroupJobs in query stage
grohli committed Nov 5, 2024
1 parent c1385b5 commit bfb27ba
Showing 1 changed file with 76 additions and 46 deletions.
122 changes: 76 additions & 46 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -5,14 +5,17 @@ import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck}
import is.hail.expr.ir.{
Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField,
TableIR, TableReader, TypeCheck,
}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.linalg.BlockMatrix
import is.hail.services.{BatchClient, JobGroupRequest, _}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success}
import is.hail.services.JobGroupStates.{Cancelled, Failure, Success}
import is.hail.types._
import is.hail.types.physical._
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
@@ -214,18 +217,33 @@ class ServiceBackend(
batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
}

private[this] def readResult(root: String, i: Int): Array[Byte] = {
val bytes = fs.readNoCompression(s"$root/result.$i")
if (bytes(0) != 0) {
bytes.slice(1, bytes.length)
} else {
val errorInformationBytes = bytes.slice(1, bytes.length)
val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
val shortMessage = readString(is)
val expandedMessage = readString(is)
val errorId = is.readInt()
throw new HailWorkerException(i, shortMessage, expandedMessage, errorId)
}
private[this] def readPartitionResult(root: String, i: Int): Array[Byte] = {
val file = s"$root/result.$i"
val bytes = fs.readNoCompression(file)
assert(bytes(0) != 0, s"$file is not a valid result.")
bytes.slice(1, bytes.length)
}

private[this] def readPartitionError(root: String, i: Int): HailWorkerException = {
val file = s"$root/result.$i"
val bytes = fs.readNoCompression(file)
assert(bytes(0) == 0, s"$file did not contain an error")
val errorInformationBytes = bytes.slice(1, bytes.length)
val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes))
val shortMessage = readString(is)
val expandedMessage = readString(is)
val errorId = is.readInt()
new HailWorkerException(i, shortMessage, expandedMessage, errorId)
}
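
Both readers above assume the same on-disk layout for a partition's result file: a one-byte status flag followed by the payload. The sketch below illustrates a writer for that layout. It is an illustration only; the string encoding (length-prefixed UTF-8, the assumed inverse of readString) is an assumption, not taken from the Hail sources.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

// A non-zero first byte marks a success payload, matching readPartitionResult.
def encodeSuccess(result: Array[Byte]): Array[Byte] =
  Array[Byte](1) ++ result

// A zero first byte marks an error payload: two strings and an Int error id,
// mirroring what readPartitionError decodes.
def encodeError(shortMessage: String, expandedMessage: String, errorId: Int): Array[Byte] = {
  val baos = new ByteArrayOutputStream()
  val os = new DataOutputStream(baos)
  os.writeByte(0) // zero flag byte: the rest of the file is error information
  // hypothetical writeString: length-prefixed UTF-8, assumed inverse of readString
  def writeString(s: String): Unit = {
    val bytes = s.getBytes(StandardCharsets.UTF_8)
    os.writeInt(bytes.length)
    os.write(bytes)
  }
  writeString(shortMessage)
  writeString(expandedMessage)
  os.writeInt(errorId)
  os.flush()
  baos.toByteArray
}
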

private[this] def getPartitionIndex(batchId: Int, jobId: Int): Int = {
val job = batchClient.getJob(batchId, jobId)
val attr = job.attributes.getOrElse(
throw new HailBatchFailure(
s"Job $jobId in batch $batchId did not have attributes."
)
)
attr("idx").toInt
}
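
getPartitionIndex recovers a job's partition index from its "idx" attribute, so the submit side must tag every job with that attribute. A minimal sketch of the round trip; mkJobAttributes is hypothetical, and only the "idx" key and its string encoding come from the method above.

// Hypothetical submit-side counterpart of getPartitionIndex.
def mkJobAttributes(partIdx: Int): Map[String, String] =
  Map("idx" -> partIdx.toString)

// Round trip: what the submitter writes is what getPartitionIndex reads back.
assert(mkJobAttributes(42)("idx").toInt == 42)
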

override def parallelizeAndComputeWithIndex(
@@ -278,41 +296,53 @@
uploadContexts.get()

val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)

// Dispatch on the job group's terminal state:
//   Success   => read every partition's result file
//   Failure   => read the first failed job's error file, plus the results of any jobs that completed
//   Cancelled => collect the results of completed jobs and raise a CancellationException

log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
val startTime = System.nanoTime()
var r @ (err, results) = runAll[Option, Array[Byte]](executor) {
/* A missing file means the job was cancelled because another job failed. Assumes that if any
* job was cancelled, then at least one job failed. We want to ignore the missing file
* exceptions and return one of the actual failure exceptions. */
case (opt, _: FileNotFoundException) => opt
case (opt, e) => opt.orElse(Some(e))
}(None) {
(partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) =>
(() => readResult(root, jobIndex), partIdx)
}
}
if (jobGroup.state != Success && err.isEmpty) {
assert(jobGroup.state != Running)
val error =
jobGroup.state match {
case Failure =>
new HailBatchFailure(
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error"
)
case Cancelled =>
new CancellationException(
s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} was cancelled"
)
}

r = (Some(error), results)
}
val r @ (_, results) =
jobGroup.state match {
case Success =>
runAllKeepFirstError(executor) {
(partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) =>
(() => readPartitionResult(root, jobIndex), partIdx)
}
}
case Failure =>
val failedEntries = batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Failed),
)
assert(
failedEntries.nonEmpty,
s"Job group ${jobGroup.job_group_id} failed, but no failed jobs found.",
)
val partId = getPartitionIndex(jobGroup.batch_id, failedEntries.head.job_id)
val error = readPartitionError(root, partId)
val successes = batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Success),
)
val results = successes.map { job =>
val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id)
(readPartitionResult(root, partIdx), partIdx)
}
(Some(error), results)
case Cancelled =>
val successes = batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Success),
)
val results = successes.map { job =>
val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id)
(readPartitionResult(root, partIdx), partIdx)
}
val error =
new CancellationException(s"Job Group ${jobGroup.job_group_id} was cancelled.")
(Some(error), results)
}
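
The match above is the heart of the change: instead of reading every result file and treating a FileNotFoundException as evidence of cancellation (the removed runAll path), the driver now asks Batch which jobs failed or succeeded and reads only those files. Below is a condensed, self-contained model of the dispatch, with the Batch client and file I/O abstracted into parameters; all names here are stand-ins for illustration, not Hail APIs.

import java.util.concurrent.CancellationException

sealed trait GroupState
case object GroupSuccess extends GroupState
case object GroupFailure extends GroupState
case object GroupCancelled extends GroupState

// `failed` and `succeeded` stand in for batchClient.getJobGroupJobs filtered by
// job state and mapped through getPartitionIndex; readResult and readError
// stand in for the two readers defined earlier.
def collect(
  state: GroupState,
  failed: => Seq[Int],
  succeeded: => Seq[Int],
  readResult: Int => Array[Byte],
  readError: Int => Throwable,
): (Option[Throwable], Seq[(Array[Byte], Int)]) =
  state match {
    case GroupSuccess =>
      (None, succeeded.map(i => (readResult(i), i)))
    case GroupFailure =>
      // Only the first failed partition's error is read; partitions that
      // completed are still collected so their results are not discarded.
      (Some(readError(failed.head)), succeeded.map(i => (readResult(i), i)))
    case GroupCancelled =>
      (Some(new CancellationException("job group was cancelled")),
        succeeded.map(i => (readResult(i), i)))
  }

In the Success arm the real code still routes the reads through runAllKeepFirstError, so an individual read failure surfaces as the first error in the returned pair; the model above elides that detail.
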

val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
val rate = results.length / resultsReadingSeconds
