Skip to content

Commit

Permalink
[query] make local and remote tmp settable on backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Jan 22, 2025
1 parent 19bf763 commit aa8d67f
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 21 deletions.
20 changes: 20 additions & 0 deletions hail/python/hail/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,23 @@ def get_flags(self, *flags) -> Mapping[str, str]:
@abc.abstractmethod
def requires_lowering(self):
pass

@property
@abc.abstractmethod
def local_tmpdir(self) -> str:
    """Return the path this backend uses as its local temporary directory.

    Abstract: concrete backends must provide the getter.
    """

@local_tmpdir.setter
@abc.abstractmethod
def local_tmpdir(self, tmpdir: str) -> None:
    """Set the path this backend uses as its local temporary directory.

    Abstract: concrete backends must provide the setter.

    The argument is named ``tmpdir`` so it matches the concrete setters in
    this change and does not shadow the ``dir`` builtin. Property setters
    are only ever invoked positionally, so the name is not caller-visible.
    """

@property
@abc.abstractmethod
def remote_tmpdir(self) -> str:
    """Return the path this backend uses as its remote temporary directory.

    Abstract: concrete backends must provide the getter.
    """

@remote_tmpdir.setter
@abc.abstractmethod
def remote_tmpdir(self, tmpdir: str) -> None:
    """Set the path this backend uses as its remote temporary directory.

    Abstract: concrete backends must provide the setter.

    The argument is named ``tmpdir`` so it matches the concrete setters in
    this change and does not shadow the ``dir`` builtin. Property setters
    are only ever invoked positionally, so the name is not caller-visible.
    """
22 changes: 20 additions & 2 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def decode_bytearray(encoded):
self._jhc = jhc

self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend)
self._jbackend.pySetLocalTmp(tmpdir)
self._jbackend.pySetRemoteTmp(remote_tmpdir)
# Assign through the property setters, which cache each path on the instance
# and forward it to the JVM backend. The remote directory must come from
# `remote_tmpdir`; passing `tmpdir` here would point both at the local path.
self.local_tmpdir = tmpdir
self.remote_tmpdir = remote_tmpdir

self._jhttp_server = self._jbackend.pyHttpServer()
self._backend_server_port: int = self._jhttp_server.port()
Expand Down Expand Up @@ -326,3 +326,21 @@ def stop(self):
self._jhc = None
uninstall_exception_handler()
super().stop()

@property
def local_tmpdir(self) -> str:
    """Return the local temporary directory most recently set on this backend."""
    return self._local_tmpdir

@local_tmpdir.setter
def local_tmpdir(self, tmpdir: str) -> None:
    """Forward the local temporary directory to the JVM backend, then cache it.

    The JVM call runs first: if it raises, the cached Python value is left
    untouched and so still agrees with what the JVM holds.
    """
    self._jbackend.pySetLocalTmp(tmpdir)
    self._local_tmpdir = tmpdir

@property
def remote_tmpdir(self) -> str:
    """Return the remote temporary directory most recently set on this backend."""
    return self._remote_tmpdir

@remote_tmpdir.setter
def remote_tmpdir(self, tmpdir: str) -> None:
    """Forward the remote temporary directory to the JVM backend, then cache it.

    The JVM call runs first: if it raises, the cached Python value is left
    untouched and so still agrees with what the JVM holds.
    """
    self._jbackend.pySetRemoteTmp(tmpdir)
    self._remote_tmpdir = tmpdir
22 changes: 19 additions & 3 deletions hail/python/hail/backend/service_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import warnings
from contextlib import AsyncExitStack
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
from typing import Any, Awaitable, Dict, List, Mapping, NoReturn, Optional, Set, Tuple, TypeVar, Union

import orjson

