prefix methods called in python with py
ehigham committed Oct 29, 2024
1 parent a3e9279 commit e24bbad
Showing 10 changed files with 37 additions and 36 deletions.
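In short, the JVM entry points that Hail's Python client reaches over the Py4J gateway gain a `py` prefix (pyExecuteLiteral, pyGetFlag, pySetFlag, pyAvailableFlags, pyRemoveJavaIR), matching existing names such as pyAddReference, while helpers used only from the JVM side lose their public wrappers (SparkBackend now reads flags.get(...) directly) or become private[this]. A minimal sketch of the convention, assuming a simplified trait; the trait and helper names below are illustrative only, not the real Py4JBackendExtensions:

// Illustrative sketch only: the naming/visibility convention this commit applies.
trait PyEntryPointsSketch {
  def flags: HailFeatureFlags // supplied by the concrete backend

  // Called from Python over the Py4J gateway, hence the `py` prefix,
  // e.g. self._jbackend.pyGetFlag(name) in py4j_backend.py.
  def pyGetFlag(name: String): String = flags.get(name)
  def pySetFlag(name: String, value: String): Unit = flags.set(name, value)

  // JVM-only helpers keep plain names and the narrowest useful visibility.
  private[this] def internalHelper(): Unit = () // placeholder body
}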
8 changes: 4 additions & 4 deletions hail/python/hail/backend/py4j_backend.py
@@ -237,17 +237,17 @@ def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]:

     def persist_expression(self, expr):
         t = expr.dtype
-        return construct_expr(JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), t)
+        return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)

     def _is_registered_ir_function_name(self, name: str) -> bool:
         return name in self._registered_ir_function_names

     def set_flags(self, **flags: Mapping[str, str]):
-        available = self._jbackend.availableFlags()
+        available = self._jbackend.pyAvailableFlags()
         invalid = []
         for flag, value in flags.items():
             if flag in available:
-                self._jbackend.setFlag(flag, value)
+                self._jbackend.pySetFlag(flag, value)
             else:
                 invalid.append(flag)
         if len(invalid) != 0:
@@ -256,7 +256,7 @@ def set_flags(self, **flags: Mapping[str, str]):
             )

     def get_flags(self, *flags) -> Mapping[str, str]:
-        return {flag: self._jbackend.getFlag(flag) for flag in flags}
+        return {flag: self._jbackend.pyGetFlag(flag) for flag in flags}

     def _add_reference_to_scala_backend(self, rg):
         self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8'))
2 changes: 1 addition & 1 deletion hail/python/hail/ir/ir.py
@@ -3880,7 +3880,7 @@ def __del__(self):
         if Env._hc:
             backend = Env.backend()
             assert isinstance(backend, Py4JBackend)
-            backend._jbackend.removeJavaIR(self._id)
+            backend._jbackend.pyRemoveJavaIR(self._id)


 class JavaIR(IR):
2 changes: 1 addition & 1 deletion hail/python/hail/ir/table_ir.py
@@ -1215,4 +1215,4 @@ def __del__(self):
         if Env._hc:
             backend = Env.backend()
             assert isinstance(backend, Py4JBackend)
-            backend._jbackend.removeJavaIR(self._id)
+            backend._jbackend.pyRemoveJavaIR(self._id)
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/HailFeatureFlags.scala
@@ -68,6 +68,9 @@ class HailFeatureFlags private (
     flags.update(flag, value)
   }

+  def +(feature: (String, String)): HailFeatureFlags =
+    new HailFeatureFlags(flags + (feature._1 -> feature._2))
+
   def get(flag: String): String = flags(flag)

   def lookup(flag: String): Option[String] =
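The new `+` builds a copy of the flag set with one entry added or overridden rather than mutating the shared instance; the LowerDistributedSortSuite change at the end of this diff relies on it to scope a flag override to a block through ctx.local. A small usage sketch along those lines (the flag value and block body are placeholders):

// Override one flag for the duration of a block; the original flags are
// untouched because `+` returns a new HailFeatureFlags.
ctx.local(flags = ctx.flags + ("shuffle_cutoff_to_local_sort" -> "40")) { ctx =>
  // work that should observe the overridden flag goes here
}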
@@ -38,30 +38,29 @@ trait Py4JBackendExtensions {
   def flags: HailFeatureFlags
   def longLifeTempFileManager: TempFileManager

-  def getFlag(name: String): String =
+  def pyGetFlag(name: String): String =
     flags.get(name)

-  def setFlag(name: String, value: String): Unit =
+  def pySetFlag(name: String, value: String): Unit =
     flags.set(name, value)

-  val availableFlags: java.util.ArrayList[String] =
+  def pyAvailableFlags: java.util.ArrayList[String] =
     flags.available

   private[this] var irID: Int = 0

-  def nextIRID(): Int =
-    synchronized {
-      irID += 1
-      irID
-    }
+  private[this] def nextIRID(): Int = {
+    irID += 1
+    irID
+  }

-  protected[this] def addJavaIR(ir: BaseIR): Int = {
+  private[this] def addJavaIR(ir: BaseIR): Int = {
     val id = nextIRID()
     persistedIR += (id -> ir)
     id
   }

-  def removeJavaIR(id: Int): Unit =
+  def pyRemoveJavaIR(id: Int): Unit =
     persistedIR.remove(id)

   def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
@@ -133,7 +132,7 @@ trait Py4JBackendExtensions {
     }
   }

-  def executeLiteral(irStr: String): Int =
+  def pyExecuteLiteral(irStr: String): Int =
     backend.withExecuteContext { ctx =>
       val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
       assert(ir.typ.isRealizable)
@@ -212,7 +211,7 @@ trait Py4JBackendExtensions {
   def pyRemoveLiftover(name: String, destRGName: String) =
     references(name).removeLiftover(destRGName)

-  def addReference(rg: ReferenceGenome): Unit = {
+  private[this] def addReference(rg: ReferenceGenome): Unit = {
     references.get(rg.name) match {
       case Some(rg2) =>
         if (rg != rg2) {
@@ -227,7 +226,7 @@
     }
   }

-  def removeReference(name: String): Unit =
+  private[this] def removeReference(name: String): Unit =
     references -= name

   def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
@@ -272,7 +271,8 @@
   )(implicit E: Enclosing
   ): T =
     backend.withExecuteContext { ctx =>
-      if (selfContainedExecution && longLifeTempFileManager != null) f(ctx)
-      else ctx.local(tempFileManager = NonOwningTempFileManager(longLifeTempFileManager))(f)
+      val tempFileManager = longLifeTempFileManager
+      if (selfContainedExecution && tempFileManager != null) f(ctx)
+      else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f)
     }
 }
@@ -6,8 +6,8 @@ import is.hail.asm4s._
 import is.hail.backend._
 import is.hail.expr.Validate
 import is.hail.expr.ir.{
-  Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField,
-  TableIR, TableReader, TypeCheck,
+  Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader,
+  TypeCheck,
 }
 import is.hail.expr.ir.analyses.SemanticHash
 import is.hail.expr.ir.functions.IRFunctionRegistry
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -531,11 +531,11 @@ class SparkBackend(
     Validate(ir)
     ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
     try {
-      val lowerTable = getFlag("lower") != null
-      val lowerBM = getFlag("lower_bm") != null
+      val lowerTable = flags.get("lower") != null
+      val lowerBM = flags.get("lower_bm") != null
       _jvmLowerAndExecute(ctx, ir, optimize = true, lowerTable, lowerBM)
     } catch {
-      case e: LowererUnsupportedOperation if getFlag("lower_only") != null => throw e
+      case e: LowererUnsupportedOperation if flags.get("lower_only") != null => throw e
       case _: LowererUnsupportedOperation =>
         CompileAndEvaluate._apply(ctx, ir, optimize = true)
     }
@@ -548,7 +548,7 @@
     rt: RTable,
     nPartitions: Option[Int],
   ): TableReader = {
-    if (getFlag("use_new_shuffle") != null)
+    if (flags.get("use_new_shuffle") != null)
       return LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt)

     val (globals, rvd) = TableStageToRVD(ctx, stage)
4 changes: 3 additions & 1 deletion hail/src/main/scala/is/hail/variant/ReferenceGenome.scala
@@ -8,7 +8,9 @@ import is.hail.expr.{
 }
 import is.hail.expr.ir.RelationalSpec
 import is.hail.io.fs.FS
-import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, FastaSequenceIndex, IndexedFastaSequenceFile, LiftOver}
+import is.hail.io.reference.{
+  FASTAReader, FASTAReaderConfig, FastaSequenceIndex, IndexedFastaSequenceFile, LiftOver,
+}
 import is.hail.types._
 import is.hail.types.virtual.{TLocus, Type}
 import is.hail.utils._
2 changes: 1 addition & 1 deletion hail/src/test/scala/is/hail/HailSuite.scala
@@ -43,7 +43,7 @@ object HailSuite {

   lazy val hc: HailContext = {
     val hc = withSparkBackend()
-    hc.sparkBackend("HailSuite.hc").setFlag("lower", "1")
+    hc.sparkBackend("HailSuite.hc").flags.set("lower", "1")
     hc.checkRVDKeys = true
     hc
   }
@@ -60,10 +60,8 @@ class LowerDistributedSortSuite extends HailSuite {
   }

   // Only does ascending for now
-  def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit = {
-    val originalShuffleCutoff = backend.getFlag("shuffle_cutoff_to_local_sort")
-    try {
-      backend.setFlag("shuffle_cutoff_to_local_sort", "40")
+  def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit =
+    ctx.local(flags = ctx.flags + ("shuffle_cutoff_to_local_sort" -> "40")) { ctx =>
       val analyses: LoweringAnalyses = LoweringAnalyses.apply(myTable, ctx)
       val rt = analyses.requirednessAnalysis.lookup(myTable).asInstanceOf[RTable]
       val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses)
@@ -103,9 +101,7 @@
         ans
       }
       assert(res == scalaSorted)
-    } finally
-      backend.setFlag("shuffle_cutoff_to_local_sort", originalShuffleCutoff)
-  }
+    }

   @Test def testDistributedSort(): Unit = {
     val tableRange = TableRange(100, 10)
