Stream failed jobGroupJobs from batch
grohli committed Nov 6, 2024
1 parent bfb27ba commit 429847a
Showing 6 changed files with 211 additions and 117 deletions.
69 changes: 33 additions & 36 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -162,7 +162,7 @@ class ServiceBackend(
token: String,
root: String,
stageIdentifier: String,
): JobGroupResponse = {
): (JobGroupResponse, Int) = {
val defaultProcess =
JvmJob(
command = null,
@@ -199,7 +199,18 @@
)
}

val jobGroupId =
// When we create a JobGroup with n jobs, Batch gives us the absolute JobGroupId
// and the startJobId of the first job. This means that all JobIds in the
// JobGroup fall in the range [startJobId, startJobId + n), so we can recover
// the partition index of any job from this startJobId offset.
//
// Why do we do this? Consider a situation where we're submitting thousands of
// jobs in a job group. If one of those jobs fails, we don't want to make
// thousands of requests to Batch to get the partition index that the failed
// job corresponds to.
val (jobGroupId, startJobId) =
batchClient.newJobGroup(
JobGroupRequest(
batch_id = batchConfig.batchId,
@@ -214,7 +225,8 @@
stageCount += 1

Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
(response, startJobId)
}
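
// Illustration, not part of this commit: with startJobId in hand, the
// partition index of any job in the group is a subtraction rather than a
// per-job Batch request (all ids below are hypothetical).
//   val startJobId  = 5234                       // first job's absolute id
//   val failedJobId = 5237                       // some job in the group
//   val partIdx     = failedJobId - startJobId   // == 3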

private[this] def readPartitionResult(root: String, i: Int): Array[Byte] = {
@@ -236,16 +248,6 @@
new HailWorkerException(i, shortMessage, expandedMessage, errorId)
}

private[this] def getPartitionIndex(batchId: Int, jobId: Int): Int = {
val job = batchClient.getJob(batchId, jobId)
val attr = job.attributes.getOrElse(
throw new HailBatchFailure(
s"Job $jobId in batch $batchId did not have attributes."
)
)
attr("idx").toInt
}

override def parallelizeAndComputeWithIndex(
_backendContext: BackendContext,
fs: FS,
@@ -295,10 +297,22 @@
uploadFunction.get()
uploadContexts.get()

val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)
val (jobGroup, startJobId) =
submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)
log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
val startTime = System.nanoTime()

def streamSuccessfulJobResults: Stream[(Array[Byte], Int)] =
for {
successes <- batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Success),
)
job <- successes
partIdx = job.job_id - startJobId
} yield (readPartitionResult(root, partIdx), partIdx)

val r @ (_, results) =
jobGroup.state match {
case Success =>
@@ -317,31 +331,14 @@
failedEntries.nonEmpty,
s"Job group ${jobGroup.job_group_id} failed, but no failed jobs found.",
)
val partId = getPartitionIndex(jobGroup.batch_id, failedEntries.head.job_id)
val error = readPartitionError(root, partId)
val successes = batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Success),
)
val results = successes.map { job =>
val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id)
(readPartitionResult(root, partIdx), partIdx)
}
(Some(error), results)
val error = readPartitionError(root, failedEntries.head.head.job_id - startJobId)

(Some(error), streamSuccessfulJobResults.toIndexedSeq)
case Cancelled =>
val successes = batchClient.getJobGroupJobs(
jobGroup.batch_id,
jobGroup.job_group_id,
Some(JobStates.Success),
)
val results = successes.map { job =>
val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id)
(readPartitionResult(root, partIdx), partIdx)
}
val error =
new CancellationException(s"Job Group ${jobGroup.job_group_id} was cancelled.")
(Some(error), results)

(Some(error), streamSuccessfulJobResults.toIndexedSeq)
}

val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
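Aside: streamSuccessfulJobResults above is a for-comprehension over a Stream of result pages, so entries are read only as the stream is forced. A minimal, self-contained sketch of that flattening behaviour (toy data, not Hail's API):

// Toy pages standing in for successive getJobGroupJobs responses.
val pages: Stream[IndexedSeq[Int]] = Stream(IndexedSeq(1, 2), IndexedSeq(3, 4))

// Same shape as streamSuccessfulJobResults: draw a page, then each job in it.
val flattened: Stream[Int] =
  for {
    page <- pages
    job <- page
  } yield job

assert(flattened.take(2).toList == List(1, 2)) // only the first page is flattened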
15 changes: 9 additions & 6 deletions hail/src/main/scala/is/hail/backend/service/Worker.scala
@@ -98,6 +98,14 @@ object Worker {
out.write(bytes)
}

def writeException(out: DataOutputStream, e: Throwable): Unit = {
val (shortMessage, expandedMessage, errorId) = handleForPython(e)
out.writeBoolean(false)
writeString(out, shortMessage)
writeString(out, expandedMessage)
out.writeInt(errorId)
}

def main(argv: Array[String]): Unit = {
val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())

@@ -217,12 +225,7 @@
dos.writeBoolean(true)
dos.write(bytes)
case Left(throwableWhileExecutingUserCode) =>
val (shortMessage, expandedMessage, errorId) =
handleForPython(throwableWhileExecutingUserCode)
dos.writeBoolean(false)
writeString(dos, shortMessage)
writeString(dos, expandedMessage)
dos.writeInt(errorId)
writeException(dos, throwableWhileExecutingUserCode)
}
}
}
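Aside: writeException above fixes the failure half of the worker's result frame: a false tag, two strings, and an error id (the success half is a true tag followed by the result bytes). A hedged sketch of a matching reader, assuming writeString length-prefixes UTF-8 bytes — its body is not shown in this diff, so the real encoding may differ:

import java.io.DataInputStream
import java.nio.charset.StandardCharsets.UTF_8

// Assumed inverse of writeString: a 4-byte length, then UTF-8 bytes.
def readString(in: DataInputStream): String = {
  val buf = new Array[Byte](in.readInt())
  in.readFully(buf)
  new String(buf, UTF_8)
}

// Mirrors the frame: writeBoolean(false), shortMessage, expandedMessage, errorId.
def readFailure(in: DataInputStream): (String, String, Int) =
  (readString(in), readString(in), in.readInt())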
58 changes: 28 additions & 30 deletions hail/src/main/scala/is/hail/services/BatchClient.scala
@@ -6,6 +6,7 @@ import is.hail.services.oauth2.CloudCredentials
import is.hail.services.requests.Requester
import is.hail.utils._

import scala.collection.immutable.Stream.cons
import scala.util.Random

import java.net.{URL, URLEncoder}
@@ -107,8 +108,6 @@ case class JobListEntry(
exit_code: Int,
)

case class JobResponse(job_id: Int, state: JobState, attributes: Option[Map[String, String]])

object BatchClient {

private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] =
@@ -146,8 +145,12 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
JobGroupResponseDeserializer +
JarSpecSerializer +
JobStateDeserializer +
JobListEntryDeserializer +
JobResponseDeserializer
JobListEntryDeserializer

private[this] def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = {
val (a, s1) = f(s0)
cons(a, paginated(s1)(f))
}

def newBatch(createRequest: BatchRequest): Int = {
val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest))
@@ -156,9 +159,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
batchId
}

def newJobGroup(req: JobGroupRequest): Int = {
def newJobGroup(req: JobGroupRequest): (Int, Int) = {
val nJobs = req.jobs.length
val (updateId, startJobGroupId) = beginUpdate(req.batch_id, req.token, nJobs)
val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs)
log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")

createJobGroup(updateId, req)
@@ -170,7 +173,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
commitUpdate(req.batch_id, updateId)
log.info(s"Committed update $updateId for batch ${req.batch_id}.")

startJobGroupId
(startJobGroupId, startJobId)
}

def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse =
@@ -179,16 +182,25 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
.extract[JobGroupResponse]

def getJobGroupJobs(batchId: Int, jobGroupId: Int, status: Option[JobState] = None)
: IndexedSeq[JobListEntry] = {
: Stream[IndexedSeq[JobListEntry]] = {
val q = status.map(s => s"state=${s.toString.toLowerCase}").getOrElse("")
req.get(
s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=${URLEncoder.encode(q, UTF_8)}"
).as { case obj: JObject => (obj \ "jobs").extract[IndexedSeq[JobListEntry]] }
paginated(Some(0): Option[Int]) {
case Some(jobId) =>
req.get(
s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=${URLEncoder.encode(q, UTF_8)}&last_job_id=$jobId"
)
.as { case obj: JObject =>
(
(obj \ "jobs").extract[IndexedSeq[JobListEntry]],
(obj \ "last_job_id").extract[Option[Int]],
)
}
case None =>
(IndexedSeq.empty, None)
}
.takeWhile(_.nonEmpty)
}

def getJob(batchId: Int, jobId: Int): JobResponse =
req.get(s"/api/v1alpha/batches/$batchId/jobs/$jobId").extract[JobResponse]

def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = {
val start = System.nanoTime()

@@ -264,7 +276,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
)
}

private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int) =
private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int, Int) =
req
.post(
s"/api/v1alpha/batches/$batchId/updates/create",
@@ -278,6 +290,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
(
(obj \ "update_id").extract[Int],
(obj \ "start_job_group_id").extract[Int],
(obj \ "start_job_id").extract[Int],
)
}

@@ -388,21 +401,6 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseable
)
)

private[this] object JobResponseDeserializer
extends CustomSerializer[JobResponse](implicit fmts =>
(
{
case o: JObject =>
JobResponse(
job_id = (o \ "job_id").extract[Int],
state = (o \ "state").extract[JobState],
attributes = (o \ "attributes").extract[Option[Map[String, String]]],
)
},
PartialFunction.empty,
)
)

private[this] object JarSpecSerializer
extends CustomSerializer[JarSpec](_ =>
(
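Aside: the paginated combinator added above unfolds an API cursor into a lazy Stream of pages, and getJobGroupJobs cuts it off at the first empty page with takeWhile. A small self-contained illustration of the same pattern (toy data; the Option[Int] cursor shape is borrowed from getJobGroupJobs):

import scala.collection.immutable.Stream.cons

// Same shape as BatchClient.paginated: thread a cursor through repeated fetches.
def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = {
  val (a, s1) = f(s0)
  cons(a, paginated(s1)(f))
}

// Toy "API" returning pages of at most two elements.
val all = Vector(10, 20, 30, 40, 50)

val pages: Stream[Vector[Int]] =
  paginated(Some(0): Option[Int]) {
    case Some(off) =>
      (all.slice(off, off + 2), if (off + 2 < all.length) Some(off + 2) else None)
    case None =>
      (Vector.empty, None) // exhausted: emit empty pages for takeWhile to stop on
  }.takeWhile(_.nonEmpty)

// pages.toList == List(Vector(10, 20), Vector(30, 40), Vector(50))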
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/utils/ErrorHandling.scala
@@ -8,7 +8,7 @@ class HailException(val msg: String, val logMsg: Option[String], cause: Throwable
def this(msg: String, errorId: Int) = this(msg, None, null, errorId)
}

class HailWorkerException(
case class HailWorkerException(
val partitionId: Int,
val shortMessage: String,
val expandedMessage: String,
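Aside: turning HailWorkerException into a case class gives callers an extractor and structural equality. A hedged sketch of why that helps (the call site is invented; the four constructor fields match the new HailWorkerException(i, shortMessage, expandedMessage, errorId) seen in ServiceBackend above):

// Hypothetical call site: destructure the failure instead of reading fields.
def describe(e: Throwable): String = e match {
  case HailWorkerException(partId, short, _, _) =>
    s"partition $partId failed: $short"
  case other =>
    s"unexpected failure: ${other.getMessage}"
}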