Skip to content

Commit

Permalink
extract references from Backend implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Oct 29, 2024
1 parent 2479255 commit a3e9279
Show file tree
Hide file tree
Showing 17 changed files with 586 additions and 657 deletions.
100 changes: 16 additions & 84 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package is.hail.backend

import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
}
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand All @@ -20,7 +20,6 @@ import is.hail.types.virtual.{BlockMatrixType, TFloat64}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag

Expand All @@ -29,7 +28,7 @@ import java.nio.charset.StandardCharsets

import com.fasterxml.jackson.core.StreamReadConstraints
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

object Backend {
Expand All @@ -41,13 +40,6 @@ object Backend {
s"hail_query_$id"
}

private var irID: Int = 0

def nextIRID(): Int = {
irID += 1
irID
}

def encodeToOutputStream(
ctx: ExecuteContext,
t: PTuple,
Expand All @@ -66,6 +58,9 @@ object Backend {
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}

def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)
}

abstract class BroadcastValue[T] { def value: T }
Expand All @@ -89,14 +84,6 @@ abstract class Backend extends Closeable {

val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

protected[this] def addJavaIR(ir: BaseIR): Int = {
val id = Backend.nextIRID()
persistedIR += (id -> ir)
id
}

def removeJavaIR(id: Int): Unit = persistedIR.remove(id)

def defaultParallelism: Int

def canExecuteParallelTasksOnDriver: Boolean = true
Expand Down Expand Up @@ -131,31 +118,6 @@ abstract class Backend extends Closeable {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

var references: Map[String, ReferenceGenome] = Map.empty

def addDefaultReferences(): Unit =
references = ReferenceGenome.builtinReferences()

def addReference(rg: ReferenceGenome): Unit = {
references.get(rg.name) match {
case Some(rg2) =>
if (rg != rg2) {
fatal(
s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. Choose a reference name NOT in the following list:\n " +
s"@1",
references.keys.truncatable("\n "),
)
}
case None =>
references += (rg.name -> rg)
}
}

def hasReference(name: String) = references.contains(name)

def removeReference(name: String): Unit =
references -= name

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down Expand Up @@ -189,9 +151,6 @@ abstract class Backend extends Closeable {

def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

private[this] def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)

final def valueType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
Expand Down Expand Up @@ -220,15 +179,7 @@ abstract class Backend extends Closeable {
}
}

def loadReferencesFromDataset(path: String): Array[Byte] = {
withExecuteContext { ctx =>
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
rgs.foreach(addReference)

implicit val formats: Formats = defaultJSONFormats
Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8)
}
}
def loadReferencesFromDataset(path: String): Array[Byte]

def fromFASTAFile(
name: String,
Expand All @@ -240,18 +191,20 @@ abstract class Backend extends Closeable {
parInput: Array[String],
): Array[Byte] =
withExecuteContext { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput)
rg.toJSONString.getBytes(StandardCharsets.UTF_8)
jsonToBytes {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}
}

def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes {
def parseVCFMetadata(path: String): Array[Byte] =
withExecuteContext { ctx =>
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
jsonToBytes {
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
}
}
}

def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
Expand All @@ -261,27 +214,6 @@ abstract class Backend extends Closeable {
)
}

def pyRegisterIR(
name: String,
typeParamStrs: java.util.ArrayList[String],
argNameStrs: java.util.ArrayList[String],
argTypeStrs: java.util.ArrayList[String],
returnType: String,
bodyStr: String,
): Unit = {
withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
name,
typeParamStrs.asScala.toArray,
argNameStrs.asScala.toArray,
argTypeStrs.asScala.toArray,
returnType,
bodyStr,
)
}
}

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
}
return
}

val response: Array[Byte] = exchange.getRequestURI.getPath match {
case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir)
case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir)
Expand Down
107 changes: 17 additions & 90 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags}
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend._
import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir.{IRParser, _}
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
Expand All @@ -17,7 +18,7 @@ import is.hail.types.virtual.{BlockMatrixType, TVoid}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag

import java.io.PrintWriter
Expand Down Expand Up @@ -47,8 +48,11 @@ object LocalBackend {
if (!skipLoggingConfiguration)
HailContext.configureLogging(logFile, quiet, append)

theLocalBackend = new LocalBackend(tmpdir)
theLocalBackend.addDefaultReferences()
theLocalBackend = new LocalBackend(
tmpdir,
mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*),
)

theLocalBackend
}

Expand All @@ -64,18 +68,16 @@ object LocalBackend {
}
}

class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache {

private[this] val flags = HailFeatureFlags.fromEnv()
private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
class LocalBackend(
val tmpdir: String,
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {

def getFlag(name: String): String = flags.get(name)
override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

def setFlag(name: String, value: String) = flags.set(name, value)

// called from python
val availableFlags: java.util.ArrayList[String] =
flags.available
private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -87,7 +89,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
tmpdir,
tmpdir,
this,
references,
references.toMap,
fs,
timer,
null,
Expand Down Expand Up @@ -191,81 +193,6 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
res
}

def executeLiteral(irStr: String): Int =
withExecuteContext { ctx =>
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
assert(ir.typ.isRealizable)
execute(ctx, ir) match {
case Left(_) => throw new HailException("Can't create literal")
case Right((pt, addr)) =>
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
addJavaIR(field)
}
}

def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig))
def pyRemoveReference(name: String): Unit = removeReference(name)

def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
withExecuteContext(ctx => references(name).addLiftover(ctx, chainFile, destRGName))

def pyRemoveLiftover(name: String, destRGName: String) =
references(name).removeLiftover(destRGName)

def pyFromFASTAFile(
name: String,
fastaFile: String,
indexFile: String,
xContigs: java.util.List[String],
yContigs: java.util.List[String],
mtContigs: java.util.List[String],
parInput: java.util.List[String],
): String =
withExecuteContext { ctx =>
val rg = ReferenceGenome.fromFASTAFile(
ctx,
name,
fastaFile,
indexFile,
xContigs.asScala.toArray,
yContigs.asScala.toArray,
mtContigs.asScala.toArray,
parInput.asScala.toArray,
)
rg.toJSONString
}

def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
withExecuteContext(ctx => references(name).addSequence(ctx, fastaFile, indexFile))

def pyRemoveSequence(name: String) = references(name).removeSequence()

def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
withExecuteContext { ctx =>
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, persistedIR.toMap),
BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
Name(n) -> IRParser.parseType(t)
}.toSeq: _*),
)
}

def parse_table_ir(s: String): TableIR =
withExecuteContext { ctx =>
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
}

def parse_matrix_ir(s: String): MatrixIR =
withExecuteContext { ctx =>
IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
}

def parse_blockmatrix_ir(s: String): BlockMatrixIR =
withExecuteContext { ctx =>
IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
}

override def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down
Loading

0 comments on commit a3e9279

Please sign in to comment.