Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions src/databricks/sql/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,21 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
if size < 0:
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)

# Hold 0-row chunks aside instead of appending them to ``partial_result_chunks``.
# CloudFetchQueue may return a placeholder empty table whose schema does not
# match the real downloaded chunks; concatenating it would corrupt the result.
partial_result_chunks: List["pyarrow.Table"] = []
zero_row_table: Optional["pyarrow.Table"] = None
n_remaining_rows = size

results = self.results.next_n_rows(size)
partial_result_chunks = [results]
n_remaining_rows = size - results.num_rows
self._next_row_index += results.num_rows
if results.num_rows == 0:
zero_row_table = results
else:
partial_result_chunks.append(results)
n_remaining_rows -= results.num_rows
self._next_row_index += results.num_rows

while (
n_remaining_rows > 0
Expand All @@ -318,10 +329,14 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
):
self._fill_results_buffer()
partial_results = self.results.next_n_rows(n_remaining_rows)
if partial_results.num_rows == 0:
continue
partial_result_chunks.append(partial_results)
n_remaining_rows -= partial_results.num_rows
self._next_row_index += partial_results.num_rows

if not partial_result_chunks:
partial_result_chunks.append(zero_row_table)
return concat_table_chunks(partial_result_chunks)

def fetchmany_columnar(self, size: int):
Expand Down Expand Up @@ -351,15 +366,30 @@ def fetchmany_columnar(self, size: int):

def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
# Hold 0-row chunks aside instead of appending them to ``partial_result_chunks``.
# CloudFetchQueue may return a placeholder empty table whose schema does not
# match the real downloaded chunks; concatenating it would corrupt the result.
partial_result_chunks: List = []
zero_row_table: Optional["pyarrow.Table"] = None

results = self.results.remaining_rows()
self._next_row_index += results.num_rows
partial_result_chunks = [results]
if results.num_rows == 0:
zero_row_table = results
else:
partial_result_chunks.append(results)
self._next_row_index += results.num_rows

while not self.has_been_closed_server_side and self.has_more_rows:
self._fill_results_buffer()
partial_results = self.results.remaining_rows()
if partial_results.num_rows == 0:
continue
partial_result_chunks.append(partial_results)
self._next_row_index += partial_results.num_rows

if not partial_result_chunks:
partial_result_chunks.append(zero_row_table)

result_table = concat_table_chunks(partial_result_chunks)
# If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
# Valid only for metadata commands result set
Expand Down
123 changes: 123 additions & 0 deletions tests/unit/test_fetches.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,34 @@
from databricks.sql.result_set import ThriftResultSet


class _StubArrowQueue:
"""Minimal queue that hands back a pre-built pyarrow.Table once.

Used to inject a schemaless / wrong-schema placeholder that the real
ArrowQueue would never produce — this is what CloudFetchQueue emits
when ``self.table is None`` and ``schema_bytes`` is missing.
"""

def __init__(self, table):
self._table = table
self._consumed = False

def _take(self):
if self._consumed:
return self._table.slice(0, 0)
self._consumed = True
return self._table

def next_n_rows(self, num_rows):
return self._take()

def remaining_rows(self):
return self._take()

def close(self):
pass


@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
class FetchTests(unittest.TestCase):
"""
Expand Down Expand Up @@ -110,6 +138,39 @@ def fetch_results(
)
return rs

@staticmethod
def make_dummy_result_set_from_queue_list(queue_list, description=None):
"""Like make_dummy_result_set_from_batch_list but yields pre-built queues.

Lets tests inject queues whose returned tables have an arbitrary schema
(or no schema at all) — needed to reproduce the CloudFetch placeholder
case that ``ArrowQueue`` would never produce.
"""
queue_index = 0

def fetch_results(**_):
nonlocal queue_index
q = queue_list[queue_index]
queue_index += 1
return q, queue_index < len(queue_list), 0

mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
mock_thrift_backend.fetch_results = fetch_results

rs = ThriftResultSet(
connection=Mock(),
execute_response=ExecuteResponse(
command_id=None,
status=None,
has_been_closed_server_side=False,
description=description or [],
lz4_compressed=True,
is_staging_operation=False,
),
thrift_client=mock_thrift_backend,
)
return rs

def assertEqualRowValues(self, actual, expected):
self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0)
for act, exp in zip(actual, expected):
Expand Down Expand Up @@ -267,6 +328,68 @@ def test_fetchone_without_initial_results(self):
dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2)
self.assertEqual(dummy_result_set.fetchone(), None)

# Regression tests for fetchmany_arrow / fetchall_arrow handling of
# the schemaless CloudFetch placeholder.
def test_fetchall_arrow_drops_mismatched_empty_placeholder(self):
# First fetch_results() call hands back a 0-row placeholder whose
# schema does not match the real chunks. The second call
# hands back real data.
placeholder = pa.Table.from_pydict(
{"stale_col": []}, schema=pa.schema({"stale_col": pa.string()})
)
_, real_table = self.make_arrow_table([[1], [2], [3]])
rs = self.make_dummy_result_set_from_queue_list(
[_StubArrowQueue(placeholder), _StubArrowQueue(real_table)],
description=[("col0", "integer", None, None, None, None, None)],
)

result = rs.fetchall_arrow()

self.assertEqual(result.num_rows, 3)
self.assertEqual(result.schema.names, ["col0"])
self.assertEqual(result.column(0).to_pylist(), [1, 2, 3])

def test_fetchall_arrow_all_empty_returns_zero_row_table(self):
# Every queue call returns the schemaless placeholder — the
# call site should fall back to zero_row_table without crashing.
placeholder = pa.Table.from_pydict({})
rs = self.make_dummy_result_set_from_queue_list(
[_StubArrowQueue(placeholder)],
)

result = rs.fetchall_arrow()

self.assertIsInstance(result, pa.Table)
self.assertEqual(result.num_rows, 0)

def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self):
# See ``test_fetchall_arrow_drops_mismatched_empty_placeholder``.
placeholder = pa.Table.from_pydict(
{"stale_col": []}, schema=pa.schema({"stale_col": pa.string()})
)
_, real_table = self.make_arrow_table([[1], [2], [3]])
rs = self.make_dummy_result_set_from_queue_list(
[_StubArrowQueue(placeholder), _StubArrowQueue(real_table)],
description=[("col0", "integer", None, None, None, None, None)],
)

result = rs.fetchmany_arrow(3)

self.assertEqual(result.num_rows, 3)
self.assertEqual(result.schema.names, ["col0"])
self.assertEqual(result.column(0).to_pylist(), [1, 2, 3])

def test_fetchmany_arrow_all_empty_returns_zero_row_table(self):
placeholder = pa.Table.from_pydict({})
rs = self.make_dummy_result_set_from_queue_list(
[_StubArrowQueue(placeholder)],
)

result = rs.fetchmany_arrow(10)

self.assertIsInstance(result, pa.Table)
self.assertEqual(result.num_rows, 0)


if __name__ == "__main__":
unittest.main()
Loading