diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/client.py
similarity index 99%
rename from src/databricks/sql/backend/sea/backend.py
rename to src/databricks/sql/backend/sea/client.py
index 3d23344b..9f7b552f 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/client.py
@@ -349,10 +349,8 @@ def _results_message_to_execute_response(
             command_id=CommandId.from_sea_statement_id(response.statement_id),
             status=response.status.state,
             description=description,
-            has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=response.manifest.is_volume_operation,
-            arrow_schema_bytes=None,
             result_format=response.manifest.format,
         )
@@ -624,7 +622,6 @@ def get_execution_result(
         return SeaResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            sea_client=self,
             result_data=response.result,
             manifest=response.manifest,
             buffer_size_bytes=cursor.buffer_size_bytes,
diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index 0644e4c0..3a1f6ef5 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -3,7 +3,7 @@
 from abc import ABC
 from typing import List, Optional, Tuple
 
-from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.client import SeaDatabricksClient
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.constants import ResultFormat
 from databricks.sql.exc import ProgrammingError
diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
index 302af5e3..6c7d2063 100644
--- a/src/databricks/sql/backend/sea/result_set.py
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -4,7 +4,7 @@
 
 import logging
 
-from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.client import SeaDatabricksClient
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
 
@@ -15,10 +15,10 @@
 if TYPE_CHECKING:
     from databricks.sql.client import Connection
 
-from databricks.sql.exc import ProgrammingError
+from databricks.sql.exc import CursorAlreadyClosedError, ProgrammingError, RequestError
 from databricks.sql.types import Row
 from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
-from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.types import CommandState, ExecuteResponse
 from databricks.sql.result_set import ResultSet
 
 logger = logging.getLogger(__name__)
@@ -31,7 +31,6 @@ def __init__(
         self,
         connection: Connection,
         execute_response: ExecuteResponse,
-        sea_client: SeaDatabricksClient,
         result_data: ResultData,
         manifest: ResultManifest,
         buffer_size_bytes: int = 104857600,
@@ -43,7 +42,6 @@ def __init__(
         Args:
             connection: The parent connection
             execute_response: Response from the execute command
-            sea_client: The SeaDatabricksClient instance for direct access
            buffer_size_bytes: Buffer size for fetching results
            arraysize: Default number of rows to fetch
            result_data: Result data from SEA response
@@ -56,32 +54,36 @@ def __init__(
         if statement_id is None:
             raise ValueError("Command ID is not a SEA statement ID")
 
-        results_queue = SeaResultSetQueueFactory.build_queue(
-            result_data,
-            self.manifest,
-            statement_id,
-            description=execute_response.description,
-            max_download_threads=sea_client.max_download_threads,
-            sea_client=sea_client,
-            lz4_compressed=execute_response.lz4_compressed,
-        )
-
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
-            backend=sea_client,
             arraysize=arraysize,
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
             status=execute_response.status,
-            has_been_closed_server_side=execute_response.has_been_closed_server_side,
-            results_queue=results_queue,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
-            arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
+        # Assert that the backend is of the correct type
+        assert isinstance(
+            self.backend, SeaDatabricksClient
+        ), "Backend must be a SeaDatabricksClient"
+
+        results_queue = SeaResultSetQueueFactory.build_queue(
+            result_data,
+            self.manifest,
+            statement_id,
+            description=execute_response.description,
+            max_download_threads=self.backend.max_download_threads,
+            sea_client=self.backend,
+            lz4_compressed=execute_response.lz4_compressed,
+        )
+
+        # Set the results queue
+        self.results = results_queue
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -160,6 +162,9 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
         if size < 0:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
 
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         self._next_row_index += len(results)
 
@@ -173,6 +178,9 @@ def fetchall_json(self) -> List[List[str]]:
             Columnar table containing all remaining rows
         """
 
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.remaining_rows()
         self._next_row_index += len(results)
 
@@ -264,3 +272,21 @@ def fetchall(self) -> List[Row]:
             return self._create_json_table(self.fetchall_json())
         else:
             raise NotImplementedError("fetchall only supported for JSON data")
+
+    def close(self) -> None:
+        """
+        Close the result set.
+
+        If the connection has not been closed, and the result set has not already
+        been closed on the server for some other reason, issue a request to the server to close it.
+        """
+        try:
+            if self.results is not None:
+                self.results.close()
+            if self.status != CommandState.CLOSED and self.connection.open:
+                self.backend.close_command(self.command_id)
+        except RequestError as e:
+            if isinstance(e.args[1], CursorAlreadyClosedError):
+                logger.info("Operation was canceled by a prior request")
+        finally:
+            self.status = CommandState.CLOSED
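Note: with `sea_client` removed from the constructor, `SeaResultSet` now resolves its backend from `connection.session.backend` (see the base-class change in `result_set.py` further down). A minimal sketch of the new construction path, assuming `execute_response`, `response`, and `cursor` as in `get_execution_result()` above:

```python
# Sketch only, not part of this patch: mirrors the SeaResultSet call in
# SeaDatabricksClient.get_execution_result() after this change.
from databricks.sql.backend.sea.result_set import SeaResultSet

result_set = SeaResultSet(
    connection=cursor.connection,       # backend taken from connection.session.backend
    execute_response=execute_response,
    result_data=response.result,
    manifest=response.manifest,
    buffer_size_bytes=cursor.buffer_size_bytes,
    arraysize=cursor.arraysize,
)
result_set.close()  # issues close_command unless the state is already CLOSED
```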
+ """ + try: + if self.results is not None: + self.results.close() + if self.status != CommandState.CLOSED and self.connection.open: + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.status = CommandState.CLOSED diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index ef6c91d7..9e7a85c5 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -12,14 +12,13 @@ Optional, Any, Callable, - cast, TYPE_CHECKING, ) if TYPE_CHECKING: from databricks.sql.backend.sea.result_set import SeaResultSet -from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse, CommandId, CommandState logger = logging.getLogger(__name__) @@ -45,6 +44,9 @@ def _filter_sea_result_set( """ # Get all remaining rows + if result_set.results is None: + raise RuntimeError("Results queue is not initialized") + all_rows = result_set.results.remaining_rows() # Filter rows @@ -58,9 +60,7 @@ def _filter_sea_result_set( command_id=command_id, status=result_set.status, description=result_set.description, - has_been_closed_server_side=result_set.has_been_closed_server_side, lz4_compressed=result_set.lz4_compressed, - arrow_schema_bytes=result_set._arrow_schema_bytes, is_staging_operation=False, ) @@ -69,7 +69,7 @@ def _filter_sea_result_set( result_data = ResultData(data=filtered_rows, external_links=None) - from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data @@ -79,7 +79,6 @@ def _filter_sea_result_set( filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, - sea_client=cast(SeaDatabricksClient, result_set.backend), result_data=result_data, manifest=manifest, buffer_size_bytes=result_set.buffer_size_bytes, diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 32e024d4..48a7a1dd 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -821,14 +821,17 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id=command_id, status=status, description=description, - has_been_closed_server_side=has_been_closed_server_side, lz4_compressed=lz4_compressed, is_staging_operation=t_result_set_metadata_resp.isStagingOperation, - arrow_schema_bytes=schema_bytes, result_format=t_result_set_metadata_resp.resultFormat, ) - return execute_response, is_direct_results + return ( + execute_response, + is_direct_results, + has_been_closed_server_side, + schema_bytes, + ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" @@ -881,17 +884,14 @@ def get_execution_result( command_id=command_id, status=status, description=description, - has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - arrow_schema_bytes=schema_bytes, result_format=t_result_set_metadata_resp.resultFormat, ) return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, 
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -899,6 +899,8 @@ def get_execution_result(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            arrow_schema_bytes=schema_bytes,
+            has_been_closed_server_side=False,
         )
 
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -1017,9 +1019,12 @@ def execute_command(
             self._handle_execute_response_async(resp, cursor)
             return None
         else:
-            execute_response, is_direct_results = self._handle_execute_response(
-                resp, cursor
-            )
+            (
+                execute_response,
+                is_direct_results,
+                has_been_closed_server_side,
+                schema_bytes,
+            ) = self._handle_execute_response(resp, cursor)
 
             t_row_set = None
             if resp.directResults and resp.directResults.resultSet:
@@ -1028,7 +1033,6 @@ def execute_command(
             return ThriftResultSet(
                 connection=cursor.connection,
                 execute_response=execute_response,
-                thrift_client=self,
                 buffer_size_bytes=max_bytes,
                 arraysize=max_rows,
                 use_cloud_fetch=use_cloud_fetch,
@@ -1036,6 +1040,8 @@ def execute_command(
                 max_download_threads=self.max_download_threads,
                 ssl_options=self._ssl_options,
                 is_direct_results=is_direct_results,
+                has_been_closed_server_side=has_been_closed_server_side,
+                arrow_schema_bytes=schema_bytes,
             )
 
     def get_catalogs(
@@ -1057,9 +1063,12 @@ def get_catalogs(
         )
         resp = self.make_request(self._client.GetCatalogs, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1068,7 +1077,6 @@ def get_catalogs(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -1076,6 +1084,8 @@ def get_catalogs(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def get_schemas(
@@ -1103,9 +1113,12 @@ def get_schemas(
         )
         resp = self.make_request(self._client.GetSchemas, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1114,7 +1127,6 @@ def get_schemas(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -1122,6 +1134,8 @@ def get_schemas(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def get_tables(
@@ -1153,9 +1167,12 @@ def get_tables(
         )
         resp = self.make_request(self._client.GetTables, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1164,7 +1181,6 @@ def get_tables(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -1172,6 +1188,8 @@ def get_tables(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def get_columns(
@@ -1203,9 +1221,12 @@ def get_columns(
         )
         resp = self.make_request(self._client.GetColumns, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1214,7 +1235,6 @@ def get_columns(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -1222,6 +1242,8 @@ def get_columns(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def _handle_execute_response(self, resp, cursor):
diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py
index f6428a18..b188b7ba 100644
--- a/src/databricks/sql/backend/types.py
+++ b/src/databricks/sql/backend/types.py
@@ -419,8 +419,6 @@ class ExecuteResponse:
     command_id: CommandId
     status: CommandState
     description: List[Tuple]
-    has_been_closed_server_side: bool = False
    lz4_compressed: bool = True
    is_staging_operation: bool = False
-    arrow_schema_bytes: Optional[bytes] = None
    result_format: Optional[Any] = None
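Note: `_handle_execute_response` (via `_results_message_to_execute_response`) now returns a 4-tuple, and the slimmed `ExecuteResponse` above no longer carries the server-side close flag or the Arrow schema. A sketch of a caller under the new contract, assuming `resp` and `cursor` are in scope:

```python
# Sketch only: the unpacking pattern now used by execute_command() and the
# metadata methods in thrift_backend.py.
(
    execute_response,             # slimmed ExecuteResponse
    is_direct_results,
    has_been_closed_server_side,  # previously ExecuteResponse.has_been_closed_server_side
    schema_bytes,                 # previously ExecuteResponse.arrow_schema_bytes
) = self._handle_execute_response(resp, cursor)
```

The two extracted values are then forwarded to `ThriftResultSet` as explicit keyword arguments, as the hunks above show.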
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index dc279cf9..51128da8 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -20,6 +20,7 @@
 from databricks.sql.utils import (
     ColumnTable,
     ColumnQueue,
+    ResultSetQueue,
 )
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
 
@@ -36,50 +37,41 @@ class ResultSet(ABC):
     def __init__(
         self,
         connection: "Connection",
-        backend: "DatabricksClient",
         arraysize: int,
         buffer_size_bytes: int,
         command_id: CommandId,
         status: CommandState,
-        has_been_closed_server_side: bool = False,
         is_direct_results: bool = False,
-        results_queue=None,
         description: List[Tuple] = [],
         is_staging_operation: bool = False,
         lz4_compressed: bool = False,
-        arrow_schema_bytes: Optional[bytes] = None,
     ):
         """
         A ResultSet manages the results of a single command.
 
         Parameters:
             :param connection: The parent connection
-            :param backend: The backend client
             :param arraysize: The max number of rows to fetch at a time (PEP-249)
             :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch
             :param command_id: The command ID
             :param status: The command status
-            :param has_been_closed_server_side: Whether the command has been closed on the server
             :param is_direct_results: Whether the command has more rows
-            :param results_queue: The results queue
             :param description: column description of the results
             :param is_staging_operation: Whether the command is a staging operation
         """
 
-        self.connection = connection
-        self.backend = backend
-        self.arraysize = arraysize
-        self.buffer_size_bytes = buffer_size_bytes
-        self._next_row_index = 0
-        self.description = description
-        self.command_id = command_id
-        self.status = status
-        self.has_been_closed_server_side = has_been_closed_server_side
-        self.is_direct_results = is_direct_results
-        self.results = results_queue
-        self._is_staging_operation = is_staging_operation
-        self.lz4_compressed = lz4_compressed
-        self._arrow_schema_bytes = arrow_schema_bytes
+        self.connection: "Connection" = connection
+        self.backend: DatabricksClient = connection.session.backend
+        self.arraysize: int = arraysize
+        self.buffer_size_bytes: int = buffer_size_bytes
+        self._next_row_index: int = 0
+        self.description: List[Tuple] = description
+        self.command_id: CommandId = command_id
+        self.status: CommandState = status
+        self.is_direct_results: bool = is_direct_results
+        self.results: Optional[ResultSetQueue] = None
+        self._is_staging_operation: bool = is_staging_operation
+        self.lz4_compressed: bool = lz4_compressed
 
     def __iter__(self):
         while True:
@@ -161,27 +153,12 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all remaining rows as an Arrow table."""
         pass
 
+    @abstractmethod
     def close(self) -> None:
         """
         Close the result set.
-
-        If the connection has not been closed, and the result set has not already
-        been closed on the server for some other reason, issue a request to the server to close it.
         """
-        try:
-            self.results.close()
-            if (
-                self.status != CommandState.CLOSED
-                and not self.has_been_closed_server_side
-                and self.connection.open
-            ):
-                self.backend.close_command(self.command_id)
-        except RequestError as e:
-            if isinstance(e.args[1], CursorAlreadyClosedError):
-                logger.info("Operation was canceled by a prior request")
-        finally:
-            self.has_been_closed_server_side = True
-            self.status = CommandState.CLOSED
+        pass
 
 
 class ThriftResultSet(ResultSet):
@@ -191,7 +168,6 @@ def __init__(
         self,
         connection: "Connection",
         execute_response: "ExecuteResponse",
-        thrift_client: "ThriftDatabricksClient",
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -199,6 +175,8 @@ def __init__(
         max_download_threads: int = 10,
         ssl_options=None,
         is_direct_results: bool = True,
+        has_been_closed_server_side: bool = False,
+        arrow_schema_bytes: Optional[bytes] = None,
     ):
         """
         Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -206,7 +184,6 @@ def __init__(
         Parameters:
             :param connection: The parent connection
             :param execute_response: Response from the execute command
-            :param thrift_client: The ThriftDatabricksClient instance for direct access
             :param buffer_size_bytes: Buffer size for fetching results
             :param arraysize: Default number of rows to fetch
             :param use_cloud_fetch: Whether to use cloud fetch for retrieving results
@@ -214,11 +191,15 @@ def __init__(
             :param max_download_threads: Maximum number of download threads for cloud fetch
             :param ssl_options: SSL options for cloud fetch
             :param is_direct_results: Whether there are more rows to fetch
+            :param has_been_closed_server_side: Whether the command has been closed on the server
+            :param arrow_schema_bytes: The schema of the result set
         """
         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
         self.is_direct_results = is_direct_results
+        self.has_been_closed_server_side = has_been_closed_server_side
+        self._arrow_schema_bytes = arrow_schema_bytes
 
         # Build the results queue if t_row_set is provided
         results_queue = None
@@ -229,7 +210,7 @@ def __init__(
             results_queue = ThriftResultSetQueueFactory.build_queue(
                 row_set_type=execute_response.result_format,
                 t_row_set=t_row_set,
-                arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
+                arrow_schema_bytes=self._arrow_schema_bytes or b"",
                 max_download_threads=max_download_threads,
                 lz4_compressed=execute_response.lz4_compressed,
                 description=execute_response.description,
@@ -239,20 +220,26 @@ def __init__(
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
-            backend=thrift_client,
             arraysize=arraysize,
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
             status=execute_response.status,
-            has_been_closed_server_side=execute_response.has_been_closed_server_side,
             is_direct_results=is_direct_results,
-            results_queue=results_queue,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
-            arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
+        # Assert that the backend is of the correct type
+        from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
+
+        assert isinstance(
+            self.backend, ThriftDatabricksClient
+        ), "Backend must be a ThriftDatabricksClient"
+
+        # Set the results queue
+        self.results = results_queue
+
         # Initialize results queue if not provided
         if not self.results:
             self._fill_results_buffer()
@@ -308,6 +295,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
+
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
@@ -333,6 +324,9 @@ def fetchmany_columnar(self, size: int):
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
 
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
@@ -352,6 +346,9 @@ def fetchmany_columnar(self, size: int):
     def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
 
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -378,6 +375,9 @@ def fetchall_arrow(self) -> "pyarrow.Table": def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -394,6 +394,9 @@ def fetchone(self) -> Optional[Row]: Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ + if self.results is None: + raise RuntimeError("Results queue is not initialized") + if isinstance(self.results, ColumnQueue): res = self._convert_columnar_table(self.fetchmany_columnar(1)) else: @@ -440,3 +443,26 @@ def map_col_type(type_): (column.name, map_col_type(column.datatype), None, None, None, None, None) for column in table_schema_message.columns ] + + def close(self) -> None: + """ + Close the result set. + + If the connection has not been closed, and the result set has not already + been closed on the server for some other reason, issue a request to the server to close it. + """ + try: + if self.results: + self.results.close() + if ( + self.status != CommandState.CLOSED + and not self.has_been_closed_server_side + and self.connection.open + ): + self.backend.close_command(self.command_id) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + finally: + self.has_been_closed_server_side = True + self.status = CommandState.CLOSED diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 4f59857e..6f3f7387 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -8,7 +8,7 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 35764bf8..fa0bb1e6 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -13,7 +13,7 @@ import lz4.frame -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest try: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 83e83fd4..51430e9e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -108,7 +108,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): mock_execute_response.status = ( CommandState.SUCCEEDED if not closed else CommandState.CLOSED ) - mock_execute_response.has_been_closed_server_side = closed mock_execute_response.is_staging_operation = False mock_execute_response.description = [] @@ -127,7 +126,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): real_result_set = ThriftResultSet( connection=connection, execute_response=mock_execute_response, - thrift_client=mock_backend, + has_been_closed_server_side=closed, ) # Mock execute_command to return our real result set @@ -187,22 +186,26 @@ 
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 83e83fd4..51430e9e 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -108,7 +108,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
         mock_execute_response.status = (
             CommandState.SUCCEEDED if not closed else CommandState.CLOSED
         )
-        mock_execute_response.has_been_closed_server_side = closed
         mock_execute_response.is_staging_operation = False
         mock_execute_response.description = []
 
@@ -127,7 +126,7 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
         real_result_set = ThriftResultSet(
             connection=connection,
             execute_response=mock_execute_response,
-            thrift_client=mock_backend,
+            has_been_closed_server_side=closed,
         )
 
         # Mock execute_command to return our real result set
@@ -187,22 +186,26 @@ def test_arraysize_buffer_size_passthrough(
 
     def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
         mock_connection = Mock()
-        mock_backend = Mock()
+        mock_backend = Mock(spec=ThriftDatabricksClient)
         mock_results = Mock()
         mock_backend.fetch_results.return_value = (Mock(), False)
 
+        # Ensure connection appears closed
+        type(mock_connection).open = PropertyMock(return_value=False)
+        # Ensure isinstance check passes if needed
+        mock_backend.__class__ = ThriftDatabricksClient
+
+        # Setup session mock on the mock_connection
+        mock_session = Mock()
+        mock_session.backend = mock_backend
+        type(mock_connection).session = PropertyMock(return_value=mock_session)
+
         result_set = ThriftResultSet(
             connection=mock_connection,
             execute_response=Mock(),
-            thrift_client=mock_backend,
         )
         result_set.results = mock_results
 
-        # Setup session mock on the mock_connection
-        mock_session = Mock()
-        mock_session.open = False
-        type(mock_connection).session = PropertyMock(return_value=mock_session)
-
         result_set.close()
 
         self.assertFalse(mock_backend.close_command.called)
@@ -213,16 +216,18 @@ def test_closing_result_set_hard_closes_commands(self):
         mock_results_response = Mock()
         mock_results_response.has_been_closed_server_side = False
         mock_connection = Mock()
-        mock_thrift_backend = Mock()
+        mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
         mock_results = Mock()
         # Setup session mock on the mock_connection
         mock_session = Mock()
         mock_session.open = True
+        mock_session.backend = mock_thrift_backend
         type(mock_connection).session = PropertyMock(return_value=mock_session)
 
         mock_thrift_backend.fetch_results.return_value = (Mock(), False)
         result_set = ThriftResultSet(
-            mock_connection, mock_results_response, mock_thrift_backend
+            mock_connection,
+            mock_results_response,
         )
         result_set.results = mock_results
 
@@ -267,10 +272,20 @@ def test_closed_cursor_doesnt_allow_operations(self):
         self.assertIn("closed", e.msg)
 
     def test_negative_fetch_throws_exception(self):
-        mock_backend = Mock()
+        mock_connection = Mock()
+        from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
+
+        mock_backend = Mock(spec=ThriftDatabricksClient)
         mock_backend.fetch_results.return_value = (Mock(), False)
+        # Ensure isinstance check passes
+        mock_backend.__class__ = ThriftDatabricksClient
+
+        # Setup session mock on the mock_connection
+        mock_session = Mock()
+        mock_session.backend = mock_backend
+        type(mock_connection).session = PropertyMock(return_value=mock_session)
 
-        result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
+        result_set = ThriftResultSet(mock_connection, Mock())
 
         with self.assertRaises(ValueError) as e:
             result_set.fetchmany(-1)
diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py
index a649941e..8643404b 100644
--- a/tests/unit/test_fetches.py
+++ b/tests/unit/test_fetches.py
@@ -1,6 +1,6 @@
 import unittest
 import pytest
-from unittest.mock import Mock
+from unittest.mock import Mock, PropertyMock
 
 try:
     import pyarrow as pa
@@ -38,12 +38,19 @@ def make_arrow_queue(batch):
     @staticmethod
     def make_dummy_result_set_from_initial_results(initial_results):
         # If the initial results have been set, then we should never try and fetch more
-        schema, arrow_table = FetchTests.make_arrow_table(initial_results)
-        arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0)
+        arrow_queue = FetchTests.make_arrow_queue(initial_results)
+        from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
 
-        # Create a mock backend that will return the queue when _fill_results_buffer is called
         mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
         mock_thrift_backend.fetch_results.return_value = (arrow_queue, False)
+        # Ensure isinstance check passes
+        mock_thrift_backend.__class__ = ThriftDatabricksClient
+
+        # Setup mock connection with session.backend
+        mock_connection = Mock()
+        mock_session = Mock()
+        mock_session.backend = mock_thrift_backend
+        type(mock_connection).session = PropertyMock(return_value=mock_session)
 
         num_cols = len(initial_results[0]) if initial_results else 0
         description = [
@@ -52,17 +59,16 @@ def make_dummy_result_set_from_initial_results(initial_results):
         ]
 
         rs = ThriftResultSet(
-            connection=Mock(),
+            connection=mock_connection,
             execute_response=ExecuteResponse(
                 command_id=None,
                 status=None,
-                has_been_closed_server_side=True,
                 description=description,
                 lz4_compressed=True,
                 is_staging_operation=False,
             ),
-            thrift_client=mock_thrift_backend,
             t_row_set=None,
+            has_been_closed_server_side=True,
         )
         return rs
 
@@ -86,8 +92,19 @@ def fetch_results(
 
             return results, batch_index < len(batch_list)
 
+        from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
+
         mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
         mock_thrift_backend.fetch_results = fetch_results
+        # Ensure isinstance check passes
+        mock_thrift_backend.__class__ = ThriftDatabricksClient
+
+        # Setup mock connection with session.backend
+        mock_connection = Mock()
+        mock_session = Mock()
+        mock_session.backend = mock_thrift_backend
+        type(mock_connection).session = PropertyMock(return_value=mock_session)
+
         num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0
 
         description = [
@@ -96,16 +113,15 @@ def fetch_results(
         ]
 
         rs = ThriftResultSet(
-            connection=Mock(),
+            connection=mock_connection,
             execute_response=ExecuteResponse(
                 command_id=None,
                 status=None,
-                has_been_closed_server_side=False,
                 description=description,
                 lz4_compressed=True,
                 is_staging_operation=False,
             ),
-            thrift_client=mock_thrift_backend,
+            has_been_closed_server_side=False,
         )
         return rs
 
diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py
index da45b429..e3dda181 100644
--- a/tests/unit/test_sea_backend.py
+++ b/tests/unit/test_sea_backend.py
@@ -8,7 +8,7 @@
 import pytest
 from unittest.mock import patch, MagicMock, Mock
 
-from databricks.sql.backend.sea.backend import (
+from databricks.sql.backend.sea.client import (
     SeaDatabricksClient,
     _filter_session_configuration,
 )
@@ -33,7 +33,7 @@ class TestSeaBackend:
     def mock_http_client(self):
         """Create a mock HTTP client."""
         with patch(
-            "databricks.sql.backend.sea.backend.SeaHttpClient"
+            "databricks.sql.backend.sea.client.SeaHttpClient"
         ) as mock_client_class:
             mock_client = mock_client_class.return_value
             yield mock_client
@@ -70,12 +70,20 @@ def sea_command_id(self):
         return CommandId.from_sea_statement_id("test-statement-123")
 
     @pytest.fixture
-    def mock_cursor(self):
+    def mock_cursor(self, sea_client):
         """Create a mock cursor."""
         cursor = Mock()
         cursor.active_command_id = None
         cursor.buffer_size_bytes = 1000
         cursor.arraysize = 100
+
+        # Set up a mock connection with session.backend pointing to the sea_client
+        mock_connection = Mock()
+        mock_session = Mock()
+        mock_session.backend = sea_client
+        mock_connection.session = mock_session
+        cursor.connection = mock_connection
+
         return cursor
 
     @pytest.fixture
diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py
index 544edaf9..25ac2313 100644
--- a/tests/unit/test_sea_result_set.py
+++ b/tests/unit/test_sea_result_set.py
@@ -23,12 +23,20 @@ def mock_connection(self):
"""Create a mock connection.""" connection = Mock() connection.open = True - return connection - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() + # Mock the session.backend to return a SeaDatabricksClient + mock_session = Mock() + from databricks.sql.backend.sea.client import SeaDatabricksClient + + mock_backend = Mock(spec=SeaDatabricksClient) + mock_backend.max_download_threads = 10 + mock_backend.close_command = Mock() + # Ensure isinstance check passes + mock_backend.__class__ = SeaDatabricksClient + mock_session.backend = mock_backend + connection.session = mock_session + + return connection @pytest.fixture def execute_response(self): @@ -71,9 +79,7 @@ def _create_empty_manifest(self, format: ResultFormat): ) @pytest.fixture - def result_set_with_data( - self, mock_connection, mock_sea_client, execute_response, sample_data - ): + def result_set_with_data(self, mock_connection, execute_response, sample_data): """Create a SeaResultSet with sample data.""" # Create ResultData with inline data result_data = ResultData( @@ -84,7 +90,6 @@ def result_set_with_data( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=result_data, manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -99,14 +104,11 @@ def json_queue(self, sample_data): """Create a JsonQueue with sample data.""" return JsonQueue(sample_data) - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): + def test_init_with_execute_response(self, mock_connection, execute_response): """Test initializing SeaResultSet with an execute response.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -117,17 +119,15 @@ def test_init_with_execute_response( assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize == 100 assert result_set.description == execute_response.description - def test_close(self, mock_connection, mock_sea_client, execute_response): + def test_close(self, mock_connection, execute_response): """Test closing a result set.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -138,42 +138,17 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): result_set.close() # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) - assert result_set.has_been_closed_server_side is True - assert result_set.status == CommandState.CLOSED - - def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response - ): - """Test closing a result set that has already been closed server-side.""" - result_set = SeaResultSet( - connection=mock_connection, - execute_response=execute_response, - sea_client=mock_sea_client, - result_data=ResultData(data=[]), - manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), 
-            buffer_size_bytes=1000,
-            arraysize=100,
+        mock_connection.session.backend.close_command.assert_called_once_with(
+            result_set.command_id
         )
-        result_set.has_been_closed_server_side = True
-
-        # Close the result set
-        result_set.close()
-
-        # Verify the backend's close_command was NOT called
-        mock_sea_client.close_command.assert_not_called()
-        assert result_set.has_been_closed_server_side is True
         assert result_set.status == CommandState.CLOSED
 
-    def test_close_when_connection_closed(
-        self, mock_connection, mock_sea_client, execute_response
-    ):
+    def test_close_when_connection_closed(self, mock_connection, execute_response):
         """Test closing a result set when the connection is closed."""
         mock_connection.open = False
         result_set = SeaResultSet(
             connection=mock_connection,
             execute_response=execute_response,
-            sea_client=mock_sea_client,
             result_data=ResultData(data=[]),
             manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY),
             buffer_size_bytes=1000,
@@ -184,8 +159,7 @@ def test_close_when_connection_closed(
         result_set.close()
 
         # Verify the backend's close_command was NOT called
-        mock_sea_client.close_command.assert_not_called()
-        assert result_set.has_been_closed_server_side is True
+        mock_connection.session.backend.close_command.assert_not_called()
         assert result_set.status == CommandState.CLOSED
 
     def test_init_with_result_data(self, result_set_with_data, sample_data):
@@ -316,7 +290,7 @@ def test_iteration(self, result_set_with_data, sample_data):
         assert rows[0].col3 is True
 
     def test_fetchmany_arrow_not_implemented(
-        self, mock_connection, mock_sea_client, execute_response, sample_data
+        self, mock_connection, execute_response, sample_data
     ):
         """Test that fetchmany_arrow raises NotImplementedError for non-JSON data."""
 
@@ -329,7 +303,6 @@ def test_fetchmany_arrow_not_implemented(
             result_set = SeaResultSet(
                 connection=mock_connection,
                 execute_response=execute_response,
-                sea_client=mock_sea_client,
                 result_data=ResultData(data=None, external_links=[]),
                 manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM),
                 buffer_size_bytes=1000,
@@ -337,7 +310,7 @@ def test_fetchmany_arrow_not_implemented(
 
     def test_fetchall_arrow_not_implemented(
-        self, mock_connection, mock_sea_client, execute_response, sample_data
+        self, mock_connection, execute_response, sample_data
     ):
         """Test that fetchall_arrow raises NotImplementedError for non-JSON data."""
         # Test that NotImplementedError is raised
@@ -349,16 +322,13 @@ def test_fetchall_arrow_not_implemented(
         result_set = SeaResultSet(
             connection=mock_connection,
             execute_response=execute_response,
-            sea_client=mock_sea_client,
             result_data=ResultData(data=None, external_links=[]),
             manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM),
             buffer_size_bytes=1000,
             arraysize=100,
         )
 
-    def test_is_staging_operation(
-        self, mock_connection, mock_sea_client, execute_response
-    ):
+    def test_is_staging_operation(self, mock_connection, execute_response):
         """Test the is_staging_operation property."""
         # Set is_staging_operation to True
         execute_response.is_staging_operation = True
@@ -367,7 +337,6 @@ def test_is_staging_operation(
         result_set = SeaResultSet(
             connection=mock_connection,
             execute_response=execute_response,
-            sea_client=mock_sea_client,
             result_data=ResultData(data=[]),
             manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY),
             buffer_size_bytes=1000,
diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py
index 37569f75..f66b356c 100644
--- a/tests/unit/test_thrift_backend.py
+++ b/tests/unit/test_thrift_backend.py
@@ -649,7 +649,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(
             ssl_options=SSLOptions(),
         )
 
-        execute_response, _ = thrift_backend._handle_execute_response(
+        execute_response, _, _, _ = thrift_backend._handle_execute_response(
             t_execute_resp, Mock()
         )
         self.assertEqual(execute_response.lz4_compressed, lz4Compressed)
@@ -889,10 +889,9 @@ def test_handle_execute_response_can_handle_without_direct_results(
             auth_provider=AuthProvider(),
             ssl_options=SSLOptions(),
         )
-        (
-            execute_response,
-            _,
-        ) = thrift_backend._handle_execute_response(execute_resp, Mock())
+        (execute_response, _, _, _) = thrift_backend._handle_execute_response(
+            execute_resp, Mock()
+        )
         self.assertEqual(
             execute_response.status,
             CommandState.SUCCEEDED,
@@ -965,11 +964,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
             )
         )
         thrift_backend = self._make_fake_thrift_backend()
-        execute_response, _ = thrift_backend._handle_execute_response(
-            t_execute_resp, Mock()
-        )
+        (
+            execute_response,
+            _,
+            _,
+            arrow_schema_bytes,
+        ) = thrift_backend._handle_execute_response(t_execute_resp, Mock())
 
-        self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock)
+        self.assertEqual(arrow_schema_bytes, arrow_schema_mock)
 
     @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
     def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
@@ -997,7 +999,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
             )
         )
         thrift_backend = self._make_fake_thrift_backend()
-        _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())
+        _, _, _, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())
 
         self.assertEqual(
             hive_schema_mock,
@@ -1046,6 +1048,8 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
         (
             execute_response,
             has_more_rows_result,
+            _,
+            _,
         ) = thrift_backend._handle_execute_response(execute_resp, Mock())
 
         self.assertEqual(is_direct_results, has_more_rows_result)
@@ -1179,7 +1183,12 @@ def test_execute_statement_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (
+            Mock(),
+            Mock(),
+            Mock(),
+            Mock(),
+        )
         cursor_mock = Mock()
 
         result = thrift_backend.execute_command(
@@ -1215,7 +1224,12 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (
+            Mock(),
+            Mock(),
+            Mock(),
+            Mock(),
+        )
         cursor_mock = Mock()
 
         result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
@@ -1248,7 +1262,12 @@ def test_get_schemas_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (
+            Mock(),
+            Mock(),
+            Mock(),
+            Mock(),
+        )
         cursor_mock = Mock()
 
         result = thrift_backend.get_schemas(
@@ -1290,7 +1309,12 @@ def test_get_tables_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (
+            Mock(),
+            Mock(),
+            Mock(),
+            Mock(),
+        )
         cursor_mock = Mock()
 
         result = thrift_backend.get_tables(
@@ -1336,7 +1360,12 @@ def test_get_columns_calls_client_and_handle_execute_response(
             ssl_options=SSLOptions(),
         )
         thrift_backend._handle_execute_response = Mock()
-        thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
+        thrift_backend._handle_execute_response.return_value = (
+            Mock(),
+            Mock(),
+            Mock(),
+            Mock(),
+        )
         cursor_mock = Mock()
 
         result = thrift_backend.get_columns(
@@ -2254,6 +2283,8 @@ def test_execute_command_sets_complex_type_fields_correctly(
             mock_handle_execute_response.return_value = (
                 mock_execute_response,
                 mock_arrow_schema,
+                Mock(),
+                Mock(),
             )
 
             # Iterate through each possible combination of native types (True, False and unset)
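Note: the `session.backend` wiring above is repeated across several test modules; it could be factored into a shared helper in a follow-up. A hypothetical sketch (the name `make_mock_connection` is illustrative, not part of this patch):

```python
from unittest.mock import Mock, PropertyMock

from databricks.sql.backend.thrift_backend import ThriftDatabricksClient


def make_mock_connection(backend_cls=ThriftDatabricksClient):
    """Hypothetical test helper: build a mock Connection whose
    session.backend satisfies the isinstance() asserts added in this patch."""
    mock_backend = Mock(spec=backend_cls)
    mock_backend.__class__ = backend_cls  # make isinstance() pass
    mock_session = Mock()
    mock_session.backend = mock_backend
    mock_connection = Mock()
    type(mock_connection).session = PropertyMock(return_value=mock_session)
    return mock_connection, mock_backend
```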