19
19
import binascii
20
20
import gzip
21
21
import json
22
+ import math
22
23
import os
23
24
import random
24
25
import tempfile
@@ -177,6 +178,9 @@ def _get_type(self):
177
178
('bitWidth' , self .bit_width )
178
179
])
179
180
181
def _encode_values(self, values):
    """Convert raw integer values into the form stored in generated columns.

    Each value is converted with ``int`` when ``self.bit_width`` is below 64
    and with ``str`` otherwise.
    """
    # NOTE(review): 64-bit values are presumably stringified so that they
    # survive JSON serialization without precision loss -- confirm.
    if self.bit_width < 64:
        convert = int
    else:
        convert = str
    return [convert(value) for value in values]
183
+
180
184
def generate_column (self , size , name = None ):
181
185
lower_bound , upper_bound = self ._get_generated_data_bounds ()
182
186
return self .generate_range (size , lower_bound , upper_bound ,
@@ -187,42 +191,29 @@ def generate_range(self, size, lower, upper, name=None,
187
191
values = np .random .randint (lower , upper , size = size , dtype = np .int64 )
188
192
if include_extremes and size >= 2 :
189
193
values [:2 ] = [lower , upper ]
190
- values = list ( map ( int if self .bit_width < 64 else str , values ) )
194
+ values = self ._encode_values ( values )
191
195
192
196
is_valid = self ._make_is_valid (size )
193
197
194
198
if name is None :
195
199
name = self .name
196
200
return PrimitiveColumn (name , size , is_valid , values )
197
201
202
@property
def column_class(self):
    """Return the column class used to wrap values generated for this field."""
    column_type = PrimitiveColumn
    return column_type
205
+
198
206
199
207
# Integer field that fulfils the requirements for the run ends field of REE.
200
208
# The integers are positive and in a strictly increasing sequence
201
209
class RunEndsField(IntegerField):
    """Signed, non-nullable integer field for the run ends child of an REE.

    This field only describes the type. It refuses to generate data on its
    own (``generate_column`` raises), so the run-end values have to be
    supplied by whoever builds the enclosing column.
    """

    def __init__(self, name, bit_width, *, metadata=None):
        """Create the field.

        Parameters
        ----------
        name : field name, forwarded to ``IntegerField``.
        bit_width : width of the run-end integers; must be 16, 32 or 64.
        metadata : optional metadata, forwarded to ``IntegerField``.

        Raises
        ------
        ValueError
            If ``bit_width`` is not one of 16, 32 or 64.
        """
        # Raise explicitly instead of asserting: assertions are stripped when
        # Python runs with -O, which would let an invalid width through.
        if bit_width not in (16, 32, 64):
            raise ValueError(
                f"run ends bit_width must be 16, 32 or 64, got {bit_width!r}")
        super().__init__(name, is_signed=True, bit_width=bit_width,
                         nullable=False, metadata=metadata)

    def generate_column(self, size, name=None):
        """Always raise: run ends cannot be generated independently."""
        raise NotImplementedError("cannot be generated directly")
226
217
227
218
228
219
class DateField (IntegerField ):
@@ -1159,11 +1150,32 @@ def _get_children(self):
1159
1150
]
1160
1151
1161
1152
def generate_column(self, size, name=None):
    """Generate a run-end-encoded column of logical length ``size``.

    A random number of physical runs (at least one when ``size > 0``) is
    chosen, strictly increasing run ends are drawn so that the last one
    equals ``size``, and the values child is generated with one entry per
    run.

    Parameters
    ----------
    size : logical length of the run-end-encoded column.
    name : optional column name; defaults to ``self.name``.

    Returns
    -------
    RunEndEncodedColumn wrapping the run-ends column and the values column.
    """
    # The `size` of a RunEndEncodedField is the logical length of the
    # run-end-encoded column, so we choose a number of physical runs
    # that's smaller.
    if size > 0:
        # randint's upper bound is exclusive and must be greater than the
        # lower bound. For size == 1, ceil(0.75) == 1 would make the range
        # empty and raise, so clamp the bound to 2 (yielding a single run).
        # For size >= 2 the bound is unchanged.
        runs_upper_bound = max(2, math.ceil(size * 0.75))
        num_runs = np.random.randint(1, runs_upper_bound)
        # Draw num_runs - 1 distinct interior run ends from [1, size - 1];
        # the final run end is always `size` itself.
        run_ends = np.random.choice(size - 1, num_runs - 1, replace=False) + 1
        run_ends.sort()
        # np.concatenate is available on every NumPy version, whereas the
        # np.concat alias only exists from NumPy 2.0 onwards.
        run_ends = np.concatenate((run_ends, [size]))
        assert len(run_ends) == num_runs
        assert len(set(run_ends)) == num_runs
        assert (run_ends > 0).all()
        assert (run_ends <= size).all()
    else:
        num_runs = 0
        run_ends = []
    # Run ends are never null, hence a null probability of zero.
    run_ends_is_valid = self._make_is_valid(num_runs, null_probability=0)
    run_ends = self.run_ends_field._encode_values(run_ends)

    run_end_column = self.run_ends_field.column_class(
        self.run_ends_field.name, num_runs, run_ends_is_valid, run_ends)
    values = self.values_field.generate_column(num_runs)

    if name is None:
        name = self.name
    return RunEndEncodedColumn(name, size, run_end_column, values)
1167
1179
1168
1180
1169
1181
class _BaseUnionField (Field ):
@@ -1746,11 +1758,14 @@ def generate_recursive_nested_case():
1746
1758
1747
1759
def generate_run_end_encoded_case():
    """Build the "run_end_encoded" integration file.

    The schema holds run-end-encoded fields named ``ree<run-end width>_<value
    type>`` plus one plain boolean field, generated for batches of 0, 7 and
    20 rows.
    """
    fields = [
        RunEndEncodedField('ree16_int32', 16, get_field('values', 'int32')),
        RunEndEncodedField('ree32_utf8', 32, get_field('values', 'utf8')),
        RunEndEncodedField('ree64_float32', 64, get_field('values', 'float32')),
        # This field uses 64-bit run ends, so it is labelled ree64 to match
        # the naming of its siblings (it was previously named 'ree16_bool').
        # NOTE(review): if 16-bit run ends were intended instead, change the
        # width rather than the name.
        RunEndEncodedField('ree64_bool', 64, get_field('values', 'bool')),
        # Add a non-REE-encoded field to check column size correctness
        BooleanField('bool'),
    ]
    batch_sizes = [0, 7, 20]
    return _generate_file("run_end_encoded", fields, batch_sizes)
1755
1770
1756
1771
0 commit comments