From bfb27baa603395a084f140a098c3947d924fe046 Mon Sep 17 00:00:00 2001 From: grohli <22306963+grohli@users.noreply.github.com> Date: Tue, 29 Oct 2024 16:54:53 -0400 Subject: [PATCH] Query batch for failed jobGroupJobs in query stage --- .../hail/backend/service/ServiceBackend.scala | 122 +++++++++++------- 1 file changed, 76 insertions(+), 46 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 40e5c283aaf..a914315d224 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -5,14 +5,17 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate -import is.hail.expr.ir.{Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck} +import is.hail.expr.ir.{ + Compile, IR, IRParser, IRParserEnvironment, IRSize, LoweringAnalyses, MakeTuple, SortField, + TableIR, TableReader, TypeCheck, +} import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.linalg.BlockMatrix import is.hail.services.{BatchClient, JobGroupRequest, _} -import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} +import is.hail.services.JobGroupStates.{Cancelled, Failure, Success} import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType @@ -214,18 +217,33 @@ class ServiceBackend( batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId) } - private[this] def readResult(root: String, i: Int): Array[Byte] = { - val bytes = fs.readNoCompression(s"$root/result.$i") - if (bytes(0) != 0) { - bytes.slice(1, bytes.length) - } else { - val errorInformationBytes = bytes.slice(1, bytes.length) - val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes)) - val shortMessage = readString(is) - val expandedMessage = readString(is) - val errorId = is.readInt() - throw new HailWorkerException(i, shortMessage, expandedMessage, errorId) - } + private[this] def readPartitionResult(root: String, i: Int): Array[Byte] = { + val file = s"$root/result.$i" + val bytes = fs.readNoCompression(file) + assert(bytes(0) != 0, s"$file is not a valid result.") + bytes.slice(1, bytes.length) + } + + private[this] def readPartitionError(root: String, i: Int): HailWorkerException = { + val file = s"$root/result.$i" + val bytes = fs.readNoCompression(file) + assert(bytes(0) == 0, s"$file did not contain an error") + val errorInformationBytes = bytes.slice(1, bytes.length) + val is = new DataInputStream(new ByteArrayInputStream(errorInformationBytes)) + val shortMessage = readString(is) + val expandedMessage = readString(is) + val errorId = is.readInt() + new HailWorkerException(i, shortMessage, expandedMessage, errorId) + } + + private[this] def getPartitionIndex(batchId: Int, jobId: Int): Int = { + val job = batchClient.getJob(batchId, jobId) + val attr = job.attributes.getOrElse( + throw new HailBatchFailure( + s"Job $jobId in batch $batchId did not have attributes." + ) + ) + attr("idx").toInt } override def parallelizeAndComputeWithIndex( @@ -278,41 +296,53 @@ class ServiceBackend( uploadContexts.get() val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier) - - // case match on jobGroup - // success => read files - // failure => read failure only - // cancelled => propagate failure message - log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() - var r @ (err, results) = runAll[Option, Array[Byte]](executor) { - /* A missing file means the job was cancelled because another job failed. Assumes that if any - * job was cancelled, then at least one job failed. We want to ignore the missing file - * exceptions and return one of the actual failure exceptions. */ - case (opt, _: FileNotFoundException) => opt - case (opt, e) => opt.orElse(Some(e)) - }(None) { - (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => - (() => readResult(root, jobIndex), partIdx) - } - } - if (jobGroup.state != Success && err.isEmpty) { - assert(jobGroup.state != Running) - val error = - jobGroup.state match { - case Failure => - new HailBatchFailure( - s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} failed with an unknown error" - ) - case Cancelled => - new CancellationException( - s"Job group ${jobGroup.job_group_id} for batch ${batchConfig.batchId} was cancelled" - ) - } - r = (Some(error), results) - } + val r @ (_, results) = + jobGroup.state match { + case Success => + runAllKeepFirstError(executor) { + (partIdxs, parts.indices).zipped.map { (partIdx, jobIndex) => + (() => readPartitionResult(root, jobIndex), partIdx) + } + } + case Failure => + val failedEntries = batchClient.getJobGroupJobs( + jobGroup.batch_id, + jobGroup.job_group_id, + Some(JobStates.Failed), + ) + assert( + failedEntries.nonEmpty, + s"Job group ${jobGroup.job_group_id} failed, but no failed jobs found.", + ) + val partId = getPartitionIndex(jobGroup.batch_id, failedEntries.head.job_id) + val error = readPartitionError(root, partId) + val successes = batchClient.getJobGroupJobs( + jobGroup.batch_id, + jobGroup.job_group_id, + Some(JobStates.Success), + ) + val results = successes.map { job => + val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id) + (readPartitionResult(root, partIdx), partIdx) + } + (Some(error), results) + case Cancelled => + val successes = batchClient.getJobGroupJobs( + jobGroup.batch_id, + jobGroup.job_group_id, + Some(JobStates.Success), + ) + val results = successes.map { job => + val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id) + (readPartitionResult(root, partIdx), partIdx) + } + val error = + new CancellationException(s"Job Group ${jobGroup.job_group_id} was cancelled.") + (Some(error), results) + } val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0 val rate = results.length / resultsReadingSeconds