3
3
import zipfile
4
4
from dataclasses import dataclass
5
5
from enum import Enum
6
- from typing import AbstractSet , Any , ClassVar , Dict , List , Mapping , Optional , Tuple , TypeVar , Union
6
+ from typing import AbstractSet , Any , ClassVar , Dict , List , Mapping , Optional , Set , Tuple , TypeVar , Union
7
7
8
8
import orjson
9
9
@@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation:
68
68
raise ValueError (f'Hail requires either { hail_jar } or { hail_all_spark_jar } .' )
69
69
70
70
71
+ class IRFunction :
72
+ def __init__ (
73
+ self ,
74
+ name : str ,
75
+ type_parameters : Union [Tuple [HailType , ...], List [HailType ]],
76
+ value_parameter_names : Union [Tuple [str , ...], List [str ]],
77
+ value_parameter_types : Union [Tuple [HailType , ...], List [HailType ]],
78
+ return_type : HailType ,
79
+ body : Expression ,
80
+ ):
81
+ assert len (value_parameter_names ) == len (value_parameter_types )
82
+ render = CSERenderer ()
83
+ self ._name = name
84
+ self ._type_parameters = type_parameters
85
+ self ._value_parameter_names = value_parameter_names
86
+ self ._value_parameter_types = value_parameter_types
87
+ self ._return_type = return_type
88
+ self ._rendered_body = render (finalize_randomness (body ._ir ))
89
+
90
+ def to_dataclass (self ):
91
+ return SerializedIRFunction (
92
+ name = self ._name ,
93
+ type_parameters = [tp ._parsable_string () for tp in self ._type_parameters ],
94
+ value_parameter_names = list (self ._value_parameter_names ),
95
+ value_parameter_types = [vpt ._parsable_string () for vpt in self ._value_parameter_types ],
96
+ return_type = self ._return_type ._parsable_string (),
97
+ rendered_body = self ._rendered_body ,
98
+ )
99
+
100
+
71
101
class ActionTag (Enum ):
72
- LOAD_REFERENCES_FROM_DATASET = 1
73
- VALUE_TYPE = 2
74
- TABLE_TYPE = 3
75
- MATRIX_TABLE_TYPE = 4
76
- BLOCK_MATRIX_TYPE = 5
77
- EXECUTE = 6
78
- PARSE_VCF_METADATA = 7
79
- IMPORT_FAM = 8
102
+ VALUE_TYPE = 1
103
+ TABLE_TYPE = 2
104
+ MATRIX_TABLE_TYPE = 3
105
+ BLOCK_MATRIX_TYPE = 4
106
+ EXECUTE = 5
107
+ PARSE_VCF_METADATA = 6
108
+ IMPORT_FAM = 7
109
+ LOAD_REFERENCES_FROM_DATASET = 8
80
110
FROM_FASTA_FILE = 9
81
111
82
112
@@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload):
90
120
ir : str
91
121
92
122
123
+ @dataclass
124
+ class SerializedIRFunction :
125
+ name : str
126
+ type_parameters : List [str ]
127
+ value_parameter_names : List [str ]
128
+ value_parameter_types : List [str ]
129
+ return_type : str
130
+ rendered_body : str
131
+
132
+
93
133
@dataclass
94
134
class ExecutePayload (ActionPayload ):
95
135
ir : str
136
+ fns : List [SerializedIRFunction ]
96
137
stream_codec : str
97
- timed : bool
98
138
99
139
100
140
@dataclass
@@ -164,17 +204,24 @@ def _valid_flags(self) -> AbstractSet[str]:
164
204
def __init__ (self ):
165
205
self ._persisted_locations = dict ()
166
206
self ._references = {}
207
+ self .functions : List [IRFunction ] = []
208
+ self ._registered_ir_function_names : Set [str ] = set ()
167
209
168
210
@abc .abstractmethod
169
211
def validate_file (self , uri : str ):
170
212
raise NotImplementedError
171
213
172
214
@abc .abstractmethod
173
215
def stop (self ):
174
- pass
216
+ self .functions = []
217
+ self ._registered_ir_function_names = set ()
175
218
176
219
def execute (self , ir : BaseIR , timed : bool = False ) -> Any :
177
- payload = ExecutePayload (self ._render_ir (ir ), '{"name":"StreamBufferSpec"}' , timed )
220
+ payload = ExecutePayload (
221
+ self ._render_ir (ir ),
222
+ fns = [fn .to_dataclass () for fn in self .functions ],
223
+ stream_codec = '{"name":"StreamBufferSpec"}' ,
224
+ )
178
225
try :
179
226
result , timings = self ._rpc (ActionTag .EXECUTE , payload )
180
227
except FatalError as e :
@@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset:
300
347
tempfile .__exit__ (None , None , None )
301
348
return unpersisted
302
349
303
- @abc .abstractmethod
304
350
def register_ir_function (
305
351
self ,
306
352
name : str ,
@@ -310,11 +356,13 @@ def register_ir_function(
310
356
return_type : HailType ,
311
357
body : Expression ,
312
358
):
313
- pass
359
+ self ._registered_ir_function_names .add (name )
360
+ self .functions .append (
361
+ IRFunction (name , type_parameters , value_parameter_names , value_parameter_types , return_type , body )
362
+ )
314
363
315
- @abc .abstractmethod
316
364
def _is_registered_ir_function_name (self , name : str ) -> bool :
317
- pass
365
+ return name in self . _registered_ir_function_names
318
366
319
367
@abc .abstractmethod
320
368
def persist_expression (self , expr : Expression ) -> Expression :
0 commit comments