diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index e0b448bef2d1..8f04b754f856 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -4,8 +4,7 @@ import is.hail.asm4s._ import is.hail.backend.Backend.jsonToBytes import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.{ - BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses, - SortField, TableIR, TableReader, + BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader, } import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} @@ -107,9 +106,6 @@ abstract class Backend extends Closeable { def shouldCacheQueryInfo: Boolean = true - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] - def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, @@ -208,23 +204,3 @@ abstract class Backend extends Closeable { def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } - -trait BackendWithCodeCache { - private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50) - - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] = { - codeCache.get(k) match { - case Some(v) => v.asInstanceOf[CompiledFunction[T]] - case None => - val compiledFunction = f - codeCache += ((k, compiledFunction)) - compiledFunction - } - } -} - -trait BackendWithNoCodeCache { - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) - : CompiledFunction[T] = f -} diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index 7012d3db2e58..42112522ffa0 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext +import is.hail.expr.ir.{CodeCacheKey, CompiledFunction} import is.hail.expr.ir.lowering.IrMetadata import is.hail.io.fs.FS import is.hail.linalg.BlockMatrix @@ -73,6 +74,7 @@ object ExecuteContext { backendContext: BackendContext, irMetadata: IrMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix], + codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], )( f: ExecuteContext => T ): T = { @@ -92,6 +94,7 @@ object ExecuteContext { backendContext, irMetadata, blockMatrixCache, + codeCache, ))(f(_)) } } @@ -122,6 +125,7 @@ class ExecuteContext( val backendContext: BackendContext, val irMetadata: IrMetadata, val BlockMatrixCache: mutable.Map[String, BlockMatrix], + val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], ) extends Closeable { val rngNonce: Long = @@ -191,6 +195,7 @@ class ExecuteContext( backendContext: BackendContext = this.backendContext, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, + codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, )( f: ExecuteContext => A ): A = @@ -208,5 +213,6 @@ class ExecuteContext( backendContext, irMetadata, blockMatrixCache, + codeCache, ))(f) } diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 7a34954d243b..09983e0f6982 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -70,13 +70,14 @@ object LocalBackend { class LocalBackend( val tmpdir: String, override val references: mutable.Map[String, ReferenceGenome], -) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { +) extends Backend with Py4JBackendExtensions { override def backend: Backend = this override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv() override def longLifeTempFileManager: TempFileManager = null - private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) + private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) @@ -100,6 +101,7 @@ class LocalBackend( }, new IrMetadata(), ImmutableMap.empty, + codeCache, )(f) } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 62c6099dc45c..6a34503fdd08 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -51,7 +51,6 @@ class ServiceBackendContext( ) extends BackendContext with Serializable {} object ServiceBackend { - private val log = Logger.getLogger(getClass.getName()) def apply( jarLocation: String, @@ -132,8 +131,7 @@ class ServiceBackend( val fs: FS, val serviceBackendContext: ServiceBackendContext, val scratchDir: String, -) extends Backend with BackendWithNoCodeCache { - import ServiceBackend.log +) extends Backend with Logging { private[this] var stageCount = 0 private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000 @@ -388,6 +386,7 @@ class ServiceBackend( serviceBackendContext, new IrMetadata(), ImmutableMap.empty, + mutable.Map.empty, )(f) } diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index f392b809ced9..3380e76a4cda 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -307,7 +307,7 @@ class SparkBackend( override val references: mutable.Map[String, ReferenceGenome], gcsRequesterPaysProject: String, gcsRequesterPaysBuckets: String, -) extends Backend with BackendWithCodeCache with Py4JBackendExtensions { +) extends Backend with Py4JBackendExtensions { assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() @@ -338,8 +338,8 @@ class SparkBackend( override val longLifeTempFileManager: TempFileManager = new OwningTempFileManager(fs) - private[this] val bmCache: BlockMatrixCache = - new BlockMatrixCache() + private[this] val bmCache = new BlockMatrixCache() + private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) def createExecuteContextForTests( timer: ExecutionTimer, @@ -363,6 +363,7 @@ class SparkBackend( }, new IrMetadata(), ImmutableMap.empty, + mutable.Map.empty, ) override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = @@ -383,6 +384,7 @@ class SparkBackend( }, new IrMetadata(), bmCache, + codeCache, )(f) } diff --git a/hail/src/main/scala/is/hail/expr/ir/Compile.scala b/hail/src/main/scala/is/hail/expr/ir/Compile.scala index d6cff956a383..b54537947b12 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Compile.scala @@ -44,53 +44,49 @@ object Compile { ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = ctx.time { val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) - val k = - CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody) - (ctx.backend.lookupOrCompileCachedFunction[F](k) { - - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), - ) - ir = - LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - - TypeCheck(ctx, ir, BindingEnv.empty) - - val returnParam = CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)) - - val fb = EmitFunctionBuilder[F]( - ctx, - "Compiled", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => - pt - }, - returnParam, - Some("Emit.scala"), - ) - - /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ - * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } - * } - * - * visit(ir) } */ - - assert( - fb.mb.parameterTypeInfo == expectedCodeParamTypes, - s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}", - ) - assert( - fb.mb.returnTypeInfo == expectedCodeReturnType, - s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}", - ) - - val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length) - CompiledFunction(rt, fb.resultWithIndex(print)) - }).tuple + ctx.CodeCache.getOrElseUpdate( + CodeCacheKey(FastSeq(), params.map { case (n, pt) => (n, pt) }, normalizedBody), { + var ir = body + ir = Subst( + ir, + BindingEnv(params + .zipWithIndex + .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), + ) + ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx) + + TypeCheck(ctx, ir) + + val fb = EmitFunctionBuilder[F]( + ctx, + "Compiled", + CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => + pt + }, + CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)), + Some("Emit.scala"), + ) + + /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ + * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) + * } } + * + * visit(ir) } */ + + assert( + fb.mb.parameterTypeInfo == expectedCodeParamTypes, + s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}", + ) + assert( + fb.mb.returnTypeInfo == expectedCodeReturnType, + s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}", + ) + + val emitContext = EmitContext.analyze(ctx, ir) + val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length) + CompiledFunction(rt, fb.resultWithIndex(print)) + }, + ).asInstanceOf[CompiledFunction[F]].tuple } } @@ -108,55 +104,44 @@ object CompileWithAggregators { (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion), ) = ctx.time { - val normalizedBody = - NormalizeNames(ctx, body, allowFreeVariables = true) - val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody) - (ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) { - - var ir = body - ir = Subst( - ir, - BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), - ) - ir = - LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - - TypeCheck( - ctx, - ir, - BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })), - ) - - val fb = EmitFunctionBuilder[F]( - ctx, - "CompiledWithAggs", - CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt }, - SingleCodeType.typeInfoFromType(ir.typ), - Some("Emit.scala"), - ) - - /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ - * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } - * } - * - * visit(ir) } */ - - val emitContext = EmitContext.analyze(ctx, ir) - val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs)) - - val f = fb.resultWithIndex() - CompiledFunction( - rt, - f.asInstanceOf[( - HailClassLoader, - FS, - HailTaskContext, - Region, - ) => (F with FunctionWithAggRegion)], - ) - }).tuple + val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) + ctx.CodeCache.getOrElseUpdate( + CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody), { + var ir = body + ir = Subst( + ir, + BindingEnv(params + .zipWithIndex + .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), + ) + ir = + LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) + + TypeCheck( + ctx, + ir, + BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })), + ) + + val fb = EmitFunctionBuilder[F with FunctionWithAggRegion]( + ctx, + "CompiledWithAggs", + CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt }, + SingleCodeType.typeInfoFromType(ir.typ), + Some("Emit.scala"), + ) + + /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ + * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) + * } } + * + * visit(ir) } */ + + val emitContext = EmitContext.analyze(ctx, ir) + val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs)) + CompiledFunction(rt, fb.resultWithIndex()) + }, + ).asInstanceOf[CompiledFunction[F with FunctionWithAggRegion]].tuple } } diff --git a/hail/src/main/scala/is/hail/utils/Cache.scala b/hail/src/main/scala/is/hail/utils/Cache.scala index 3aa40a6c5473..8e924e7ed02c 100644 --- a/hail/src/main/scala/is/hail/utils/Cache.scala +++ b/hail/src/main/scala/is/hail/utils/Cache.scala @@ -2,20 +2,29 @@ package is.hail.utils import is.hail.annotations.{Region, RegionMemory} +import scala.collection.mutable +import scala.jdk.CollectionConverters.asScalaIteratorConverter + import java.io.Closeable import java.util import java.util.Map.Entry -class Cache[K, V](capacity: Int) { +class Cache[K, V](capacity: Int) extends mutable.AbstractMap[K, V] { private[this] val m = new util.LinkedHashMap[K, V](capacity, 0.75f, true) { override def removeEldestEntry(eldest: Entry[K, V]): Boolean = size() > capacity } - def get(k: K): Option[V] = synchronized(Option(m.get(k))) + override def +=(kv: (K, V)): Cache.this.type = + synchronized { m.put(kv._1, kv._2); this } + + override def -=(key: K): Cache.this.type = + synchronized { m.remove(key); this } - def +=(p: (K, V)): Unit = synchronized(m.put(p._1, p._2)) + override def get(key: K): Option[V] = + synchronized(Option(m.get(key))) - def size: Int = synchronized(m.size()) + override def iterator: Iterator[(K, V)] = + for { e <- m.entrySet().iterator().asScala } yield (e.getKey, e.getValue) } class LongToRegionValueCache(capacity: Int) extends Closeable {