From e24bbad200eeb3410de68eae9141024f9cc19925 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Tue, 17 Sep 2024 12:29:49 -0400 Subject: [PATCH] prefix methods called in python with `py` --- hail/python/hail/backend/py4j_backend.py | 8 ++--- hail/python/hail/ir/ir.py | 2 +- hail/python/hail/ir/table_ir.py | 2 +- .../main/scala/is/hail/HailFeatureFlags.scala | 3 ++ .../backend/py4j/Py4JBackendExtensions.scala | 30 +++++++++---------- .../hail/backend/service/ServiceBackend.scala | 4 +-- .../is/hail/backend/spark/SparkBackend.scala | 8 ++--- .../is/hail/variant/ReferenceGenome.scala | 4 ++- hail/src/test/scala/is/hail/HailSuite.scala | 2 +- .../lowering/LowerDistributedSortSuite.scala | 10 ++----- 10 files changed, 37 insertions(+), 36 deletions(-) diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 2c1610b56145..b9f986e5834a 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -237,17 +237,17 @@ def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]: def persist_expression(self, expr): t = expr.dtype - return construct_expr(JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), t) + return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t) def _is_registered_ir_function_name(self, name: str) -> bool: return name in self._registered_ir_function_names def set_flags(self, **flags: Mapping[str, str]): - available = self._jbackend.availableFlags() + available = self._jbackend.pyAvailableFlags() invalid = [] for flag, value in flags.items(): if flag in available: - self._jbackend.setFlag(flag, value) + self._jbackend.pySetFlag(flag, value) else: invalid.append(flag) if len(invalid) != 0: @@ -256,7 +256,7 @@ def set_flags(self, **flags: Mapping[str, str]): ) def get_flags(self, *flags) -> Mapping[str, str]: - return {flag: self._jbackend.getFlag(flag) for flag in flags} + return {flag: self._jbackend.pyGetFlag(flag) for flag in flags} def _add_reference_to_scala_backend(self, rg): self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8')) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index b024a36304bf..2bef587fc1da 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -3880,7 +3880,7 @@ def __del__(self): if Env._hc: backend = Env.backend() assert isinstance(backend, Py4JBackend) - backend._jbackend.removeJavaIR(self._id) + backend._jbackend.pyRemoveJavaIR(self._id) class JavaIR(IR): diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 8184401c1269..eb96deee8639 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -1215,4 +1215,4 @@ def __del__(self): if Env._hc: backend = Env.backend() assert isinstance(backend, Py4JBackend) - backend._jbackend.removeJavaIR(self._id) + backend._jbackend.pyRemoveJavaIR(self._id) diff --git a/hail/src/main/scala/is/hail/HailFeatureFlags.scala b/hail/src/main/scala/is/hail/HailFeatureFlags.scala index 48bb22bb3907..49eff3139ec1 100644 --- a/hail/src/main/scala/is/hail/HailFeatureFlags.scala +++ b/hail/src/main/scala/is/hail/HailFeatureFlags.scala @@ -68,6 +68,9 @@ class HailFeatureFlags private ( flags.update(flag, value) } + def +(feature: (String, String)): HailFeatureFlags = + new HailFeatureFlags(flags + (feature._1 -> feature._2)) + def get(flag: String): String = flags(flag) def lookup(flag: String): Option[String] = diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala index 49d075c73d65..2fa2b37404bd 100644 --- a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala @@ -38,30 +38,29 @@ trait Py4JBackendExtensions { def flags: HailFeatureFlags def longLifeTempFileManager: TempFileManager - def getFlag(name: String): String = + def pyGetFlag(name: String): String = flags.get(name) - def setFlag(name: String, value: String): Unit = + def pySetFlag(name: String, value: String): Unit = flags.set(name, value) - val availableFlags: java.util.ArrayList[String] = + def pyAvailableFlags: java.util.ArrayList[String] = flags.available private[this] var irID: Int = 0 - def nextIRID(): Int = - synchronized { - irID += 1 - irID - } + private[this] def nextIRID(): Int = { + irID += 1 + irID + } - protected[this] def addJavaIR(ir: BaseIR): Int = { + private[this] def addJavaIR(ir: BaseIR): Int = { val id = nextIRID() persistedIR += (id -> ir) id } - def removeJavaIR(id: Int): Unit = + def pyRemoveJavaIR(id: Int): Unit = persistedIR.remove(id) def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = @@ -133,7 +132,7 @@ trait Py4JBackendExtensions { } } - def executeLiteral(irStr: String): Int = + def pyExecuteLiteral(irStr: String): Int = backend.withExecuteContext { ctx => val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap)) assert(ir.typ.isRealizable) @@ -212,7 +211,7 @@ trait Py4JBackendExtensions { def pyRemoveLiftover(name: String, destRGName: String) = references(name).removeLiftover(destRGName) - def addReference(rg: ReferenceGenome): Unit = { + private[this] def addReference(rg: ReferenceGenome): Unit = { references.get(rg.name) match { case Some(rg2) => if (rg != rg2) { @@ -227,7 +226,7 @@ trait Py4JBackendExtensions { } } - def removeReference(name: String): Unit = + private[this] def removeReference(name: String): Unit = references -= name def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = @@ -272,7 +271,8 @@ trait Py4JBackendExtensions { )(implicit E: Enclosing ): T = backend.withExecuteContext { ctx => - if (selfContainedExecution && longLifeTempFileManager != null) f(ctx) - else ctx.local(tempFileManager = NonOwningTempFileManager(longLifeTempFileManager))(f) + val tempFileManager = longLifeTempFileManager + if (selfContainedExecution && tempFileManager != null) f(ctx) + else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) } } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 66f4035e47dc..f1192f637a9e 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -6,8 +6,8 @@ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate import is.hail.expr.ir.{ - Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, - TableIR, TableReader, TypeCheck, + Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, + TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.functions.IRFunctionRegistry diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 786117fb4f8b..94a1020734a1 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -531,11 +531,11 @@ class SparkBackend( Validate(ir) ctx.irMetadata.semhash = SemanticHash(ctx)(ir) try { - val lowerTable = getFlag("lower") != null - val lowerBM = getFlag("lower_bm") != null + val lowerTable = flags.get("lower") != null + val lowerBM = flags.get("lower_bm") != null _jvmLowerAndExecute(ctx, ir, optimize = true, lowerTable, lowerBM) } catch { - case e: LowererUnsupportedOperation if getFlag("lower_only") != null => throw e + case e: LowererUnsupportedOperation if flags.get("lower_only") != null => throw e case _: LowererUnsupportedOperation => CompileAndEvaluate._apply(ctx, ir, optimize = true) } @@ -548,7 +548,7 @@ class SparkBackend( rt: RTable, nPartitions: Option[Int], ): TableReader = { - if (getFlag("use_new_shuffle") != null) + if (flags.get("use_new_shuffle") != null) return LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) val (globals, rvd) = TableStageToRVD(ctx, stage) diff --git a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala index 594631a9a2be..7d7c3317e25a 100644 --- a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala +++ b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala @@ -8,7 +8,9 @@ import is.hail.expr.{ } import is.hail.expr.ir.RelationalSpec import is.hail.io.fs.FS -import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, FastaSequenceIndex, IndexedFastaSequenceFile, LiftOver} +import is.hail.io.reference.{ + FASTAReader, FASTAReaderConfig, FastaSequenceIndex, IndexedFastaSequenceFile, LiftOver, +} import is.hail.types._ import is.hail.types.virtual.{TLocus, Type} import is.hail.utils._ diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index 50b25f5b4041..c4994aa4be84 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -43,7 +43,7 @@ object HailSuite { lazy val hc: HailContext = { val hc = withSparkBackend() - hc.sparkBackend("HailSuite.hc").setFlag("lower", "1") + hc.sparkBackend("HailSuite.hc").flags.set("lower", "1") hc.checkRVDKeys = true hc } diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index fb296ad605b3..827b439e288c 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -60,10 +60,8 @@ class LowerDistributedSortSuite extends HailSuite { } // Only does ascending for now - def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit = { - val originalShuffleCutoff = backend.getFlag("shuffle_cutoff_to_local_sort") - try { - backend.setFlag("shuffle_cutoff_to_local_sort", "40") + def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit = + ctx.local(flags = ctx.flags + ("shuffle_cutoff_to_local_sort" -> "40")) { ctx => val analyses: LoweringAnalyses = LoweringAnalyses.apply(myTable, ctx) val rt = analyses.requirednessAnalysis.lookup(myTable).asInstanceOf[RTable] val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses) @@ -103,9 +101,7 @@ class LowerDistributedSortSuite extends HailSuite { ans } assert(res == scalaSorted) - } finally - backend.setFlag("shuffle_cutoff_to_local_sort", originalShuffleCutoff) - } + } @Test def testDistributedSort(): Unit = { val tableRange = TableRange(100, 10)