diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py index 71db98f..5a05bf9 100644 --- a/src/databricks/labs/blueprint/installation.py +++ b/src/databricks/labs/blueprint/installation.py @@ -40,6 +40,8 @@ __all__ = ["Installation", "MockInstallation", "IllegalState", "NotInstalled", "SerdeError"] +FILE_SIZE_LIMIT: int = 1024 * 1024 * 10 + class IllegalState(ValueError): pass @@ -132,6 +134,10 @@ def check_folder(install_folder: str) -> Installation | None: tasks.append(functools.partial(check_folder, service_principal_folder)) return Threads.strict(f"finding {product} installations", tasks) + @staticmethod + def extension(filename): + return filename.split(".")[-1] + @classmethod def load_local(cls, type_ref: type[T], file: Path) -> T: """Loads a typed file from the local file system.""" @@ -348,17 +354,25 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): The `as_dict` argument is the dictionary representation of the object that is to be written to the file. 
The `type_ref` argument is the type of the object that is being saved.""" - converters: dict[str, Callable[[Any, type], bytes]] = { + converters: dict[str, Callable[[Any, type], list[bytes]]] = { "json": self._dump_json, "yml": self._dump_yaml, "csv": self._dump_csv, } - extension = filename.split(".")[-1] + extension = self.extension(filename) if extension not in converters: raise KeyError(f"Unknown extension: {extension}") logger.debug(f"Converting {type_ref.__name__} into {extension.upper()} format") - raw = converters[extension](as_dict, type_ref) - self.upload(filename, raw) + raws = converters[extension](as_dict, type_ref) + if len(raws) > 1: + for i, raw in enumerate(raws): + self.upload(f"{filename[0:-4]}.{i + 1}.csv", raw) + return + # Check if the file is more than 10MB + if len(raws[0]) > FILE_SIZE_LIMIT: + raise ValueError(f"File size too large: {len(raws[0])} bytes") + + self.upload(filename, raws[0]) @staticmethod def _global_installation(product): @@ -377,17 +391,39 @@ def _unmarshal_type(cls, as_dict, filename, type_ref): as_dict = cls._migrate_file_format(type_ref, expected_version, as_dict, filename) return cls._unmarshal(as_dict, [], type_ref) - def _load_content(self, filename: str) -> Json: + def _load_content(self, filename: str) -> Json | list[Json]: """The `_load_content` method is a private method that is used to load the contents of a file from WorkspaceFS as a dictionary. 
This method is called by the `load` method.""" with self._lock: # TODO: check how to make this fail fast during unit testing, otherwise # this currently hangs with the real installation class and mocked workspace client - with self._ws.workspace.download(f"{self.install_folder()}/{filename}") as f: - return self._convert_content(filename, f) + try: + with self._ws.workspace.download(f"{self.install_folder()}/{filename}") as f: + return self._convert_content(filename, f) + except NotFound: + # If the file is not found, check if it is a multi-part csv file + if self.extension(filename) != "csv": + raise + current_part = 1 + content: list[Json] = [] + try: + while True: + with self._ws.workspace.download( + f"{self.install_folder()}/{filename[0:-4]}.{current_part}.csv" + ) as f: + converted_content = self._convert_content(filename, f) + current_part += 1 + if isinstance(converted_content, list): + content += converted_content + else: + content.append(converted_content) + except NotFound: + if current_part == 1: + raise + return content @classmethod - def _convert_content(cls, filename: str, raw: BinaryIO) -> Json: + def _convert_content(cls, filename: str, raw: BinaryIO) -> Json | list[Json]: """The `_convert_content` method is a private method that is used to convert the raw bytes of a file to a dictionary. 
This method is called by the `_load_content` method.""" converters: dict[str, Callable[[BinaryIO], Any]] = { @@ -395,7 +431,7 @@ def _convert_content(cls, filename: str, raw: BinaryIO) -> Json: "yml": cls._load_yaml, "csv": cls._load_csv, } - extension = filename.split(".")[-1] + extension = cls.extension(filename) if extension not in converters: raise KeyError(f"Unknown extension: {extension}") try: @@ -747,19 +783,19 @@ def _explain_why(type_ref: type, path: list[str], raw: Any) -> str: return f'{".".join(path)}: not a {type_ref.__name__}: {raw}' @staticmethod - def _dump_json(as_dict: Json, _: type) -> bytes: + def _dump_json(as_dict: Json, _: type) -> list[bytes]: """The `_dump_json` method is a private method that is used to serialize a dictionary to a JSON string. This method is called by the `save` method.""" - return json.dumps(as_dict, indent=2).encode("utf8") + return [json.dumps(as_dict, indent=2).encode("utf8")] @staticmethod - def _dump_yaml(raw: Json, _: type) -> bytes: + def _dump_yaml(raw: Json, _: type) -> list[bytes]: """The `_dump_yaml` method is a private method that is used to serialize a dictionary to a YAML string. This method is called by the `save` method.""" try: from yaml import dump # pylint: disable=import-outside-toplevel - return dump(raw).encode("utf8") + return [dump(raw).encode("utf8")] except ImportError as err: raise SyntaxError("PyYAML is not installed. Fix: pip install databricks-labs-blueprint[yaml]") from err @@ -781,9 +817,10 @@ def _load_yaml(raw: BinaryIO) -> Json: raise SyntaxError("PyYAML is not installed. Fix: pip install databricks-labs-blueprint[yaml]") from err @staticmethod - def _dump_csv(raw: list[Json], type_ref: type) -> bytes: + def _dump_csv(raw: list[Json], type_ref: type) -> list[bytes]: """The `_dump_csv` method is a private method that is used to serialize a list of dictionaries to a CSV string. 
This method is called by the `save` method.""" + raws = [] type_args = get_args(type_ref) if not type_args: raise SerdeError(f"Writing CSV is only supported for lists. Got {type_ref}") @@ -804,9 +841,21 @@ def _dump_csv(raw: list[Json], type_ref: type) -> bytes: writer = csv.DictWriter(buffer, field_names, dialect="excel") writer.writeheader() for as_dict in raw: + # Check if the buffer + the current row is over the file size limit + before_pos = buffer.tell() writer.writerow(as_dict) + if buffer.tell() > FILE_SIZE_LIMIT: + buffer.seek(before_pos) + buffer.truncate() + raws.append(buffer.getvalue().encode("utf8")) + buffer = io.StringIO() + writer = csv.DictWriter(buffer, field_names, dialect="excel") + writer.writeheader() + writer.writerow(as_dict) + buffer.seek(0) - return buffer.read().encode("utf8") + raws.append(buffer.getvalue().encode("utf8")) + return raws @staticmethod def _load_csv(raw: BinaryIO) -> list[Json]: diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index 5d9bede..891151c 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -2,6 +2,7 @@ import pytest from databricks.sdk.errors import PermissionDenied +from databricks.sdk.service.catalog import TableInfo from databricks.sdk.service.provisioning import Workspace from databricks.labs.blueprint.installation import Installation @@ -73,6 +74,19 @@ def test_saving_list_of_dataclasses_to_csv(new_installation): assert len(loaded) == 2 +def test_saving_list_of_dataclasses_to_multiple_csvs(new_installation): + tables: list[TableInfo] = [] + for i in range(500000): + tables.append(TableInfo(name=f"long_table_name_{i}", schema_name="very_long_schema_name")) + new_installation.save( + tables, + filename="many_tables_test.csv", + ) + + loaded = new_installation.load(list[TableInfo], filename="many_tables_test.csv") + assert len(loaded) == 500000 + + @pytest.mark.parametrize( "ext,magic", [