diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py
index 995d0b668c42..e7e14bbe90a8 100644
--- a/hail/python/hail/backend/backend.py
+++ b/hail/python/hail/backend/backend.py
@@ -3,7 +3,7 @@
 import zipfile
 from dataclasses import dataclass
 from enum import Enum
-from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union

 import orjson

@@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation:
         raise ValueError(f'Hail requires either {hail_jar} or {hail_all_spark_jar}.')


+class IRFunction:
+    def __init__(
+        self,
+        name: str,
+        type_parameters: Union[Tuple[HailType, ...], List[HailType]],
+        value_parameter_names: Union[Tuple[str, ...], List[str]],
+        value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
+        return_type: HailType,
+        body: Expression,
+    ):
+        assert len(value_parameter_names) == len(value_parameter_types)
+        render = CSERenderer()
+        self._name = name
+        self._type_parameters = type_parameters
+        self._value_parameter_names = value_parameter_names
+        self._value_parameter_types = value_parameter_types
+        self._return_type = return_type
+        self._rendered_body = render(finalize_randomness(body._ir))
+
+    def to_dataclass(self):
+        return SerializedIRFunction(
+            name=self._name,
+            type_parameters=[tp._parsable_string() for tp in self._type_parameters],
+            value_parameter_names=list(self._value_parameter_names),
+            value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types],
+            return_type=self._return_type._parsable_string(),
+            rendered_body=self._rendered_body,
+        )
+
+
 class ActionTag(Enum):
-    LOAD_REFERENCES_FROM_DATASET = 1
-    VALUE_TYPE = 2
-    TABLE_TYPE = 3
-    MATRIX_TABLE_TYPE = 4
-    BLOCK_MATRIX_TYPE = 5
-    EXECUTE = 6
-    PARSE_VCF_METADATA = 7
-    IMPORT_FAM = 8
+    VALUE_TYPE = 1
+    TABLE_TYPE = 2
+    MATRIX_TABLE_TYPE = 3
+    BLOCK_MATRIX_TYPE = 4
+    EXECUTE = 5
+    PARSE_VCF_METADATA = 6
+    IMPORT_FAM = 7
+    LOAD_REFERENCES_FROM_DATASET = 8
     FROM_FASTA_FILE = 9


@@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload):
     ir: str


+@dataclass
+class SerializedIRFunction:
+    name: str
+    type_parameters: List[str]
+    value_parameter_names: List[str]
+    value_parameter_types: List[str]
+    return_type: str
+    rendered_body: str
+
+
 @dataclass
 class ExecutePayload(ActionPayload):
     ir: str
+    fns: List[SerializedIRFunction]
     stream_codec: str
-    timed: bool


 @dataclass
@@ -164,6 +204,8 @@ def _valid_flags(self) -> AbstractSet[str]:
     def __init__(self):
         self._persisted_locations = dict()
         self._references = {}
+        self.functions: List[IRFunction] = []
+        self._registered_ir_function_names: Set[str] = set()

     @abc.abstractmethod
     def validate_file(self, uri: str):
@@ -171,10 +213,15 @@ def validate_file(self, uri: str):

     @abc.abstractmethod
     def stop(self):
-        pass
+        self.functions = []
+        self._registered_ir_function_names = set()

     def execute(self, ir: BaseIR, timed: bool = False) -> Any:
-        payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed)
+        payload = ExecutePayload(
+            self._render_ir(ir),
+            fns=[fn.to_dataclass() for fn in self.functions],
+            stream_codec='{"name":"StreamBufferSpec"}',
+        )
         try:
             result, timings = self._rpc(ActionTag.EXECUTE, payload)
         except FatalError as e:
@@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset:
             tempfile.__exit__(None, None, None)
         return unpersisted

-    @abc.abstractmethod
     def register_ir_function(
         self,
         name: str,
@@ -310,11 +356,13 @@ def register_ir_function(
         return_type: HailType,
         body: Expression,
     ):
-        pass
+        self._registered_ir_function_names.add(name)
+        self.functions.append(
+            IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body)
+        )

-    @abc.abstractmethod
     def _is_registered_ir_function_name(self, name: str) -> bool:
-        pass
+        return name in self._registered_ir_function_names

     @abc.abstractmethod
     def persist_expression(self, expr: Expression) -> Expression:
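Reviewer note: with registration implemented once on the base Backend, every backend records user-defined IR functions the same way, and execute() ships the rendered bodies as `fns` on each EXECUTE payload. A minimal sketch of the user-visible path, assuming hl.experimental.define_function (the public entry point that reaches register_ir_function) is otherwise unchanged by this patch:

    import hail as hl

    hl.init()  # any backend: Spark, local, or service

    # define_function ends up in Backend.register_ir_function, which now just
    # records an IRFunction on the backend; no RPC happens at this point.
    inc = hl.experimental.define_function(lambda x: x + 1, hl.tint32)

    # The CSE-rendered body rides along with the next ExecutePayload.
    assert hl.eval(inc(41)) == 42
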
diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py
index 708eacaef2d7..7bcb1145259e 100644
--- a/hail/python/hail/backend/local_backend.py
+++ b/hail/python/hail/backend/local_backend.py
@@ -119,7 +119,7 @@ def register_ir_function(
         )

     def stop(self):
-        super().stop()
+        super(Py4JBackend, self).stop()
         self._exit_stack.close()
         uninstall_exception_handler()

diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py
index b9f986e5834a..697269b152be 100644
--- a/hail/python/hail/backend/py4j_backend.py
+++ b/hail/python/hail/backend/py4j_backend.py
@@ -4,7 +4,7 @@
 import socketserver
 import sys
 from threading import Thread
-from typing import Mapping, Optional, Set, Tuple
+from typing import Mapping, Optional, Tuple

 import orjson
 import py4j
@@ -192,8 +192,6 @@ def decode_bytearray(encoded):
         self._backend_server.start()
         self._requests_session = requests.Session()

-        self._registered_ir_function_names: Set[str] = set()
-
         # This has to go after creating the SparkSession. Unclear why.
         # Maybe it does its own patch?
         install_exception_handler()
@@ -239,9 +237,6 @@ def persist_expression(self, expr):
         t = expr.dtype
         return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)

-    def _is_registered_ir_function_name(self, name: str) -> bool:
-        return name in self._registered_ir_function_names
-
     def set_flags(self, **flags: Mapping[str, str]):
         available = self._jbackend.pyAvailableFlags()
         invalid = []
@@ -276,12 +271,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome):
     def remove_liftover(self, name, dest_reference_genome):
         self._jbackend.pyRemoveLiftover(name, dest_reference_genome)

-    def _parse_value_ir(self, code, ref_map={}):
-        return self._jbackend.parse_value_ir(
-            code,
-            {k: t._parsable_string() for k, t in ref_map.items()},
-        )
-
     def _register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, code):
         self._registered_ir_function_names.add(name)
         self._jbackend.pyRegisterIR(
@@ -293,12 +282,6 @@ def _register_ir_function(self, name, type_parameters, argument_names, argument_
             code,
         )

-    def _parse_table_ir(self, code):
-        return self._jbackend.parse_table_ir(code)
-
-    def _parse_matrix_ir(self, code):
-        return self._jbackend.parse_matrix_ir(code)
-
     def _parse_blockmatrix_ir(self, code):
         return self._jbackend.parse_blockmatrix_ir(code)

@@ -310,5 +293,5 @@ def stop(self):
         self._jbackend.close()
         self._jhc.stop()
         self._jhc = None
-        self._registered_ir_function_names = set()
         uninstall_exception_handler()
+        super().stop()
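Reviewer note: the stop() chain is the subtle part of the local/py4j changes. LocalBackend now jumps past Py4JBackend.stop in the MRO (presumably because its exit stack owns the JVM teardown, an assumption worth confirming) while still reaching the new base-class cleanup, and Py4JBackend.stop gains a trailing super().stop() for the same reason. The mechanics, abridged to plain Python:

    class Backend:
        def stop(self):
            print('clear functions / registered IR function names')

    class Py4JBackend(Backend):
        def stop(self):
            print('close _jbackend and _jhc')
            super().stop()

    class LocalBackend(Py4JBackend):
        def stop(self):
            super(Py4JBackend, self).stop()  # skip Py4JBackend.stop in the MRO
            print('close exit stack')

    LocalBackend().stop()
    # clear functions / registered IR function names
    # close exit stack
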
diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py
index 5a5a24571736..8d57f1e0211f 100644
--- a/hail/python/hail/backend/service_backend.py
+++ b/hail/python/hail/backend/service_backend.py
@@ -12,10 +12,6 @@
 import hailtop.aiotools.fs as afs
 from hail.context import TemporaryDirectory, TemporaryFilename, revision, tmp_dir, version
 from hail.experimental import read_expression, write_expression
-from hail.expr.expressions.base_expression import Expression
-from hail.expr.types import HailType
-from hail.ir import finalize_randomness
-from hail.ir.renderer import CSERenderer
 from hail.utils import FatalError
 from hailtop import yamlx
 from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration, get_gcs_requester_pays_configuration
@@ -32,7 +28,7 @@
 from ..builtin_references import BUILTIN_REFERENCES
 from ..utils import ANY_REGION

-from .backend import ActionPayload, ActionTag, Backend, ExecutePayload, fatal_error_from_java_error_triplet
+from .backend import ActionPayload, ActionTag, Backend, fatal_error_from_java_error_triplet

 ReferenceGenomeConfig = Dict[str, Any]

@@ -66,53 +62,6 @@ async def read_str(strm: afs.ReadableStream) -> str:
     return b.decode('utf-8')


-@dataclass
-class SerializedIRFunction:
-    name: str
-    type_parameters: List[str]
-    value_parameter_names: List[str]
-    value_parameter_types: List[str]
-    return_type: str
-    rendered_body: str
-
-
-class IRFunction:
-    def __init__(
-        self,
-        name: str,
-        type_parameters: Union[Tuple[HailType, ...], List[HailType]],
-        value_parameter_names: Union[Tuple[str, ...], List[str]],
-        value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
-        return_type: HailType,
-        body: Expression,
-    ):
-        assert len(value_parameter_names) == len(value_parameter_types)
-        render = CSERenderer()
-        self._name = name
-        self._type_parameters = type_parameters
-        self._value_parameter_names = value_parameter_names
-        self._value_parameter_types = value_parameter_types
-        self._return_type = return_type
-        self._rendered_body = render(finalize_randomness(body._ir))
-
-    def to_dataclass(self):
-        return SerializedIRFunction(
-            name=self._name,
-            type_parameters=[tp._parsable_string() for tp in self._type_parameters],
-            value_parameter_names=list(self._value_parameter_names),
-            value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types],
-            return_type=self._return_type._parsable_string(),
-            rendered_body=self._rendered_body,
-        )
-
-
-@dataclass
-class ServiceBackendExecutePayload(ActionPayload):
-    functions: List[SerializedIRFunction]
-    idempotency_token: str
-    payload: ExecutePayload
-
-
 @dataclass
 class CloudfuseConfig:
     bucket: str
@@ -131,15 +80,21 @@ class ServiceBackendRPCConfig:
     tmp_dir: str
     remote_tmpdir: str
     billing_project: str
+    flags: Dict[str, str]
+    custom_references: List[str]
+    liftovers: Dict[str, Dict[str, str]]
+    sequences: Dict[str, SequenceConfig]
+
+
+@dataclass
+class BatchJobConfig:
+    token: str
+    billing_project: str
     worker_cores: str
     worker_memory: str
     storage: str
     cloudfuse_configs: List[CloudfuseConfig]
     regions: List[str]
-    flags: Dict[str, str]
-    custom_references: List[str]
-    liftovers: Dict[str, Dict[str, str]]
-    sequences: Dict[str, SequenceConfig]


 class ServiceBackend(Backend):
@@ -148,14 +103,14 @@ class ServiceBackend(Backend):
     DRIVER = "driver"

     # is.hail.backend.service.ServiceBackendSocketAPI2 protocol
-    LOAD_REFERENCES_FROM_DATASET = 1
-    VALUE_TYPE = 2
-    TABLE_TYPE = 3
-    MATRIX_TABLE_TYPE = 4
-    BLOCK_MATRIX_TYPE = 5
-    EXECUTE = 6
-    PARSE_VCF_METADATA = 7
-    IMPORT_FAM = 8
+    VALUE_TYPE = 1
+    TABLE_TYPE = 2
+    MATRIX_TABLE_TYPE = 3
+    BLOCK_MATRIX_TYPE = 4
+    EXECUTE = 5
+    PARSE_VCF_METADATA = 6
+    IMPORT_FAM = 7
+    LOAD_REFERENCES_FROM_DATASET = 8
     FROM_FASTA_FILE = 9

     @staticmethod
@@ -288,7 +243,6 @@ def __init__(
         self.batch_attributes = batch_attributes
         self.remote_tmpdir = remote_tmpdir
         self.flags: Dict[str, str] = {}
-        self.functions: List[IRFunction] = []
         self._registered_ir_function_names: Set[str] = set()
         self.driver_cores = driver_cores
         self.driver_memory = driver_memory
@@ -333,16 +287,16 @@ def logger(self):

     def stop(self):
         hail_event_loop().run_until_complete(self._stop())
+        super().stop()

     async def _stop(self):
         await self._async_exit_stack.aclose()
-        self.functions = []
-        self._registered_ir_function_names = set()

     async def _run_on_batch(
         self,
         name: str,
         service_backend_config: ServiceBackendRPCConfig,
+        job_config: BatchJobConfig,
         action: ActionTag,
         payload: ActionPayload,
         *,
@@ -356,7 +310,8 @@ async def _run_on_batch(
         async with await self._async_fs.create(iodir + '/in') as infile:
             await infile.write(
                 orjson.dumps({
-                    'config': service_backend_config,
+                    'rpc_config': service_backend_config,
+                    'job_config': job_config,
                     'action': action.value,
                     'payload': payload,
                 })
@@ -466,11 +421,6 @@ def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Option
         return self._cancel_on_ctrl_c(self._async_rpc(action, payload))

     async def _async_rpc(self, action: ActionTag, payload: ActionPayload):
-        if isinstance(payload, ExecutePayload):
-            payload = ServiceBackendExecutePayload(
-                [f.to_dataclass() for f in self.functions], self._batch.token, payload
-            )
-
         storage_requirement_bytes = 0
         readonly_fuse_buckets: Set[str] = set()

@@ -485,31 +435,38 @@ async def _async_rpc(self, action: ActionTag, payload: ActionPayload):
                 readonly_fuse_buckets.add(bucket)
                 storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size()
             sequence_file_mounts[rg_name] = SequenceConfig(
-                f'/cloudfuse/{fasta_bucket}/{fasta_path}', f'/cloudfuse/{index_bucket}/{index_path}'
+                f'/cloudfuse/{fasta_bucket}/{fasta_path}',
+                f'/cloudfuse/{index_bucket}/{index_path}',
             )

-        storage_gib_str = f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi'
-        qob_config = ServiceBackendRPCConfig(
-            tmp_dir=tmp_dir(),
-            remote_tmpdir=self.remote_tmpdir,
-            billing_project=self.billing_project,
-            worker_cores=str(self.worker_cores),
-            worker_memory=str(self.worker_memory),
-            storage=storage_gib_str,
-            cloudfuse_configs=[
-                CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets
-            ],
-            regions=self.regions,
-            flags=self.flags,
-            custom_references=[
-                orjson.dumps(rg._config).decode('utf-8')
-                for rg in self._references.values()
-                if rg.name not in BUILTIN_REFERENCES
-            ],
-            liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0},
-            sequences=sequence_file_mounts,
+        return await self._run_on_batch(
+            name=f'{action.name.lower()}(...)',
+            service_backend_config=ServiceBackendRPCConfig(
+                tmp_dir=tmp_dir(),
+                remote_tmpdir=self.remote_tmpdir,
+                flags=self.flags,
+                custom_references=[
+                    orjson.dumps(rg._config).decode('utf-8')
+                    for rg in self._references.values()
+                    if rg.name not in BUILTIN_REFERENCES
+                ],
+                liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0},
+                sequences=sequence_file_mounts,
+            ),
+            job_config=BatchJobConfig(
+                token=self._batch.token,
+                billing_project=self.billing_project,
+                worker_cores=str(self.worker_cores),
+                worker_memory=str(self.worker_memory),
+                storage=f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi',
+                cloudfuse_configs=[
+                    CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets
+                ],
+                regions=self.regions,
+            ),
+            action=action,
+            payload=payload,
         )
-        return await self._run_on_batch(f'{action.name.lower()}(...)', qob_config, action, payload)

     # Sequence and liftover information is stored on the ReferenceGenome
     # and there is no persistent backend to keep in sync.
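Reviewer note: after this change the driver input file nests two top-level config objects instead of one. An illustrative EXECUTE envelope, with keys taken from the dataclasses in this diff and all values as placeholders:

    import orjson

    envelope = {
        'rpc_config': {  # ServiceBackendRPCConfig
            'tmp_dir': 'gs://my-bucket/tmp',        # placeholder
            'remote_tmpdir': 'gs://my-bucket/tmp',  # placeholder
            'flags': {},
            'custom_references': [],
            'liftovers': {},
            'sequences': {},
        },
        'job_config': {  # BatchJobConfig
            'token': '<batch token>',
            'billing_project': 'my-project',
            'worker_cores': '1',
            'worker_memory': 'standard',
            'storage': '0Gi',
            'cloudfuse_configs': [],
            'regions': ['us-central1'],
        },
        'action': 5,  # ActionTag.EXECUTE under the new numbering
        'payload': {  # ExecutePayload
            'ir': '<rendered IR>',
            'fns': [],
            'stream_codec': '{"name":"StreamBufferSpec"}',
        },
    }
    print(orjson.dumps(envelope).decode())
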
@@ -532,23 +489,6 @@ def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str):
     def remove_liftover(self, name, dest_reference_genome):  # pylint: disable=unused-argument
         pass

-    def register_ir_function(
-        self,
-        name: str,
-        type_parameters: Union[Tuple[HailType, ...], List[HailType]],
-        value_parameter_names: Union[Tuple[str, ...], List[str]],
-        value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
-        return_type: HailType,
-        body: Expression,
-    ):
-        self._registered_ir_function_names.add(name)
-        self.functions.append(
-            IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body)
-        )
-
-    def _is_registered_ir_function_name(self, name: str) -> bool:
-        return name in self._registered_ir_function_names
-
     def persist_expression(self, expr):
         # FIXME: should use context manager to clean up persisted resources
         fname = TemporaryFilename(prefix='persist_expression').name
diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py
index a4de4a62dd95..24ad808d3ee5 100644
--- a/hail/python/test/hail/test_ir.py
+++ b/hail/python/test/hail/test_ir.py
@@ -1,7 +1,7 @@
 import functools
 import re
 import unittest
-from test.hail.helpers import resource, skip_unless_spark_backend, skip_when_service_backend
+from test.hail.helpers import resource, skip_unless_spark_backend

 import numpy as np
 import pytest
@@ -188,12 +188,6 @@ def value_ir(value_irs, request):
     return value_irs[request.param]


-@skip_when_service_backend()
-def test_ir_parses(value_ir):
-    env = value_irs_env()
-    Env.backend()._parse_value_ir(str(value_ir), env)
-
-
 def test_ir_value_type(value_ir):
     env = value_irs_env()
     typ = Env.backend().value_type(
@@ -312,11 +306,6 @@ def table_ir(table_irs, request):
     return table_irs[request.param]


-@skip_when_service_backend()
-def test_table_ir_parses(table_ir):
-    Env.backend()._parse_table_ir(str(table_ir))
-
-
 def test_table_ir_table_type(table_ir):
     typ = Env.backend().table_type(table_ir)
     assert table_ir.typ == typ
@@ -419,11 +408,6 @@ def matrix_ir(matrix_irs, request):
     return matrix_irs[request.param]


-@skip_when_service_backend()
-def test_matrix_ir_parses(matrix_ir):
-    Env.backend()._parse_matrix_ir(str(matrix_ir))
-
-
 def test_matrix_ir_matrix_type(matrix_ir):
     typ = Env.backend().matrix_type(matrix_ir)
     assert typ == matrix_ir.typ
diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala
index bf0f9425c0ca..76281713a511 100644
--- a/hail/src/main/scala/is/hail/backend/Backend.scala
+++ b/hail/src/main/scala/is/hail/backend/Backend.scala
@@ -15,11 +15,8 @@ import is.hail.utils.fatal
 import scala.reflect.ClassTag

 import java.io.{Closeable, OutputStream}
-import java.nio.charset.StandardCharsets

 import com.fasterxml.jackson.core.StreamReadConstraints
-import org.json4s.JValue
-import org.json4s.jackson.JsonMethods
 import sourcecode.Enclosing

 object Backend {
@@ -48,9 +45,6 @@ object Backend {
     assert(t.isFieldDefined(off, 0))
     codec.encode(ctx, elementType, t.loadField(off, 0), os)
   }
-
-  def jsonToBytes(f: => JValue): Array[Byte] =
-    JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)
 }

 abstract class BroadcastValue[T] { def value: T }
diff --git a/hail/src/main/scala/is/hail/backend/BackendRpc.scala b/hail/src/main/scala/is/hail/backend/BackendRpc.scala
index 64f15af7be9d..ea9bff7d5656 100644
--- a/hail/src/main/scala/is/hail/backend/BackendRpc.scala
+++ b/hail/src/main/scala/is/hail/backend/BackendRpc.scala
@@ -5,21 +5,21 @@ import is.hail.expr.ir.functions.IRFunctionRegistry
 import is.hail.io.BufferSpec
 import is.hail.io.plink.LoadPlink
 import is.hail.io.vcf.LoadVCF
-import is.hail.linalg.RowMatrix
 import is.hail.services.retryTransientErrors
 import is.hail.types.virtual.{Kind, TFloat64, VType}
 import is.hail.types.virtual.Kinds._
-import is.hail.utils.{toRichIterable, using, ExecutionTimer}
+import is.hail.utils.{using, ExecutionTimer}
 import is.hail.utils.ExecutionTimer.Timings
 import is.hail.variant.ReferenceGenome

+import scala.language.existentials
 import scala.util.control.NonFatal

 import java.io.ByteArrayOutputStream
 import java.nio.charset.StandardCharsets

-import org.json4s.{DefaultFormats, Extraction, Formats, JValue}
-import org.json4s.jackson.{JsonMethods, Serialization}
+import org.json4s.{DefaultFormats, Extraction, Formats, JArray, JValue}
+import org.json4s.jackson.JsonMethods

 case class IRTypePayload(ir: String)
 case class LoadReferencesFromDatasetPayload(path: String)
@@ -39,7 +39,7 @@ case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: Strin

 case class ExecutePayload(
   ir: String,
-  fs: Array[SerializedIRFunction],
+  fns: Array[SerializedIRFunction],
   stream_codec: String,
 )
@@ -75,17 +75,6 @@ trait BackendRpc {
       mt_contigs: Array[String],
       par: Array[String],
     ) extends Command
-
-    case class ExportBlockMatrix(
-      pathIn: String,
-      pathOut: String,
-      delimiter: String,
-      header: String,
-      addIndex: Boolean,
-      exportType: String,
-      partitionSize: Int,
-      entries: String,
-    ) extends Command
   }

   trait Ask[Env] {
@@ -144,14 +133,16 @@
       }

       case ImportFam(path, isQuantPheno, delimiter, missing) =>
-        LoadPlink
-          .importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missing)
-          .getBytes(StandardCharsets.UTF_8)
+        jsonToBytes {
+          LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missing)
+        }

       case LoadReferencesFromDataset(path) =>
-        val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
-        ctx.References ++= rgs.map(rg => rg.name -> rg)
-        Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8)
+        jsonToBytes {
+          val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
+          ctx.References ++= rgs.map(rg => rg.name -> rg)
+          JArray(rgs.map(_.toJSON).toList)
+        }

       case LoadReferencesFromFASTA(name, fasta, index, xContigs, yContigs, mtContigs, par) =>
         jsonToBytes {
@@ -168,51 +159,6 @@
           ctx.References += rg.name -> rg
           rg.toJSON
         }
-
-      case ExportBlockMatrix(pathIn, pathOut, delimiter, header, addIndex, exportType,
-            partitionSize, entries) =>
-        val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize)
-        entries match {
-          case "full" =>
-            rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType)
-          case "lower" =>
-            rm.exportLowerTriangle(
-              ctx,
-              pathOut,
-              delimiter,
-              Option(header),
-              addIndex,
-              exportType,
-            )
-          case "strict_lower" =>
-            rm.exportStrictLowerTriangle(
-              ctx,
-              pathOut,
-              delimiter,
-              Option(header),
-              addIndex,
-              exportType,
-            )
-          case "upper" =>
-            rm.exportUpperTriangle(
-              ctx,
-              pathOut,
-              delimiter,
-              Option(header),
-              addIndex,
-              exportType,
-            )
-          case "strict_upper" =>
-            rm.exportStrictUpperTriangle(
-              ctx,
-              pathOut,
-              delimiter,
-              Option(header),
-              addIndex,
-              exportType,
-            )
-        }
-        Array()
     }
   }
 }
@@ -276,8 +222,8 @@ trait HttpLikeBackendRpc[A] extends BackendRpc {
       case Routes.TypeOf(k) => TypeOf(k, payload(a).extract[IRTypePayload].ir)
       case Routes.Execute =>
-        val ExecutePayload(ir, fs, codec) = payload(a).extract[ExecutePayload]
-        Execute(ir, fs, codec)
+        val ExecutePayload(ir, fns, codec) = payload(a).extract[ExecutePayload]
+        Execute(ir, fns, codec)
       case Routes.ParseVcfMetadata =>
         ParseVcfMetadata(payload(a).extract[ParseVCFMetadataPayload].path)
       case Routes.ImportFam =>
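Reviewer note: the fs to fns rename matters because orjson emits dataclass field names verbatim and json4s extracts case-class fields by name, so the Python and Scala sides must agree exactly. A standalone demonstration (ExecutePayloadDemo is a stand-in for the real class, not part of the patch):

    from dataclasses import dataclass
    from typing import List

    import orjson

    @dataclass
    class ExecutePayloadDemo:  # stand-in for hail.backend.backend.ExecutePayload
        ir: str
        fns: List[dict]
        stream_codec: str

    payload = ExecutePayloadDemo('(I32 5)', [], '{"name":"StreamBufferSpec"}')
    print(orjson.dumps(payload).decode())
    # {"ir":"(I32 5)","fns":[],"stream_codec":"{\"name\":\"StreamBufferSpec\"}"}
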
diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala
index 88881bc435bd..13b89888876c 100644
--- a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala
+++ b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala
@@ -3,11 +3,7 @@ package is.hail.backend.py4j
 import is.hail.HailFeatureFlags
 import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
 import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
-import is.hail.expr.ir.{
-  BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, Interpret,
-  MatrixIR, MatrixNativeReader, MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral,
-  TableValue,
-}
+import is.hail.expr.ir.{BaseIR, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue}
 import is.hail.expr.ir.IRParser.parseType
 import is.hail.expr.ir.functions.IRFunctionRegistry
 import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver}
@@ -18,9 +14,7 @@ import is.hail.utils.{fatal, log, toRichIterable, HailException, Interval}
 import is.hail.variant.ReferenceGenome

 import scala.collection.mutable
-import scala.jdk.CollectionConverters.{
-  asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter,
-}
+import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter}

 import java.util
@@ -224,23 +218,6 @@ trait Py4JBackendExtensions {
   private[this] def removeReference(name: String): Unit =
     references -= name

-  def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
-    backend.withExecuteContext { ctx =>
-      IRParser.parse_value_ir(
-        ctx,
-        s,
-        BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
-          Name(n) -> IRParser.parseType(t)
-        }.toSeq: _*),
-      )
-    }._1
-
-  def parse_table_ir(s: String): TableIR =
-    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s))
-
-  def parse_matrix_ir(s: String): MatrixIR =
-    withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s))
-
   def parse_blockmatrix_ir(s: String): BlockMatrixIR =
     withExecuteContext(selfContainedExecution = false) { ctx =>
       IRParser.parse_blockmatrix_ir(ctx, s)
\ "action").extract[Int], input \ "payload", @@ -424,15 +425,15 @@ object ServiceBackendAPI extends HttpLikeBackendRpc[Request] with Logging { import Routes._ override def route(a: Request): Route = - a.action match { - case 2 => TypeOf(Kinds.Value) - case 3 => TypeOf(Kinds.Table) - case 4 => TypeOf(Kinds.Matrix) - case 5 => TypeOf(Kinds.BlockMatrix) - case 6 => Execute - case 7 => ParseVcfMetadata - case 8 => ImportFam - case 1 => LoadReferencesFromDataset + (a.action: @switch) match { + case 1 => TypeOf(Kinds.Value) + case 2 => TypeOf(Kinds.Table) + case 3 => TypeOf(Kinds.Matrix) + case 4 => TypeOf(Kinds.BlockMatrix) + case 5 => Execute + case 6 => ParseVcfMetadata + case 7 => ImportFam + case 8 => LoadReferencesFromDataset case 9 => LoadReferencesFromFASTA } @@ -535,6 +536,7 @@ case class ServiceBackendRPCPayload( ) case class BatchJobConfig( + token: String, billing_project: String, worker_cores: String, worker_memory: String, diff --git a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala index 9c4cd0d5f43e..a4cf8a5fa436 100644 --- a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -19,7 +19,6 @@ import is.hail.variant._ import org.apache.spark.TaskContext import org.apache.spark.sql.Row import org.json4s.{DefaultFormats, Formats, JValue} -import org.json4s.jackson.JsonMethods case class FamFileConfig( isQuantPheno: Boolean = false, @@ -82,14 +81,13 @@ object LoadPlink { isQuantPheno: Boolean, delimiter: String, missingValue: String, - ): String = { + ): JValue = { val ffConfig = FamFileConfig(isQuantPheno, delimiter, missingValue) val (data, ptyp) = LoadPlink.parseFam(fs, path, ffConfig) - val jv = JSONAnnotationImpex.exportAnnotation( + JSONAnnotationImpex.exportAnnotation( Row(ptyp.virtualType.toString, data), TStruct("type" -> TString, "data" -> TArray(ptyp.virtualType)), ) - JsonMethods.compact(jv) } def parseFam(fs: FS, filename: String, ffConfig: FamFileConfig) diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index 631a95c07159..694280806cbb 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -155,6 +155,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV sequences = Map(), ), BatchJobConfig( + token = tokenUrlSafe, billing_project = "fancy", worker_cores = "128", worker_memory = "a lot.",