Skip to content

Commit

Permalink
[query] Expose references via ExecuteContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Oct 29, 2024
1 parent 4cf050e commit 2479255
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 60 deletions.
9 changes: 7 additions & 2 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ object ExecuteContext {
tmpdir: String,
localTmpdir: String,
backend: Backend,
references: Map[String, ReferenceGenome],
fs: FS,
timer: ExecutionTimer,
tempFileManager: TempFileManager,
Expand All @@ -79,6 +80,7 @@ object ExecuteContext {
tmpdir,
localTmpdir,
backend,
references,
fs,
region,
timer,
Expand Down Expand Up @@ -107,6 +109,7 @@ class ExecuteContext(
val tmpdir: String,
val localTmpdir: String,
val backend: Backend,
val references: Map[String, ReferenceGenome],
val fs: FS,
val r: Region,
val timer: ExecutionTimer,
Expand All @@ -128,7 +131,7 @@ class ExecuteContext(
)
}

val stateManager = HailStateManager(backend.references)
val stateManager = HailStateManager(references)

val tempFileManager: TempFileManager =
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)
Expand All @@ -154,7 +157,7 @@ class ExecuteContext(

def getFlag(name: String): String = flags.get(name)

def getReference(name: String): ReferenceGenome = backend.references(name)
def getReference(name: String): ReferenceGenome = references(name)

def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null

Expand All @@ -174,6 +177,7 @@ class ExecuteContext(
tmpdir: String = this.tmpdir,
localTmpdir: String = this.localTmpdir,
backend: Backend = this.backend,
references: Map[String, ReferenceGenome] = this.references,
fs: FS = this.fs,
r: Region = this.r,
timer: ExecutionTimer = this.timer,
Expand All @@ -189,6 +193,7 @@ class ExecuteContext(
tmpdir,
localTmpdir,
backend,
references,
fs,
r,
timer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
tmpdir,
tmpdir,
this,
references,
fs,
timer,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ class ServiceBackend(
tmpdir,
"file:///tmp",
this,
references,
fs,
timer,
null,
Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references,
fs,
region,
timer,
Expand Down Expand Up @@ -399,6 +400,7 @@ class SparkBackend(
tmpdir,
localTmpdir,
this,
references,
fs,
timer,
tmpFileManager,
Expand Down
97 changes: 39 additions & 58 deletions hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package is.hail.expr.ir
import is.hail.HailSuite
import is.hail.annotations.{Region, SafeIndexedSeq}
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.check.Gen
import is.hail.check.Prop.forAll
import is.hail.expr.ir.functions.LocusFunctions
Expand All @@ -17,46 +18,19 @@ import is.hail.variant.{Locus, ReferenceGenome}
import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper}
import org.testng.annotations.Test

sealed trait StagedCoercions[A] {
def ti: TypeInfo[A]
def sType: SType
def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[A]): SValue
def toType(cb: EmitCodeBuilder, sa: SValue): Value[A]
}
class StagedMinHeapSuite extends HailSuite {

sealed trait StagedCoercionInstances {
implicit object StagedIntCoercions extends StagedCoercions[Int] {
override def ti: TypeInfo[Int] =
implicitly

override def sType: SType =
SInt32
override def ti: TypeInfo[Int] = implicitly
override def sType: SType = SInt32

override def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[Int]): SValue =
override def fromValue(cb: EmitCodeBuilder, region: Value[Region], a: Value[Int]): SValue =
new SInt32Value(a)

override def toType(cb: EmitCodeBuilder, sa: SValue): Value[Int] =
override def toValue(cb: EmitCodeBuilder, sa: SValue): Value[Int] =
sa.asInt.value
}

def stagedLocusCoercions(rg: ReferenceGenome): StagedCoercions[Locus] =
new StagedCoercions[Locus] {
override def ti: TypeInfo[Locus] =
implicitly

override def sType: SType =
PCanonicalLocus(rg.name, required = true).sType

override def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[Locus]): SValue =
LocusFunctions.emitLocus(cb, region, a, sType.storageType().asInstanceOf[PCanonicalLocus])

override def toType(cb: EmitCodeBuilder, sa: SValue): Value[Locus] =
sa.asLocus.getLocusObj(cb)
}
}

class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances {

@Test def testSorting(): Unit =
forAll((xs: IndexedSeq[Int]) => sort(xs) == xs.sorted).check()

Expand All @@ -70,7 +44,7 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances {
}.check()

@Test def testNonEmpty(): Unit =
gen("NonEmpty") { (heap: IntHeap) =>
gen(ctx, "NonEmpty") { (heap: IntHeap) =>
heap.nonEmpty should be(false)
for (i <- 0 to 10) heap.push(i)
heap.nonEmpty should be(true)
Expand All @@ -86,10 +60,12 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances {

@Test def testLocus(): Unit =
forAll(loci) { case (rg: ReferenceGenome, loci: IndexedSeq[Locus]) =>
withReferenceGenome(rg) {
ctx.local(references = Map(rg.name -> rg)) { ctx =>
implicit val coercions: StagedCoercions[Locus] =
stagedLocusCoercions(rg)

val sortedLoci =
gen("Locus", stagedLocusCoercions(rg)) { (heap: LocusHeap) =>
gen[Locus, LocusHeap, IndexedSeq[Locus]](ctx, "Locus") { (heap: LocusHeap) =>
loci.foreach(heap.push)
IndexedSeq.fill(loci.size)(heap.pop())
}
Expand All @@ -98,38 +74,24 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances {
}
}.check()

def withReferenceGenome[A](rg: ReferenceGenome)(f: => A): A = {
ctx.backend.addReference(rg)
try f
finally ctx.backend.removeReference(rg.name)
}

def sort(xs: IndexedSeq[Int]): IndexedSeq[Int] =
gen("Sort") { (heap: IntHeap) =>
gen(ctx, "Sort") { (heap: IntHeap) =>
xs.foreach(heap.push)
IndexedSeq.fill(xs.size)(heap.pop())
}

def heapify(xs: IndexedSeq[Int]): IndexedSeq[Int] =
gen("Heapify") { (heap: IntHeap) =>
gen(ctx, "Heapify") { (heap: IntHeap) =>
pool.scopedRegion { r =>
xs.foreach(heap.push)
val ptr = heap.toArray(r)
SafeIndexedSeq(PCanonicalArray(PInt32Required), ptr).asInstanceOf[IndexedSeq[Int]]
}
}

def gen[H <: Heap[A], A, B](
def gen[A, H <: Heap[A], B](
ctx: ExecuteContext,
name: String,
A: StagedCoercions[A],
)(
f: H => B
)(implicit H: TypeInfo[H]
): B =
gen[H, A, B](name)(f)(H, A)

def gen[H <: Heap[A], A, B](
name: String
)(
f: H => B
)(implicit
Expand All @@ -146,13 +108,13 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances {

Main.defineEmitMethod("push", FastSeq(A.ti), UnitInfo) { mb =>
mb.voidWithBuilder { cb =>
MinHeap.push(cb, A.fromType(cb, Main.partitionRegion, mb.getCodeParam[A](1)(A.ti)))
MinHeap.push(cb, A.fromValue(cb, Main.partitionRegion, mb.getCodeParam[A](1)(A.ti)))
}
}

Main.defineEmitMethod("pop", FastSeq(), A.ti) { mb =>
mb.emitWithBuilder[A] { cb =>
val res = A.toType(cb, MinHeap.peek(cb))
val res = A.toValue(cb, MinHeap.peek(cb))
MinHeap.pop(cb)
MinHeap.realloc(cb)
res
Expand Down Expand Up @@ -182,11 +144,11 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances {
MinHeap.init(cb, Main.pool())
}
}
Main.defineEmitMethod("close", FastSeq(), UnitInfo)(mb => mb.voidWithBuilder(MinHeap.close))
Main.defineEmitMethod("close", FastSeq(), UnitInfo)(_.voidWithBuilder(MinHeap.close))

pool.scopedRegion { r =>
ctx.scopedExecution { (cl, fs, tc, r) =>
val heap = Main
.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)
.resultWithIndex()(cl, fs, tc, r)
.asInstanceOf[H with Resource]

heap.init()
Expand All @@ -208,4 +170,23 @@ class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances {
def nonEmpty: Boolean
def toArray(r: Region): Long
}

sealed trait StagedCoercions[A] {
def ti: TypeInfo[A]
def sType: SType
def fromValue(cb: EmitCodeBuilder, region: Value[Region], a: Value[A]): SValue
def toValue(cb: EmitCodeBuilder, sa: SValue): Value[A]
}

def stagedLocusCoercions(rg: ReferenceGenome): StagedCoercions[Locus] =
new StagedCoercions[Locus] {
override def ti: TypeInfo[Locus] = implicitly
override def sType: SType = PCanonicalLocus(rg.name, required = true).sType

override def fromValue(cb: EmitCodeBuilder, region: Value[Region], a: Value[Locus]): SValue =
LocusFunctions.emitLocus(cb, region, a, sType.storageType().asInstanceOf[PCanonicalLocus])

override def toValue(cb: EmitCodeBuilder, sa: SValue): Value[Locus] =
sa.asLocus.getLocusObj(cb)
}
}

0 comments on commit 2479255

Please sign in to comment.