|
7 | 7 | from databricks.sql.utils import ExecuteResponse, ArrowQueue |
8 | 8 |
|
9 | 9 |
|
| 10 | +class _StubArrowQueue: |
| 11 | + """Minimal queue that hands back a pre-built pyarrow.Table once. |
| 12 | +
|
| 13 | + Used to inject a placeholder whose schema differs from the real chunks — |
| 14 | + what ``CloudFetchQueue._create_empty_table`` can produce when its |
| 15 | + ``schema_bytes`` are stale. |
| 16 | + """ |
| 17 | + |
| 18 | + def __init__(self, table): |
| 19 | + self._table = table |
| 20 | + self._consumed = False |
| 21 | + |
| 22 | + def _take(self): |
| 23 | + if self._consumed: |
| 24 | + return self._table.slice(0, 0) |
| 25 | + self._consumed = True |
| 26 | + return self._table |
| 27 | + |
| 28 | + def next_n_rows(self, num_rows): |
| 29 | + return self._take() |
| 30 | + |
| 31 | + def remaining_rows(self): |
| 32 | + return self._take() |
| 33 | + |
| 34 | + def close(self): |
| 35 | + pass |
| 36 | + |
| 37 | + |
10 | 38 | class FetchTests(unittest.TestCase): |
11 | 39 | """ |
12 | 40 | Unit tests for checking the fetch logic. |
@@ -98,6 +126,42 @@ def fetch_results( |
98 | 126 | ) |
99 | 127 | return rs |
100 | 128 |
|
| 129 | + @staticmethod |
| 130 | + def make_dummy_result_set_from_queue_list(queue_list, description=None): |
| 131 | + """Like make_dummy_result_set_from_batch_list but yields pre-built queues. |
| 132 | +
|
| 133 | + Lets tests inject queues whose returned tables have an arbitrary schema |
| 134 | + — needed to reproduce the CloudFetch placeholder case that ``ArrowQueue`` |
| 135 | + would never produce on its own. |
| 136 | + """ |
| 137 | + queue_index = 0 |
| 138 | + |
| 139 | + def fetch_results(**_): |
| 140 | + nonlocal queue_index |
| 141 | + q = queue_list[queue_index] |
| 142 | + queue_index += 1 |
| 143 | + return q, queue_index < len(queue_list) |
| 144 | + |
| 145 | + mock_thrift_backend = Mock() |
| 146 | + mock_thrift_backend.fetch_results = fetch_results |
| 147 | + |
| 148 | + rs = client.ResultSet( |
| 149 | + connection=Mock(), |
| 150 | + thrift_backend=mock_thrift_backend, |
| 151 | + execute_response=ExecuteResponse( |
| 152 | + status=None, |
| 153 | + has_been_closed_server_side=False, |
| 154 | + has_more_rows=True, |
| 155 | + description=description or [], |
| 156 | + lz4_compressed=Mock(), |
| 157 | + command_handle=None, |
| 158 | + arrow_queue=None, |
| 159 | + arrow_schema_bytes=None, |
| 160 | + is_staging_operation=False, |
| 161 | + ), |
| 162 | + ) |
| 163 | + return rs |
| 164 | + |
101 | 165 | def assertEqualRowValues(self, actual, expected): |
102 | 166 | self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0) |
103 | 167 | for act, exp in zip(actual, expected): |
@@ -255,6 +319,68 @@ def test_fetchone_without_initial_results(self): |
255 | 319 | dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2) |
256 | 320 | self.assertEqual(dummy_result_set.fetchone(), None) |
257 | 321 |
|
| 322 | + # Regression tests for fetchmany_arrow / fetchall_arrow handling of |
| 323 | + # the CloudFetch empty placeholder |
| 324 | + def test_fetchall_arrow_drops_mismatched_empty_placeholder(self): |
| 325 | + # First fetch_results call hands back a 0-row placeholder whose |
| 326 | + # schema does not match the real chunks . The second |
| 327 | + # call hands back the real data. |
| 328 | + placeholder = pa.Table.from_pydict( |
| 329 | + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) |
| 330 | + ) |
| 331 | + _, real_table = self.make_arrow_table([[1], [2], [3]]) |
| 332 | + rs = self.make_dummy_result_set_from_queue_list( |
| 333 | + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], |
| 334 | + description=[("col0", "integer", None, None, None, None, None)], |
| 335 | + ) |
| 336 | + |
| 337 | + result = rs.fetchall_arrow() |
| 338 | + |
| 339 | + self.assertEqual(result.num_rows, 3) |
| 340 | + self.assertEqual(result.schema.names, ["col0"]) |
| 341 | + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) |
| 342 | + |
| 343 | + def test_fetchall_arrow_all_empty_returns_zero_row_table(self): |
| 344 | + # Every queue call returns the placeholder — the call site should |
| 345 | + # fall back to ``zero_row_table`` and return a real pa.Table. |
| 346 | + placeholder = pa.Table.from_pydict({}) |
| 347 | + rs = self.make_dummy_result_set_from_queue_list( |
| 348 | + [_StubArrowQueue(placeholder)], |
| 349 | + ) |
| 350 | + |
| 351 | + result = rs.fetchall_arrow() |
| 352 | + |
| 353 | + self.assertIsInstance(result, pa.Table) |
| 354 | + self.assertEqual(result.num_rows, 0) |
| 355 | + |
| 356 | + def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self): |
| 357 | + # See ``test_fetchall_arrow_drops_mismatched_empty_placeholder``. |
| 358 | + placeholder = pa.Table.from_pydict( |
| 359 | + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) |
| 360 | + ) |
| 361 | + _, real_table = self.make_arrow_table([[1], [2], [3]]) |
| 362 | + rs = self.make_dummy_result_set_from_queue_list( |
| 363 | + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], |
| 364 | + description=[("col0", "integer", None, None, None, None, None)], |
| 365 | + ) |
| 366 | + |
| 367 | + result = rs.fetchmany_arrow(3) |
| 368 | + |
| 369 | + self.assertEqual(result.num_rows, 3) |
| 370 | + self.assertEqual(result.schema.names, ["col0"]) |
| 371 | + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) |
| 372 | + |
| 373 | + def test_fetchmany_arrow_all_empty_returns_zero_row_table(self): |
| 374 | + placeholder = pa.Table.from_pydict({}) |
| 375 | + rs = self.make_dummy_result_set_from_queue_list( |
| 376 | + [_StubArrowQueue(placeholder)], |
| 377 | + ) |
| 378 | + |
| 379 | + result = rs.fetchmany_arrow(10) |
| 380 | + |
| 381 | + self.assertIsInstance(result, pa.Table) |
| 382 | + self.assertEqual(result.num_rows, 0) |
| 383 | + |
258 | 384 |
|
259 | 385 | if __name__ == "__main__": |
260 | 386 | unittest.main() |
0 commit comments