Skip to content

Commit 0eb79a8

Browse files
authored
#44 Make Split and Predicate Serializable (#45)
1 parent 75d00d7 commit 0eb79a8

File tree

2 files changed

+49
-21
lines changed

2 files changed

+49
-21
lines changed

pypaimon/py4j/java_implementation.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from pypaimon.py4j.java_gateway import get_gateway
2525
from pypaimon.py4j.util import java_utils, constants
26+
from pypaimon.py4j.util.java_utils import serialize_java_object, deserialize_java_object
2627
from pypaimon.api import \
2728
(catalog, table, read_builder, table_scan, split, row_type,
2829
table_read, write_builder, table_write, commit_message,
@@ -145,33 +146,41 @@ def __init__(self, j_splits):
145146
self._j_splits = j_splits
146147

147148
def splits(self) -> List['Split']:
148-
return list(map(lambda s: Split(s), self._j_splits))
149+
return list(map(lambda s: self._build_single_split(s), self._j_splits))
150+
151+
def _build_single_split(self, j_split) -> 'Split':
    """Turn one Java split into a ``Split`` holding only Python-side data.

    The Java split is serialized to bytes, and its row count, total raw
    file length and raw file paths are read from the JVM once, here, so
    the resulting ``Split`` does not keep a reference to the Java object.

    Args:
        j_split: Py4J reference to the Java split.

    Returns:
        A ``Split`` built from the serialized bytes and the collected
        metadata. When ``convertToRawFiles()`` yields an empty optional,
        the file size is 0 and the path list is empty.
    """
    serialized = serialize_java_object(j_split)
    rows = j_split.rowCount()
    raw_files = j_split.convertToRawFiles()
    if raw_files.isPresent():
        files = raw_files.get()
        total_size = sum(f.length() for f in files)
        paths = [f.path() for f in files]
    else:
        # No raw-file view available for this split.
        total_size = 0
        paths = []
    return Split(serialized, rows, total_size, paths)
149163

150164

151165
class Split(split.Split):
152166

153-
def __init__(self, j_split):
154-
self._j_split = j_split
167+
def __init__(self, j_split_bytes, row_count: int, file_size: int, file_paths: List[str]):
168+
self._j_split_bytes = j_split_bytes
169+
self._row_count = row_count
170+
self._file_size = file_size
171+
self._file_paths = file_paths
155172

156173
def to_j_split(self):
157-
return self._j_split
174+
return deserialize_java_object(self._j_split_bytes)
158175

159176
def row_count(self) -> int:
160-
return self._j_split.rowCount()
177+
return self._row_count
161178

162179
def file_size(self) -> int:
163-
files_optional = self._j_split.convertToRawFiles()
164-
if not files_optional.isPresent():
165-
return 0
166-
files = files_optional.get()
167-
return sum(file.length() for file in files)
180+
return self._file_size
168181

169182
def file_paths(self) -> List[str]:
170-
files_optional = self._j_split.convertToRawFiles()
171-
if not files_optional.isPresent():
172-
return []
173-
files = files_optional.get()
174-
return [file.path() for file in files]
183+
return self._file_paths
175184

176185

177186
class TableRead(table_read.TableRead):
@@ -317,11 +326,11 @@ def close(self):
317326

318327
class Predicate(predicate.Predicate):
319328

320-
def __init__(self, j_predicate):
321-
self._j_predicate = j_predicate
329+
def __init__(self, j_predicate_bytes):
330+
self._j_predicate_bytes = j_predicate_bytes
322331

323332
def to_j_predicate(self):
324-
return self._j_predicate
333+
return deserialize_java_object(self._j_predicate_bytes)
325334

326335

327336
class PredicateBuilder(predicate.PredicateBuilder):
@@ -350,7 +359,7 @@ def _build(self, method: str, field: str, literals: Optional[List[Any]] = None):
350359
index,
351360
literals
352361
)
353-
return Predicate(j_predicate)
362+
return Predicate(serialize_java_object(j_predicate))
354363

355364
def equal(self, field: str, literal: Any) -> Predicate:
356365
return self._build('equal', field, [literal])
@@ -397,8 +406,10 @@ def between(self, field: str, included_lower_bound: Any, included_upper_bound: A
397406

398407
def and_predicates(self, predicates: List[Predicate]) -> Predicate:
399408
predicates = list(map(lambda p: p.to_j_predicate(), predicates))
400-
return Predicate(get_gateway().jvm.PredicationUtil.buildAnd(predicates))
409+
j_predicate = get_gateway().jvm.PredicationUtil.buildAnd(predicates)
410+
return Predicate(serialize_java_object(j_predicate))
401411

402412
def or_predicates(self, predicates: List[Predicate]) -> Predicate:
403413
predicates = list(map(lambda p: p.to_j_predicate(), predicates))
404-
return Predicate(get_gateway().jvm.PredicationUtil.buildOr(predicates))
414+
j_predicate = get_gateway().jvm.PredicationUtil.buildOr(predicates)
415+
return Predicate(serialize_java_object(j_predicate))

pypaimon/py4j/util/java_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,20 @@ def to_arrow_schema(j_row_type):
100100
arrow_schema = schema_reader.schema
101101
schema_reader.close()
102102
return arrow_schema
103+
104+
105+
def serialize_java_object(java_obj) -> bytes:
    """Serialize a Java object from the gateway JVM into Python ``bytes``.

    Delegates to ``org.apache.paimon.utils.InstantiationUtil.serializeObject``
    on the JVM side and converts the result to ``bytes``.

    Args:
        java_obj: Py4J reference to the Java object to serialize.

    Returns:
        The serialized form of ``java_obj``.

    Raises:
        RuntimeError: If the JVM call or the conversion to ``bytes`` fails.
            The original exception is attached as ``__cause__``.
    """
    util = get_gateway().jvm.org.apache.paimon.utils.InstantiationUtil
    try:
        return bytes(util.serializeObject(java_obj))
    except Exception as e:
        # Chain explicitly so the underlying Py4J/Java error stays visible
        # in the traceback instead of being reduced to its message.
        raise RuntimeError(f"Java serialization failed: {e}") from e
113+
114+
115+
def deserialize_java_object(bytes_data):
    """Rebuild a Java object in the gateway JVM from serialized bytes.

    Calls ``org.apache.paimon.utils.InstantiationUtil.deserializeObject``
    with the context class loader of the JVM's current thread.

    Args:
        bytes_data: Serialized object bytes; presumably the output of
            ``serialize_java_object`` (confirm against callers).

    Returns:
        The Py4J result of ``deserializeObject``.
    """
    # Resolve the JVM view once; the original looked the gateway up twice.
    jvm = get_gateway().jvm
    class_loader = jvm.Thread.currentThread().getContextClassLoader()
    util = jvm.org.apache.paimon.utils.InstantiationUtil
    return util.deserializeObject(bytes_data, class_loader)

0 commit comments

Comments
 (0)