diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py
index 71db98f..5a05bf9 100644
--- a/src/databricks/labs/blueprint/installation.py
+++ b/src/databricks/labs/blueprint/installation.py
@@ -40,6 +40,8 @@
__all__ = ["Installation", "MockInstallation", "IllegalState", "NotInstalled", "SerdeError"]
+FILE_SIZE_LIMIT: int = 1024 * 1024 * 10  # 10 MiB per uploaded file
+
class IllegalState(ValueError):
pass
@@ -132,6 +134,10 @@ def check_folder(install_folder: str) -> Installation | None:
tasks.append(functools.partial(check_folder, service_principal_folder))
return Threads.strict(f"finding {product} installations", tasks)
+ @staticmethod
+    def extension(filename: str) -> str:
+        """Returns the suffix after the last dot of the filename."""
+        return filename.split(".")[-1]
+
@classmethod
def load_local(cls, type_ref: type[T], file: Path) -> T:
"""Loads a typed file from the local file system."""
@@ -348,17 +354,25 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type):
The `as_dict` argument is the dictionary representation of the object that is to be written to the file.
The `type_ref` argument is the type of the object that is being saved."""
- converters: dict[str, Callable[[Any, type], bytes]] = {
+ converters: dict[str, Callable[[Any, type], list[bytes]]] = {
"json": self._dump_json,
"yml": self._dump_yaml,
"csv": self._dump_csv,
}
- extension = filename.split(".")[-1]
+ extension = self.extension(filename)
if extension not in converters:
raise KeyError(f"Unknown extension: {extension}")
logger.debug(f"Converting {type_ref.__name__} into {extension.upper()} format")
- raw = converters[extension](as_dict, type_ref)
- self.upload(filename, raw)
+        raws = converters[extension](as_dict, type_ref)
+        if len(raws) > 1:
+            # Multi-part write: only the CSV converter returns more than one chunk,
+            # so strip the 4-character ".csv" suffix and number the parts from 1.
+            for i, raw in enumerate(raws):
+                self.upload(f"{filename[0:-4]}.{i + 1}.csv", raw)
+            return
+        # Single-part write: refuse any file over the 10 MiB limit
+        if len(raws[0]) > FILE_SIZE_LIMIT:
+            raise ValueError(f"File size {len(raws[0])} bytes exceeds the limit of {FILE_SIZE_LIMIT} bytes")
+
+        self.upload(filename, raws[0])
@staticmethod
def _global_installation(product):
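Since only `_dump_csv` can return more than one chunk, the multi-part branch above effectively fires for `.csv` targets only, and `filename[0:-4]` strips exactly the four characters of that suffix. A minimal sketch of the resulting naming scheme (the chunk payloads here are placeholders, not real CSV data):

```python
filename = "many_tables_test.csv"
raws = [b"chunk-1", b"chunk-2", b"chunk-3"]  # placeholders for ~10 MiB CSV chunks

for i, raw in enumerate(raws):
    # "many_tables_test.1.csv", "many_tables_test.2.csv", "many_tables_test.3.csv"
    print(f"{filename[0:-4]}.{i + 1}.csv")
```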
@@ -377,17 +391,39 @@ def _unmarshal_type(cls, as_dict, filename, type_ref):
as_dict = cls._migrate_file_format(type_ref, expected_version, as_dict, filename)
return cls._unmarshal(as_dict, [], type_ref)
- def _load_content(self, filename: str) -> Json:
+ def _load_content(self, filename: str) -> Json | list[Json]:
"""The `_load_content` method is a private method that is used to load the contents of a file from
WorkspaceFS as a dictionary. This method is called by the `load` method."""
with self._lock:
# TODO: check how to make this fail fast during unit testing, otherwise
# this currently hangs with the real installation class and mocked workspace client
- with self._ws.workspace.download(f"{self.install_folder()}/{filename}") as f:
- return self._convert_content(filename, f)
+ try:
+ with self._ws.workspace.download(f"{self.install_folder()}/{filename}") as f:
+ return self._convert_content(filename, f)
+            except NotFound:
+                # If the file is missing, check whether it was saved as multi-part CSVs
+                if self.extension(filename) != "csv":
+                    raise
+                current_part = 1
+                content: list[Json] = []
+                try:
+                    while True:
+                        with self._ws.workspace.download(
+                            f"{self.install_folder()}/{filename[0:-4]}.{current_part}.csv"
+                        ) as f:
+                            converted_content = self._convert_content(filename, f)
+                            # the CSV converter returns a list of rows; extend rather than nest
+                            if isinstance(converted_content, list):
+                                content += converted_content
+                            else:
+                                content.append(converted_content)
+                        current_part += 1
+                except NotFound:
+                    # no part files at all: surface the original error
+                    if current_part == 1:
+                        raise
+                    return content
@classmethod
- def _convert_content(cls, filename: str, raw: BinaryIO) -> Json:
+ def _convert_content(cls, filename: str, raw: BinaryIO) -> Json | list[Json]:
"""The `_convert_content` method is a private method that is used to convert the raw bytes of a file to a
dictionary. This method is called by the `_load_content` method."""
converters: dict[str, Callable[[BinaryIO], Any]] = {
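The fallback probes `<stem>.<n>.csv` starting at part 1 and stops at the first missing part; if even part 1 is absent, the original `NotFound` propagates. Note the `current_part += 1` at the end of the loop body: without it the loop would re-read part 1 forever. A minimal standalone sketch of the probe pattern, with a stub `download` standing in for `WorkspaceClient.workspace.download` (all names here are illustrative):

```python
class NotFound(Exception):
    pass

PARTS = {"state.1.csv": ["row-a"], "state.2.csv": ["row-b"]}  # stand-in for WorkspaceFS

def download(path: str) -> list[str]:
    if path not in PARTS:
        raise NotFound(path)
    return PARTS[path]

def load_multipart(stem: str) -> list[str]:
    content: list[str] = []
    part = 1
    try:
        while True:
            content += download(f"{stem}.{part}.csv")
            part += 1  # advance, or the loop re-reads part 1 forever
    except NotFound:
        if part == 1:
            raise  # no parts at all: surface the original error
    return content

assert load_multipart("state") == ["row-a", "row-b"]
```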
@@ -395,7 +431,7 @@ def _convert_content(cls, filename: str, raw: BinaryIO) -> Json:
"yml": cls._load_yaml,
"csv": cls._load_csv,
}
- extension = filename.split(".")[-1]
+ extension = cls.extension(filename)
if extension not in converters:
raise KeyError(f"Unknown extension: {extension}")
try:
@@ -747,19 +783,19 @@ def _explain_why(type_ref: type, path: list[str], raw: Any) -> str:
return f'{".".join(path)}: not a {type_ref.__name__}: {raw}'
@staticmethod
- def _dump_json(as_dict: Json, _: type) -> bytes:
+ def _dump_json(as_dict: Json, _: type) -> list[bytes]:
"""The `_dump_json` method is a private method that is used to serialize a dictionary to a JSON string. This
method is called by the `save` method."""
- return json.dumps(as_dict, indent=2).encode("utf8")
+ return [json.dumps(as_dict, indent=2).encode("utf8")]
@staticmethod
- def _dump_yaml(raw: Json, _: type) -> bytes:
+ def _dump_yaml(raw: Json, _: type) -> list[bytes]:
"""The `_dump_yaml` method is a private method that is used to serialize a dictionary to a YAML string. This
method is called by the `save` method."""
try:
from yaml import dump # pylint: disable=import-outside-toplevel
- return dump(raw).encode("utf8")
+ return [dump(raw).encode("utf8")]
except ImportError as err:
raise SyntaxError("PyYAML is not installed. Fix: pip install databricks-labs-blueprint[yaml]") from err
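Wrapping the single JSON/YAML payload in a one-element list keeps all three converters on the same `list[bytes]` signature, so `_overwrite_content` can treat `len(raws) == 1` as the common path. The invariant, in miniature (illustrative helper name):

```python
import json

def dump_json(as_dict: dict) -> list[bytes]:
    # JSON payloads are never split, so the list always has exactly one element
    return [json.dumps(as_dict, indent=2).encode("utf8")]

raws = dump_json({"version": 1})
assert len(raws) == 1 and raws[0].startswith(b"{")
```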
@@ -781,9 +817,10 @@ def _load_yaml(raw: BinaryIO) -> Json:
raise SyntaxError("PyYAML is not installed. Fix: pip install databricks-labs-blueprint[yaml]") from err
@staticmethod
- def _dump_csv(raw: list[Json], type_ref: type) -> bytes:
+ def _dump_csv(raw: list[Json], type_ref: type) -> list[bytes]:
"""The `_dump_csv` method is a private method that is used to serialize a list of dictionaries to a CSV string.
This method is called by the `save` method."""
+        raws: list[bytes] = []
type_args = get_args(type_ref)
if not type_args:
raise SerdeError(f"Writing CSV is only supported for lists. Got {type_ref}")
@@ -804,9 +841,21 @@ def _dump_csv(raw: list[Json], type_ref: type) -> bytes:
writer = csv.DictWriter(buffer, field_names, dialect="excel")
writer.writeheader()
for as_dict in raw:
+            # Remember where this row starts so an over-limit row can be rolled back
+            before_pos = buffer.tell()
             writer.writerow(as_dict)
+            if buffer.tell() > FILE_SIZE_LIMIT:
+                # Roll back the row, seal the current chunk, then start a fresh
+                # buffer with its own header and re-write the rolled-back row into it
+                buffer.seek(before_pos)
+                buffer.truncate()
+                raws.append(buffer.getvalue().encode("utf8"))
+                buffer = io.StringIO()
+                writer = csv.DictWriter(buffer, field_names, dialect="excel")
+                writer.writeheader()
+                writer.writerow(as_dict)
+
-        buffer.seek(0)
-        return buffer.read().encode("utf8")
+        # getvalue() ignores the stream position, so the old seek(0) is no longer needed
+        raws.append(buffer.getvalue().encode("utf8"))
+        return raws
@staticmethod
def _load_csv(raw: BinaryIO) -> list[Json]:
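The chunking above measures the buffer with `StringIO.tell()`, rolls an over-limit row back with `seek`/`truncate`, seals the chunk, and re-writes the row into a fresh buffer that gets its own header. Two caveats worth flagging: `tell()` counts characters rather than encoded bytes, so the limit is approximate for non-ASCII data, and a single row larger than the limit still yields an oversized chunk. A self-contained sketch of the same technique with a toy limit (`dump_csv_chunks` and `LIMIT` are illustrative names, not library API):

```python
import csv
import io

LIMIT = 64  # toy limit for demonstration; the real constant is 10 MiB

def dump_csv_chunks(rows: list[dict], field_names: list[str]) -> list[bytes]:
    chunks: list[bytes] = []
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, field_names, dialect="excel")
    writer.writeheader()
    for row in rows:
        before = buffer.tell()
        writer.writerow(row)
        if buffer.tell() > LIMIT:
            # Roll back the oversized row, seal this chunk, then start a
            # fresh buffer with its own header and re-write the row there
            buffer.seek(before)
            buffer.truncate()
            chunks.append(buffer.getvalue().encode("utf8"))
            buffer = io.StringIO()
            writer = csv.DictWriter(buffer, field_names, dialect="excel")
            writer.writeheader()
            writer.writerow(row)
    chunks.append(buffer.getvalue().encode("utf8"))
    return chunks

chunks = dump_csv_chunks([{"name": f"table_{i}"} for i in range(20)], ["name"])
assert len(chunks) > 1 and all(c.startswith(b"name\r\n") for c in chunks)
```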
diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py
index 5d9bede..891151c 100644
--- a/tests/integration/test_installation.py
+++ b/tests/integration/test_installation.py
@@ -2,6 +2,7 @@
import pytest
from databricks.sdk.errors import PermissionDenied
+from databricks.sdk.service.catalog import TableInfo
from databricks.sdk.service.provisioning import Workspace
from databricks.labs.blueprint.installation import Installation
@@ -73,6 +74,19 @@ def test_saving_list_of_dataclasses_to_csv(new_installation):
assert len(loaded) == 2
+def test_saving_list_of_dataclasses_to_multiple_csvs(new_installation):
+ tables: list[TableInfo] = []
+ for i in range(500000):
+ tables.append(TableInfo(name=f"long_table_name_{i}", schema_name="very_long_schema_name"))
+ new_installation.save(
+ tables,
+ filename="many_tables_test.csv",
+ )
+
+    loaded = new_installation.load(list[TableInfo], filename="many_tables_test.csv")
+ assert len(loaded) == 500000
+
+
@pytest.mark.parametrize(
"ext,magic",
[
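Back-of-the-envelope sizing for the new integration test (my arithmetic, not stated in the PR): each CSV row is roughly `long_table_name_<i>,very_long_schema_name\r\n`, so 500,000 rows come to roughly 23 MB and should be written as at least three parts:

```python
row = b"long_table_name_499999,very_long_schema_name\r\n"  # 46 bytes at the widest
approx_total = len(row) * 500_000               # ~23 MB of CSV
limit = 1024 * 1024 * 10                        # FILE_SIZE_LIMIT, 10 MiB
print(approx_total, -(-approx_total // limit))  # ~23_000_000 bytes -> 3 parts
```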