Skip to content

Commit dd84033

Browse files
committed
[query] make local and remote tmp settable on backend
1 parent 6f3042d commit dd84033

File tree

5 files changed

+81
-21
lines changed

5 files changed

+81
-21
lines changed

hail/hail/src/is/hail/backend/api/Py4JBackendApi.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,15 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
7070
synchronized { tmpdir = tmp }
7171

7272
def pySetLocalTmp(tmp: String): Unit =
73-
synchronized { localTmpdir = tmp }
73+
synchronized {
74+
localTmpdir = tmp
75+
backend match {
76+
case s: SparkBackend =>
77+
s.sc.getConf.set("spark.local.dir", tmp)
78+
case _ =>
79+
()
80+
}
81+
}
7482

7583
def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
7684
synchronized {

hail/python/hail/backend/backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,23 @@ def get_flags(self, *flags) -> Mapping[str, str]:
392392
@abc.abstractmethod
393393
def requires_lowering(self):
394394
pass
395+
396+
@property
397+
@abc.abstractmethod
398+
def local_tmpdir(self) -> str:
399+
pass
400+
401+
@local_tmpdir.setter
402+
@abc.abstractmethod
403+
def local_tmpdir(self, dir: str) -> None:
404+
pass
405+
406+
@property
407+
@abc.abstractmethod
408+
def remote_tmpdir(self) -> str:
409+
pass
410+
411+
@remote_tmpdir.setter
412+
@abc.abstractmethod
413+
def remote_tmpdir(self, dir: str) -> None:
414+
pass

hail/python/hail/backend/py4j_backend.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ def decode_bytearray(encoded):
197197
self._jhc = jhc
198198

199199
self._jbackend = self._hail_package.backend.api.Py4JBackendApi(jbackend)
200-
self._jbackend.pySetLocalTmp(tmpdir)
201-
self._jbackend.pySetRemoteTmp(remote_tmpdir)
200+
self.local_tmpdir = tmpdir
201+
self.remote_tmpdir = tmpdir
202202

203203
self._jhttp_server = self._jbackend.pyHttpServer()
204204
self._backend_server_port: int = self._jhttp_server.port()
@@ -325,3 +325,21 @@ def stop(self):
325325
self._jhc = None
326326
uninstall_exception_handler()
327327
super().stop()
328+
329+
@property
330+
def local_tmpdir(self) -> str:
331+
return self._local_tmpdir
332+
333+
@local_tmpdir.setter
334+
def local_tmpdir(self, tmpdir: str) -> None:
335+
self._local_tmpdir = tmpdir
336+
self._jbackend.pySetLocalTmp(tmpdir)
337+
338+
@property
339+
def remote_tmpdir(self) -> str:
340+
return self._remote_tmpdir
341+
342+
@remote_tmpdir.setter
343+
def remote_tmpdir(self, tmpdir: str) -> None:
344+
self._remote_tmpdir = tmpdir
345+
self._jbackend.pySetRemoteTmp(tmpdir)

hail/python/hail/backend/service_backend.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import warnings
66
from contextlib import AsyncExitStack
77
from dataclasses import dataclass
8-
from typing import Any, Awaitable, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
8+
from typing import Any, Awaitable, Dict, List, Mapping, NoReturn, Optional, Set, Tuple, TypeVar, Union
99

1010
import orjson
1111

1212
import hailtop.aiotools.fs as afs
13-
from hail.context import TemporaryDirectory, TemporaryFilename, tmp_dir
13+
from hail.context import TemporaryDirectory, TemporaryFilename
1414
from hail.experimental import read_expression, write_expression
1515
from hail.utils import FatalError
1616
from hail.version import __revision__, __version__
@@ -240,7 +240,7 @@ def __init__(
240240
self._batch_was_submitted: bool = False
241241
self.disable_progress_bar = disable_progress_bar
242242
self.batch_attributes = batch_attributes
243-
self.remote_tmpdir = remote_tmpdir
243+
self._remote_tmpdir = remote_tmpdir
244244
self.flags: Dict[str, str] = {}
245245
self._registered_ir_function_names: Set[str] = set()
246246
self.driver_cores = driver_cores
@@ -520,3 +520,19 @@ def get_flags(self, *flags: str) -> Mapping[str, str]:
520520
@property
521521
def requires_lowering(self):
522522
return True
523+
524+
@property
525+
def local_tmpdir(self) -> NoReturn:
526+
raise AttributeError('local tmp folders are not supported on the batch backend')
527+
528+
@local_tmpdir.setter
529+
def local_tmpdir(self, tmpdir: str) -> NoReturn:
530+
raise AttributeError('local tmp folders are not supported on the batch backend')
531+
532+
@property
533+
def remote_tmpdir(self) -> str:
534+
return self._remote_tmpdir
535+
536+
@remote_tmpdir.setter
537+
def remote_tmpdir(self, tmpdir: str) -> None:
538+
self._remote_tmpdir = tmpdir

hail/python/hail/context.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ def create(
6666
log: str,
6767
quiet: bool,
6868
append: bool,
69-
tmpdir: str,
70-
local_tmpdir: str,
7169
default_reference: str,
7270
global_seed: Optional[int],
7371
backend: Backend,
@@ -76,25 +74,17 @@ def create(
7674
log=log,
7775
quiet=quiet,
7876
append=append,
79-
tmpdir=tmpdir,
80-
local_tmpdir=local_tmpdir,
8177
global_seed=global_seed,
8278
backend=backend,
8379
)
8480
hc.initialize_references(default_reference)
8581
return hc
8682

87-
@typecheck_method(
88-
log=str, quiet=bool, append=bool, tmpdir=str, local_tmpdir=str, global_seed=nullable(int), backend=Backend
89-
)
90-
def __init__(self, log, quiet, append, tmpdir, local_tmpdir, global_seed, backend):
83+
@typecheck_method(log=str, quiet=bool, append=bool, global_seed=nullable(int), backend=Backend)
84+
def __init__(self, log, quiet, append, global_seed, backend: Backend):
9185
assert not Env._hc
9286

9387
self._log = log
94-
95-
self._tmpdir = tmpdir
96-
self._local_tmpdir = local_tmpdir
97-
9888
self._backend = backend
9989

10090
self._warn_cols_order = True
@@ -136,6 +126,14 @@ def initialize_references(self, default_reference):
136126
else:
137127
self._default_ref = ReferenceGenome.read(default_reference)
138128

129+
@property
130+
def _tmpdir(self) -> str:
131+
return self._backend.remote_tmpdir
132+
133+
@property
134+
def _local_tmpdir(self) -> str:
135+
return self._backend.local_tmpdir
136+
139137
@property
140138
def default_reference(self) -> ReferenceGenome:
141139
assert self._default_ref is not None, '_default_ref should have been initialized in HailContext.create'
@@ -498,7 +496,7 @@ def init_spark(
498496
if not backend.fs.exists(tmpdir):
499497
backend.fs.mkdir(tmpdir)
500498

501-
HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
499+
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
502500
if not quiet:
503501
connect_logger(backend._utils_package_object, 'localhost', 12888)
504502

@@ -569,7 +567,7 @@ async def init_batch(
569567
tmpdir = os.path.join(backend.remote_tmpdir, 'tmp/hail', secret_alnum_string())
570568
local_tmpdir = _get_local_tmpdir(local_tmpdir)
571569

572-
HailContext.create(log, quiet, append, tmpdir, local_tmpdir, default_reference, global_seed, backend)
570+
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
573571

574572

575573
@typecheck(
@@ -621,7 +619,7 @@ def init_local(
621619
if not backend.fs.exists(tmpdir):
622620
backend.fs.mkdir(tmpdir)
623621

624-
HailContext.create(log, quiet, append, tmpdir, tmpdir, default_reference, global_seed, backend)
622+
HailContext.create(log, quiet, append, default_reference, global_seed, backend)
625623
if not quiet:
626624
connect_logger(backend._utils_package_object, 'localhost', 12888)
627625

0 commit comments

Comments
 (0)