Skip to content

Commit

Permalink
update python rpc methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Oct 1, 2024
1 parent d25cc7f commit 708053a
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 279 deletions.
80 changes: 64 additions & 16 deletions hail/python/hail/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import zipfile
from dataclasses import dataclass
from enum import Enum
from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union

import orjson

Expand Down Expand Up @@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation:
raise ValueError(f'Hail requires either {hail_jar} or {hail_all_spark_jar}.')


class IRFunction:
def __init__(
self,
name: str,
type_parameters: Union[Tuple[HailType, ...], List[HailType]],
value_parameter_names: Union[Tuple[str, ...], List[str]],
value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
return_type: HailType,
body: Expression,
):
assert len(value_parameter_names) == len(value_parameter_types)
render = CSERenderer()
self._name = name
self._type_parameters = type_parameters
self._value_parameter_names = value_parameter_names
self._value_parameter_types = value_parameter_types
self._return_type = return_type
self._rendered_body = render(finalize_randomness(body._ir))

def to_dataclass(self):
return SerializedIRFunction(
name=self._name,
type_parameters=[tp._parsable_string() for tp in self._type_parameters],
value_parameter_names=list(self._value_parameter_names),
value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types],
return_type=self._return_type._parsable_string(),
rendered_body=self._rendered_body,
)


class ActionTag(Enum):
LOAD_REFERENCES_FROM_DATASET = 1
VALUE_TYPE = 2
TABLE_TYPE = 3
MATRIX_TABLE_TYPE = 4
BLOCK_MATRIX_TYPE = 5
EXECUTE = 6
PARSE_VCF_METADATA = 7
IMPORT_FAM = 8
VALUE_TYPE = 1
TABLE_TYPE = 2
MATRIX_TABLE_TYPE = 3
BLOCK_MATRIX_TYPE = 4
EXECUTE = 5
PARSE_VCF_METADATA = 6
IMPORT_FAM = 7
LOAD_REFERENCES_FROM_DATASET = 8
FROM_FASTA_FILE = 9


Expand All @@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload):
ir: str


@dataclass
class SerializedIRFunction:
name: str
type_parameters: List[str]
value_parameter_names: List[str]
value_parameter_types: List[str]
return_type: str
rendered_body: str


@dataclass
class ExecutePayload(ActionPayload):
ir: str
fns: List[SerializedIRFunction]
stream_codec: str
timed: bool


@dataclass
Expand Down Expand Up @@ -164,17 +204,24 @@ def _valid_flags(self) -> AbstractSet[str]:
def __init__(self):
self._persisted_locations = dict()
self._references = {}
self.functions: List[IRFunction] = []
self._registered_ir_function_names: Set[str] = set()

@abc.abstractmethod
def validate_file(self, uri: str):
raise NotImplementedError

@abc.abstractmethod
def stop(self):
pass
self.functions = []
self._registered_ir_function_names = set()

def execute(self, ir: BaseIR, timed: bool = False) -> Any:
payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed)
payload = ExecutePayload(
self._render_ir(ir),
fns=[fn.to_dataclass() for fn in self.functions],
stream_codec='{"name":"StreamBufferSpec"}',
)
try:
result, timings = self._rpc(ActionTag.EXECUTE, payload)
except FatalError as e:
Expand Down Expand Up @@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset:
tempfile.__exit__(None, None, None)
return unpersisted

@abc.abstractmethod
def register_ir_function(
self,
name: str,
Expand All @@ -310,11 +356,13 @@ def register_ir_function(
return_type: HailType,
body: Expression,
):
pass
self._registered_ir_function_names.add(name)
self.functions.append(
IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body)
)

@abc.abstractmethod
def _is_registered_ir_function_name(self, name: str) -> bool:
pass
return name in self._registered_ir_function_names

@abc.abstractmethod
def persist_expression(self, expr: Expression) -> Expression:
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/backend/local_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def register_ir_function(
)

def stop(self):
super().stop()
super(Py4JBackend, self).stop()
self._exit_stack.close()
uninstall_exception_handler()

Expand Down
21 changes: 2 additions & 19 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socketserver
import sys
from threading import Thread
from typing import Mapping, Optional, Set, Tuple
from typing import Mapping, Optional, Tuple

import orjson
import py4j
Expand Down Expand Up @@ -192,8 +192,6 @@ def decode_bytearray(encoded):
self._backend_server.start()
self._requests_session = requests.Session()

self._registered_ir_function_names: Set[str] = set()

# This has to go after creating the SparkSession. Unclear why.
# Maybe it does its own patch?
install_exception_handler()
Expand Down Expand Up @@ -239,9 +237,6 @@ def persist_expression(self, expr):
t = expr.dtype
return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)

def _is_registered_ir_function_name(self, name: str) -> bool:
return name in self._registered_ir_function_names

def set_flags(self, **flags: Mapping[str, str]):
available = self._jbackend.pyAvailableFlags()
invalid = []
Expand Down Expand Up @@ -276,12 +271,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome):
def remove_liftover(self, name, dest_reference_genome):
self._jbackend.pyRemoveLiftover(name, dest_reference_genome)

def _parse_value_ir(self, code, ref_map={}):
return self._jbackend.parse_value_ir(
code,
{k: t._parsable_string() for k, t in ref_map.items()},
)

def _register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, code):
self._registered_ir_function_names.add(name)
self._jbackend.pyRegisterIR(
Expand All @@ -293,12 +282,6 @@ def _register_ir_function(self, name, type_parameters, argument_names, argument_
code,
)

def _parse_table_ir(self, code):
return self._jbackend.parse_table_ir(code)

def _parse_matrix_ir(self, code):
return self._jbackend.parse_matrix_ir(code)

def _parse_blockmatrix_ir(self, code):
return self._jbackend.parse_blockmatrix_ir(code)

Expand All @@ -310,5 +293,5 @@ def stop(self):
self._jbackend.close()
self._jhc.stop()
self._jhc = None
self._registered_ir_function_names = set()
uninstall_exception_handler()
super().stop()
Loading

0 comments on commit 708053a

Please sign in to comment.