From 429847ad2b141109c0bc91118c9091178f5013df Mon Sep 17 00:00:00 2001
From: grohli <22306963+grohli@users.noreply.github.com>
Date: Tue, 5 Nov 2024 14:18:10 -0500
Subject: [PATCH] Stream failed jobGroupJobs from batch

---
 .../hail/backend/service/ServiceBackend.scala |  69 +++++----
 .../is/hail/backend/service/Worker.scala      |  15 +-
 .../scala/is/hail/services/BatchClient.scala  |  58 ++++----
 .../scala/is/hail/utils/ErrorHandling.scala   |   2 +-
 .../is/hail/backend/ServiceBackendSuite.scala | 135 +++++++++++++++++-
 .../is/hail/services/BatchClientSuite.scala   |  49 ++-----
 6 files changed, 211 insertions(+), 117 deletions(-)

diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
index a914315d224..fd252b15f4a 100644
--- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -162,7 +162,7 @@ class ServiceBackend(
     token: String,
     root: String,
     stageIdentifier: String,
-  ): JobGroupResponse = {
+  ): (JobGroupResponse, Int) = {
     val defaultProcess =
       JvmJob(
         command = null,
@@ -199,7 +199,18 @@ class ServiceBackend(
       )
     }

-    val jobGroupId =
+    // When we create a JobGroup with n jobs, Batch gives us back the absolute JobGroupId
+    // and the startJobId of the first job in the group.
+    // This means that every JobId in the JobGroup falls in the range
+    // [startJobId, startJobId + n), so we can recover the partition index of any job
+    // from this startJobId offset.
+
+    // Why do we do this?
+    // Consider a situation where we're submitting thousands of jobs in a job group.
+    // If one of those jobs fails, we don't want to make thousands of requests to Batch
+    // just to look up the partition index that the failed job corresponds to.
+
+    val (jobGroupId, startJobId) =
       batchClient.newJobGroup(
         JobGroupRequest(
           batch_id = batchConfig.batchId,
@@ -214,7 +225,8 @@ class ServiceBackend(
     stageCount += 1

     Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms
-    batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
+    val response = batchClient.waitForJobGroup(batchConfig.batchId, jobGroupId)
+    (response, startJobId)
   }

   private[this] def readPartitionResult(root: String, i: Int): Array[Byte] = {
@@ -236,16 +248,6 @@ class ServiceBackend(
     new HailWorkerException(i, shortMessage, expandedMessage, errorId)
   }

-  private[this] def getPartitionIndex(batchId: Int, jobId: Int): Int = {
-    val job = batchClient.getJob(batchId, jobId)
-    val attr = job.attributes.getOrElse(
-      throw new HailBatchFailure(
-        s"Job $jobId in batch $batchId did not have attributes."
-      )
-    )
-    attr("idx").toInt
-  }
-
   override def parallelizeAndComputeWithIndex(
     _backendContext: BackendContext,
     fs: FS,
@@ -295,10 +297,22 @@ class ServiceBackend(

     uploadFunction.get()
     uploadContexts.get()
-    val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)
+    val (jobGroup, startJobId) =
+      submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier)
     log.info(s"parallelizeAndComputeWithIndex: $token: reading results")
     val startTime = System.nanoTime()

+    def streamSuccessfulJobResults: Stream[(Array[Byte], Int)] =
+      for {
+        successes <- batchClient.getJobGroupJobs(
+          jobGroup.batch_id,
+          jobGroup.job_group_id,
+          Some(JobStates.Success),
+        )
+        job <- successes
+        partIdx = job.job_id - startJobId
+      } yield (readPartitionResult(root, partIdx), partIdx)
+
     val r @ (_, results) = jobGroup.state match {
       case Success =>

@@ -317,31 +331,14 @@ class ServiceBackend(
           failedEntries.nonEmpty,
           s"Job group ${jobGroup.job_group_id} failed, but no failed jobs found.",
         )
-        val partId = getPartitionIndex(jobGroup.batch_id, failedEntries.head.job_id)
-        val error = readPartitionError(root, partId)
-        val successes = batchClient.getJobGroupJobs(
-          jobGroup.batch_id,
-          jobGroup.job_group_id,
-          Some(JobStates.Success),
-        )
-        val results = successes.map { job =>
-          val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id)
-          (readPartitionResult(root, partIdx), partIdx)
-        }
-        (Some(error), results)
+        val error = readPartitionError(root, failedEntries.head.head.job_id - startJobId)
+
+        (Some(error), streamSuccessfulJobResults.toIndexedSeq)

       case Cancelled =>
-        val successes = batchClient.getJobGroupJobs(
-          jobGroup.batch_id,
-          jobGroup.job_group_id,
-          Some(JobStates.Success),
-        )
-        val results = successes.map { job =>
-          val partIdx = getPartitionIndex(jobGroup.batch_id, job.job_id)
-          (readPartitionResult(root, partIdx), partIdx)
-        }
         val error = new CancellationException(s"Job Group ${jobGroup.job_group_id} was cancelled.")
-        (Some(error), results)
+
+        (Some(error), streamSuccessfulJobResults.toIndexedSeq)
     }

     val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0
diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala
index 5722e3bc435..48eeeb525c8 100644
--- a/hail/src/main/scala/is/hail/backend/service/Worker.scala
+++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala
@@ -98,6 +98,14 @@ object Worker {
     out.write(bytes)
   }

+  def writeException(out: DataOutputStream, e: Throwable): Unit = {
+    val (shortMessage, expandedMessage, errorId) = handleForPython(e)
+    out.writeBoolean(false)
+    writeString(out, shortMessage)
+    writeString(out, expandedMessage)
+    out.writeInt(errorId)
+  }
+
   def main(argv: Array[String]): Unit = {
     val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())

@@ -217,12 +225,7 @@ object Worker {
           dos.writeBoolean(true)
           dos.write(bytes)
         case Left(throwableWhileExecutingUserCode) =>
-          val (shortMessage, expandedMessage, errorId) =
-            handleForPython(throwableWhileExecutingUserCode)
-          dos.writeBoolean(false)
-          writeString(dos, shortMessage)
-          writeString(dos, expandedMessage)
-          dos.writeInt(errorId)
+          writeException(dos, throwableWhileExecutingUserCode)
       }
     }
   }
diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala
index 5660b3db16d..df7038494b7 100644
--- a/hail/src/main/scala/is/hail/services/BatchClient.scala
+++ b/hail/src/main/scala/is/hail/services/BatchClient.scala
@@ -6,6 +6,7 @@ import is.hail.services.oauth2.CloudCredentials
 import is.hail.services.requests.Requester
 import is.hail.utils._

+import scala.collection.immutable.Stream.cons
 import scala.util.Random

 import java.net.{URL, URLEncoder}
@@ -107,8 +108,6 @@ case class JobListEntry(
   exit_code: Int,
 )

-case class JobResponse(job_id: Int, state: JobState, attributes: Option[Map[String, String]])
-
 object BatchClient {

   private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] =
@@ -146,8 +145,12 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
       JobGroupResponseDeserializer +
       JarSpecSerializer +
       JobStateDeserializer +
-      JobListEntryDeserializer +
-      JobResponseDeserializer
+      JobListEntryDeserializer
+
+  private[this] def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = {
+    val (a, s1) = f(s0)
+    cons(a, paginated(s1)(f))
+  }

   def newBatch(createRequest: BatchRequest): Int = {
     val response = req.post("/api/v1alpha/batches/create", Extraction.decompose(createRequest))
@@ -156,9 +159,9 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
     batchId
   }

-  def newJobGroup(req: JobGroupRequest): Int = {
+  def newJobGroup(req: JobGroupRequest): (Int, Int) = {
     val nJobs = req.jobs.length
-    val (updateId, startJobGroupId) = beginUpdate(req.batch_id, req.token, nJobs)
+    val (updateId, startJobGroupId, startJobId) = beginUpdate(req.batch_id, req.token, nJobs)
     log.info(s"Began update '$updateId' for batch '${req.batch_id}'.")

     createJobGroup(updateId, req)
@@ -170,7 +173,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
     commitUpdate(req.batch_id, updateId)
     log.info(s"Committed update $updateId for batch ${req.batch_id}.")

-    startJobGroupId
+    (startJobGroupId, startJobId)
   }

   def getJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse =
@@ -179,16 +182,25 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
       .extract[JobGroupResponse]

   def getJobGroupJobs(batchId: Int, jobGroupId: Int, status: Option[JobState] = None)
-    : IndexedSeq[JobListEntry] = {
+    : Stream[IndexedSeq[JobListEntry]] = {
     val q = status.map(s => s"state=${s.toString.toLowerCase}").getOrElse("")
-    req.get(
-      s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=${URLEncoder.encode(q, UTF_8)}"
-    ).as { case obj: JObject => (obj \ "jobs").extract[IndexedSeq[JobListEntry]] }
+    paginated(Some(0): Option[Int]) {
+      case Some(jobId) =>
+        req.get(
+          s"/api/v2alpha/batches/$batchId/job-groups/$jobGroupId/jobs?q=${URLEncoder.encode(q, UTF_8)}&last_job_id=$jobId"
+        )
+          .as { case obj: JObject =>
+            (
+              (obj \ "jobs").extract[IndexedSeq[JobListEntry]],
+              (obj \ "last_job_id").extract[Option[Int]],
+            )
+          }
+      case None =>
+        (IndexedSeq.empty, None)
+    }
+      .takeWhile(_.nonEmpty)
   }

-  def getJob(batchId: Int, jobId: Int): JobResponse =
-    req.get(s"/api/v1alpha/batches/$batchId/jobs/$jobId").extract[JobResponse]
-
   def waitForJobGroup(batchId: Int, jobGroupId: Int): JobGroupResponse = {
     val start = System.nanoTime()

@@ -264,7 +276,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
       )
     }

-  private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int) =
+  private[this] def beginUpdate(batchId: Int, token: String, nJobs: Int): (Int, Int, Int) =
     req
       .post(
         s"/api/v1alpha/batches/$batchId/updates/create",
@@ -278,6 +290,7 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
         (
           (obj \ "update_id").extract[Int],
           (obj \ "start_job_group_id").extract[Int],
+          (obj \ "start_job_id").extract[Int],
         )
       }

@@ -388,21 +401,6 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab
        )
      )

-  private[this] object JobResponseDeserializer
-      extends CustomSerializer[JobResponse](implicit fmts =>
-        (
-          {
-            case o: JObject =>
-              JobResponse(
-                job_id = (o \ "job_id").extract[Int],
-                state = (o \ "state").extract[JobState],
-                attributes = (o \ "attributes").extract[Option[Map[String, String]]],
-              )
-          },
-          PartialFunction.empty,
-        )
-      )
-
   private[this] object JarSpecSerializer
       extends CustomSerializer[JarSpec](_ =>
        (
diff --git a/hail/src/main/scala/is/hail/utils/ErrorHandling.scala b/hail/src/main/scala/is/hail/utils/ErrorHandling.scala
index 176df006080..beb3ce0edc5 100644
--- a/hail/src/main/scala/is/hail/utils/ErrorHandling.scala
+++ b/hail/src/main/scala/is/hail/utils/ErrorHandling.scala
@@ -8,7 +8,7 @@ class HailException(val msg: String, val logMsg: Option[String], cause: Throwabl
   def this(msg: String, errorId: Int) = this(msg, None, null, errorId)
 }

-class HailWorkerException(
+case class HailWorkerException(
   val partitionId: Int,
   val shortMessage: String,
   val expandedMessage: String,
diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala
index c66e1b0fcbd..74f60cbb5d6 100644
--- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala
+++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala
@@ -2,12 +2,15 @@ package is.hail.backend

 import is.hail.HailFeatureFlags
 import is.hail.asm4s.HailClassLoader
-import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload}
+import is.hail.backend.service.{
+  ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload, Worker,
+}
 import is.hail.io.fs.{CloudStorageFSConfig, RouterFS}
 import is.hail.services._
-import is.hail.services.JobGroupStates.Success
-import is.hail.utils.{tokenUrlSafe, using}
+import is.hail.services.JobGroupStates.{Cancelled, Failure, Success}
+import is.hail.utils.{handleForPython, tokenUrlSafe, using, HailWorkerException}

+import scala.concurrent.CancellationException
 import scala.reflect.io.{Directory, Path}
 import scala.util.Random

@@ -17,7 +20,7 @@ import org.mockito.ArgumentMatchersSugar.any
 import org.mockito.IdiomaticMockito
 import org.mockito.MockitoSugar.when
 import org.scalatest.OptionValues
-import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
+import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper}
 import org.scalatestplus.testng.TestNGSuite
 import org.testng.annotations.Test

@@ -49,8 +52,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV
             storage = Some(rpcConfig.storage),
           )
         }
-
-        backend.batchConfig.jobGroupId + 1
+        (backend.batchConfig.jobGroupId + 1, 1)
       }

       // the service backend expects that each job write its output to a well-known
@@ -95,6 +97,127 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV
     }
   }

+  @Test def testFailedJobGroup(): Unit =
+    withMockDriverContext { rpcConfig =>
+      val batchClient = mock[BatchClient]
+      using(ServiceBackend(batchClient, rpcConfig)) { backend =>
+        val contexts = Array.tabulate(100)(_.toString.getBytes)
+        val startJobId = 2356
+        when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer {
+          jobGroup: JobGroupRequest => (backend.batchConfig.jobGroupId + 1, startJobId)
+        }
+        val successes = Array(13, 34, 65, 81) // arbitrary indices
+        val failures = Array(21, 44)
+        val expectedCause = new NoSuchMethodError("")
+        when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer {
+          (id: Int, jobGroupId: Int) =>
+            val resultsDir =
+              Path(backend.serviceBackendContext.remoteTmpDir) /
+                "parallelizeAndComputeWithIndex" /
+                tokenUrlSafe
+
+            resultsDir.createDirectory()
+            for (i <- successes) (resultsDir / f"result.$i").toFile.writeAll("11")
+            for (i <- failures)
+              backend.fs.writePDOS((resultsDir / f"result.$i").toString()) {
+                os => Worker.writeException(os, expectedCause)
+              }
+
+            JobGroupResponse(
+              batch_id = id,
+              job_group_id = jobGroupId,
+              state = Failure,
+              complete = true,
+              n_jobs = contexts.length,
+              n_completed = contexts.length,
+              n_succeeded = successes.length,
+              n_failed = failures.length,
+              n_cancelled = contexts.length - failures.length - successes.length,
+            )
+        }
+
+        when(batchClient.getJobGroupJobs(any[Int], any[Int], any[Option[JobState]])) thenAnswer {
+          (batchId: Int, _: Int, s: Option[JobState]) =>
+            s match {
+              case Some(JobStates.Failed) =>
+                Stream(failures.map(i =>
+                  JobListEntry(batchId, i + startJobId, JobStates.Failed, 1)
+                ).toIndexedSeq)
+
+              case Some(JobStates.Success) =>
+                Stream(successes.map(i =>
+                  JobListEntry(batchId, i + startJobId, JobStates.Success, 1)
+                ).toIndexedSeq)
+            }
+        }
+
+        val (failure, result) =
+          backend.parallelizeAndComputeWithIndex(
+            backend.serviceBackendContext,
+            backend.fs,
+            contexts,
+            "stage1",
+          )((bytes, _, _, _) => bytes)
+
+        val (shortMessage, expanded, id) = handleForPython(expectedCause)
+        failure.value shouldBe new HailWorkerException(failures.head, shortMessage, expanded, id)
+        result.map(_._2) shouldBe successes
+      }
+    }
+
+  @Test def testCancelledJobGroup(): Unit =
+    withMockDriverContext { rpcConfig =>
+      val batchClient = mock[BatchClient]
+      using(ServiceBackend(batchClient, rpcConfig)) { backend =>
+        val contexts = Array.tabulate(100)(_.toString.getBytes)
+        val startJobId = 2356
+        when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer {
+          jobGroup: JobGroupRequest => (backend.batchConfig.jobGroupId + 1, startJobId)
+        }
+        val successes = Array(13, 34, 65, 81) // arbitrary indices
+        when(batchClient.waitForJobGroup(any[Int], any[Int])) thenAnswer {
+          (id: Int, jobGroupId: Int) =>
+            val resultsDir =
+              Path(backend.serviceBackendContext.remoteTmpDir) /
+                "parallelizeAndComputeWithIndex" /
+                tokenUrlSafe
+
+            resultsDir.createDirectory()
+            for (i <- successes) (resultsDir / f"result.$i").toFile.writeAll("11")
+
+            JobGroupResponse(
+              batch_id = id,
+              job_group_id = jobGroupId,
+              state = Cancelled,
+              complete = true,
+              n_jobs = contexts.length,
+              n_completed = contexts.length,
+              n_succeeded = successes.length,
+              n_failed = 0,
+              n_cancelled = contexts.length - successes.length,
+            )
+        }
+
+        when(batchClient.getJobGroupJobs(any[Int], any[Int], any[Option[JobState]])) thenAnswer {
+          (batchId: Int, _: Int, s: Option[JobState]) =>
+            s match {
+              case Some(JobStates.Success) =>
+                Stream(successes.map(i =>
+                  JobListEntry(batchId, i + startJobId, JobStates.Success, 1)
+                ).toIndexedSeq)
+            }
+        }
+
+        val (failure, result) =
+          backend.parallelizeAndComputeWithIndex(
+            backend.serviceBackendContext,
+            backend.fs,
+            contexts,
+            "stage1",
+          )((bytes, _, _, _) => bytes)
+
+        failure.value shouldBe a[CancellationException]
+        result.map(_._2) shouldBe successes
+      }
+    }
+
   def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = {
     val flags = HailFeatureFlags.fromEnv()
     val fs = RouterFS.buildRoutes(CloudStorageFSConfig())
diff --git a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala
index 7da061855a5..50bb2557912 100644
--- a/hail/src/test/scala/is/hail/services/BatchClientSuite.scala
+++ b/hail/src/test/scala/is/hail/services/BatchClientSuite.scala
@@ -19,7 +19,7 @@ class BatchClientSuite extends TestNGSuite {

   @BeforeClass
   def createClientAndBatch(): Unit = {
-    client = BatchClient(DeployConfig.get(), Path.of("/tmp/test-gsa-key/key.json"))
+    client = BatchClient(DeployConfig.get(), Path.of("/test-gsa-key/key.json"))
     batchId = client.newBatch(
       BatchRequest(
         billing_project = "test",
@@ -40,7 +40,7 @@ class BatchClientSuite extends TestNGSuite {
        attributes = Map("name" -> m.getName),
        jobs = FastSeq(),
      )
-    )
+    )._1
   }

   @AfterClass
@@ -49,7 +49,7 @@ class BatchClientSuite extends TestNGSuite {

   @Test
   def testCancelAfterNFailures(): Unit = {
-    val jobGroupId = client.newJobGroup(
+    val (jobGroupId, _) = client.newJobGroup(
       req = JobGroupRequest(
         batch_id = batchId,
         absolute_parent_id = parentJobGroupId,
@@ -84,46 +84,19 @@ class BatchClientSuite extends TestNGSuite {
     val jobGroup = client.getJobGroup(8218901, 2)
     assert(jobGroup.n_jobs == 2)
     assert(jobGroup.n_failed == 1)
-    assert(client.getJobGroupJobs(8218901, 2).length == 2)
-    for (state <- Array(JobStates.Failed, JobStates.Success)) {
-      val jobs = client.getJobGroupJobs(8218901, 2, Some(state))
-      assert(jobs.length == 1)
-      assert(jobs(0).state == state)
-    }
-  }
-
-  @Test
-  def testGetJobs(): Unit = {
-    val jobGroupId = client.newJobGroup(
-      req = JobGroupRequest(
-        batch_id = batchId,
-        absolute_parent_id = parentJobGroupId,
-        token = tokenUrlSafe,
-        jobs = IndexedSeq(
-          JobRequest(
-            always_run = false,
-            attributes = Map("foo" -> "bar"),
-            process = BashJob(
-              image = "ubuntu:22.04",
-              command = Array("/bin/bash", "-c", s"exit 0"),
-            ),
-          )
-        ),
-      )
-    )
-    val jobGroupJobs = client.getJobGroupJobs(batchId, jobGroupId)
-    for (entry <- jobGroupJobs) {
-      val job = client.getJob(batchId, entry.job_id)
-      assert(job.attributes.isDefined)
-      assert(job.attributes.get == Map("foo" -> "bar"))
-    }
+    assert(client.getJobGroupJobs(8218901, 2).head.length == 2)
+    for (state <- Array(JobStates.Failed, JobStates.Success))
+      for (jobs <- client.getJobGroupJobs(8218901, 2, Some(state))) {
+        assert(jobs.length == 1)
+        assert(jobs(0).state == state)
+      }
   }

   @Test
   def testNewJobGroup(): Unit =
     // The query driver submits a job group per stage with one job per partition
     for (i <- 1 to 2) {
-      val jobGroupId = client.newJobGroup(
+      val (jobGroupId, _) = client.newJobGroup(
        req = JobGroupRequest(
          batch_id = batchId,
          absolute_parent_id = parentJobGroupId,
@@ -147,7 +120,7 @@ class BatchClientSuite extends TestNGSuite {

   @Test
   def testJvmJob(): Unit = {
-    val jobGroupId = client.newJobGroup(
+    val (jobGroupId, _) = client.newJobGroup(
       req = JobGroupRequest(
         batch_id = batchId,
         absolute_parent_id = parentJobGroupId,
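
Note on the pagination pattern: the `paginated` helper added to BatchClient.scala above is easiest to see in isolation. The following is a minimal, self-contained sketch (illustrative only; `PaginationSketch`, `fetchPage`, and `allJobIds` are hypothetical stand-ins and not part of this patch) of how `paginated` unfolds a lazy Stream of pages from a cursor, and how `takeWhile(_.nonEmpty)` stops at the first empty page, so later pages are only fetched if a consumer actually reaches them:

import scala.collection.immutable.Stream.cons

object PaginationSketch {

  // Same shape as the private helper in BatchClient.scala: unfold a lazy Stream
  // from a paging state, so each page is computed only when the caller reaches it.
  def paginated[S, A](s0: S)(f: S => (A, S)): Stream[A] = {
    val (a, s1) = f(s0)
    cons(a, paginated(s1)(f))
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical stand-in for one call to a paged "list jobs" endpoint:
    // returns a page of job ids plus the cursor (last job id seen) for the next request.
    val allJobIds = 1 to 10
    def fetchPage(cursor: Option[Int]): (IndexedSeq[Int], Option[Int]) =
      cursor match {
        case Some(lastJobId) =>
          val page = allJobIds.filter(_ > lastJobId).take(4).toIndexedSeq
          (page, page.lastOption)
        case None =>
          (IndexedSeq.empty, None)
      }

    // Mirrors getJobGroupJobs: start from cursor 0 and stop at the first empty page.
    val pages = paginated(Some(0): Option[Int])(fetchPage).takeWhile(_.nonEmpty)
    pages.foreach(page => println(page.mkString(","))) // prints 1,2,3,4 then 5,6,7,8 then 9,10
  }
}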