|
23 | 23 |
|
24 | 24 | from pypaimon.py4j.java_gateway import get_gateway |
25 | 25 | from pypaimon.py4j.util import java_utils, constants |
| 26 | +from pypaimon.py4j.util.java_utils import serialize_java_object, deserialize_java_object |
26 | 27 | from pypaimon.api import \ |
27 | 28 | (catalog, table, read_builder, table_scan, split, row_type, |
28 | 29 | table_read, write_builder, table_write, commit_message, |
@@ -145,33 +146,41 @@ def __init__(self, j_splits): |
145 | 146 | self._j_splits = j_splits |
146 | 147 |
|
147 | 148 | def splits(self) -> List['Split']: |
148 | | - return list(map(lambda s: Split(s), self._j_splits)) |
| 149 | + return list(map(lambda s: self._build_single_split(s), self._j_splits)) |
| 150 | + |
| 151 | + def _build_single_split(self, j_split) -> 'Split': |
| 152 | + j_split_bytes = serialize_java_object(j_split) |
| 153 | + row_count = j_split.rowCount() |
| 154 | + files_optional = j_split.convertToRawFiles() |
| 155 | + if not files_optional.isPresent(): |
| 156 | + file_size = 0 |
| 157 | + file_paths = [] |
| 158 | + else: |
| 159 | + files = files_optional.get() |
| 160 | + file_size = sum(file.length() for file in files) |
| 161 | + file_paths = [file.path() for file in files] |
| 162 | + return Split(j_split_bytes, row_count, file_size, file_paths) |
149 | 163 |
|
150 | 164 |
|
151 | 165 | class Split(split.Split): |
152 | 166 |
|
153 | | - def __init__(self, j_split): |
154 | | - self._j_split = j_split |
| 167 | + def __init__(self, j_split_bytes, row_count: int, file_size: int, file_paths: List[str]): |
| 168 | + self._j_split_bytes = j_split_bytes |
| 169 | + self._row_count = row_count |
| 170 | + self._file_size = file_size |
| 171 | + self._file_paths = file_paths |
155 | 172 |
|
156 | 173 | def to_j_split(self): |
157 | | - return self._j_split |
| 174 | + return deserialize_java_object(self._j_split_bytes) |
158 | 175 |
|
159 | 176 | def row_count(self) -> int: |
160 | | - return self._j_split.rowCount() |
| 177 | + return self._row_count |
161 | 178 |
|
162 | 179 | def file_size(self) -> int: |
163 | | - files_optional = self._j_split.convertToRawFiles() |
164 | | - if not files_optional.isPresent(): |
165 | | - return 0 |
166 | | - files = files_optional.get() |
167 | | - return sum(file.length() for file in files) |
| 180 | + return self._file_size |
168 | 181 |
|
169 | 182 | def file_paths(self) -> List[str]: |
170 | | - files_optional = self._j_split.convertToRawFiles() |
171 | | - if not files_optional.isPresent(): |
172 | | - return [] |
173 | | - files = files_optional.get() |
174 | | - return [file.path() for file in files] |
| 183 | + return self._file_paths |
175 | 184 |
|
176 | 185 |
|
177 | 186 | class TableRead(table_read.TableRead): |
@@ -317,11 +326,11 @@ def close(self): |
317 | 326 |
|
318 | 327 | class Predicate(predicate.Predicate): |
319 | 328 |
|
320 | | - def __init__(self, j_predicate): |
321 | | - self._j_predicate = j_predicate |
| 329 | + def __init__(self, j_predicate_bytes): |
| 330 | + self._j_predicate_bytes = j_predicate_bytes |
322 | 331 |
|
323 | 332 | def to_j_predicate(self): |
324 | | - return self._j_predicate |
| 333 | + return deserialize_java_object(self._j_predicate_bytes) |
325 | 334 |
|
326 | 335 |
|
327 | 336 | class PredicateBuilder(predicate.PredicateBuilder): |
@@ -350,7 +359,7 @@ def _build(self, method: str, field: str, literals: Optional[List[Any]] = None): |
350 | 359 | index, |
351 | 360 | literals |
352 | 361 | ) |
353 | | - return Predicate(j_predicate) |
| 362 | + return Predicate(serialize_java_object(j_predicate)) |
354 | 363 |
|
355 | 364 | def equal(self, field: str, literal: Any) -> Predicate: |
356 | 365 | return self._build('equal', field, [literal]) |
@@ -397,8 +406,10 @@ def between(self, field: str, included_lower_bound: Any, included_upper_bound: A |
397 | 406 |
|
398 | 407 | def and_predicates(self, predicates: List[Predicate]) -> Predicate: |
399 | 408 | predicates = list(map(lambda p: p.to_j_predicate(), predicates)) |
400 | | - return Predicate(get_gateway().jvm.PredicationUtil.buildAnd(predicates)) |
| 409 | + j_predicate = get_gateway().jvm.PredicationUtil.buildAnd(predicates) |
| 410 | + return Predicate(serialize_java_object(j_predicate)) |
401 | 411 |
|
402 | 412 | def or_predicates(self, predicates: List[Predicate]) -> Predicate: |
403 | 413 | predicates = list(map(lambda p: p.to_j_predicate(), predicates)) |
404 | | - return Predicate(get_gateway().jvm.PredicationUtil.buildOr(predicates)) |
| 414 | + j_predicate = get_gateway().jvm.PredicationUtil.buildOr(predicates) |
| 415 | + return Predicate(serialize_java_object(j_predicate)) |
0 commit comments