Skip to content

Commit d03db7a

Browse files
committed
Made code changes
1 parent c666489 commit d03db7a

2 files changed

Lines changed: 168 additions & 11 deletions

File tree

src/databricks/sql/client.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,9 +1324,23 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
13241324
"""
13251325
if size < 0:
13261326
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
1327+
1328+
# Hold 0-row chunks aside instead of concatenating them with real chunks.
1329+
# CloudFetchQueue may emit a placeholder empty table whose schema does
1330+
# not match the real downloaded chunks; pyarrow.concat_tables with
1331+
# promote_options="default" would silently merge it in as phantom
1332+
# columns.
1333+
partial_result_chunks: List["pyarrow.Table"] = []
1334+
zero_row_table: Optional["pyarrow.Table"] = None
1335+
n_remaining_rows = size
1336+
13271337
results = self.results.next_n_rows(size)
1328-
n_remaining_rows = size - results.num_rows
1329-
self._next_row_index += results.num_rows
1338+
if results.num_rows == 0:
1339+
zero_row_table = results
1340+
else:
1341+
partial_result_chunks.append(results)
1342+
n_remaining_rows -= results.num_rows
1343+
self._next_row_index += results.num_rows
13301344

13311345
while (
13321346
n_remaining_rows > 0
@@ -1335,13 +1349,17 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
13351349
):
13361350
self._fill_results_buffer()
13371351
partial_results = self.results.next_n_rows(n_remaining_rows)
1338-
results = pyarrow.concat_tables(
1339-
[results, partial_results], promote_options="default"
1340-
)
1352+
if partial_results.num_rows == 0:
1353+
continue
1354+
partial_result_chunks.append(partial_results)
13411355
n_remaining_rows -= partial_results.num_rows
13421356
self._next_row_index += partial_results.num_rows
13431357

1344-
return results
1358+
if not partial_result_chunks:
1359+
return zero_row_table
1360+
return pyarrow.concat_tables(
1361+
partial_result_chunks, promote_options="default"
1362+
)
13451363

13461364
def merge_columnar(self, result1, result2):
13471365
"""
@@ -1387,18 +1405,31 @@ def fetchmany_columnar(self, size: int):
13871405

13881406
def fetchall_arrow(self) -> "pyarrow.Table":
13891407
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
1408+
# See ``fetchmany_arrow`` for why 0-row chunks are held aside rather than
1409+
# concatenated with the real chunks.
1410+
partial_result_chunks: List["pyarrow.Table"] = []
1411+
zero_row_table: Optional["pyarrow.Table"] = None
1412+
13901413
results = self.results.remaining_rows()
1391-
self._next_row_index += results.num_rows
1414+
if results.num_rows == 0:
1415+
zero_row_table = results
1416+
else:
1417+
partial_result_chunks.append(results)
1418+
self._next_row_index += results.num_rows
13921419

13931420
while not self.has_been_closed_server_side and self.has_more_rows:
13941421
self._fill_results_buffer()
13951422
partial_results = self.results.remaining_rows()
1396-
results = pyarrow.concat_tables(
1397-
[results, partial_results], promote_options="default"
1398-
)
1423+
if partial_results.num_rows == 0:
1424+
continue
1425+
partial_result_chunks.append(partial_results)
13991426
self._next_row_index += partial_results.num_rows
14001427

1401-
return results
1428+
if not partial_result_chunks:
1429+
return zero_row_table
1430+
return pyarrow.concat_tables(
1431+
partial_result_chunks, promote_options="default"
1432+
)
14021433

14031434
def fetchall_columnar(self):
14041435
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""

tests/unit/test_fetches.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,34 @@
77
from databricks.sql.utils import ExecuteResponse, ArrowQueue
88

99

