Files API client: recover on download failures (databricks#844)
ksafonov-db committed Jan 8, 2025
1 parent 6d6923e commit 189067d
Showing 5 changed files with 582 additions and 6 deletions.
8 changes: 7 additions & 1 deletion databricks/sdk/__init__.py

Some generated files are not rendered by default.

16 changes: 13 additions & 3 deletions databricks/sdk/_base_client.py
@@ -1,5 +1,6 @@
import io
import logging
from abc import ABC, abstractmethod
import urllib.parse
from datetime import timedelta
from types import TracebackType
@@ -284,9 +285,18 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
            return
        logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())

class _RawResponse(ABC):

    @abstractmethod
    # follows Response signature: https://github.com/psf/requests/blob/main/src/requests/models.py#L799
    def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
        pass

    @abstractmethod
    def close(self):
        pass


class _StreamingResponse(BinaryIO):
-    _response: requests.Response
+    _response: _RawResponse
    _buffer: bytes
    _content: Union[Iterator[bytes], None]
    _chunk_size: Union[int, None]
@@ -298,7 +308,7 @@ def fileno(self) -> int:

    def flush(self) -> int:
        pass

-    def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
+    def __init__(self, response: _RawResponse, chunk_size: Union[int, None] = None):
        self._response = response
        self._buffer = b''
        self._content = None
@@ -308,7 +318,7 @@ def _open(self) -> None:
        if self._closed:
            raise ValueError("I/O operation on closed file")
        if not self._content:
-            self._content = self._response.iter_content(chunk_size=self._chunk_size)
+            self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False)

    def __enter__(self) -> BinaryIO:
        self._open()
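The _RawResponse ABC above decouples _StreamingResponse from requests.Response: anything that implements iter_content and close can back a stream, which is what allows mixins/files.py (below) to substitute a self-healing response. A minimal sketch of the contract; the _BytesRawResponse class is hypothetical, and the read()/close() calls assume the _StreamingResponse behavior defined in _base_client.py:

from databricks.sdk._base_client import _RawResponse, _StreamingResponse

class _BytesRawResponse(_RawResponse):
    # Hypothetical in-memory implementation, for illustration only.
    def __init__(self, data: bytes):
        self._data = data

    def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
        # Yield the buffer in chunk_size pieces, mirroring requests.Response.iter_content.
        for i in range(0, len(self._data), chunk_size):
            yield self._data[i:i + chunk_size]

    def close(self):
        pass

stream = _StreamingResponse(_BytesRawResponse(b"hello world"), chunk_size=4)
assert stream.read() == b"hello world"  # assumes read() drains the iterator, per _base_client.py
stream.close()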
4 changes: 4 additions & 0 deletions databricks/sdk/config.py
@@ -92,6 +92,10 @@ class Config:
    max_connections_per_pool: int = ConfigAttribute()
    databricks_environment: Optional[DatabricksEnvironment] = None

    enable_experimental_files_api_client: bool = ConfigAttribute(
        env='DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT')
    files_api_client_download_max_total_recovers = None
    files_api_client_download_max_total_recovers_without_progressing = 1

    def __init__(
            self,
            *,
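Together, these attributes let callers opt in to the experimental Files API client and bound its recovery behavior. A usage sketch; the host and token values are placeholders, while the attribute names and defaults come from the diff above:

from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config

cfg = Config(host="https://example.cloud.databricks.com",  # placeholder workspace URL
             token="dapi-placeholder")                      # placeholder credential
cfg.enable_experimental_files_api_client = True  # or set DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT=true

# None puts no cap on the total number of recovery attempts across the download;
# the second limit abandons recovery after one attempt that made no byte progress.
cfg.files_api_client_download_max_total_recovers = None
cfg.files_api_client_download_max_total_recovers_without_progressing = 1

w = WorkspaceClient(config=cfg)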
183 changes: 181 additions & 2 deletions databricks/sdk/mixins/files.py
@@ -8,19 +8,25 @@
import sys
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from io import BytesIO
from types import TracebackType
-from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable,
-                    Iterator, Type, Union)
+from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable, Optional, Type, Union)
from urllib import parse

from requests import RequestException

import logging
from .._property import _cached_property
from ..config import Config
from ..errors import NotFound
from ..service import files
from ..service._internal import _escape_multi_segment_path_parameter
from ..service.files import DownloadResponse
from .._base_client import _RawResponse, _StreamingResponse

if TYPE_CHECKING:
    from _typeshed import Self

_LOG = logging.getLogger(__name__)

class _DbfsIO(BinaryIO):
    MAX_CHUNK_SIZE = 1024 * 1024
@@ -636,3 +642,176 @@ def delete(self, path: str, *, recursive=False):
        if p.is_dir and not recursive:
            raise IOError('deleting directories requires recursive flag')
        p.delete(recursive=recursive)


class FilesExt(files.FilesAPI):
    __doc__ = files.FilesAPI.__doc__

    def __init__(self, api_client, config: Config):
        super().__init__(api_client)
        self._config = config.copy()

    def download(self, file_path: str) -> DownloadResponse:
        """Download a file.

        Downloads a file of any size. The file contents are the response body.
        This is a standard HTTP file download, not a JSON RPC.

        For fault tolerance, it is strongly recommended to consume the stream
        iteratively with a bounded read(size) rather than reading the entire
        contents in one indefinite-size read.

        :param file_path: str
            The remote path of the file, e.g. /Volumes/path/to/your/file

        :returns: :class:`DownloadResponse`
        """

        initial_response: DownloadResponse = self._download_raw_stream(file_path=file_path,
                                                                       start_byte_offset=0,
                                                                       if_unmodified_since_timestamp=None)

        wrapped_response = self._wrap_stream(file_path, initial_response)
        initial_response.contents._response = wrapped_response
        return initial_response

    def _download_raw_stream(self,
                             file_path: str,
                             start_byte_offset: int,
                             if_unmodified_since_timestamp: Optional[str] = None) -> DownloadResponse:
        headers = {'Accept': 'application/octet-stream', }

        if start_byte_offset and not if_unmodified_since_timestamp:
            raise Exception("if_unmodified_since_timestamp is required if start_byte_offset is specified")

        if start_byte_offset:
            headers['Range'] = f'bytes={start_byte_offset}-'

        if if_unmodified_since_timestamp:
            headers['If-Unmodified-Since'] = if_unmodified_since_timestamp

        response_headers = ['content-length', 'content-type', 'last-modified', ]
        res = self._api.do('GET',
                           f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}',
                           headers=headers,
                           response_headers=response_headers,
                           raw=True)

        result = DownloadResponse.from_dict(res)
        if not isinstance(result.contents, _StreamingResponse):
            raise Exception("Internal error: response contents is of unexpected type: " +
                            type(result.contents).__name__)

        return result

    def _wrap_stream(self, file_path: str, download_response: DownloadResponse):
        underlying_response = _ResilientIterator._extract_raw_response(download_response)
        return _ResilientResponse(self, file_path, download_response.last_modified, offset=0,
                                  underlying_response=underlying_response)


class _ResilientResponse(_RawResponse):
    # _StreamingResponse uses only two methods of the underlying response:
    # - _response.iter_content(chunk_size=self._chunk_size)
    # - _response.close()
    # We need to provide these two and nothing else.

    def __init__(self, api: FilesExt, file_path: str, file_last_modified: str, offset: int,
                 underlying_response: _RawResponse):
        self.api = api
        self.file_path = file_path
        self.underlying_response = underlying_response
        self.offset = offset
        self.file_last_modified = file_last_modified

    def iter_content(self, chunk_size=1, decode_unicode=False):
        if decode_unicode:
            raise ValueError('Decode unicode is not supported')

        iterator = self.underlying_response.iter_content(chunk_size=chunk_size, decode_unicode=False)
        self.iterator = _ResilientIterator(iterator, self.file_path, self.file_last_modified,
                                           self.offset, self.api, chunk_size)
        return self.iterator

    def close(self):
        self.iterator.close()


class _ResilientIterator(Iterator):
    # This class tracks the current offset (as returned to the client code) and
    # recovers from failures by re-requesting the download from that offset.

    @staticmethod
    def _extract_raw_response(download_response: DownloadResponse) -> _RawResponse:
        streaming_response: _StreamingResponse = download_response.contents  # an instance of _StreamingResponse
        return streaming_response._response

    def __init__(self, underlying_iterator, file_path: str, file_last_modified: str, offset: int,
                 api: FilesExt, chunk_size: int):
        self._underlying_iterator = underlying_iterator
        self._api = api
        self._file_path = file_path

        # Absolute current offset (0-based), i.e. the number of bytes from the
        # beginning of the file that have been returned to the caller so far.
        self._offset = offset
        self._file_last_modified = file_last_modified
        self._chunk_size = chunk_size

        self._total_recovers_count: int = 0
        self._recovers_without_progressing_count: int = 0
        self._closed: bool = False

    def _should_recover(self) -> bool:
        if self._total_recovers_count == self._api._config.files_api_client_download_max_total_recovers:
            _LOG.debug("Total recovers limit exceeded")
            return False
        if (self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None
                and self._recovers_without_progressing_count >=
                self._api._config.files_api_client_download_max_total_recovers_without_progressing):
            _LOG.debug("No-progression recovers limit exceeded")
            return False
        return True

    def _recover(self) -> bool:
        if not self._should_recover():
            return False  # recover suppressed, rethrow original exception

        self._total_recovers_count += 1
        self._recovers_without_progressing_count += 1

        try:
            self._underlying_iterator.close()

            _LOG.debug("Trying to recover from offset " + str(self._offset))

            # the following call includes all the required network retries
            download_response = self._api._download_raw_stream(self._file_path, self._offset,
                                                               self._file_last_modified)
            underlying_response = _ResilientIterator._extract_raw_response(download_response)
            self._underlying_iterator = underlying_response.iter_content(chunk_size=self._chunk_size,
                                                                         decode_unicode=False)
            _LOG.debug("Recover succeeded")
            return True
        except:
            return False  # recover failed, rethrow original exception

    def __next__(self):
        if self._closed:
            # following _BaseClient
            raise ValueError("I/O operation on closed file")

        while True:
            try:
                returned_bytes = next(self._underlying_iterator)
                self._offset += len(returned_bytes)
                self._recovers_without_progressing_count = 0
                return returned_bytes

            except StopIteration:
                raise

            # https://requests.readthedocs.io/en/latest/user/quickstart/#errors-and-exceptions
            except RequestException:
                if not self._recover():
                    raise

    def close(self):
        self._underlying_iterator.close()
        self._closed = True
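With the wrapper in place, a download can survive connection failures mid-stream as long as the caller reads in bounded chunks, as the download docstring above recommends. A consumption sketch, where w is the WorkspaceClient from the earlier config sketch and the volume path and process callback are placeholders:

resp = w.files.download("/Volumes/main/default/my_volume/big-file.bin")
with resp.contents as stream:
    while True:
        chunk = stream.read(1024 * 1024)  # bounded read keeps the tracked offset recoverable
        if not chunk:
            break
        process(chunk)  # placeholder for the caller's own handling

On a RequestException mid-read, _ResilientIterator re-requests the file with Range: bytes=<offset>- and If-Unmodified-Since set to the original last-modified value, so the server can reject the resume if the file changed during the download.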
