From cbe84676310b33b9591f06f9c6ce77633f7353a2 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 4 Nov 2024 12:23:57 -0500 Subject: [PATCH] only remove ir functions registered in that execute request --- .../scala/is/hail/backend/BackendRpc.scala | 11 ++-- .../is/hail/expr/ir/functions/Functions.scala | 58 ++++++++++++++----- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/hail/src/main/scala/is/hail/backend/BackendRpc.scala b/hail/src/main/scala/is/hail/backend/BackendRpc.scala index 33cecbbb0fed..fc8e411276df 100644 --- a/hail/src/main/scala/is/hail/backend/BackendRpc.scala +++ b/hail/src/main/scala/is/hail/backend/BackendRpc.scala @@ -2,13 +2,14 @@ package is.hail.backend import is.hail.expr.ir.IRParser import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.functions.IRFunctionRegistry.UserDefinedFnKey import is.hail.io.BufferSpec import is.hail.io.plink.LoadPlink import is.hail.io.vcf.LoadVCF import is.hail.services.retryTransientErrors import is.hail.types.virtual.{Kind, TFloat64, VType} import is.hail.types.virtual.Kinds._ -import is.hail.utils.{using, ExecutionTimer} +import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer} import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome @@ -177,9 +178,10 @@ trait BackendRpc { )( body: => A ): A = { + val fns = new BoxedArrayBuilder[UserDefinedFnKey](serializedFunctions.length) try { - serializedFunctions.foreach { func => - IRFunctionRegistry.registerIR( + for (func <- serializedFunctions) { + fns += IRFunctionRegistry.registerIR( ctx, func.name, func.type_parameters, @@ -192,7 +194,8 @@ trait BackendRpc { body } finally - IRFunctionRegistry.clearUserFunctions() + for (i <- 0 until fns.length) + IRFunctionRegistry.unregisterIr(fns(i)) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index be5f0e3e3b2d..b81fe4051759 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -21,7 +21,9 @@ import scala.reflect._ import org.apache.spark.sql.Row object IRFunctionRegistry { - private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] = + type UserDefinedFnKey = (String, (Type, Seq[Type], Seq[Type])) + + private[this] val userAddedFunctions: mutable.Set[UserDefinedFnKey] = mutable.HashSet.empty def clearUserFunctions(): Unit = { @@ -69,25 +71,41 @@ object IRFunctionRegistry { typeParamStrs: Array[String], argNameStrs: Array[String], argTypeStrs: Array[String], - returnType: String, + returnTypeStr: String, bodyStr: String, - ): Unit = { + ): UserDefinedFnKey = { requireJavaIdentifier(name) - val argNames = argNameStrs.map(Name) val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq val valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq - val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*) - val body = IRParser.parse_value_ir(ctx, bodyStr, refMap) + val argNames = argNameStrs.map(Name) + + val body = + IRParser.parse_value_ir(ctx, bodyStr, BindingEnv.eval(argNames.zip(valueParameterTypes): _*)) + val returnType = IRParser.parseType(returnTypeStr) + assert(body.typ == returnType) - userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) + val key: UserDefinedFnKey = (name, (returnType, typeParameters, valueParameterTypes)) + userAddedFunctions += key addIR( name, typeParameters, valueParameterTypes, - IRParser.parseType(returnType), + returnType, false, (_, args, _) => Subst(body, BindingEnv.eval(argNames.zip(args): _*)), ) + key + } + + def unregisterIr(key: UserDefinedFnKey): Unit = { + val (name, (returnType, typeParameterTypes, valueParameterTypes)) = key + if (userAddedFunctions.remove(key)) + removeIRFunction(name, returnType, typeParameterTypes, valueParameterTypes) + else { + throw new NoSuchElementException( + s"No user defined function registered matching: ${prettyFunctionSignature(name, returnType, typeParameterTypes, valueParameterTypes)}" + ) + } } def removeIRFunction( @@ -112,7 +130,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(f) => Some(f) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." + ) } def lookupFunctionOrFail( @@ -124,28 +144,34 @@ object IRFunctionRegistry { jvmRegistry.lift(name) match { case None => fatal( - s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType" + s"no functions found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." ) case Some(functions) => functions.filter(t => t.unify(typeParameters, valueParameterTypes, returnType) ).toSeq match { case Seq() => - val prettyFunctionSignature = - s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" val prettyMismatchedFunctionSignatures = functions.map(x => s" $x").mkString("\n") fatal( - s"No function found with the signature $prettyFunctionSignature.\n" + + s"No function found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}.\n" + s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures" ) case Seq(f) => f case _ => fatal( - s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})." + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)})." ) } } } + private[this] def prettyFunctionSignature( + name: String, + returnType: Type, + typeParameterTypes: Seq[Type], + valueParameterTypes: Seq[Type], + ): String = + s"$name[${typeParameterTypes.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" + def lookupIR( name: String, returnType: Type, @@ -165,7 +191,9 @@ object IRFunctionRegistry { case Seq() => None case Seq(kv) => Some(kv) case _ => - fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + fatal( + s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}." + ) } }