Skip to content

Commit af06c70

Browse files
authored
#46 Improve Readability of TableRead Implementation (#47)
1 parent 08d0bb3 commit af06c70

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

pypaimon/py4j/java_implementation.py

Lines changed: 9 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -186,18 +186,15 @@ def file_paths(self) -> List[str]:
186186
class TableRead(table_read.TableRead):
187187

188188
def __init__(self, j_table_read, j_read_type, catalog_options):
189-
self._j_table_read = j_table_read
190-
self._j_read_type = j_read_type
191-
self._catalog_options = catalog_options
192-
self._j_bytes_reader = None
193189
self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
190+
self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
191+
j_table_read, j_read_type, TableRead._get_max_workers(catalog_options))
194192

195193
def to_arrow(self, splits):
196194
record_batch_reader = self.to_arrow_batch_reader(splits)
197195
return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
198196

199197
def to_arrow_batch_reader(self, splits):
200-
self._init()
201198
j_splits = list(map(lambda s: s.to_j_split(), splits))
202199
self._j_bytes_reader.setSplits(j_splits)
203200
batch_iterator = self._batch_generator()
@@ -222,19 +219,13 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
222219

223220
return ray.data.from_arrow(self.to_arrow(splits))
224221

225-
def _init(self):
226-
if self._j_bytes_reader is None:
227-
# get thread num
228-
max_workers = self._catalog_options.get(constants.MAX_WORKERS)
229-
if max_workers is None:
230-
# default is sequential
231-
max_workers = 1
232-
else:
233-
max_workers = int(max_workers)
234-
if max_workers <= 0:
235-
raise ValueError("max_workers must be greater than 0")
236-
self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
237-
self._j_table_read, self._j_read_type, max_workers)
222+
@staticmethod
223+
def _get_max_workers(catalog_options):
224+
# default is sequential
225+
max_workers = int(catalog_options.get(constants.MAX_WORKERS, 1))
226+
if max_workers <= 0:
227+
raise ValueError("max_workers must be greater than 0")
228+
return max_workers
238229

239230
def _batch_generator(self) -> Iterator[pa.RecordBatch]:
240231
while True:

0 commit comments

Comments
 (0)