Skip to content

Commit 6d1e128

Browse files
committed
Make read_batch take a global arrow_string_format_per_column
to be used as an override if the read request doesn't specify it. It is also used to set up per column formatting for `read_batch_and_join`
1 parent 81c1f25 commit 6d1e128

File tree

4 files changed

+73
-80
lines changed

4 files changed

+73
-80
lines changed

python/arcticdb/version_store/_store.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,7 @@ def batch_read(
12181218
query_builder: Optional[Union[QueryBuilder, List[QueryBuilder]]] = None,
12191219
columns: Optional[List[List[str]]] = None,
12201220
arrow_string_format_default: Optional[Union[ArrowOutputStringFormat, "pa.DataType"]] = None,
1221+
arrow_string_format_per_column: Optional[Dict[str, Union[ArrowOutputStringFormat, "pa.DataType"]]] = None,
12211222
per_symbol_arrow_string_format_default: Optional[
12221223
List[Optional[Union[ArrowOutputStringFormat, "pa.DataType"]]]
12231224
] = None,
@@ -1257,6 +1258,11 @@ def batch_read(
12571258
If using `output_format=EXPERIMENTAL_ARROW` it sets the output format of string columns for arrow.
12581259
See documentation of `ArrowOutputStringFormat` for more information on the different options.
12591260
It serves as the default for the entire batch.
1261+
arrow_string_format_per_column: Optional[Dict[str, Union[ArrowOutputStringFormat, "pa.DataType"]]], default=None,
1262+
If using `output_format=EXPERIMENTAL_ARROW` it sets the output format of string columns for arrow.
1263+
See documentation of `ArrowOutputStringFormat` for more information on the different options.
1264+
It defines the setting per column. It is applied to all symbols which don't have a
1265+
`per_symbol_arrow_string_format_per_column` set.
12601266
per_symbol_arrow_string_format_default: Optional[List[Optional[Union[ArrowOutputStringFormat, "pa.DataType"]]]], default=None,
12611267
If using `output_format=EXPERIMENTAL_ARROW` it sets the output format of string columns for arrow.
12621268
See documentation of `ArrowOutputStringFormat` for more information on the different options.
@@ -1292,6 +1298,7 @@ def batch_read(
12921298
query_builder=query_builder,
12931299
throw_on_error=throw_on_error,
12941300
arrow_string_format_default=arrow_string_format_default,
1301+
arrow_string_format_per_column=arrow_string_format_per_column,
12951302
per_symbol_arrow_string_format_default=per_symbol_arrow_string_format_default,
12961303
per_symbol_arrow_string_format_per_column=per_symbol_arrow_string_format_per_column,
12971304
**kwargs,
@@ -1312,6 +1319,7 @@ def _batch_read_to_versioned_items(
13121319
query_builder,
13131320
throw_on_error,
13141321
arrow_string_format_default,
1322+
arrow_string_format_per_column,
13151323
per_symbol_arrow_string_format_default,
13161324
per_symbol_arrow_string_format_per_column,
13171325
**kwargs,
@@ -1327,6 +1335,7 @@ def _batch_read_to_versioned_items(
13271335
len(symbols),
13281336
throw_on_error,
13291337
arrow_string_format_default,
1338+
arrow_string_format_per_column,
13301339
per_symbol_arrow_string_format_default,
13311340
per_symbol_arrow_string_format_per_column,
13321341
**kwargs,
@@ -2136,6 +2145,7 @@ def _get_batch_read_options(
21362145
num_symbols,
21372146
batch_throw_on_error,
21382147
global_arrow_string_format_default=None,
2148+
global_arrow_string_format_per_column=None,
21392149
per_symbol_arrow_string_format_default=None,
21402150
per_symbol_arrow_string_format_per_column=None,
21412151
**kwargs,
@@ -2159,15 +2169,17 @@ def _get_batch_read_options(
21592169
)
21602170
for idx in range(num_symbols):
21612171
arrow_string_format_default = global_arrow_string_format_default
2162-
arrow_string_format_per_column = None
2172+
arrow_string_format_per_column = global_arrow_string_format_per_column
21632173

21642174
if per_symbol_arrow_string_format_default is not None:
21652175
arrow_string_format_default = (
21662176
per_symbol_arrow_string_format_default[idx] or global_arrow_string_format_default
21672177
)
21682178

21692179
if per_symbol_arrow_string_format_per_column is not None:
2170-
arrow_string_format_per_column = per_symbol_arrow_string_format_per_column[idx]
2180+
arrow_string_format_per_column = (
2181+
per_symbol_arrow_string_format_per_column[idx] or global_arrow_string_format_per_column
2182+
)
21712183

21722184
read_options_per_symbol.append(
21732185
self._get_read_options(

python/arcticdb/version_store/library.py

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,8 @@ class LazyDataFrameCollection(QueryBuilder):
538538
def __init__(
539539
self,
540540
lazy_dataframes: List[LazyDataFrame],
541+
arrow_string_format_default: Optional[Union[ArrowOutputStringFormat, "pa.DataType"]] = None,
542+
arrow_string_format_per_column: Optional[Dict[str, Union[ArrowOutputStringFormat, "pa.DataType"]]] = None,
541543
):
542544
"""
543545
Gather a list of `LazyDataFrame`s into a single object that can be collected together.
@@ -563,6 +565,8 @@ def __init__(
563565
)
564566
super().__init__()
565567
self._lazy_dataframes = lazy_dataframes
568+
self._arrow_string_format_default = arrow_string_format_default
569+
self._arrow_string_format_per_column = arrow_string_format_per_column
566570
if len(self._lazy_dataframes):
567571
self._lib = self._lazy_dataframes[0].lib
568572
self._output_format = self._lazy_dataframes[0].read_request.output_format
@@ -588,7 +592,12 @@ def collect(self) -> List[Union[VersionedItem, DataError]]:
588592
"""
589593
if not len(self._lazy_dataframes):
590594
return []
591-
return self._lib.read_batch(self._read_requests(), output_format=self._output_format)
595+
return self._lib.read_batch(
596+
self._read_requests(),
597+
output_format=self._output_format,
598+
arrow_string_format_default=self._arrow_string_format_default,
599+
arrow_string_format_per_column=self._arrow_string_format_per_column,
600+
)
592601

593602
def _read_requests(self) -> List[ReadRequest]:
594603
# Combines queries for individual LazyDataFrames with the global query associated with this
@@ -647,33 +656,6 @@ def __init__(
647656
super().__init__()
648657
self._lazy_dataframes = lazy_dataframes
649658
self.then(join)
650-
self.arrow_string_format_default = None
651-
self.arrow_string_format_per_column = {}
652-
for lf in self._lazy_dataframes._lazy_dataframes:
653-
self.arrow_string_format_default = (
654-
self.arrow_string_format_default or lf.read_request.arrow_string_format_default
655-
)
656-
check(
657-
lf.read_request.arrow_string_format_default is None
658-
or self.arrow_string_format_default == lf.read_request.arrow_string_format_default,
659-
"Lazy frames from collection cannot be combined for join because they have incompatible arrow_string_format_default values {} and {}",
660-
self.arrow_string_format_default,
661-
lf.read_request.arrow_string_format_default,
662-
)
663-
if lf.read_request.arrow_string_format_per_column is not None:
664-
common_cols = (
665-
self.arrow_string_format_per_column.keys() & lf.read_request.arrow_string_format_per_column.keys()
666-
)
667-
for common_col in common_cols:
668-
check(
669-
self.arrow_string_format_per_column[common_col]
670-
== lf.read_request.arrow_string_format_per_column[common_col],
671-
"Lazy frames from collection cannot be combined for join because they have incompatible arrow_string_format_per_column values {} and {} for column {}",
672-
self.arrow_string_format_per_column[common_col],
673-
lf.read_request.arrow_string_format_per_column[common_col],
674-
common_col,
675-
)
676-
self.arrow_string_format_per_column.update(lf.read_request.arrow_string_format_per_column)
677659

678660
def collect(self) -> VersionedItemWithJoin:
679661
"""
@@ -693,8 +675,8 @@ def collect(self) -> VersionedItemWithJoin:
693675
self._lazy_dataframes._read_requests(),
694676
self,
695677
output_format=self._lazy_dataframes._output_format,
696-
arrow_string_format_default=self.arrow_string_format_default,
697-
arrow_string_format_per_column=self.arrow_string_format_per_column,
678+
arrow_string_format_default=self._lazy_dataframes._arrow_string_format_default,
679+
arrow_string_format_per_column=self._lazy_dataframes._arrow_string_format_per_column,
698680
)
699681

700682
def __str__(self) -> str:
@@ -2079,6 +2061,7 @@ def read_batch(
20792061
lazy: bool = False,
20802062
output_format: Optional[Union[OutputFormat, str]] = None,
20812063
arrow_string_format_default: Optional[Union[ArrowOutputStringFormat, "pa.DataType"]] = None,
2064+
arrow_string_format_per_column: Optional[Dict[str, Union[ArrowOutputStringFormat, "pa.DataType"]]] = None,
20822065
) -> Union[List[Union[VersionedItem, DataError]], LazyDataFrameCollection]:
20832066
"""
20842067
Reads multiple symbols.
@@ -2107,6 +2090,10 @@ def read_batch(
21072090
It serves as the default for the entire batch. The string format settings inside the `ReadRequest`s will
21082091
override this batch level setting.
21092092
2093+
arrow_string_format_per_column: Optional[Dict[str, Union[ArrowOutputStringFormat, "pa.DataType"]]], default=None,
2094+
Provides per column name overrides for `arrow_string_format_default`. It is only applied to symbols which
2095+
don't have a `arrow_string_format_per_column` set in their `ReadRequest`.
2096+
21102097
Returns
21112098
-------
21122099
Union[List[Union[VersionedItem, DataError]], LazyDataFrameCollection]
@@ -2221,14 +2208,18 @@ def handle_symbol(s_):
22212208
columns=columns[idx],
22222209
query_builder=q,
22232210
output_format=output_format,
2224-
arrow_string_format_default=(
2225-
per_symbol_arrow_string_format_default[idx] or arrow_string_format_default
2226-
),
2227-
arrow_string_format_per_column=per_symbol_arrow_string_format_per_column[idx],
2211+
arrow_string_format_default=per_symbol_arrow_string_format_default[idx]
2212+
or arrow_string_format_default,
2213+
arrow_string_format_per_column=per_symbol_arrow_string_format_per_column[idx]
2214+
or arrow_string_format_per_column,
22282215
),
22292216
)
22302217
)
2231-
return LazyDataFrameCollection(lazy_dataframes)
2218+
return LazyDataFrameCollection(
2219+
lazy_dataframes,
2220+
arrow_string_format_default=arrow_string_format_default,
2221+
arrow_string_format_per_column=arrow_string_format_per_column,
2222+
)
22322223
else:
22332224
return self._nvs._batch_read_to_versioned_items(
22342225
symbol_strings,
@@ -2242,6 +2233,7 @@ def handle_symbol(s_):
22422233
iterate_snapshots_if_tombstoned=False,
22432234
output_format=output_format,
22442235
arrow_string_format_default=arrow_string_format_default,
2236+
arrow_string_format_per_column=arrow_string_format_per_column,
22452237
per_symbol_arrow_string_format_default=per_symbol_arrow_string_format_default,
22462238
per_symbol_arrow_string_format_per_column=per_symbol_arrow_string_format_per_column,
22472239
)

python/tests/unit/arcticdb/test_arrow_api.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -342,10 +342,14 @@ def test_read_batch_strings(lmdb_storage, lib_name, lazy, batch_default):
342342
ReadRequest(
343343
symbol=sym_2,
344344
arrow_string_format_default=ArrowOutputStringFormat.LARGE_STRING,
345-
arrow_string_format_per_column={"col_2": ArrowOutputStringFormat.CATEGORICAL},
346345
),
347346
]
348-
batch_result = lib.read_batch(read_requests, arrow_string_format_default=batch_default, lazy=lazy)
347+
batch_result = lib.read_batch(
348+
read_requests,
349+
arrow_string_format_default=batch_default,
350+
arrow_string_format_per_column={"col_2": ArrowOutputStringFormat.CATEGORICAL},
351+
lazy=lazy,
352+
)
349353
if lazy:
350354
batch_result = batch_result.collect()
351355
table_1 = batch_result[0].data
@@ -354,49 +358,32 @@ def test_read_batch_strings(lmdb_storage, lib_name, lazy, batch_default):
354358
assert_frame_equal_with_arrow(table_1, df_1)
355359
table_2 = batch_result[1].data
356360
assert table_2.schema.field(0).type == pa.large_string() # per symbol default
357-
assert table_2.schema.field(1).type == pa.dictionary(pa.int32(), pa.large_string()) # per_column override
361+
assert table_2.schema.field(1).type == pa.dictionary(pa.int32(), pa.large_string()) # global per_column
358362
assert_frame_equal_with_arrow(table_2, df_2)
359363

360364

361-
@pytest.mark.parametrize("default_1", [None, ArrowOutputStringFormat.SMALL_STRING])
362-
@pytest.mark.parametrize(
363-
"default_2", [None, ArrowOutputStringFormat.LARGE_STRING, ArrowOutputStringFormat.SMALL_STRING]
364-
)
365-
@pytest.mark.parametrize("per_column_1", [ArrowOutputStringFormat.CATEGORICAL, ArrowOutputStringFormat.LARGE_STRING])
366-
def test_read_batch_and_join_strings(lmdb_storage, lib_name, default_1, default_2, per_column_1):
365+
@pytest.mark.parametrize("default", [None, ArrowOutputStringFormat.SMALL_STRING])
366+
@pytest.mark.parametrize("per_column", [None, ArrowOutputStringFormat.CATEGORICAL])
367+
def test_read_batch_and_join_strings(lmdb_storage, lib_name, default, per_column):
367368
ac = lmdb_storage.create_arctic(output_format=OutputFormat.EXPERIMENTAL_ARROW)
368369
lib = ac.create_library(lib_name, library_options=LibraryOptions(dynamic_schema=True))
369370
sym_1, sym_2 = "sym_1", "sym_2"
370371
df_1 = pd.DataFrame({"col_1": ["a", "a", "bb"], "col_2": ["x", "y", "z"]})
371372
df_2 = pd.DataFrame({"col_2": ["a", "aa", "aaa"], "col_3": ["a", "a", "a"]})
372373
lib.write_batch([WritePayload(sym_1, df_1), WritePayload(sym_2, df_2)])
373374

374-
read_requests = [
375-
ReadRequest(
376-
symbol=sym_1, arrow_string_format_default=default_1, arrow_string_format_per_column={"col_2": per_column_1}
377-
),
378-
ReadRequest(
379-
symbol=sym_2,
380-
arrow_string_format_default=default_2,
381-
arrow_string_format_per_column={
382-
"col_2": ArrowOutputStringFormat.CATEGORICAL,
383-
"col_3": ArrowOutputStringFormat.LARGE_STRING,
384-
},
385-
),
386-
]
387-
lazy_dfs = lib.read_batch(read_requests, lazy=True)
388-
389-
has_mismatch_default = default_1 is not None and default_2 is not None and default_1 != default_2
390-
has_mismatch_per_column = per_column_1 != ArrowOutputStringFormat.CATEGORICAL
391-
should_raise = has_mismatch_default or has_mismatch_per_column
392-
if should_raise:
393-
with pytest.raises(ArcticNativeException):
394-
lazy_with_join = concat(lazy_dfs)
395-
else:
396-
lazy_with_join = concat(lazy_dfs)
397-
result = lazy_with_join.collect().data
398-
assert result.schema.field(0).type == default_1 or default_2 or pa.large_string()
399-
assert result.schema.field(1).type == pa.dictionary(pa.int32(), pa.large_string())
400-
assert result.schema.field(2).type == pa.large_string()
401-
expected_df = pd.concat([df_1, df_2]).reset_index(drop=True)
402-
assert_frame_equal_with_arrow(expected_df, result)
375+
arrow_string_format_per_column = {"col_1": per_column, "col_3": per_column} if per_column is not None else None
376+
lazy_dfs = lib.read_batch(
377+
[sym_1, sym_2],
378+
arrow_string_format_default=default,
379+
arrow_string_format_per_column=arrow_string_format_per_column,
380+
lazy=True,
381+
)
382+
383+
lazy_with_join = concat(lazy_dfs)
384+
result = lazy_with_join.collect().data
385+
assert result.schema.field(0).type == per_column or default or pa.large_string()
386+
assert result.schema.field(1).type == default or pa.large_string()
387+
assert result.schema.field(2).type == per_column or default or pa.large_string()
388+
expected_df = pd.concat([df_1, df_2]).reset_index(drop=True)
389+
assert_frame_equal_with_arrow(expected_df, result)

python/tests/unit/arcticdb/version_store/test_arrow_read.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,22 +1135,24 @@ def test_arrow_read_batch_with_strings(lmdb_version_store_arrow):
11351135
lib.batch_write([sym_1, sym_2], [df_1, df_2])
11361136

11371137
arrow_string_format_default = ArrowOutputStringFormat.SMALL_STRING
1138-
per_symbol_arrow_string_format_default = [None, ArrowOutputStringFormat.LARGE_STRING]
1138+
arrow_string_format_per_column = {"col_1": ArrowOutputStringFormat.CATEGORICAL}
1139+
per_symbol_arrow_string_format_default = [ArrowOutputStringFormat.LARGE_STRING, None]
11391140
per_symbol_arrow_string_format_per_column = [
1140-
{"col_1": ArrowOutputStringFormat.CATEGORICAL},
1141+
None, # First item will use the global arrow_string_format_per_column
11411142
{"col_2": ArrowOutputStringFormat.CATEGORICAL},
11421143
]
11431144
batch_result = lib.batch_read(
11441145
[sym_1, sym_2],
11451146
arrow_string_format_default=arrow_string_format_default,
1147+
arrow_string_format_per_column=arrow_string_format_per_column,
11461148
per_symbol_arrow_string_format_default=per_symbol_arrow_string_format_default,
11471149
per_symbol_arrow_string_format_per_column=per_symbol_arrow_string_format_per_column,
11481150
)
11491151
table_1 = batch_result[sym_1].data
1150-
assert table_1.schema.field(0).type == pa.dictionary(pa.int32(), pa.large_string()) # per_column override
1151-
assert table_1.schema.field(1).type == pa.string() # global default for all symbols
1152+
assert table_1.schema.field(0).type == pa.dictionary(pa.int32(), pa.large_string()) # global per_column
1153+
assert table_1.schema.field(1).type == pa.large_string() # per symbol default
11521154
assert_frame_equal_with_arrow(table_1, df_1)
11531155
table_2 = batch_result[sym_2].data
1154-
assert table_2.schema.field(0).type == pa.large_string() # per symbol default
1156+
assert table_2.schema.field(0).type == pa.string() # global default for all symbols
11551157
assert table_2.schema.field(1).type == pa.dictionary(pa.int32(), pa.large_string()) # per_column override
11561158
assert_frame_equal_with_arrow(table_2, df_2)

0 commit comments

Comments
 (0)