diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index 618fafd82a53..eb5fbbcb5735 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -12,11 +12,10 @@ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ import is.hail.io.plink.LoadPlink import is.hail.io.vcf.LoadVCF -import is.hail.linalg.BlockMatrix import is.hail.types._ import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.types.virtual.{BlockMatrixType, TFloat64} +import is.hail.types.virtual.TFloat64 import is.hail.utils._ import is.hail.variant.ReferenceGenome @@ -90,15 +89,6 @@ abstract class Backend extends Closeable { def broadcast[T: ClassTag](value: T): BroadcastValue[T] - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) - : Unit - - def unpersist(backendContext: BackendContext, id: String): Unit - - def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix - - def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType - def parallelizeAndComputeWithIndex( backendContext: BackendContext, fs: FS, diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index 1babbe897fc8..7012d3db2e58 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -6,6 +6,7 @@ import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext import is.hail.expr.ir.lowering.IrMetadata import is.hail.io.fs.FS +import is.hail.linalg.BlockMatrix import is.hail.utils._ import is.hail.variant.ReferenceGenome @@ -71,6 +72,7 @@ object ExecuteContext { flags: HailFeatureFlags, backendContext: BackendContext, irMetadata: IrMetadata, + blockMatrixCache: mutable.Map[String, BlockMatrix], )( f: ExecuteContext => T ): T = { @@ -89,6 +91,7 @@ object ExecuteContext { flags, backendContext, irMetadata, + blockMatrixCache, ))(f(_)) } } @@ -118,6 +121,7 @@ class ExecuteContext( val flags: HailFeatureFlags, val backendContext: BackendContext, val irMetadata: IrMetadata, + val BlockMatrixCache: mutable.Map[String, BlockMatrix], ) extends Closeable { val rngNonce: Long = @@ -186,6 +190,7 @@ class ExecuteContext( flags: HailFeatureFlags = this.flags, backendContext: BackendContext = this.backendContext, irMetadata: IrMetadata = this.irMetadata, + blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, )( f: ExecuteContext => A ): A = @@ -202,5 +207,6 @@ class ExecuteContext( flags, backendContext, irMetadata, + blockMatrixCache, ))(f) } diff --git a/hail/src/main/scala/is/hail/backend/caching/BlockMatrixCache.scala b/hail/src/main/scala/is/hail/backend/caching/BlockMatrixCache.scala new file mode 100644 index 000000000000..db505a29239f --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/caching/BlockMatrixCache.scala @@ -0,0 +1,30 @@ +package is.hail.backend.caching + +import is.hail.linalg.BlockMatrix + +import scala.collection.mutable + +class BlockMatrixCache extends mutable.AbstractMap[String, BlockMatrix] with AutoCloseable { + + private[this] val blockmatrices: mutable.Map[String, BlockMatrix] = + mutable.LinkedHashMap.empty + + override def +=(kv: (String, BlockMatrix)): BlockMatrixCache.this.type = { + blockmatrices += kv; this + } + + override def -=(key: String): BlockMatrixCache.this.type = { + get(key).foreach { bm => bm.unpersist(); blockmatrices -= key }; this + } + + override def get(key: String): Option[BlockMatrix] = + blockmatrices.get(key) + + override def iterator: Iterator[(String, BlockMatrix)] = + blockmatrices.iterator + + override def close(): Unit = { + blockmatrices.values.foreach(_.unpersist()) + blockmatrices.clear() + } +} diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index bea8191bd907..9d666ae62ef7 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -10,11 +10,10 @@ import is.hail.expr.ir._ import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ import is.hail.io.fs._ -import is.hail.linalg.BlockMatrix import is.hail.types._ import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType -import is.hail.types.virtual.{BlockMatrixType, TVoid} +import is.hail.types.virtual.TVoid import is.hail.utils._ import is.hail.variant.ReferenceGenome @@ -100,6 +99,7 @@ class LocalBackend( ExecutionCache.fromFlags(flags, fs, tmpdir) }, new IrMetadata(), + mutable.Map.empty, )(f) } @@ -202,15 +202,6 @@ class LocalBackend( ): TableReader = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt, nPartitions) - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) - : Unit = ??? - - def unpersist(backendContext: BackendContext, id: String): Unit = ??? - - def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ??? - - def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ??? - def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 8b58067b61be..78a54c0d15d5 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -14,7 +14,6 @@ import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} -import is.hail.linalg.BlockMatrix import is.hail.services.{BatchClient, JobGroupRequest, _} import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} import is.hail.types._ @@ -374,15 +373,6 @@ class ServiceBackend( ): TableReader = LowerDistributedSort.distributedSort(ctx, inputStage, sortFields, rt, nPartitions) - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) - : Unit = ??? - - def unpersist(backendContext: BackendContext, id: String): Unit = ??? - - def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = ??? - - def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ??? - def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) @@ -401,6 +391,7 @@ class ServiceBackend( flags, serviceBackendContext, new IrMetadata(), + mutable.Map.empty, )(f) } diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 94a1020734a1..fd1db82c26cc 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.caching.BlockMatrixCache import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ @@ -11,7 +12,6 @@ import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ -import is.hail.linalg.BlockMatrix import is.hail.rvd.RVD import is.hail.types._ import is.hail.types.physical.{PStruct, PTuple} @@ -338,20 +338,8 @@ class SparkBackend( override val longLifeTempFileManager: TempFileManager = new OwningTempFileManager(fs) - val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() - - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) - : Unit = bmCache.persistBlockMatrix(id, value, storageLevel) - - def unpersist(backendContext: BackendContext, id: String): Unit = unpersist(id) - - def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = - bmCache.getPersistedBlockMatrix(id) - - def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = - bmCache.getPersistedBlockMatrixType(id) - - def unpersist(id: String): Unit = bmCache.unpersistBlockMatrix(id) + private[this] val bmCache: BlockMatrixCache = + new BlockMatrixCache() def createExecuteContextForTests( timer: ExecutionTimer, @@ -374,6 +362,7 @@ class SparkBackend( ExecutionCache.forTesting }, new IrMetadata(), + null, ) override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = @@ -393,6 +382,7 @@ class SparkBackend( ExecutionCache.fromFlags(flags, fs, tmpdir) }, new IrMetadata(), + bmCache, )(f) } @@ -457,6 +447,7 @@ class SparkBackend( override def asSpark(op: String): SparkBackend = this def close(): Unit = { + bmCache.close() SparkBackend.stop() longLifeTempFileManager.close() } diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala deleted file mode 100644 index b512902e37d2..000000000000 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala +++ /dev/null @@ -1,25 +0,0 @@ -package is.hail.backend.spark - -import is.hail.linalg.BlockMatrix -import is.hail.types.virtual.BlockMatrixType -import is.hail.utils._ - -import scala.collection.mutable - -case class SparkBlockMatrixCache() { - private[this] val blockmatrices: mutable.Map[String, BlockMatrix] = new mutable.HashMap() - - def persistBlockMatrix(id: String, value: BlockMatrix, storageLevel: String): Unit = - blockmatrices.update(id, value.persist(storageLevel)) - - def getPersistedBlockMatrix(id: String): BlockMatrix = - blockmatrices.getOrElse(id, fatal(s"Persisted BlockMatrix with id $id does not exist.")) - - def getPersistedBlockMatrixType(id: String): BlockMatrixType = - BlockMatrixType.fromBlockMatrix(getPersistedBlockMatrix(id)) - - def unpersistBlockMatrix(id: String): Unit = { - getPersistedBlockMatrix(id).unpersist() - blockmatrices.remove(id) - } -} diff --git a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala index b7cbe20dcbbe..f08c36cadb2e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala @@ -1,8 +1,7 @@ package is.hail.expr.ir -import is.hail.HailContext import is.hail.annotations.NDArray -import is.hail.backend.{BackendContext, ExecuteContext} +import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.lowering.{BMSContexts, BlockMatrixStage2, LowererUnsupportedOperation} import is.hail.io.{StreamBufferSpec, TypedCodecSpec} @@ -106,7 +105,7 @@ object BlockMatrixReader { def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixReader = (jv \ "name").extract[String] match { case "BlockMatrixNativeReader" => BlockMatrixNativeReader.fromJValue(ctx.fs, jv) - case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx.backendContext, jv) + case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx, jv) case _ => jv.extract[BlockMatrixReader] } } @@ -274,12 +273,12 @@ case class BlockMatrixBinaryReader(path: String, shape: IndexedSeq[Long], blockS case class BlockMatrixNativePersistParameters(id: String) object BlockMatrixPersistReader { - def fromJValue(ctx: BackendContext, jv: JValue): BlockMatrixPersistReader = { + def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixPersistReader = { implicit val formats: Formats = BlockMatrixReader.formats val params = jv.extract[BlockMatrixNativePersistParameters] BlockMatrixPersistReader( params.id, - HailContext.backend.getPersistedBlockMatrixType(ctx, params.id), + BlockMatrixType.fromBlockMatrix(ctx.BlockMatrixCache(params.id)), ) } } @@ -287,9 +286,7 @@ object BlockMatrixPersistReader { case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends BlockMatrixReader { def pathsUsed: Seq[String] = FastSeq() lazy val fullType: BlockMatrixType = typ - - def apply(ctx: ExecuteContext): BlockMatrix = - HailContext.backend.getPersistedBlockMatrix(ctx.backendContext, id) + def apply(ctx: ExecuteContext): BlockMatrix = ctx.BlockMatrixCache(id) } case class BlockMatrixMap(child: BlockMatrixIR, eltName: Name, f: IR, needsDense: Boolean) diff --git a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala index b6d36fcee658..6342f7fef6f5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala @@ -1,6 +1,5 @@ package is.hail.expr.ir -import is.hail.HailContext import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext @@ -190,7 +189,7 @@ case class BlockMatrixPersistWriter(id: String, storageLevel: String) extends Bl def pathOpt: Option[String] = None def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit = - HailContext.backend.persist(ctx.backendContext, id, bm, storageLevel) + ctx.BlockMatrixCache += id -> bm.persist(storageLevel) def loweredTyp: Type = TVoid } diff --git a/hail/src/main/scala/is/hail/utils/ErrorHandling.scala b/hail/src/main/scala/is/hail/utils/ErrorHandling.scala index 176df006080a..59878c0be579 100644 --- a/hail/src/main/scala/is/hail/utils/ErrorHandling.scala +++ b/hail/src/main/scala/is/hail/utils/ErrorHandling.scala @@ -18,7 +18,7 @@ class HailWorkerException( trait ErrorHandling { def fatal(msg: String): Nothing = throw new HailException(msg) - def fatal(msg: String, errorId: Int) = throw new HailException(msg, errorId) + def fatal(msg: String, errorId: Int): Nothing = throw new HailException(msg, errorId) def fatal(msg: String, cause: Throwable): Nothing = throw new HailException(msg, None, cause) diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index 75a64588c487..e7855248d146 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -5,13 +5,13 @@ import is.hail.ExecStrategy.ExecStrategy import is.hail.TestUtils._ import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} import is.hail.backend.ExecuteContext +import is.hail.backend.caching.BlockMatrixCache import is.hail.expr.Nat import is.hail.expr.ir.ArrayZipBehavior.ArrayZipBehavior import is.hail.expr.ir.agg._ import is.hail.expr.ir.functions._ import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.bgen.MatrixBGENReader -import is.hail.linalg.BlockMatrix import is.hail.methods._ import is.hail.rvd.{PartitionBoundOrdering, RVD, RVDPartitioner} import is.hail.types.{tcoerce, VirtualTypeWithReq} @@ -3906,17 +3906,26 @@ class IRSuite extends HailSuite { assert(x2 == x) } - def testBlockMatrixIRParserPersist(): Unit = { - val bm = BlockMatrix.fill(1, 1, 0.0, 5) - backend.persist(ctx.backendContext, "x", bm, "MEMORY_ONLY") - val persist = - BlockMatrixRead(BlockMatrixPersistReader("x", BlockMatrixType.fromBlockMatrix(bm))) + @Test def testBlockMatrixIRParserPersist(): Unit = + using(new BlockMatrixCache()) { cache => + val bm = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3) - val s = Pretty.sexprStyle(persist, elideLiterals = false) - val x2 = IRParser.parse_blockmatrix_ir(ctx, s) - assert(x2 == persist) - backend.unpersist(ctx.backendContext, "x") - } + backend.withExecuteContext { ctx => + ctx.local(blockMatrixCache = cache) { ctx => + backend.execute(ctx, BlockMatrixWrite(bm, BlockMatrixPersistWriter("x", "MEMORY_ONLY"))) + } + } + + backend.withExecuteContext { ctx => + ctx.local(blockMatrixCache = cache) { ctx => + val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) + + val s = Pretty.sexprStyle(persist, elideLiterals = false) + val x2 = IRParser.parse_blockmatrix_ir(ctx, s) + assert(x2 == persist) + } + } + } @Test def testCachedIR(): Unit = { val cached = Literal(TSet(TInt32), Set(1))