import hailtop.aiotools.fs as afs
from hail.context import TemporaryDirectory, TemporaryFilename, revision, tmp_dir, version
from hail.context import TemporaryDirectory, TemporaryFilename, revision, version
from hail.experimental import read_expression, write_expression
from hail.utils import FatalError
from hailtop import yamlx
Expand Down Expand Up @@ -239,7 +239,7 @@ def __init__(
self._batch_was_submitted: bool = False
self.disable_progress_bar = disable_progress_bar
self.batch_attributes = batch_attributes
self.remote_tmpdir = remote_tmpdir
self._remote_tmpdir = remote_tmpdir
self.flags: Dict[str, str] = {}
self._registered_ir_function_names: Set[str] = set()
self.driver_cores = driver_cores
Expand Down Expand Up @@ -519,3 +519,19 @@ def get_flags(self, *flags: str) -> Mapping[str, str]:
@property
def requires_lowering(self):
return True

@property
def local_tmpdir(self) -> NoReturn:
    """Always raise ``AttributeError``; reading a local tmp dir is unsupported here."""
    reason = 'local tmp folders are not supported on the batch backend'
    raise AttributeError(reason)

@local_tmpdir.setter
def local_tmpdir(self, tmpdir: str) -> NoReturn:
    """Always raise ``AttributeError``; ``tmpdir`` is ignored."""
    reason = 'local tmp folders are not supported on the batch backend'
    raise AttributeError(reason)

@property
def remote_tmpdir(self) -> str:
    """Return the remote temporary directory currently stored on this backend."""
    current = self._remote_tmpdir
    return current

@remote_tmpdir.setter
def remote_tmpdir(self, tmpdir: str) -> None:
    """Store ``tmpdir`` as this backend's remote temporary directory."""
    self._remote_tmpdir = tmpdir
28 changes: 13 additions & 15 deletions hail/python/hail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def create(
log: str,
quiet: bool,
append: bool,
tmpdir: str,
local_tmpdir: str,
default_reference: str,
global_seed: Optional[int],
backend: Backend,
Expand All @@ -77,25 +75,17 @@ def create(
log=log,
quiet=quiet,
append=append,
tmpdir=tmpdir,
local_tmpdir=local_tmpdir,
global_seed=global_seed,
backend=backend,
)
hc.initialize_references(default_reference)
return hc

@typecheck_method(
log=str, quiet=bool, append=bool, tmpdir=str, local_tmpdir=str, global_seed=nullable(int), backend=Backend
)
def __init__(self, log, quiet, append, tmpdir, local_tmpdir, global_seed, backend):
@typecheck_method(log=str, quiet=bool, append=bool, global_seed=nullable(int), backend=Backend)
def __init__(self, log, quiet, append, global_seed, backend: Backend):
assert not Env._hc

self._log = log

self._tmpdir = tmpdir
self._local_tmpdir = local_tmpdir

self._backend = backend

self._warn_cols_order = True
Expand Down Expand Up @@ -138,6 +128,14 @@ def initialize_references(self, default_reference):
else:
self._default_ref = ReferenceGenome.read(default_reference)

@property
def _tmpdir(self) -> str:
    """Return the backend's remote temporary directory, read on every access."""
    backend = self._backend
    return backend.remote_tmpdir

@property
def _local_tmpdir(self) -> str:
    """Return the backend's local temporary directory, read on every access."""
    backend = self._backend
    return backend.local_tmpdir

@property
def default_reference(self) -> ReferenceGenome:
assert self._default_ref is not None, '_default_ref should have been initialized in HailContext.create'
Expand Down Expand Up @@ -500,7 +498,7 @@ def init_spark(
if not backend.fs.exists(tmpdir):
backend.fs.mkdir(tmpdir)

HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
if not quiet:
connect_logger(backend._utils_package_object, 'localhost', 12888)

Expand Down Expand Up @@ -571,7 +569,7 @@ async def init_batch(
tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string())
local_tmpdir = _get_local_tmpdir(local_tmpdir)

HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
HailContext.create(log, quiet, append, default_reference, global_seed, backend)


@typecheck(
Expand Down Expand Up @@ -623,7 +621,7 @@ def init_local(
if not backend.fs.exists(tmpdir):
backend.fs.mkdir(tmpdir)

HailContext.create(log, quiet, append, tmpdir, tmpdir, default_reference, global_seed, backend)
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
if not quiet:
connect_logger(backend._utils_package_object, 'localhost', 12888)

Expand Down
10 changes: 9 additions & 1 deletion hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
synchronized { tmpdir = tmp }

// Stores `tmp` in `localTmpdir` inside a `synchronized` block and, when the
// wrapped backend is a SparkBackend, also sets "spark.local.dir" on the conf
// object returned by `sc.getConf`. Other backends are left untouched.
// NOTE(review): two `synchronized` bodies appear back to back here; the
// single-line one looks like the pre-change body — confirm which one applies.
// NOTE(review): Spark documents `SparkContext.getConf` as returning a copy of
// the configuration, so the `set` below may not affect the running context —
// confirm before relying on it.
def pySetLocalTmp(tmp: String): Unit =
synchronized { localTmpdir = tmp }
synchronized {
localTmpdir = tmp
backend match {
case s: SparkBackend =>
s.sc.getConf.set("spark.local.dir", tmp)
case _ =>
()
}
}

def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
synchronized {
Expand Down

0 comments on commit aa8d67f

Please sign in to comment.