From 5763812744f6a2be2a1740ac8c25a7d433a66cd5 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Thu, 19 Sep 2024 20:03:43 -0400 Subject: [PATCH] [query] Move LoweredTableReaderCoercer into ExecuteContext --- .../main/scala/is/hail/backend/Backend.scala | 2 - .../is/hail/backend/ExecuteContext.scala | 6 + .../is/hail/backend/caching/package.scala | 16 + .../is/hail/backend/local/LocalBackend.scala | 3 + .../hail/backend/service/ServiceBackend.scala | 10 +- .../is/hail/backend/spark/SparkBackend.scala | 8 +- .../is/hail/expr/ir/GenericTableValue.scala | 35 +- .../main/scala/is/hail/expr/ir/TableIR.scala | 678 +++++++++--------- 8 files changed, 379 insertions(+), 379 deletions(-) create mode 100644 hail/src/main/scala/is/hail/backend/caching/package.scala diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index ecdd74e7d0ae..7c53294e9417 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -86,8 +86,6 @@ abstract class Backend extends Closeable { def asSpark(implicit E: Enclosing): SparkBackend = fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend") - def shouldCacheQueryInfo: Boolean = true - def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index 383b6c296b3f..0b570ab6a61a 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -5,6 +5,7 @@ import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction} +import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.lowering.IrMetadata import is.hail.io.fs.FS import is.hail.linalg.BlockMatrix @@ -74,6 +75,7 @@ object ExecuteContext { blockMatrixCache: mutable.Map[String, BlockMatrix], codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], irCache: mutable.Map[Int, BaseIR], + coercerCache: mutable.Map[Any, LoweredTableReaderCoercer], )( f: ExecuteContext => T ): T = { @@ -95,6 +97,7 @@ object ExecuteContext { blockMatrixCache, codeCache, irCache, + coercerCache, ))(f(_)) } } @@ -127,6 +130,7 @@ class ExecuteContext( val BlockMatrixCache: mutable.Map[String, BlockMatrix], val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], val IrCache: mutable.Map[Int, BaseIR], + val CoercerCache: mutable.Map[Any, LoweredTableReaderCoercer], ) extends Closeable { val rngNonce: Long = @@ -199,6 +203,7 @@ class ExecuteContext( blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, irCache: mutable.Map[Int, BaseIR] = this.IrCache, + coercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.CoercerCache, )( f: ExecuteContext => A ): A = @@ -218,5 +223,6 @@ class ExecuteContext( blockMatrixCache, codeCache, irCache, + coercerCache, ))(f) } diff --git a/hail/src/main/scala/is/hail/backend/caching/package.scala b/hail/src/main/scala/is/hail/backend/caching/package.scala new file mode 100644 index 000000000000..316708421439 --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/caching/package.scala @@ -0,0 +1,16 @@ +package is.hail.backend + +import scala.collection.mutable + +package object caching { + private[this] object NoCachingInstance extends mutable.AbstractMap[Any, Any] { + override def +=(kv: (Any, Any)): NoCachingInstance.this.type = this + override def -=(key: Any): NoCachingInstance.this.type = this + override def get(key: Any): Option[Any] = None + override def iterator: Iterator[(Any, Any)] = Iterator.empty + override def getOrElseUpdate(key: Any, op: => Any): Any = op + } + + def NoCaching[K, V]: mutable.Map[K, V] = + NoCachingInstance.asInstanceOf[mutable.Map[K, V]] +} diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 1781dab949c4..e8da0ad9c95b 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -7,6 +7,7 @@ import is.hail.backend._ import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ +import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ @@ -81,6 +82,7 @@ class LocalBackend( private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader) private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map() + private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) @@ -106,6 +108,7 @@ class LocalBackend( ImmutableMap.empty, codeCache, persistedIR, + coercerCache, )(f) } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 9736cb4f3eb3..1d87c8d8a0f2 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -4,6 +4,7 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.caching.NoCaching import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections import is.hail.expr.Validate import is.hail.expr.ir.{ @@ -64,8 +65,6 @@ class ServiceBackend( private[this] var stageCount = 0 private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections) - override def shouldCacheQueryInfo: Boolean = false - def defaultParallelism: Int = 4 def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = { @@ -322,9 +321,10 @@ class ServiceBackend( ), new IrMetadata(), references, - ImmutableMap.empty, - mutable.Map.empty, - ImmutableMap.empty, + NoCaching, + NoCaching, + NoCaching, + NoCaching, )(f) } } diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 69fc39bf47c0..cdd57b45bad3 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -4,10 +4,11 @@ import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ -import is.hail.backend.caching.BlockMatrixCache +import is.hail.backend.caching.{BlockMatrixCache, NoCaching} import is.hail.backend.py4j.Py4JBackendExtensions import is.hail.expr.Validate import is.hail.expr.ir._ +import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile import is.hail.expr.ir.lowering._ @@ -343,6 +344,7 @@ class SparkBackend( private[this] val bmCache = new BlockMatrixCache() private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50) private[this] val persistedIr = mutable.Map.empty[Int, BaseIR] + private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32) def createExecuteContextForTests( timer: ExecutionTimer, @@ -365,8 +367,9 @@ class SparkBackend( new IrMetadata(), references, ImmutableMap.empty, - mutable.Map.empty, + NoCaching, ImmutableMap.empty, + NoCaching, ) override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = @@ -389,6 +392,7 @@ class SparkBackend( bmCache, codeCache, persistedIr, + coercerCache, )(f) } diff --git a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala index 4758ac92cc03..37f7cc23b954 100644 --- a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala +++ b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer import is.hail.expr.ir.functions.UtilFunctions import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.expr.ir.streams.StreamProducer @@ -143,16 +144,6 @@ class PartitionIteratorLongReader( ) } -abstract class LoweredTableReaderCoercer { - def coerce( - ctx: ExecuteContext, - globals: IR, - contextType: Type, - contexts: IndexedSeq[Any], - body: IR => IR, - ): TableStage -} - class GenericTableValue( val fullTableType: TableType, val uidFieldName: String, @@ -168,12 +159,11 @@ class GenericTableValue( assert(contextType.hasField("partitionIndex")) assert(contextType.fieldType("partitionIndex") == TInt32) - private var ltrCoercer: LoweredTableReaderCoercer = _ - private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any) - : LoweredTableReaderCoercer = { - if (ltrCoercer == null) { - ltrCoercer = LoweredTableReader.makeCoercer( + : LoweredTableReaderCoercer = + ctx.CoercerCache.getOrElseUpdate( + (1, contextType, fullTableType.key, cacheKey), + LoweredTableReader.makeCoercer( ctx, fullTableType.key, 1, @@ -184,11 +174,8 @@ class GenericTableValue( bodyPType, body, context, - cacheKey, - ) - } - ltrCoercer - } + ), + ) def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any) : TableStage = { @@ -217,11 +204,13 @@ class GenericTableValue( val contextsIR = ToStream(Literal(TArray(contextType), contexts)) TableStage(globalsIR, p, TableStageDependency.none, contextsIR, requestedBody) } else { - getLTVCoercer(ctx, context, cacheKey).coerce( + getLTVCoercer(ctx, context, cacheKey)( ctx, globalsIR, - contextType, contexts, - requestedBody) + contextType, + contexts, + requestedBody, + ) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index f50b07d597a6..e44e002560b3 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -144,7 +144,8 @@ object TableReader { object LoweredTableReader { - private[this] val coercerCache: Cache[Any, LoweredTableReaderCoercer] = new Cache(32) + type LoweredTableReaderCoercer = + (ExecuteContext, IR, Type, IndexedSeq[Any], IR => IR) => TableStage def makeCoercer( ctx: ExecuteContext, @@ -157,7 +158,6 @@ object LoweredTableReader { bodyPType: (TStruct) => PStruct, keys: (TStruct) => (Region, HailClassLoader, FS, Any) => Iterator[Long], context: String, - cacheKey: Any, ): LoweredTableReaderCoercer = { assert(key.nonEmpty) assert(contexts.nonEmpty) @@ -173,379 +173,363 @@ object LoweredTableReader { def selectPK(k: IR): IR = SelectFields(k, key.take(partitionKey)) - val cacheKeyWithInfo = (partitionKey, keyType, key, cacheKey) - coercerCache.get(cacheKeyWithInfo) match { - case Some(r) => r - case None => - info(s"scanning $context for sortedness...") - val prevkey = AggSignature(PrevNonnull(), FastSeq(), FastSeq(keyType)) - val count = AggSignature(Count(), FastSeq(), FastSeq()) - val samplekey = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(keyType, TFloat64)) - val sum = AggSignature(Sum(), FastSeq(), FastSeq(TInt64)) - val minkey = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(keyType, keyType)) - val maxkey = AggSignature(TakeBy(Descending), FastSeq(TInt32), FastSeq(keyType, keyType)) - - val xType = TStruct( - "key" -> keyType, - "token" -> TFloat64, - "prevkey" -> keyType, - ) + info(s"scanning $context for sortedness...") + val prevkey = AggSignature(PrevNonnull(), FastSeq(), FastSeq(keyType)) + val count = AggSignature(Count(), FastSeq(), FastSeq()) + val samplekey = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(keyType, TFloat64)) + val sum = AggSignature(Sum(), FastSeq(), FastSeq(TInt64)) + val minkey = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(keyType, keyType)) + val maxkey = AggSignature(TakeBy(Descending), FastSeq(TInt32), FastSeq(keyType, keyType)) + + val xType = TStruct( + "key" -> keyType, + "token" -> TFloat64, + "prevkey" -> keyType, + ) - val keyRef = Ref(freshName(), keyType) - val xRef = Ref(freshName(), xType) - val nRef = Ref(freshName(), TInt64) - - val scanBody = (ctx: IR) => - StreamAgg( - StreamAggScan( - ReadPartition( - ctx, - keyType, - new PartitionIteratorLongReader( - keyType, - uidFieldName, - contextType, - (requestedType: Type) => bodyPType(requestedType.asInstanceOf[TStruct]), - (requestedType: Type) => keys(requestedType.asInstanceOf[TStruct]), + val keyRef = Ref(freshName(), keyType) + val xRef = Ref(freshName(), xType) + val nRef = Ref(freshName(), TInt64) + + val scanBody = (ctx: IR) => + StreamAgg( + StreamAggScan( + ReadPartition( + ctx, + keyType, + new PartitionIteratorLongReader( + keyType, + uidFieldName, + contextType, + (requestedType: Type) => bodyPType(requestedType.asInstanceOf[TStruct]), + (requestedType: Type) => keys(requestedType.asInstanceOf[TStruct]), + ), + ), + keyRef.name, + MakeStruct(FastSeq( + "key" -> keyRef, + "token" -> invokeSeeded( + "rand_unif", + 1, + TFloat64, + RNGStateLiteral(), + F64(0.0), + F64(1.0), + ), + "prevkey" -> ApplyScanOp(FastSeq(), FastSeq(keyRef), prevkey), + )), + ), + xRef.name, + Let( + FastSeq(nRef.name -> ApplyAggOp(FastSeq(), FastSeq(), count)), + AggLet( + keyRef.name, + GetField(xRef, "key"), + MakeStruct(FastSeq( + "n" -> nRef, + "minkey" -> + ApplyAggOp( + FastSeq(I32(1)), + FastSeq(keyRef, keyRef), + minkey, ), - ), - keyRef.name, - MakeStruct(FastSeq( - "key" -> keyRef, - "token" -> invokeSeeded( - "rand_unif", - 1, - TFloat64, - RNGStateLiteral(), - F64(0.0), - F64(1.0), + "maxkey" -> + ApplyAggOp( + FastSeq(I32(1)), + FastSeq(keyRef, keyRef), + maxkey, ), - "prevkey" -> ApplyScanOp(FastSeq(), FastSeq(keyRef), prevkey), - )), - ), - xRef.name, - Let( - FastSeq(nRef.name -> ApplyAggOp(FastSeq(), FastSeq(), count)), - AggLet( - keyRef.name, - GetField(xRef, "key"), - MakeStruct(FastSeq( - "n" -> nRef, - "minkey" -> - ApplyAggOp( - FastSeq(I32(1)), - FastSeq(keyRef, keyRef), - minkey, - ), - "maxkey" -> - ApplyAggOp( - FastSeq(I32(1)), - FastSeq(keyRef, keyRef), - maxkey, - ), - "ksorted" -> - ApplyComparisonOp( - EQ(TInt64), - ApplyAggOp( - FastSeq(), - FastSeq( - invoke( - "toInt64", - TInt64, - invoke( - "lor", - TBoolean, - IsNA(GetField(xRef, "prevkey")), - ApplyComparisonOp( - LTEQ(keyType), - GetField(xRef, "prevkey"), - GetField(xRef, "key"), - ), - ), - ) + "ksorted" -> + ApplyComparisonOp( + EQ(TInt64), + ApplyAggOp( + FastSeq(), + FastSeq( + invoke( + "toInt64", + TInt64, + invoke( + "lor", + TBoolean, + IsNA(GetField(xRef, "prevkey")), + ApplyComparisonOp( + LTEQ(keyType), + GetField(xRef, "prevkey"), + GetField(xRef, "key"), + ), ), - sum, - ), - nRef, + ) ), - "pksorted" -> - ApplyComparisonOp( - EQ(TInt64), - ApplyAggOp( - FastSeq(), - FastSeq( - invoke( - "toInt64", - TInt64, - invoke( - "lor", - TBoolean, - IsNA(selectPK(GetField(xRef, "prevkey"))), - ApplyComparisonOp( - LTEQ(pkType), - selectPK(GetField(xRef, "prevkey")), - selectPK(GetField(xRef, "key")), - ), - ), - ) + sum, + ), + nRef, + ), + "pksorted" -> + ApplyComparisonOp( + EQ(TInt64), + ApplyAggOp( + FastSeq(), + FastSeq( + invoke( + "toInt64", + TInt64, + invoke( + "lor", + TBoolean, + IsNA(selectPK(GetField(xRef, "prevkey"))), + ApplyComparisonOp( + LTEQ(pkType), + selectPK(GetField(xRef, "prevkey")), + selectPK(GetField(xRef, "key")), + ), ), - sum, - ), - nRef, + ) ), - "sample" -> ApplyAggOp( - FastSeq(I32(samplesPerPartition)), - FastSeq(GetField(xRef, "key"), GetField(xRef, "token")), - samplekey, + sum, ), - )), - isScan = false, + nRef, + ), + "sample" -> ApplyAggOp( + FastSeq(I32(samplesPerPartition)), + FastSeq(GetField(xRef, "key"), GetField(xRef, "token")), + samplekey, ), - ), - ) + )), + isScan = false, + ), + ), + ) - val scanResult = cdaIR( - ToStream(Literal(TArray(contextType), contexts)), - MakeStruct(FastSeq()), - "table_coerce_sortedness", - NA(TString), - )((context, _) => scanBody(context)) + val scanResult = cdaIR( + ToStream(Literal(TArray(contextType), contexts)), + MakeStruct(FastSeq()), + "table_coerce_sortedness", + NA(TString), + )((context, _) => scanBody(context)) - val sortedPartDataIR = sortIR(bindIR(scanResult) { scanResult => + val sortedPartDataIR = sortIR(bindIR(scanResult) { scanResult => + mapIR( + filterIR( mapIR( - filterIR( - mapIR( - rangeIR(I32(0), ArrayLen(scanResult)) - ) { i => - InsertFields( - ArrayRef(scanResult, i), - FastSeq("i" -> i), - ) - } - )(row => ArrayLen(GetField(row, "minkey")) > 0) - ) { row => + rangeIR(I32(0), ArrayLen(scanResult)) + ) { i => InsertFields( - row, - FastSeq( - ("minkey", ArrayRef(GetField(row, "minkey"), I32(0))), - ("maxkey", ArrayRef(GetField(row, "maxkey"), I32(0))), - ), + ArrayRef(scanResult, i), + FastSeq("i" -> i), ) } - }) { (l, r) => - ApplyComparisonOp( - LT(TStruct("minkey" -> keyType, "maxkey" -> keyType)), - SelectFields(l, FastSeq("minkey", "maxkey")), - SelectFields(r, FastSeq("minkey", "maxkey")), - ) - } + )(row => ArrayLen(GetField(row, "minkey")) > 0) + ) { row => + InsertFields( + row, + FastSeq( + ("minkey", ArrayRef(GetField(row, "minkey"), I32(0))), + ("maxkey", ArrayRef(GetField(row, "maxkey"), I32(0))), + ), + ) + } + }) { (l, r) => + ApplyComparisonOp( + LT(TStruct("minkey" -> keyType, "maxkey" -> keyType)), + SelectFields(l, FastSeq("minkey", "maxkey")), + SelectFields(r, FastSeq("minkey", "maxkey")), + ) + } - val summary = bindIR(sortedPartDataIR) { sortedPartData => - MakeStruct(FastSeq( - "ksorted" -> - invoke( - "land", - TBoolean, - foldIR(ToStream(sortedPartData), True()) { (acc, partDataWithIndex) => - invoke("land", TBoolean, acc, GetField(partDataWithIndex, "ksorted")) - }, - foldIR(StreamRange(I32(0), ArrayLen(sortedPartData) - I32(1), I32(1)), True()) { - (acc, i) => - invoke( - "land", - TBoolean, - acc, - ApplyComparisonOp( - LTEQ(keyType), - GetField(ArrayRef(sortedPartData, i), "maxkey"), - GetField(ArrayRef(sortedPartData, i + I32(1)), "minkey"), - ), - ) - }, - ), - "pksorted" -> - invoke( - "land", - TBoolean, - foldIR(ToStream(sortedPartData), True()) { (acc, partDataWithIndex) => - invoke("land", TBoolean, acc, GetField(partDataWithIndex, "pksorted")) - }, - foldIR(StreamRange(I32(0), ArrayLen(sortedPartData) - I32(1), I32(1)), True()) { - (acc, i) => - invoke( - "land", - TBoolean, - acc, - ApplyComparisonOp( - LTEQ(pkType), - selectPK(GetField(ArrayRef(sortedPartData, i), "maxkey")), - selectPK(GetField(ArrayRef(sortedPartData, i + I32(1)), "minkey")), - ), - ) - }, - ), - "sortedPartData" -> sortedPartData, - )) - } + val summary = bindIR(sortedPartDataIR) { sortedPartData => + MakeStruct(FastSeq( + "ksorted" -> + invoke( + "land", + TBoolean, + foldIR(ToStream(sortedPartData), True()) { (acc, partDataWithIndex) => + invoke("land", TBoolean, acc, GetField(partDataWithIndex, "ksorted")) + }, + foldIR(StreamRange(I32(0), ArrayLen(sortedPartData) - I32(1), I32(1)), True()) { + (acc, i) => + invoke( + "land", + TBoolean, + acc, + ApplyComparisonOp( + LTEQ(keyType), + GetField(ArrayRef(sortedPartData, i), "maxkey"), + GetField(ArrayRef(sortedPartData, i + I32(1)), "minkey"), + ), + ) + }, + ), + "pksorted" -> + invoke( + "land", + TBoolean, + foldIR(ToStream(sortedPartData), True()) { (acc, partDataWithIndex) => + invoke("land", TBoolean, acc, GetField(partDataWithIndex, "pksorted")) + }, + foldIR(StreamRange(I32(0), ArrayLen(sortedPartData) - I32(1), I32(1)), True()) { + (acc, i) => + invoke( + "land", + TBoolean, + acc, + ApplyComparisonOp( + LTEQ(pkType), + selectPK(GetField(ArrayRef(sortedPartData, i), "maxkey")), + selectPK(GetField(ArrayRef(sortedPartData, i + I32(1)), "minkey")), + ), + ) + }, + ), + "sortedPartData" -> sortedPartData, + )) + } - val (Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = - Compile[AsmFunction1RegionLong]( - ctx, - FastSeq(), - FastSeq[TypeInfo[_]](classInfo[Region]), - LongInfo, - summary, - optimize = true, - ) + val (Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = + Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq[TypeInfo[_]](classInfo[Region]), + LongInfo, + summary, + optimize = true, + ) - val s = ctx.scopedExecution { (hcl, fs, htc, r) => - val a = f(hcl, fs, htc, r)(r) - SafeRow(resultPType, a) - } + val s = ctx.scopedExecution { (hcl, fs, htc, r) => + val a = f(hcl, fs, htc, r)(r) + SafeRow(resultPType, a) + } - val ksorted = s.getBoolean(0) - val pksorted = s.getBoolean(1) - val sortedPartData = s.getAs[IndexedSeq[Row]](2) - - val coercer = if (ksorted) { - info(s"Coerced sorted $context - no additional import work to do") - - new LoweredTableReaderCoercer { - def coerce( - ctx: ExecuteContext, - globals: IR, - contextType: Type, - contexts: IndexedSeq[Any], - body: IR => IR, - ): TableStage = { - val partOrigIndex = sortedPartData.map(_.getInt(6)) - - val partitioner = new RVDPartitioner( - ctx.stateManager, - keyType, - sortedPartData.map { partData => - Interval( - partData.get(1), - partData.get(2), - includesStart = true, - includesEnd = true, - ) - }, - key.length, - ) + val ksorted = s.getBoolean(0) + val pksorted = s.getBoolean(1) + val sortedPartData = s.getAs[IndexedSeq[Row]](2) + + if (ksorted) { + info(s"Coerced sorted $context - no additional import work to do") + ( + ctx: ExecuteContext, + globals: IR, + contextType: Type, + contexts: IndexedSeq[Any], + body: IR => IR, + ) => { + val partOrigIndex = sortedPartData.map(_.getInt(6)) + + val partitioner = new RVDPartitioner( + ctx.stateManager, + keyType, + sortedPartData.map { partData => + Interval( + partData.get(1), + partData.get(2), + includesStart = true, + includesEnd = true, + ) + }, + key.length, + ) - TableStage( - globals, - partitioner, - TableStageDependency.none, - ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), - body, - ) - } - } - } else if (pksorted) { - info( - s"Coerced prefix-sorted $context, requiring additional sorting within data partitions on each query." - ) + TableStage( + globals, + partitioner, + TableStageDependency.none, + ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), + body, + ) + } + } else if (pksorted) { + info( + s"Coerced prefix-sorted $context, requiring additional sorting within data partitions on each query." + ) - new LoweredTableReaderCoercer { - private[this] def selectPK(r: Row): Row = { - val a = new Array[Any](partitionKey) - var i = 0 - while (i < partitionKey) { - a(i) = r.get(i) - i += 1 - } - Row.fromSeq(a) - } + def selectPK(r: Row): Row = { + val a = new Array[Any](partitionKey) + var i = 0 + while (i < partitionKey) { + a(i) = r.get(i) + i += 1 + } + Row.fromSeq(a) + } - def coerce( - ctx: ExecuteContext, - globals: IR, - contextType: Type, - contexts: IndexedSeq[Any], - body: IR => IR, - ): TableStage = { - val partOrigIndex = sortedPartData.map(_.getInt(6)) - - val partitioner = new RVDPartitioner( - ctx.stateManager, - pkType, - sortedPartData.map { partData => - Interval( - selectPK(partData.getAs[Row](1)), - selectPK(partData.getAs[Row](2)), - includesStart = true, - includesEnd = true, - ) - }, - pkType.size, - ) + ( + ctx: ExecuteContext, + globals: IR, + contextType: Type, + contexts: IndexedSeq[Any], + body: IR => IR, + ) => { + val partOrigIndex = sortedPartData.map(_.getInt(6)) - val pkPartitioned = TableStage( - globals, - partitioner, - TableStageDependency.none, - ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), - body, - ) + val partitioner = new RVDPartitioner( + ctx.stateManager, + pkType, + sortedPartData.map { partData => + Interval( + selectPK(partData.getAs[Row](1)), + selectPK(partData.getAs[Row](2)), + includesStart = true, + includesEnd = true, + ) + }, + pkType.size, + ) - pkPartitioned - .extendKeyPreservesPartitioning(ctx, key) - .mapPartition(None) { part => - flatMapIR(StreamGroupByKey(part, pkType.fieldNames, missingEqual = true)) { - inner => - ToStream(sortIR(inner) { case (l, r) => ApplyComparisonOp(LT(l.typ), l, r) }) - } - } + val pkPartitioned = TableStage( + globals, + partitioner, + TableStageDependency.none, + ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), + body, + ) + + pkPartitioned + .extendKeyPreservesPartitioning(ctx, key) + .mapPartition(None) { part => + flatMapIR(StreamGroupByKey(part, pkType.fieldNames, missingEqual = true)) { + inner => ToStream(sortIR(inner) { case (l, r) => ApplyComparisonOp(LT(l.typ), l, r) }) } } - } else { - info( - s"$context is out of order..." + - s"\n Write the dataset to disk before running multiple queries to avoid multiple costly data shuffles." - ) + } + } else { + info( + s"$context is out of order..." + + s"\n Write the dataset to disk before running multiple queries to avoid multiple costly data shuffles." + ) - new LoweredTableReaderCoercer { - def coerce( - ctx: ExecuteContext, - globals: IR, - contextType: Type, - contexts: IndexedSeq[Any], - body: IR => IR, - ): TableStage = { - val partOrigIndex = sortedPartData.map(_.getInt(6)) - - val partitioner = RVDPartitioner.unkeyed(ctx.stateManager, sortedPartData.length) - - val tableStage = TableStage( - globals, - partitioner, - TableStageDependency.none, - ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), - body, - ) + ( + ctx: ExecuteContext, + globals: IR, + contextType: Type, + contexts: IndexedSeq[Any], + body: IR => IR, + ) => { + val partOrigIndex = sortedPartData.map(_.getInt(6)) + + val partitioner = RVDPartitioner.unkeyed(ctx.stateManager, sortedPartData.length) + + val tableStage = TableStage( + globals, + partitioner, + TableStageDependency.none, + ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), + body, + ) - val rowRType = - VirtualTypeWithReq(bodyPType(tableStage.rowType)).r.asInstanceOf[RStruct] - val globReq = Requiredness(globals, ctx) - val globRType = globReq.lookup(globals).asInstanceOf[RStruct] - - ctx.backend.lowerDistributedSort( - ctx, - tableStage, - keyType.fieldNames.map(f => SortField(f, Ascending)), - RTable(rowRType, globRType, FastSeq()), - ).lower( - ctx, - TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]), - ) - } - } - } - if (ctx.backend.shouldCacheQueryInfo) - coercerCache += (cacheKeyWithInfo -> coercer) - coercer + val rowRType = + VirtualTypeWithReq(bodyPType(tableStage.rowType)).r.asInstanceOf[RStruct] + val globReq = Requiredness(globals, ctx) + val globRType = globReq.lookup(globals).asInstanceOf[RStruct] + + ctx.backend.lowerDistributedSort( + ctx, + tableStage, + keyType.fieldNames.map(f => SortField(f, Ascending)), + RTable(rowRType, globRType, FastSeq()), + ).lower( + ctx, + TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]), + ) + } } } }