From 89f86804f4a89c88c2fc110ac1dfa0adcb999fc5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 7 Jul 2025 09:13:33 +0530 Subject: [PATCH 1/7] stop passing client to ResultSet, infer from connection Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 1 - src/databricks/sql/backend/sea/result_set.py | 38 ++++++----- .../sql/backend/sea/utils/filters.py | 7 ++- src/databricks/sql/backend/thrift_backend.py | 6 -- src/databricks/sql/result_set.py | 63 ++++++++++++------- tests/unit/test_client.py | 42 +++++++++---- tests/unit/test_fetches.py | 32 +++++++--- tests/unit/test_sea_backend.py | 10 ++- tests/unit/test_sea_result_set.py | 59 ++++++++--------- 9 files changed, 159 insertions(+), 99 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 814859a31..353252c42 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -620,7 +620,6 @@ def get_execution_result( return SeaResultSet( connection=cursor.connection, execute_response=execute_response, - sea_client=self, result_data=response.result, manifest=response.manifest, buffer_size_bytes=cursor.buffer_size_bytes, diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 302af5e3a..14ed61575 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -31,7 +31,6 @@ def __init__( self, connection: Connection, execute_response: ExecuteResponse, - sea_client: SeaDatabricksClient, result_data: ResultData, manifest: ResultManifest, buffer_size_bytes: int = 104857600, @@ -43,7 +42,6 @@ def __init__( Args: connection: The parent connection execute_response: Response from the execute command - sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch result_data: Result data from SEA response @@ -56,32 +54,38 @@ def __init__( if statement_id is None: raise ValueError("Command ID is not a SEA statement ID") - results_queue = SeaResultSetQueueFactory.build_queue( - result_data, - self.manifest, - statement_id, - description=execute_response.description, - max_download_threads=sea_client.max_download_threads, - sea_client=sea_client, - lz4_compressed=execute_response.lz4_compressed, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, - backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + # Assert that the backend is of the correct type + assert isinstance( + self.backend, SeaDatabricksClient + ), "Backend must be a SeaDatabricksClient" + + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + self.manifest, + statement_id, + description=execute_response.description, + max_download_threads=self.backend.max_download_threads, + sea_client=self.backend, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Set the results queue + self.results = results_queue + def _convert_json_types(self, row: List[str]) -> List[Any]: """ Convert 
string values in the row to appropriate Python types based on column metadata. @@ -160,6 +164,9 @@ def fetchmany_json(self, size: int) -> List[List[str]]: if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) self._next_row_index += len(results) @@ -173,6 +180,9 @@ def fetchall_json(self) -> List[List[str]]: Columnar table containing all remaining rows """ + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += len(results) diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index ef6c91d7d..cd27778fb 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -12,14 +12,13 @@ Optional, Any, Callable, - cast, TYPE_CHECKING, ) if TYPE_CHECKING: from databricks.sql.backend.sea.result_set import SeaResultSet -from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse, CommandId, CommandState logger = logging.getLogger(__name__) @@ -45,6 +44,9 @@ def _filter_sea_result_set( """ # Get all remaining rows + if result_set.results is None: + raise RuntimeError("Results queue is not initialized") + all_rows = result_set.results.remaining_rows() # Filter rows @@ -79,7 +81,6 @@ def _filter_sea_result_set( filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, - sea_client=cast(SeaDatabricksClient, result_set.backend), result_data=result_data, manifest=manifest, buffer_size_bytes=result_set.buffer_size_bytes, diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 02d335aa4..12b727120 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -856,7 +856,6 @@ def get_execution_result( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -987,7 +986,6 @@ def execute_command( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, @@ -1027,7 +1025,6 @@ def get_catalogs( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1071,7 +1068,6 @@ def get_schemas( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1119,7 +1115,6 @@ def get_tables( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1167,7 +1162,6 @@ def get_columns( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, 
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 8934d0d56..5151988ad 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -20,6 +20,7 @@ from databricks.sql.utils import ( ColumnTable, ColumnQueue, + ResultSetQueue, ) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse @@ -36,14 +37,12 @@ class ResultSet(ABC): def __init__( self, connection: "Connection", - backend: "DatabricksClient", arraysize: int, buffer_size_bytes: int, command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, is_direct_results: bool = False, - results_queue=None, description: List[Tuple] = [], is_staging_operation: bool = False, lz4_compressed: bool = False, @@ -54,32 +53,30 @@ def __init__( Parameters: :param connection: The parent connection - :param backend: The backend client :param arraysize: The max number of rows to fetch at a time (PEP-249) :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch :param command_id: The command ID :param status: The command status :param has_been_closed_server_side: Whether the command has been closed on the server :param is_direct_results: Whether the command has more rows - :param results_queue: The results queue :param description: column description of the results :param is_staging_operation: Whether the command is a staging operation """ - self.connection = connection - self.backend = backend - self.arraysize = arraysize - self.buffer_size_bytes = buffer_size_bytes - self._next_row_index = 0 - self.description = description - self.command_id = command_id - self.status = status - self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results - self.results = results_queue - self._is_staging_operation = is_staging_operation - self.lz4_compressed = lz4_compressed - self._arrow_schema_bytes = arrow_schema_bytes + self.connection: "Connection" = connection + self.backend: DatabricksClient = connection.session.backend + self.arraysize: int = arraysize + self.buffer_size_bytes: int = buffer_size_bytes + self._next_row_index: int = 0 + self.description: List[Tuple] = description + self.command_id: CommandId = command_id + self.status: CommandState = status + self.has_been_closed_server_side: bool = has_been_closed_server_side + self.is_direct_results: bool = is_direct_results + self.results: Optional[ResultSetQueue] = None # Children will set this + self._is_staging_operation: bool = is_staging_operation + self.lz4_compressed: bool = lz4_compressed + self._arrow_schema_bytes: Optional[bytes] = arrow_schema_bytes def __iter__(self): while True: @@ -190,7 +187,6 @@ def __init__( self, connection: "Connection", execute_response: "ExecuteResponse", - thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, @@ -205,7 +201,6 @@ def __init__( Parameters: :param connection: The parent connection :param execute_response: Response from the execute command - :param thrift_client: The ThriftDatabricksClient instance for direct access :param buffer_size_bytes: Buffer size for fetching results :param arraysize: Default number of rows to fetch :param use_cloud_fetch: Whether to use cloud fetch for retrieving results @@ -238,20 +233,28 @@ def __init__( # Call parent constructor with common attributes super().__init__( connection=connection, - backend=thrift_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, 
command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, is_direct_results=is_direct_results, - results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + # Assert that the backend is of the correct type + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + assert isinstance( + self.backend, ThriftDatabricksClient + ), "Backend must be a ThriftDatabricksClient" + + # Set the results queue + self.results = results_queue + # Initialize results queue if not provided if not self.results: self._fill_results_buffer() @@ -307,6 +310,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows @@ -332,6 +339,9 @@ def fetchmany_columnar(self, size: int): if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows @@ -351,6 +361,9 @@ def fetchmany_columnar(self, size: int): def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -377,6 +390,9 @@ def fetchall_arrow(self) -> "pyarrow.Table": def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -393,6 +409,9 @@ def fetchone(self) -> Optional[Row]: Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. 
""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + if isinstance(self.results, ColumnQueue): res = self._convert_columnar_table(self.fetchmany_columnar(1)) else: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5ffdea9f0..c14a74038 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -118,7 +118,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): real_result_set = ThriftResultSet( connection=connection, execute_response=mock_execute_response, - thrift_client=mock_backend, ) # Verify initial state @@ -185,19 +184,24 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() - mock_backend = Mock() + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.fetch_results.return_value = (Mock(), False) + # Ensure isinstance check passes + mock_backend.__class__ = ThriftDatabricksClient - result_set = ThriftResultSet( - connection=mock_connection, - execute_response=Mock(), - thrift_client=mock_backend, - ) # Setup session mock on the mock_connection mock_session = Mock() mock_session.open = False + mock_session.backend = mock_backend type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = ThriftResultSet( + connection=mock_connection, + execute_response=Mock(), + ) + result_set.close() self.assertFalse(mock_backend.close_command.called) @@ -207,15 +211,21 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response = Mock() mock_results_response.has_been_closed_server_side = False mock_connection = Mock() - mock_thrift_backend = Mock() + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient # Setup session mock on the mock_connection mock_session = Mock() mock_session.open = True + mock_session.backend = mock_thrift_backend type(mock_connection).session = PropertyMock(return_value=mock_session) mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( - mock_connection, mock_results_response, mock_thrift_backend + mock_connection, + mock_results_response, ) result_set.close() @@ -258,10 +268,20 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_backend = Mock() + mock_connection = Mock() + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.fetch_results.return_value = (Mock(), False) + # Ensure isinstance check passes + mock_backend.__class__ = ThriftDatabricksClient + + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.backend = mock_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set = ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet(mock_connection, Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..e6ad33aae 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -1,6 +1,6 @@ import unittest import pytest -from unittest.mock import Mock +from 
unittest.mock import Mock, PropertyMock try: import pyarrow as pa @@ -38,12 +38,19 @@ def make_arrow_queue(batch): @staticmethod def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more - schema, arrow_table = FetchTests.make_arrow_table(initial_results) - arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + arrow_queue = FetchTests.make_arrow_queue(initial_results) + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient - # Create a mock backend that will return the queue when _fill_results_buffer is called mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient + + # Setup mock connection with session.backend + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = mock_thrift_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) num_cols = len(initial_results[0]) if initial_results else 0 description = [ @@ -52,7 +59,7 @@ def make_dummy_result_set_from_initial_results(initial_results): ] rs = ThriftResultSet( - connection=Mock(), + connection=mock_connection, execute_response=ExecuteResponse( command_id=None, status=None, @@ -61,7 +68,6 @@ def make_dummy_result_set_from_initial_results(initial_results): lz4_compressed=True, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, t_row_set=None, ) return rs @@ -86,8 +92,19 @@ def fetch_results( return results, batch_index < len(batch_list) + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient + + # Setup mock connection with session.backend + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = mock_thrift_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) + num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 description = [ @@ -96,7 +113,7 @@ def fetch_results( ] rs = ThriftResultSet( - connection=Mock(), + connection=mock_connection, execute_response=ExecuteResponse( command_id=None, status=None, @@ -105,7 +122,6 @@ def fetch_results( lz4_compressed=True, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, ) return rs diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 7eae8e5a8..3185589f6 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -68,12 +68,20 @@ def sea_command_id(self): return CommandId.from_sea_statement_id("test-statement-123") @pytest.fixture - def mock_cursor(self): + def mock_cursor(self, sea_client): """Create a mock cursor.""" cursor = Mock() cursor.active_command_id = None cursor.buffer_size_bytes = 1000 cursor.arraysize = 100 + + # Set up a mock connection with session.backend pointing to the sea_client + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = sea_client + mock_connection.session = mock_session + cursor.connection = mock_connection + return cursor @pytest.fixture diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 544edaf96..49b2564c4 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -23,12 +23,20 @@ def 
mock_connection(self): """Create a mock connection.""" connection = Mock() connection.open = True - return connection - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() + # Mock the session.backend to return a SeaDatabricksClient + mock_session = Mock() + from databricks.sql.backend.sea.backend import SeaDatabricksClient + + mock_backend = Mock(spec=SeaDatabricksClient) + mock_backend.max_download_threads = 10 + mock_backend.close_command = Mock() + # Ensure isinstance check passes + mock_backend.__class__ = SeaDatabricksClient + mock_session.backend = mock_backend + connection.session = mock_session + + return connection @pytest.fixture def execute_response(self): @@ -71,9 +79,7 @@ def _create_empty_manifest(self, format: ResultFormat): ) @pytest.fixture - def result_set_with_data( - self, mock_connection, mock_sea_client, execute_response, sample_data - ): + def result_set_with_data(self, mock_connection, execute_response, sample_data): """Create a SeaResultSet with sample data.""" # Create ResultData with inline data result_data = ResultData( @@ -84,7 +90,6 @@ def result_set_with_data( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=result_data, manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -99,14 +104,11 @@ def json_queue(self, sample_data): """Create a JsonQueue with sample data.""" return JsonQueue(sample_data) - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): + def test_init_with_execute_response(self, mock_connection, execute_response): """Test initializing SeaResultSet with an execute response.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -117,17 +119,15 @@ def test_init_with_execute_response( assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize == 100 assert result_set.description == execute_response.description - def test_close(self, mock_connection, mock_sea_client, execute_response): + def test_close(self, mock_connection, execute_response): """Test closing a result set.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -138,18 +138,19 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): result_set.close() # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + mock_connection.session.backend.close_command.assert_called_once_with( + result_set.command_id + ) assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response + self, mock_connection, execute_response ): """Test closing a result set that has already been closed server-side.""" result_set = SeaResultSet( connection=mock_connection, 
execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -161,19 +162,16 @@ def test_close_when_already_closed_server_side( result_set.close() # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() + mock_connection.session.backend.close_command.assert_not_called() assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): + def test_close_when_connection_closed(self, mock_connection, execute_response): """Test closing a result set when the connection is closed.""" mock_connection.open = False result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -184,7 +182,7 @@ def test_close_when_connection_closed( result_set.close() # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() + mock_connection.session.backend.close_command.assert_not_called() assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED @@ -316,7 +314,7 @@ def test_iteration(self, result_set_with_data, sample_data): assert rows[0].col3 is True def test_fetchmany_arrow_not_implemented( - self, mock_connection, mock_sea_client, execute_response, sample_data + self, mock_connection, execute_response, sample_data ): """Test that fetchmany_arrow raises NotImplementedError for non-JSON data.""" @@ -329,7 +327,6 @@ def test_fetchmany_arrow_not_implemented( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), buffer_size_bytes=1000, @@ -337,7 +334,7 @@ def test_fetchmany_arrow_not_implemented( ) def test_fetchall_arrow_not_implemented( - self, mock_connection, mock_sea_client, execute_response, sample_data + self, mock_connection, execute_response, sample_data ): """Test that fetchall_arrow raises NotImplementedError for non-JSON data.""" # Test that NotImplementedError is raised @@ -349,16 +346,13 @@ def test_fetchall_arrow_not_implemented( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), buffer_size_bytes=1000, arraysize=100, ) - def test_is_staging_operation( - self, mock_connection, mock_sea_client, execute_response - ): + def test_is_staging_operation(self, mock_connection, execute_response): """Test the is_staging_operation property.""" # Set is_staging_operation to True execute_response.is_staging_operation = True @@ -367,7 +361,6 @@ def test_is_staging_operation( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, From 7f9b35d54685880ba9e7f36c7322eb5bedd46421 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 7 Jul 2025 09:23:19 +0530 Subject: [PATCH 2/7] rename 
sea.backend to sea.client for clarity Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/{backend.py => client.py} | 0 src/databricks/sql/backend/sea/queue.py | 2 +- src/databricks/sql/backend/sea/result_set.py | 2 +- src/databricks/sql/backend/sea/utils/filters.py | 2 +- src/databricks/sql/session.py | 2 +- src/databricks/sql/utils.py | 2 +- tests/unit/test_sea_backend.py | 4 ++-- tests/unit/test_sea_result_set.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) rename src/databricks/sql/backend/sea/{backend.py => client.py} (100%) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/client.py similarity index 100% rename from src/databricks/sql/backend/sea/backend.py rename to src/databricks/sql/backend/sea/client.py diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 73f47ea96..3aeee41c4 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -3,7 +3,7 @@ from abc import ABC from typing import List, Optional, Tuple -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.constants import ResultFormat from databricks.sql.exc import ProgrammingError diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 14ed61575..c6ed63900 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -4,7 +4,7 @@ import logging -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index cd27778fb..639b6495f 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -71,7 +71,7 @@ def _filter_sea_result_set( result_data = ResultData(data=filtered_rows, external_links=None) - from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 76aec4675..4c8b882f4 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -8,7 +8,7 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 35c7bce4d..1a2a8e693 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -13,7 +13,7 @@ import lz4.frame -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient 
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest try: diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 3185589f6..96485d235 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -8,7 +8,7 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from databricks.sql.backend.sea.backend import ( +from databricks.sql.backend.sea.client import ( SeaDatabricksClient, _filter_session_configuration, ) @@ -31,7 +31,7 @@ class TestSeaBackend: def mock_http_client(self): """Create a mock HTTP client.""" with patch( - "databricks.sql.backend.sea.backend.SeaHttpClient" + "databricks.sql.backend.sea.client.SeaHttpClient" ) as mock_client_class: mock_client = mock_client_class.return_value yield mock_client diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 49b2564c4..8884e812a 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -26,7 +26,7 @@ def mock_connection(self): # Mock the session.backend to return a SeaDatabricksClient mock_session = Mock() - from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.backend.sea.client import SeaDatabricksClient mock_backend = Mock(spec=SeaDatabricksClient) mock_backend.max_download_threads = 10 From 88371ea8371ae80ea1ff7a2bf4aef09f5f0838d8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 15 Jul 2025 10:32:42 +0530 Subject: [PATCH 3/7] type issues Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 5 +++-- tests/unit/test_client.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 196a2a313..233750194 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -73,7 +73,7 @@ def __init__( self.status: CommandState = status self.has_been_closed_server_side: bool = has_been_closed_server_side self.is_direct_results: bool = is_direct_results - self.results: Optional[ResultSetQueue] = None # Children will set this + self.results: Optional[ResultSetQueue] = None self._is_staging_operation: bool = is_staging_operation self.lz4_compressed: bool = lz4_compressed self._arrow_schema_bytes: Optional[bytes] = arrow_schema_bytes @@ -166,7 +166,8 @@ def close(self) -> None: been closed on the server for some other reason, issue a request to the server to close it. 
""" try: - self.results.close() + if self.results: + self.results.close() if ( self.status != CommandState.CLOSED and not self.has_been_closed_server_side diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0d29e2c49..e4d7646d1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -219,7 +219,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response = Mock() mock_results_response.has_been_closed_server_side = False mock_connection = Mock() - mock_thrift_backend = Mock() + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_results = Mock() # Setup session mock on the mock_connection mock_session = Mock() From 120dfc0bc692262513e255e7336010ba4408a6ac Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 15 Jul 2025 11:05:04 +0530 Subject: [PATCH 4/7] remove service specific state from ExecuteResponse Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/client.py | 2 - src/databricks/sql/backend/sea/result_set.py | 24 +++++-- .../sql/backend/sea/utils/filters.py | 2 - src/databricks/sql/backend/thrift_backend.py | 68 +++++++++++++------ src/databricks/sql/backend/types.py | 2 - src/databricks/sql/result_set.py | 61 ++++++++++------- 6 files changed, 103 insertions(+), 56 deletions(-) diff --git a/src/databricks/sql/backend/sea/client.py b/src/databricks/sql/backend/sea/client.py index 418d6b51a..9f7b552f8 100644 --- a/src/databricks/sql/backend/sea/client.py +++ b/src/databricks/sql/backend/sea/client.py @@ -349,10 +349,8 @@ def _results_message_to_execute_response( command_id=CommandId.from_sea_statement_id(response.statement_id), status=response.status.state, description=description, - has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=response.manifest.is_volume_operation, - arrow_schema_bytes=None, result_format=response.manifest.format, ) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index c6ed63900..6c7d20636 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -15,10 +15,10 @@ if TYPE_CHECKING: from databricks.sql.client import Connection -from databricks.sql.exc import ProgrammingError +from databricks.sql.exc import CursorAlreadyClosedError, ProgrammingError, RequestError from databricks.sql.types import Row from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory -from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.types import CommandState, ExecuteResponse from databricks.sql.result_set import ResultSet logger = logging.getLogger(__name__) @@ -61,11 +61,9 @@ def __init__( buffer_size_bytes=buffer_size_bytes, command_id=execute_response.command_id, status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, ) # Assert that the backend is of the correct type @@ -274,3 +272,21 @@ def fetchall(self) -> List[Row]: return self._create_json_table(self.fetchall_json()) else: raise NotImplementedError("fetchall only supported for JSON data") + + def close(self) -> None: + """ + Close the result set. 
+ + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. + """ + try: + if self.results is not None: + self.results.close() + if self.status != CommandState.CLOSED and self.connection.open: + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.status = CommandState.CLOSED diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index 639b6495f..9e7a85c56 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -60,9 +60,7 @@ def _filter_sea_result_set( command_id=command_id, status=result_set.status, description=result_set.description, - has_been_closed_server_side=result_set.has_been_closed_server_side, lz4_compressed=result_set.lz4_compressed, - arrow_schema_bytes=result_set._arrow_schema_bytes, is_staging_operation=False, ) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 9b3105171..48a7a1ddb 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -821,14 +821,17 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id=command_id, status=status, description=description, - has_been_closed_server_side=has_been_closed_server_side, lz4_compressed=lz4_compressed, is_staging_operation=t_result_set_metadata_resp.isStagingOperation, - arrow_schema_bytes=schema_bytes, result_format=t_result_set_metadata_resp.resultFormat, ) - return execute_response, is_direct_results + return ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" @@ -881,10 +884,8 @@ def get_execution_result( command_id=command_id, status=status, description=description, - has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - arrow_schema_bytes=schema_bytes, result_format=t_result_set_metadata_resp.resultFormat, ) @@ -898,6 +899,8 @@ def get_execution_result( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + arrow_schema_bytes=schema_bytes, + has_been_closed_server_side=False, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -1016,9 +1019,12 @@ def execute_command( self._handle_execute_response_async(resp, cursor) return None else: - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1034,6 +1040,8 @@ def execute_command( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_catalogs( @@ -1055,9 +1063,12 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + 
has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1073,6 +1084,8 @@ def get_catalogs( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_schemas( @@ -1100,9 +1113,12 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1118,6 +1134,8 @@ def get_schemas( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_tables( @@ -1149,9 +1167,12 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1167,6 +1188,8 @@ def get_tables( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def get_columns( @@ -1198,9 +1221,12 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response, is_direct_results = self._handle_execute_response( - resp, cursor - ) + ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) = self._handle_execute_response(resp, cursor) t_row_set = None if resp.directResults and resp.directResults.resultSet: @@ -1216,6 +1242,8 @@ def get_columns( max_download_threads=self.max_download_threads, ssl_options=self._ssl_options, is_direct_results=is_direct_results, + has_been_closed_server_side=has_been_closed_server_side, + arrow_schema_bytes=schema_bytes, ) def _handle_execute_response(self, resp, cursor): diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index f6428a187..b188b7ba1 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -419,8 +419,6 @@ class ExecuteResponse: command_id: CommandId status: CommandState description: List[Tuple] - has_been_closed_server_side: bool = False lz4_compressed: bool = True is_staging_operation: bool = False - arrow_schema_bytes: Optional[bytes] = None result_format: Optional[Any] = None diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 233750194..c947ac739 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -41,12 +41,10 @@ def __init__( buffer_size_bytes: int, command_id: CommandId, status: CommandState, - has_been_closed_server_side: bool = False, is_direct_results: bool = False, description: List[Tuple] = [], is_staging_operation: bool = False, lz4_compressed: bool = False, - arrow_schema_bytes: Optional[bytes] = None, ): """ A ResultSet manages the results of a single 
command. @@ -57,7 +55,6 @@ def __init__( :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch :param command_id: The command ID :param status: The command status - :param has_been_closed_server_side: Whether the command has been closed on the server :param is_direct_results: Whether the command has more rows :param description: column description of the results :param is_staging_operation: Whether the command is a staging operation @@ -71,12 +68,10 @@ def __init__( self.description: List[Tuple] = description self.command_id: CommandId = command_id self.status: CommandState = status - self.has_been_closed_server_side: bool = has_been_closed_server_side self.is_direct_results: bool = is_direct_results self.results: Optional[ResultSetQueue] = None self._is_staging_operation: bool = is_staging_operation self.lz4_compressed: bool = lz4_compressed - self._arrow_schema_bytes: Optional[bytes] = arrow_schema_bytes def __iter__(self): while True: @@ -158,28 +153,12 @@ def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all remaining rows as an Arrow table.""" pass + @abstractmethod def close(self) -> None: """ Close the result set. - - If the connection has not been closed, and the result set has not already - been closed on the server for some other reason, issue a request to the server to close it. """ - try: - if self.results: - self.results.close() - if ( - self.status != CommandState.CLOSED - and not self.has_been_closed_server_side - and self.connection.open - ): - self.backend.close_command(self.command_id) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - finally: - self.has_been_closed_server_side = True - self.status = CommandState.CLOSED + pass class ThriftResultSet(ResultSet): @@ -196,6 +175,8 @@ def __init__( max_download_threads: int = 10, ssl_options=None, is_direct_results: bool = True, + has_been_closed_server_side: bool = False, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -210,11 +191,15 @@ def __init__( :param max_download_threads: Maximum number of download threads for cloud fetch :param ssl_options: SSL options for cloud fetch :param is_direct_results: Whether there are more rows to fetch + :param has_been_closed_server_side: Whether the command has been closed on the server + :param arrow_schema_bytes: The schema of the result set """ # Initialize ThriftResultSet-specific attributes self._use_cloud_fetch = use_cloud_fetch self.is_direct_results = is_direct_results + self.has_been_closed_server_side = has_been_closed_server_side + self._arrow_schema_bytes = arrow_schema_bytes # Build the results queue if t_row_set is provided results_queue = None @@ -225,7 +210,7 @@ def __init__( results_queue = ThriftResultSetQueueFactory.build_queue( row_set_type=execute_response.result_format, t_row_set=t_row_set, - arrow_schema_bytes=execute_response.arrow_schema_bytes or b"", + arrow_schema_bytes=self._arrow_schema_bytes or b"", max_download_threads=max_download_threads, lz4_compressed=execute_response.lz4_compressed, description=execute_response.description, @@ -239,12 +224,10 @@ def __init__( buffer_size_bytes=buffer_size_bytes, command_id=execute_response.command_id, status=execute_response.status, - has_been_closed_server_side=execute_response.has_been_closed_server_side, is_direct_results=is_direct_results, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, - arrow_schema_bytes=execute_response.arrow_schema_bytes, ) # Assert that the backend is of the correct type @@ -460,3 +443,29 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] + + def close(self) -> None: + """ + Close the result set. + + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. 
+ """ + try: + if self.results: + self.results.close() + print(f"has_been_closed_server_side: {self.has_been_closed_server_side}") + print(f"status: {self.status}") + print(f"connection.open: {self.connection.open}") + if ( + self.status != CommandState.CLOSED + and not self.has_been_closed_server_side + and self.connection.open + ): + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.has_been_closed_server_side = True + self.status = CommandState.CLOSED From 962e4a6e33b49aff8f98da62ccb5dfe2e6370862 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 15 Jul 2025 15:26:42 +0530 Subject: [PATCH 5/7] fix some tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 2 +- tests/unit/test_fetches.py | 4 ++-- tests/unit/test_sea_result_set.py | 24 ------------------------ tests/unit/test_thrift_backend.py | 24 +++++++++++++++--------- 4 files changed, 18 insertions(+), 36 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e4d7646d1..06f0378a1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -108,7 +108,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): mock_execute_response.status = ( CommandState.SUCCEEDED if not closed else CommandState.CLOSED ) - mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False mock_execute_response.description = [] @@ -127,6 +126,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): real_result_set = ThriftResultSet( connection=connection, execute_response=mock_execute_response, + has_been_closed_server_side=closed, ) # Mock execute_command to return our real result set diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index e6ad33aae..8643404ba 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -63,12 +63,12 @@ def make_dummy_result_set_from_initial_results(initial_results): execute_response=ExecuteResponse( command_id=None, status=None, - has_been_closed_server_side=True, description=description, lz4_compressed=True, is_staging_operation=False, ), t_row_set=None, + has_been_closed_server_side=True, ) return rs @@ -117,11 +117,11 @@ def fetch_results( execute_response=ExecuteResponse( command_id=None, status=None, - has_been_closed_server_side=False, description=description, lz4_compressed=True, is_staging_operation=False, ), + has_been_closed_server_side=False, ) return rs diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 8884e812a..25ac23133 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -141,29 +141,6 @@ def test_close(self, mock_connection, execute_response): mock_connection.session.backend.close_command.assert_called_once_with( result_set.command_id ) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), - buffer_size_bytes=1000, - arraysize=100, - ) - 
result_set.has_been_closed_server_side = True - - # Close the result set - result_set.close() - - # Verify the backend's close_command was NOT called - mock_connection.session.backend.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED def test_close_when_connection_closed(self, mock_connection, execute_response): @@ -183,7 +160,6 @@ def test_close_when_connection_closed(self, mock_connection, execute_response): # Verify the backend's close_command was NOT called mock_connection.session.backend.close_command.assert_not_called() - assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED def test_init_with_result_data(self, result_set_with_data, sample_data): diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 37569f755..4636077a4 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -649,7 +649,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, _, _, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -892,6 +892,8 @@ def test_handle_execute_response_can_handle_without_direct_results( ( execute_response, _, + _, + _ ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual( execute_response.status, @@ -965,11 +967,11 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _ = thrift_backend._handle_execute_response( + execute_response, _, _, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -997,7 +999,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + _, _, _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( hive_schema_mock, @@ -1046,6 +1048,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ( execute_response, has_more_rows_result, + _, + _ ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual(is_direct_results, has_more_rows_result) @@ -1179,7 +1183,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1215,7 +1219,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) cursor_mock = Mock() result = 
thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1248,7 +1252,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1290,7 +1294,7 @@ def test_get_tables_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1336,7 +1340,7 @@ def test_get_columns_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -2254,6 +2258,8 @@ def test_execute_command_sets_complex_type_fields_correctly( mock_handle_execute_response.return_value = ( mock_execute_response, mock_arrow_schema, + Mock(), + Mock() ) # Iterate through each possible combination of native types (True, False and unset) From 53975da1a057e7abfc79d00ff30824d6c72165af Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 15 Jul 2025 15:27:07 +0530 Subject: [PATCH 6/7] remove redundant logs Signed-off-by: varun-edachali-dbx --- src/databricks/sql/result_set.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index c947ac739..51128da8c 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -454,9 +454,6 @@ def close(self) -> None: try: if self.results: self.results.close() - print(f"has_been_closed_server_side: {self.has_been_closed_server_side}") - print(f"status: {self.status}") - print(f"connection.open: {self.connection.open}") if ( self.status != CommandState.CLOSED and not self.has_been_closed_server_side From 17cbb99cb2a422feec41786903495a49b945df0b Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 15 Jul 2025 15:40:54 +0530 Subject: [PATCH 7/7] fix more tests Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 13 +++---- tests/unit/test_thrift_backend.py | 61 ++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 06f0378a1..51430e9e0 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -186,15 +186,17 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() - mock_backend = Mock() + mock_backend = Mock(spec=ThriftDatabricksClient) mock_results = Mock() mock_backend.fetch_results.return_value = (Mock(), False) - # Ensure isinstance check passes + + # Ensure connection appears closed + type(mock_connection).open = PropertyMock(return_value=False) + # Ensure isinstance check passes if needed mock_backend.__class__ = ThriftDatabricksClient # Setup session mock on the mock_connection mock_session = Mock() - mock_session.open = False mock_session.backend = mock_backend type(mock_connection).session = 
PropertyMock(return_value=mock_session) @@ -204,11 +206,6 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): ) result_set.results = mock_results - # Setup session mock on the mock_connection - mock_session = Mock() - mock_session.open = False - type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set.close() self.assertFalse(mock_backend.close_command.called) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 4636077a4..f66b356ca 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -889,12 +889,9 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - ( - execute_response, - _, - _, - _ - ) = thrift_backend._handle_execute_response(execute_resp, Mock()) + (execute_response, _, _, _) = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) self.assertEqual( execute_response.status, CommandState.SUCCEEDED, @@ -967,9 +964,12 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): ) ) thrift_backend = self._make_fake_thrift_backend() - execute_response, _, _, arrow_schema_bytes = thrift_backend._handle_execute_response( - t_execute_resp, Mock() - ) + ( + execute_response, + _, + _, + arrow_schema_bytes, + ) = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @@ -1048,8 +1048,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ( execute_response, has_more_rows_result, - _, - _ + _, + _, ) = thrift_backend._handle_execute_response(execute_resp, Mock()) self.assertEqual(is_direct_results, has_more_rows_result) @@ -1183,7 +1183,12 @@ def test_execute_statement_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1219,7 +1224,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1252,7 +1262,12 @@ def test_get_schemas_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1294,7 +1309,12 @@ def test_get_tables_calls_client_and_handle_execute_response( ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1340,7 +1360,12 @@ def test_get_columns_calls_client_and_handle_execute_response( 
ssl_options=SSLOptions(), ) thrift_backend._handle_execute_response = Mock() - thrift_backend._handle_execute_response.return_value = (Mock(), Mock(), Mock(), Mock()) + thrift_backend._handle_execute_response.return_value = ( + Mock(), + Mock(), + Mock(), + Mock(), + ) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -2258,8 +2283,8 @@ def test_execute_command_sets_complex_type_fields_correctly( mock_handle_execute_response.return_value = ( mock_execute_response, mock_arrow_schema, - Mock(), - Mock() + Mock(), + Mock(), ) # Iterate through each possible combination of native types (True, False and unset)
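
A minimal, self-contained sketch of the pattern this series converges on, for readers skimming the diffs: result sets infer their backend from connection.session.backend and assert the concrete client type, instead of receiving the client as a constructor argument. The class names below are simplified stand-ins for the real ones in the diff; this is illustrative only, not part of the patch set.

    # Illustrative sketch (not part of the patches): the core refactor of
    # PATCH 1/7. A ResultSet derives its backend client from the parent
    # connection's session, and each concrete subclass asserts the backend
    # type it requires. All names are simplified stand-ins.

    from typing import Optional


    class DatabricksClient:
        """Stand-in for the abstract backend client."""


    class SeaDatabricksClient(DatabricksClient):
        """Stand-in for the SEA backend client."""

        max_download_threads = 10


    class Session:
        def __init__(self, backend: DatabricksClient):
            self.backend = backend


    class Connection:
        def __init__(self, session: Session):
            self.session = session


    class ResultSet:
        def __init__(self, connection: Connection):
            self.connection = connection
            # Infer the backend from the connection instead of taking it
            # as a constructor parameter -- the key change in PATCH 1/7.
            self.backend: DatabricksClient = connection.session.backend
            # Children build and assign the results queue themselves.
            self.results: Optional[object] = None


    class SeaResultSet(ResultSet):
        def __init__(self, connection: Connection):
            super().__init__(connection)
            # Each subclass validates that the inferred backend speaks its
            # protocol, replacing the old explicit sea_client argument.
            assert isinstance(
                self.backend, SeaDatabricksClient
            ), "Backend must be a SeaDatabricksClient"
            # Backend-specific settings are read off the inferred client.
            self.max_download_threads = self.backend.max_download_threads


    if __name__ == "__main__":
        conn = Connection(Session(SeaDatabricksClient()))
        rs = SeaResultSet(conn)
        print(type(rs.backend).__name__)  # SeaDatabricksClient

The same principle drives the later patches: protocol-specific state (has_been_closed_server_side, arrow_schema_bytes) moves out of the shared ExecuteResponse and into ThriftResultSet, with _handle_execute_response returning those values alongside the response as a four-element tuple.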