From fc4afd0a0b75849846188441cb46170b4152376f Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 4 Nov 2024 12:23:57 -0500 Subject: [PATCH] only remove ir functions registered in that execute request --- hail/python/hail/backend/service_backend.py | 10 ++-- .../scala/is/hail/backend/BackendRpc.scala | 11 ++-- .../hail/backend/service/ServiceBackend.scala | 7 +-- .../is/hail/expr/ir/functions/Functions.scala | 58 ++++++++++++++----- 4 files changed, 57 insertions(+), 29 deletions(-) diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index d753dab03ed1..b56bde184970 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -78,7 +78,6 @@ class SequenceConfig: @dataclass class ServiceBackendRPCConfig: tmp_dir: str - remote_tmpdir: str flags: Dict[str, str] custom_references: List[str] liftovers: Dict[str, Dict[str, str]] @@ -328,8 +327,8 @@ async def _run_on_batch( elif self.driver_memory is not None: resources['memory'] = str(self.driver_memory) - if service_backend_config.storage != '0Gi': - resources['storage'] = service_backend_config.storage + if job_config.storage != '0Gi': + resources['storage'] = job_config.storage j = self._batch.create_jvm_job( jar_spec=self.jar_spec, @@ -343,7 +342,7 @@ async def _run_on_batch( resources=resources, attributes={'name': name + '_driver'}, regions=self.regions, - cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in service_backend_config.cloudfuse_configs], + cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in job_config.cloudfuse_configs], profile=self.flags['profile'] is not None, ) await self._batch.submit(disable_progress_bar=True) @@ -441,8 +440,7 @@ async def _async_rpc(self, action: ActionTag, payload: ActionPayload): return await self._run_on_batch( name=f'{action.name.lower()}(...)', service_backend_config=ServiceBackendRPCConfig( - tmp_dir=tmp_dir(), - remote_tmpdir=self.remote_tmpdir, + tmp_dir=self.remote_tmpdir, flags=self.flags, custom_references=[ orjson.dumps(rg._config).decode('utf-8') diff --git a/hail/src/main/scala/is/hail/backend/BackendRpc.scala b/hail/src/main/scala/is/hail/backend/BackendRpc.scala index 33cecbbb0fed..fc8e411276df 100644 --- a/hail/src/main/scala/is/hail/backend/BackendRpc.scala +++ b/hail/src/main/scala/is/hail/backend/BackendRpc.scala @@ -2,13 +2,14 @@ package is.hail.backend import is.hail.expr.ir.IRParser import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.functions.IRFunctionRegistry.UserDefinedFnKey import is.hail.io.BufferSpec import is.hail.io.plink.LoadPlink import is.hail.io.vcf.LoadVCF import is.hail.services.retryTransientErrors import is.hail.types.virtual.{Kind, TFloat64, VType} import is.hail.types.virtual.Kinds._ -import is.hail.utils.{using, ExecutionTimer} +import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer} import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome @@ -177,9 +178,10 @@ trait BackendRpc { )( body: => A ): A = { + val fns = new BoxedArrayBuilder[UserDefinedFnKey](serializedFunctions.length) try { - serializedFunctions.foreach { func => - IRFunctionRegistry.registerIR( + for (func <- serializedFunctions) { + fns += IRFunctionRegistry.registerIR( ctx, func.name, func.type_parameters, @@ -192,7 +194,8 @@ trait BackendRpc { body } finally - IRFunctionRegistry.clearUserFunctions() + for (i <- 0 until fns.length) + IRFunctionRegistry.unregisterIr(fns(i)) } } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 752fd0b43f34..70571573199e 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -308,7 +308,7 @@ class ServiceBackend( ExecutionTimer.time { timer => ExecuteContext.scoped( rpcConfig.tmp_dir, - rpcConfig.remote_tmpdir, + rpcConfig.tmp_dir, this, fs, timer, @@ -316,9 +316,9 @@ class ServiceBackend( theHailClassLoader, flags, ServiceBackendContext( - rpcConfig.remote_tmpdir, + rpcConfig.tmp_dir, jobConfig, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), + ExecutionCache.fromFlags(flags, fs, rpcConfig.tmp_dir), ), new IrMetadata(), references, @@ -534,7 +534,6 @@ case class SequenceConfig(fasta: String, index: String) case class ServiceBackendRPCPayload( tmp_dir: String, - remote_tmpdir: String, flags: Map[String, String], custom_references: Array[String], liftovers: Map[String, Map[String, String]], diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index be5f0e3e3b2d..b81fe4051759 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -21,7 +21,9 @@ import scala.reflect._ import org.apache.spark.sql.Row object IRFunctionRegistry { - private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] = + type UserDefinedFnKey = (String, (Type, Seq[Type], Seq[Type])) + + private[this] val userAddedFunctions: mutable.Set[UserDefinedFnKey] = mutable.HashSet.empty def clearUserFunctions(): Unit = { @@ -69,25 +71,41 @@ object IRFunctionRegistry { typeParamStrs: Array[String], argNameStrs: Array[String], argTypeStrs: Array[String], - returnType: String, + returnTypeStr: String, bodyStr: String, - ): Unit = { + ): UserDefinedFnKey = { requireJavaIdentifier(name) - val argNames = argNameStrs.map(Name) val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq val valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq - val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*) - val body = IRParser.parse_value_ir(ctx, bodyStr, refMap) + val argNames = argNameStrs.map(Name) + + val body = + IRParser.parse_value_ir(ctx, bodyStr, BindingEnv.eval(argNames.zip(valueParameterTypes): _*)) + val returnType = IRParser.parseType(returnTypeStr) + assert(body.typ == returnType) - userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) + val key: UserDefinedFnKey = (name, (returnType, typeParameters, valueParameterTypes)) + userAddedFunctions += key addIR( name, typeParameters, valueParameterTypes, - IRParser.parseType(returnType), + returnType, false, (_, args, _) => Subst(body, BindingEnv.eval(argNames.zip(args): _*)), ) + key + } + + def unregisterIr(key: UserDefinedFnKey): Unit = { + val (name, (returnType, typeParameterTypes, valueParameterTypes)) = key + if (userAddedFunctions.remove(key)) + removeIRFunction(name, returnType, typeParameterTypes, valueParameterTypes) + else { + throw new NoSuchElementException( + s"No user defined function registered matching: ${prettyFunctionSignature(name, returnType, typeParameterTypes, valueParameterTypes)}" + ) + } } def removeIRFunction( @@ -112,7 +130,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(f) => Some(f) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." + ) } def lookupFunctionOrFail( @@ -124,28 +144,34 @@ object IRFunctionRegistry { jvmRegistry.lift(name) match { case None => fatal( - s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType" + s"no functions found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." ) case Some(functions) => functions.filter(t => t.unify(typeParameters, valueParameterTypes, returnType) ).toSeq match { case Seq() => - val prettyFunctionSignature = - s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" val prettyMismatchedFunctionSignatures = functions.map(x => s" $x").mkString("\n") fatal( - s"No function found with the signature $prettyFunctionSignature.\n" + + s"No function found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}.\n" + s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures" ) case Seq(f) => f case _ => fatal( - s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})." + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)})." ) } } } + private[this] def prettyFunctionSignature( + name: String, + returnType: Type, + typeParameterTypes: Seq[Type], + valueParameterTypes: Seq[Type], + ): String = + s"$name[${typeParameterTypes.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" + def lookupIR( name: String, returnType: Type, @@ -165,7 +191,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(kv) => Some(kv) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." + ) } }