Skip to content

Commit 13c2615

Browse files
authored
GH-47029: [Archery][Integration] Fix generation of run-end-encoded data (#47653)
### Rationale for this change The size passed when generating a run-end-encoded field must be interpreted as the logical length of the run-end-encoded column, not the physical number of runs. ### Are these changes tested? Yes, by the CI integration test. I also checked the generated JSON manually. ### Are there any user-facing changes? No. * GitHub Issue: #47029 Authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Antoine Pitrou <[email protected]>
1 parent dc9753c commit 13c2615

File tree

1 file changed

+44
-29
lines changed

1 file changed

+44
-29
lines changed

dev/archery/archery/integration/datagen.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import binascii
2020
import gzip
2121
import json
22+
import math
2223
import os
2324
import random
2425
import tempfile
@@ -177,6 +178,9 @@ def _get_type(self):
177178
('bitWidth', self.bit_width)
178179
])
179180

181+
def _encode_values(self, values):
182+
return list(map(int if self.bit_width < 64 else str, values))
183+
180184
def generate_column(self, size, name=None):
181185
lower_bound, upper_bound = self._get_generated_data_bounds()
182186
return self.generate_range(size, lower_bound, upper_bound,
@@ -187,42 +191,29 @@ def generate_range(self, size, lower, upper, name=None,
187191
values = np.random.randint(lower, upper, size=size, dtype=np.int64)
188192
if include_extremes and size >= 2:
189193
values[:2] = [lower, upper]
190-
values = list(map(int if self.bit_width < 64 else str, values))
194+
values = self._encode_values(values)
191195

192196
is_valid = self._make_is_valid(size)
193197

194198
if name is None:
195199
name = self.name
196200
return PrimitiveColumn(name, size, is_valid, values)
197201

202+
@property
203+
def column_class(self):
204+
return PrimitiveColumn
205+
198206

199207
# Integer field that fulfils the requirements for the run ends field of REE.
200208
# The integers are positive and in a strictly increasing sequence
201209
class RunEndsField(IntegerField):
202-
# bit_width should only be one of 16/32/64
203210
def __init__(self, name, bit_width, *, metadata=None):
211+
assert bit_width in (16, 32, 64)
204212
super().__init__(name, is_signed=True, bit_width=bit_width,
205-
nullable=False, metadata=metadata, min_value=1)
213+
nullable=False, metadata=metadata)
206214

207-
def generate_range(self, size, lower, upper, name=None,
208-
include_extremes=False):
209-
rng = np.random.default_rng()
210-
# generate values that are strictly increasing with a min-value of
211-
# 1, but don't go higher than the max signed value for the given
212-
# bit width. We sort the values to ensure they are strictly increasing
213-
# and set replace to False to avoid duplicates, ensuring a valid
214-
# run-ends array.
215-
values = rng.choice(2 ** (self.bit_width - 1) - 1, size=size, replace=False)
216-
values += 1
217-
values = sorted(values)
218-
values = list(map(int if self.bit_width < 64 else str, values))
219-
# RunEnds cannot be null, as such self.nullable == False and this
220-
# will generate a validity map of all ones.
221-
is_valid = self._make_is_valid(size)
222-
223-
if name is None:
224-
name = self.name
225-
return PrimitiveColumn(name, size, is_valid, values)
215+
def generate_column(self, size, name=None):
216+
raise NotImplementedError("cannot be generated directly")
226217

227218

228219
class DateField(IntegerField):
@@ -1159,11 +1150,32 @@ def _get_children(self):
11591150
]
11601151

11611152
def generate_column(self, size, name=None):
1162-
values = self.values_field.generate_column(size)
1163-
run_ends = self.run_ends_field.generate_column(size)
1153+
# The `size` of a RunEndEncodedField is the logical length of the
1154+
# run-end-encoded column, so we choose a number of physical runs
1155+
# that's smaller.
1156+
if size > 0:
1157+
num_runs = np.random.randint(1, math.ceil(size * 0.75))
1158+
# Generate run ends
1159+
run_ends = np.random.choice(size - 1, num_runs - 1, replace=False) + 1
1160+
run_ends.sort()
1161+
run_ends = np.concat((run_ends, [size]))
1162+
assert len(run_ends) == num_runs
1163+
assert len(set(run_ends)) == num_runs
1164+
assert (run_ends > 0).all()
1165+
assert (run_ends <= size).all()
1166+
else:
1167+
num_runs = 0
1168+
run_ends = []
1169+
run_ends_is_valid = self._make_is_valid(num_runs, null_probability=0)
1170+
run_ends = self.run_ends_field._encode_values(run_ends)
1171+
1172+
run_end_column = self.run_ends_field.column_class(
1173+
self.run_ends_field.name, num_runs, run_ends_is_valid, run_ends)
1174+
values = self.values_field.generate_column(num_runs)
1175+
11641176
if name is None:
11651177
name = self.name
1166-
return RunEndEncodedColumn(name, size, run_ends, values)
1178+
return RunEndEncodedColumn(name, size, run_end_column, values)
11671179

11681180

11691181
class _BaseUnionField(Field):
@@ -1746,11 +1758,14 @@ def generate_recursive_nested_case():
17461758

17471759
def generate_run_end_encoded_case():
17481760
fields = [
1749-
RunEndEncodedField('ree16', 16, get_field('values', 'int32')),
1750-
RunEndEncodedField('ree32', 32, get_field('values', 'utf8')),
1751-
RunEndEncodedField('ree64', 64, get_field('values', 'float32')),
1761+
RunEndEncodedField('ree16_int32', 16, get_field('values', 'int32')),
1762+
RunEndEncodedField('ree32_utf8', 32, get_field('values', 'utf8')),
1763+
RunEndEncodedField('ree64_float32', 64, get_field('values', 'float32')),
1764+
RunEndEncodedField('ree16_bool', 64, get_field('values', 'bool')),
1765+
# Add a non-REE-encoded field to check column size correctness
1766+
BooleanField('bool'),
17521767
]
1753-
batch_sizes = [0, 7, 10]
1768+
batch_sizes = [0, 7, 20]
17541769
return _generate_file("run_end_encoded", fields, batch_sizes)
17551770

17561771

0 commit comments

Comments
 (0)