Skip to content

Commit 73eaa93

Browse files
committed
[query] unify backend rpc
1 parent b60c6de commit 73eaa93

35 files changed

+821
-856
lines changed

hail/python/hail/backend/backend.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import zipfile
44
from dataclasses import dataclass
55
from enum import Enum
6-
from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
6+
from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
77

88
import orjson
99

@@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation:
6868
raise ValueError(f'Hail requires either {hail_jar} or {hail_all_spark_jar}.')
6969

7070

71+
class IRFunction:
72+
def __init__(
73+
self,
74+
name: str,
75+
type_parameters: Union[Tuple[HailType, ...], List[HailType]],
76+
value_parameter_names: Union[Tuple[str, ...], List[str]],
77+
value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
78+
return_type: HailType,
79+
body: Expression,
80+
):
81+
assert len(value_parameter_names) == len(value_parameter_types)
82+
render = CSERenderer()
83+
self._name = name
84+
self._type_parameters = type_parameters
85+
self._value_parameter_names = value_parameter_names
86+
self._value_parameter_types = value_parameter_types
87+
self._return_type = return_type
88+
self._rendered_body = render(finalize_randomness(body._ir))
89+
90+
def to_dataclass(self):
91+
return SerializedIRFunction(
92+
name=self._name,
93+
type_parameters=[tp._parsable_string() for tp in self._type_parameters],
94+
value_parameter_names=list(self._value_parameter_names),
95+
value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types],
96+
return_type=self._return_type._parsable_string(),
97+
rendered_body=self._rendered_body,
98+
)
99+
100+
71101
class ActionTag(Enum):
72-
LOAD_REFERENCES_FROM_DATASET = 1
73-
VALUE_TYPE = 2
74-
TABLE_TYPE = 3
75-
MATRIX_TABLE_TYPE = 4
76-
BLOCK_MATRIX_TYPE = 5
77-
EXECUTE = 6
78-
PARSE_VCF_METADATA = 7
79-
IMPORT_FAM = 8
102+
VALUE_TYPE = 1
103+
TABLE_TYPE = 2
104+
MATRIX_TABLE_TYPE = 3
105+
BLOCK_MATRIX_TYPE = 4
106+
EXECUTE = 5
107+
PARSE_VCF_METADATA = 6
108+
IMPORT_FAM = 7
109+
LOAD_REFERENCES_FROM_DATASET = 8
80110
FROM_FASTA_FILE = 9
81111

82112

@@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload):
90120
ir: str
91121

92122

123+
@dataclass
124+
class SerializedIRFunction:
125+
name: str
126+
type_parameters: List[str]
127+
value_parameter_names: List[str]
128+
value_parameter_types: List[str]
129+
return_type: str
130+
rendered_body: str
131+
132+
93133
@dataclass
94134
class ExecutePayload(ActionPayload):
95135
ir: str
136+
fns: List[SerializedIRFunction]
96137
stream_codec: str
97-
timed: bool
98138

99139

100140
@dataclass
@@ -164,17 +204,24 @@ def _valid_flags(self) -> AbstractSet[str]:
164204
def __init__(self):
165205
self._persisted_locations = dict()
166206
self._references = {}
207+
self.functions: List[IRFunction] = []
208+
self._registered_ir_function_names: Set[str] = set()
167209

168210
@abc.abstractmethod
169211
def validate_file(self, uri: str):
170212
raise NotImplementedError
171213

172214
@abc.abstractmethod
173215
def stop(self):
174-
pass
216+
self.functions = []
217+
self._registered_ir_function_names = set()
175218

176219
def execute(self, ir: BaseIR, timed: bool = False) -> Any:
177-
payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed)
220+
payload = ExecutePayload(
221+
self._render_ir(ir),
222+
fns=[fn.to_dataclass() for fn in self.functions],
223+
stream_codec='{"name":"StreamBufferSpec"}',
224+
)
178225
try:
179226
result, timings = self._rpc(ActionTag.EXECUTE, payload)
180227
except FatalError as e:
@@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset:
300347
tempfile.__exit__(None, None, None)
301348
return unpersisted
302349

303-
@abc.abstractmethod
304350
def register_ir_function(
305351
self,
306352
name: str,
@@ -310,11 +356,13 @@ def register_ir_function(
310356
return_type: HailType,
311357
body: Expression,
312358
):
313-
pass
359+
self._registered_ir_function_names.add(name)
360+
self.functions.append(
361+
IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body)
362+
)
314363

315-
@abc.abstractmethod
316364
def _is_registered_ir_function_name(self, name: str) -> bool:
317-
pass
365+
return name in self._registered_ir_function_names
318366

319367
@abc.abstractmethod
320368
def persist_expression(self, expr: Expression) -> Expression:

hail/python/hail/backend/local_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def register_ir_function(
119119
)
120120

121121
def stop(self):
122-
super().stop()
122+
super(Py4JBackend, self).stop()
123123
self._exit_stack.close()
124124
uninstall_exception_handler()
125125

hail/python/hail/backend/py4j_backend.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import socketserver
55
import sys
66
from threading import Thread
7-
from typing import Mapping, Optional, Set, Tuple
7+
from typing import Mapping, Optional, Tuple
88

99
import orjson
1010
import py4j
@@ -192,8 +192,6 @@ def decode_bytearray(encoded):
192192
self._backend_server.start()
193193
self._requests_session = requests.Session()
194194

195-
self._registered_ir_function_names: Set[str] = set()
196-
197195
# This has to go after creating the SparkSession. Unclear why.
198196
# Maybe it does its own patch?
199197
install_exception_handler()
@@ -239,9 +237,6 @@ def persist_expression(self, expr):
239237
t = expr.dtype
240238
return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)
241239

242-
def _is_registered_ir_function_name(self, name: str) -> bool:
243-
return name in self._registered_ir_function_names
244-
245240
def set_flags(self, **flags: Mapping[str, str]):
246241
available = self._jbackend.pyAvailableFlags()
247242
invalid = []
@@ -276,12 +271,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome):
276271
def remove_liftover(self, name, dest_reference_genome):
277272
self._jbackend.pyRemoveLiftover(name, dest_reference_genome)
278273

279-
def _parse_value_ir(self, code, ref_map={}):
280-
return self._jbackend.parse_value_ir(
281-
code,
282-
{k: t._parsable_string() for k, t in ref_map.items()},
283-
)
284-
285274
def _register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, code):
286275
self._registered_ir_function_names.add(name)
287276
self._jbackend.pyRegisterIR(
@@ -293,12 +282,6 @@ def _register_ir_function(self, name, type_parameters, argument_names, argument_
293282
code,
294283
)
295284

296-
def _parse_table_ir(self, code):
297-
return self._jbackend.parse_table_ir(code)
298-
299-
def _parse_matrix_ir(self, code):
300-
return self._jbackend.parse_matrix_ir(code)
301-
302285
def _parse_blockmatrix_ir(self, code):
303286
return self._jbackend.parse_blockmatrix_ir(code)
304287

@@ -310,5 +293,5 @@ def stop(self):
310293
self._jbackend.close()
311294
self._jhc.stop()
312295
self._jhc = None
313-
self._registered_ir_function_names = set()
314296
uninstall_exception_handler()
297+
super().stop()

0 commit comments

Comments
 (0)