@@ -538,6 +538,8 @@ class LazyDataFrameCollection(QueryBuilder):
538538 def __init__ (
539539 self ,
540540 lazy_dataframes : List [LazyDataFrame ],
541+ arrow_string_format_default : Optional [Union [ArrowOutputStringFormat , "pa.DataType" ]] = None ,
542+ arrow_string_format_per_column : Optional [Dict [str , Union [ArrowOutputStringFormat , "pa.DataType" ]]] = None ,
541543 ):
542544 """
543545 Gather a list of `LazyDataFrame`s into a single object that can be collected together.
@@ -563,6 +565,8 @@ def __init__(
563565 )
564566 super ().__init__ ()
565567 self ._lazy_dataframes = lazy_dataframes
568+ self ._arrow_string_format_default = arrow_string_format_default
569+ self ._arrow_string_format_per_column = arrow_string_format_per_column
566570 if len (self ._lazy_dataframes ):
567571 self ._lib = self ._lazy_dataframes [0 ].lib
568572 self ._output_format = self ._lazy_dataframes [0 ].read_request .output_format
@@ -588,7 +592,12 @@ def collect(self) -> List[Union[VersionedItem, DataError]]:
588592 """
589593 if not len (self ._lazy_dataframes ):
590594 return []
591- return self ._lib .read_batch (self ._read_requests (), output_format = self ._output_format )
595+ return self ._lib .read_batch (
596+ self ._read_requests (),
597+ output_format = self ._output_format ,
598+ arrow_string_format_default = self ._arrow_string_format_default ,
599+ arrow_string_format_per_column = self ._arrow_string_format_per_column ,
600+ )
592601
593602 def _read_requests (self ) -> List [ReadRequest ]:
594603 # Combines queries for individual LazyDataFrames with the global query associated with this
@@ -647,33 +656,6 @@ def __init__(
647656 super ().__init__ ()
648657 self ._lazy_dataframes = lazy_dataframes
649658 self .then (join )
650- self .arrow_string_format_default = None
651- self .arrow_string_format_per_column = {}
652- for lf in self ._lazy_dataframes ._lazy_dataframes :
653- self .arrow_string_format_default = (
654- self .arrow_string_format_default or lf .read_request .arrow_string_format_default
655- )
656- check (
657- lf .read_request .arrow_string_format_default is None
658- or self .arrow_string_format_default == lf .read_request .arrow_string_format_default ,
659- "Lazy frames from collection cannot be combined for join because they have incompatible arrow_string_format_default values {} and {}" ,
660- self .arrow_string_format_default ,
661- lf .read_request .arrow_string_format_default ,
662- )
663- if lf .read_request .arrow_string_format_per_column is not None :
664- common_cols = (
665- self .arrow_string_format_per_column .keys () & lf .read_request .arrow_string_format_per_column .keys ()
666- )
667- for common_col in common_cols :
668- check (
669- self .arrow_string_format_per_column [common_col ]
670- == lf .read_request .arrow_string_format_per_column [common_col ],
671- "Lazy frames from collection cannot be combined for join because they have incompatible arrow_string_format_per_column values {} and {} for column {}" ,
672- self .arrow_string_format_per_column [common_col ],
673- lf .read_request .arrow_string_format_per_column [common_col ],
674- common_col ,
675- )
676- self .arrow_string_format_per_column .update (lf .read_request .arrow_string_format_per_column )
677659
678660 def collect (self ) -> VersionedItemWithJoin :
679661 """
@@ -693,8 +675,8 @@ def collect(self) -> VersionedItemWithJoin:
693675 self ._lazy_dataframes ._read_requests (),
694676 self ,
695677 output_format = self ._lazy_dataframes ._output_format ,
696- arrow_string_format_default = self .arrow_string_format_default ,
697- arrow_string_format_per_column = self .arrow_string_format_per_column ,
678+ arrow_string_format_default = self ._lazy_dataframes . _arrow_string_format_default ,
679+ arrow_string_format_per_column = self ._lazy_dataframes . _arrow_string_format_per_column ,
698680 )
699681
700682 def __str__ (self ) -> str :
@@ -2079,6 +2061,7 @@ def read_batch(
20792061 lazy : bool = False ,
20802062 output_format : Optional [Union [OutputFormat , str ]] = None ,
20812063 arrow_string_format_default : Optional [Union [ArrowOutputStringFormat , "pa.DataType" ]] = None ,
2064+ arrow_string_format_per_column : Optional [Dict [str , Union [ArrowOutputStringFormat , "pa.DataType" ]]] = None ,
20822065 ) -> Union [List [Union [VersionedItem , DataError ]], LazyDataFrameCollection ]:
20832066 """
20842067 Reads multiple symbols.
@@ -2107,6 +2090,10 @@ def read_batch(
21072090 It serves as the default for the entire batch. The string format settings inside the `ReadRequest`s will
21082091 override this batch level setting.
21092092
2093+ arrow_string_format_per_column: Optional[Dict[str, Union[ArrowOutputStringFormat, "pa.DataType"]]], default=None,
2094+ Provides per column name overrides for `arrow_string_format_default`. It is only applied to symbols which
2095+ don't have a `arrow_string_format_per_column` set in their `ReadRequest`.
2096+
21102097 Returns
21112098 -------
21122099 Union[List[Union[VersionedItem, DataError]], LazyDataFrameCollection]
@@ -2221,14 +2208,18 @@ def handle_symbol(s_):
22212208 columns = columns [idx ],
22222209 query_builder = q ,
22232210 output_format = output_format ,
2224- arrow_string_format_default = (
2225- per_symbol_arrow_string_format_default [ idx ] or arrow_string_format_default
2226- ),
2227- arrow_string_format_per_column = per_symbol_arrow_string_format_per_column [ idx ] ,
2211+ arrow_string_format_default = per_symbol_arrow_string_format_default [ idx ]
2212+ or arrow_string_format_default ,
2213+ arrow_string_format_per_column = per_symbol_arrow_string_format_per_column [ idx ]
2214+ or arrow_string_format_per_column ,
22282215 ),
22292216 )
22302217 )
2231- return LazyDataFrameCollection (lazy_dataframes )
2218+ return LazyDataFrameCollection (
2219+ lazy_dataframes ,
2220+ arrow_string_format_default = arrow_string_format_default ,
2221+ arrow_string_format_per_column = arrow_string_format_per_column ,
2222+ )
22322223 else :
22332224 return self ._nvs ._batch_read_to_versioned_items (
22342225 symbol_strings ,
@@ -2242,6 +2233,7 @@ def handle_symbol(s_):
22422233 iterate_snapshots_if_tombstoned = False ,
22432234 output_format = output_format ,
22442235 arrow_string_format_default = arrow_string_format_default ,
2236+ arrow_string_format_per_column = arrow_string_format_per_column ,
22452237 per_symbol_arrow_string_format_default = per_symbol_arrow_string_format_default ,
22462238 per_symbol_arrow_string_format_per_column = per_symbol_arrow_string_format_per_column ,
22472239 )
0 commit comments