10+
class _StubArrowQueue:
11+
"""Minimal queue that hands back a pre-built pyarrow.Table once.
12+
13+
Used to inject a placeholder whose schema differs from the real chunks —
14+
what ``CloudFetchQueue._create_empty_table`` can produce when its
15+
``schema_bytes`` are stale.
16+
"""
17+
18+
def __init__(self, table):
19+
self._table = table
20+
self._consumed = False
21+
22+
def _take(self):
23+
if self._consumed:
24+
return self._table.slice(0, 0)
25+
self._consumed = True
26+
return self._table
27+
28+
def next_n_rows(self, num_rows):
29+
return self._take()
30+
31+
def remaining_rows(self):
32+
return self._take()
33+
34+
def close(self):
35+
pass
36+
37+
1038
class FetchTests(unittest.TestCase):
1139
"""
1240
Unit tests for checking the fetch logic.
@@ -98,6 +126,42 @@ def fetch_results(
98126
)
99127
return rs
100128

129+
@staticmethod
def make_dummy_result_set_from_queue_list(queue_list, description=None):
    """Build a ResultSet whose backend serves the given pre-built queues in order.

    Counterpart of make_dummy_result_set_from_batch_list for tests that need
    queues returning tables with an arbitrary schema — required to reproduce
    the CloudFetch placeholder case, which ``ArrowQueue`` would never produce
    on its own.
    """
    # Mutable cell holding the index of the next queue to serve.
    state = {"next": 0}

    def fetch_results(**_):
        position = state["next"]
        queue = queue_list[position]
        state["next"] = position + 1
        # Second element: whether further queues are still pending.
        return queue, state["next"] < len(queue_list)

    backend = Mock()
    backend.fetch_results = fetch_results

    # No initial arrow queue is supplied; all data comes from fetch_results.
    response = ExecuteResponse(
        status=None,
        has_been_closed_server_side=False,
        has_more_rows=True,
        description=description or [],
        lz4_compressed=Mock(),
        command_handle=None,
        arrow_queue=None,
        arrow_schema_bytes=None,
        is_staging_operation=False,
    )
    return client.ResultSet(
        connection=Mock(),
        thrift_backend=backend,
        execute_response=response,
    )
164+
101165
def assertEqualRowValues(self, actual, expected):
102166
self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0)
103167
for act, exp in zip(actual, expected):
@@ -255,6 +319,68 @@ def test_fetchone_without_initial_results(self):
255319
dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2)
256320
self.assertEqual(dummy_result_set.fetchone(), None)
257321

322+
# Regression tests for fetchmany_arrow / fetchall_arrow handling of
323+
# the CloudFetch empty placeholder
324+
def test_fetchall_arrow_drops_mismatched_empty_placeholder(self):
    """fetchall_arrow must not merge a stale 0-row placeholder into real data."""
    # The first fetch_results call yields a 0-row placeholder whose schema
    # does not match the real chunks. The second call yields the real data.
    stale_schema = pa.schema({"stale_col": pa.string()})
    empty_placeholder = pa.Table.from_pydict({"stale_col": []}, schema=stale_schema)
    _, data_table = self.make_arrow_table([[1], [2], [3]])
    queues = [_StubArrowQueue(empty_placeholder), _StubArrowQueue(data_table)]
    result_set = self.make_dummy_result_set_from_queue_list(
        queues,
        description=[("col0", "integer", None, None, None, None, None)],
    )

    fetched = result_set.fetchall_arrow()

    # Only the real column and its rows may survive.
    self.assertEqual(fetched.num_rows, 3)
    self.assertEqual(fetched.schema.names, ["col0"])
    self.assertEqual(fetched.column(0).to_pylist(), [1, 2, 3])
342+
343+
def test_fetchall_arrow_all_empty_returns_zero_row_table(self):
    """fetchall_arrow returns a real, empty pa.Table when no rows ever arrive."""
    # The only queue serves the placeholder, so the call site has to fall
    # back to ``zero_row_table`` rather than return nothing.
    empty_placeholder = pa.Table.from_pydict({})
    queues = [_StubArrowQueue(empty_placeholder)]
    result_set = self.make_dummy_result_set_from_queue_list(queues)

    fetched = result_set.fetchall_arrow()

    self.assertIsInstance(fetched, pa.Table)
    self.assertEqual(fetched.num_rows, 0)
355+
356+
def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self):
    """fetchmany_arrow must not merge a stale 0-row placeholder into real data."""
    # The first queue serves a 0-row placeholder with a mismatched schema;
    # the second queue serves the three real rows.
    stale_schema = pa.schema({"stale_col": pa.string()})
    empty_placeholder = pa.Table.from_pydict({"stale_col": []}, schema=stale_schema)
    _, data_table = self.make_arrow_table([[1], [2], [3]])
    queues = [_StubArrowQueue(empty_placeholder), _StubArrowQueue(data_table)]
    result_set = self.make_dummy_result_set_from_queue_list(
        queues,
        description=[("col0", "integer", None, None, None, None, None)],
    )

    fetched = result_set.fetchmany_arrow(3)

    # Only the real column and its rows may survive.
    self.assertEqual(fetched.num_rows, 3)
    self.assertEqual(fetched.schema.names, ["col0"])
    self.assertEqual(fetched.column(0).to_pylist(), [1, 2, 3])
372+
373+
def test_fetchmany_arrow_all_empty_returns_zero_row_table(self):
    """fetchmany_arrow returns a real, empty pa.Table when no rows ever arrive."""
    # The only queue serves an empty placeholder, so no real chunk exists.
    empty_placeholder = pa.Table.from_pydict({})
    queues = [_StubArrowQueue(empty_placeholder)]
    result_set = self.make_dummy_result_set_from_queue_list(queues)

    fetched = result_set.fetchmany_arrow(10)

    self.assertIsInstance(fetched, pa.Table)
    self.assertEqual(fetched.num_rows, 0)
383+
258384

259385
if __name__ == "__main__":
260386
unittest.main()

0 commit comments

Comments
 (0)