diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index e7e14bbe90a8..dd35b670e660 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -392,3 +392,13 @@ def get_flags(self, *flags) -> Mapping[str, str]: @abc.abstractmethod def requires_lowering(self): pass + + @property + @abc.abstractmethod + def local_tmpdir(self) -> str: + pass + + @property + @abc.abstractmethod + def remote_tmpdir(self) -> str: + pass diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index d59ae456a937..c34ca1d6b986 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -196,8 +196,8 @@ def decode_bytearray(encoded): self._jhc = jhc self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend) - self._jbackend.pySetLocalTmp(tmpdir) - self._jbackend.pySetRemoteTmp(remote_tmpdir) + self.local_tmpdir = tmpdir + self.remote_tmpdir = tmpdir self._jhttp_server = self._jbackend.pyHttpServer() self._backend_server_port: int = self._jhttp_server.port() @@ -326,3 +326,21 @@ def stop(self): self._jhc = None uninstall_exception_handler() super().stop() + + @property + def local_tmpdir(self) -> str: + return self._local_tmpdir + + @local_tmpdir.setter + def local_tmpdir(self, tmpdir) -> str: + self._local_tmpdir = tmpdir + self._jbackend.pySetLocalTmp(tmpdir) + + @property + def remote_tmpdir(self) -> str: + return self._remote_tmpdir + + @remote_tmpdir.setter + def remote_tmpdir(self, tmpdir) -> str: + self._remote_tmpdir = tmpdir + self._jbackend.pySetRemoteTmp(tmpdir) diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index d753dab03ed1..e9fd13537c9f 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -240,6 +240,7 @@ def __init__( self._batch_was_submitted: bool = False self.disable_progress_bar = disable_progress_bar self.batch_attributes = batch_attributes + self.local_tmpdir = tmp_dir() self.remote_tmpdir = remote_tmpdir self.flags: Dict[str, str] = {} self._registered_ir_function_names: Set[str] = set() @@ -441,7 +442,7 @@ async def _async_rpc(self, action: ActionTag, payload: ActionPayload): return await self._run_on_batch( name=f'{action.name.lower()}(...)', service_backend_config=ServiceBackendRPCConfig( - tmp_dir=tmp_dir(), + tmp_dir=self.local_tmpdir, remote_tmpdir=self.remote_tmpdir, flags=self.flags, custom_references=[