Skip to content

Commit 2bbfd39

Browse files
authored
remove zipping (#105)
This PR removes zipping before upload and instead passes the structure of the directory to the API directly
1 parent 9c25bbf commit 2bbfd39

File tree

5 files changed

+195
-68
lines changed

5 files changed

+195
-68
lines changed

integration_tests/test_model_upload.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,22 @@ def test_model_upload_directory(self) -> None:
102102
# Create Version
103103
model_upload(self.handle, temp_dir, LICENSE_NAME)
104104

105+
def test_model_upload_directory_structure(self) -> None:
    """Upload a directory holding both a top-level file and a nested subdirectory."""
    base_path = Path(self.temp_dir)
    nested_dir = base_path / "nested"
    nested_dir.mkdir()

    (base_path / "file1.txt").write_text("dummy content in nested file")

    # Create dummy files in the nested directory
    for dummy_name in ("nested_model.h5", "nested_config.json", "nested_metadata.json"):
        (nested_dir / dummy_name).write_text("dummy content in nested file")

    # Call the model upload function with the base directory
    model_upload(self.handle, self.temp_dir, LICENSE_NAME)
120+
105121
def test_model_upload_nested_dir(self) -> None:
106122
# Create a nested directory within self.temp_dir
107123
nested_dir = Path(self.temp_dir) / "nested"

src/kagglehub/gcs_upload.py

Lines changed: 125 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import logging
22
import os
3-
import shutil
43
import time
54
import zipfile
65
from datetime import datetime
7-
from multiprocessing import Pool
8-
from pathlib import Path
96
from tempfile import TemporaryDirectory
10-
from typing import List, Tuple, Union
7+
from typing import Dict, List, Optional, Union
118

129
import requests
1310
from requests.exceptions import ConnectionError, Timeout
@@ -25,6 +22,25 @@
2522
REQUEST_TIMEOUT = 600
2623

2724

25+
class UploadDirectoryInfo:
    """Tree node describing an upload: a directory name, the blob tokens of the
    files directly inside it, and its child directories (recursively)."""

    def __init__(
        self,
        name: str,
        files: Optional[List[str]] = None,
        directories: Optional[List["UploadDirectoryInfo"]] = None,
    ):
        # `None` defaults avoid the shared-mutable-default pitfall.
        self.name = name
        self.files = [] if files is None else files
        self.directories = [] if directories is None else directories

    def serialize(self) -> Dict:
        """Recursively convert this node into the dict shape the API expects."""
        serialized_children = [child.serialize() for child in self.directories]
        return {
            "name": self.name,
            "files": [{"token": token} for token in self.files],
            "directories": serialized_children,
        }
42+
43+
2844
def parse_datetime_string(string: str) -> Union[datetime, str]:
2945
time_formats = ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%fZ"]
3046
for t in time_formats:
@@ -138,51 +154,108 @@ def _upload_blob(file_path: str, model_type: str) -> str:
138154
return response["token"]
139155

140156

141-
def zip_file(args: Tuple[Path, Path, Path]) -> int:
142-
file_path, zip_path, source_path_obj = args
143-
arcname = file_path.relative_to(source_path_obj)
144-
size = file_path.stat().st_size
145-
with zipfile.ZipFile(zip_path, "a", zipfile.ZIP_STORED, allowZip64=True) as zipf:
146-
zipf.write(file_path, arcname)
147-
return size
148-
149-
150-
def zip_files(source_path_obj: Path, zip_path: Path) -> List[int]:
151-
files = [file for file in source_path_obj.rglob("*") if file.is_file()]
152-
args = [(file, zip_path, source_path_obj) for file in files]
153-
154-
with Pool() as pool:
155-
sizes = pool.map(zip_file, args)
156-
return sizes
157-
158-
159-
def upload_files(source_path: str, model_type: str) -> List[str]:
160-
source_path_obj = Path(source_path)
161-
with TemporaryDirectory() as temp_dir:
162-
temp_dir_path = Path(temp_dir)
163-
total_size = 0
164-
165-
if source_path_obj.is_dir():
166-
for file_path in source_path_obj.rglob("*"):
167-
if file_path.is_file():
168-
total_size += file_path.stat().st_size
169-
elif source_path_obj.is_file():
170-
total_size = source_path_obj.stat().st_size
171-
else:
172-
path_error_message = "The source path does not point to a valid file or directory."
173-
raise ValueError(path_error_message)
174-
175-
with tqdm(total=total_size, desc="Zipping", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
176-
if source_path_obj.is_dir():
177-
zip_path = temp_dir_path / "archive.zip"
178-
sizes = zip_files(source_path_obj, zip_path)
179-
for size in sizes:
180-
pbar.update(size)
181-
upload_path = str(zip_path)
182-
elif source_path_obj.is_file():
183-
temp_file_path = temp_dir_path / source_path_obj.name
184-
shutil.copy(source_path_obj, temp_file_path)
185-
pbar.update(temp_file_path.stat().st_size)
186-
upload_path = str(temp_file_path)
187-
188-
return [token for token in [_upload_blob(upload_path, model_type)] if token]
157+
def upload_files_and_directories(
    folder: str, model_type: str, quiet: bool = False  # noqa: FBT002, FBT001
) -> UploadDirectoryInfo:
    """Upload a file or a directory tree and return its structure with blob tokens.

    If the tree holds more than MAX_FILES_TO_UPLOAD files, everything is zipped
    and uploaded as one archive blob instead of individual files.

    Parameters
    ==========
    folder: path to the file or directory to upload.
    model_type: Type of the model that is being uploaded.
    quiet: suppress verbose output (default is False)
    :return: an UploadDirectoryInfo tree mirroring the uploaded structure.
    """
    # Total file count decides between per-file upload and a single zip archive.
    total_files = sum(len(filenames) for _, _, filenames in os.walk(folder))

    if total_files > MAX_FILES_TO_UPLOAD:
        if not quiet:
            logger.info(f"More than {MAX_FILES_TO_UPLOAD} files detected, creating a zip archive...")

        with TemporaryDirectory() as temp_dir:
            zip_path = os.path.join(temp_dir, TEMP_ARCHIVE_FILE)
            with zipfile.ZipFile(zip_path, "w") as archive:
                for dirpath, _, filenames in os.walk(folder):
                    for filename in filenames:
                        absolute = os.path.join(dirpath, filename)
                        # Store paths relative to `folder` so the archive layout
                        # matches the original tree.
                        archive.write(absolute, os.path.relpath(absolute, folder))

            archive_token = _upload_file_or_folder(temp_dir, TEMP_ARCHIVE_FILE, model_type, quiet)
            archive_tokens = [] if archive_token is None else [archive_token]
            return UploadDirectoryInfo(name="archive", files=archive_tokens)

    root_dict = UploadDirectoryInfo(name="root")

    if os.path.isfile(folder):
        # A single file: upload it directly under the root node.
        token = _upload_file_or_folder(os.path.dirname(folder), os.path.basename(folder), model_type, quiet)
        if token:
            root_dict.files.append(token)
        return root_dict

    def _descend_to(relative_path: str) -> UploadDirectoryInfo:
        # Walk down from the root, creating child nodes as needed, and return
        # the node corresponding to `relative_path` ("." means the root).
        node = root_dict
        if relative_path == ".":
            return node
        for part in relative_path.split(os.sep):
            existing = next((child for child in node.directories if child.name == part), None)
            if existing is None:
                existing = UploadDirectoryInfo(name=part)
                node.directories.append(existing)
            node = existing
        return node

    for dirpath, _, filenames in os.walk(folder):
        node = _descend_to(os.path.relpath(dirpath, folder))
        # Attach each uploaded file's token to its containing directory node.
        for filename in filenames:
            token = _upload_file_or_folder(dirpath, filename, model_type, quiet)
            if token:
                node.files.append(token)

    return root_dict
218+
219+
220+
def _upload_file_or_folder(
    parent_path: str,
    file_or_folder_name: str,
    model_type: str,
    quiet: bool = False,  # noqa: FBT002, FBT001
) -> Optional[str]:
    """
    Uploads a single regular file from a specified path to a remote service.
    Parameters
    ==========
    parent_path: The parent directory path from where the file or folder is to be uploaded.
    file_or_folder_name: The name of the file or folder to be uploaded.
    model_type: Type of the model that is being uploaded.
    quiet: suppress verbose output (default is False)
    :return: A token if the upload is successful, or None if the entry is not a
        regular file (directories are skipped here; the caller's os.walk visits
        their contents individually).
    """
    # NOTE: an earlier signature took a `dir_mode` parameter; the stale docstring
    # line describing it has been removed.
    full_path = os.path.join(parent_path, file_or_folder_name)
    if os.path.isfile(full_path):
        return _upload_file(file_or_folder_name, full_path, quiet, model_type)
    return None
241+
242+
243+
def _upload_file(file_name: str, full_path: str, quiet: bool, model_type: str) -> Optional[str]:  # noqa: FBT001
    """Helper function to upload a single file
    Parameters
    ==========
    file_name: name of the file to upload
    full_path: path to the file to upload
    quiet: suppress verbose output
    model_type: Type of the model that is being uploaded.
    :return: the blob token string returned by the upload endpoint
    """
    # (Docstring fix: this returns the token from _upload_blob, not an
    # UploadFile instance as previously documented.)

    if not quiet:
        logger.info("Starting upload for file " + file_name)

    # Size is read before the upload so the success message can report it.
    content_length = os.path.getsize(full_path)
    token = _upload_blob(full_path, model_type)
    if not quiet:
        logger.info("Upload successful: " + file_name + " (" + File.get_size(content_length) + ")")
    return token

src/kagglehub/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Optional
33

44
from kagglehub import registry
5-
from kagglehub.gcs_upload import upload_files
5+
from kagglehub.gcs_upload import upload_files_and_directories
66
from kagglehub.handle import parse_model_handle
77
from kagglehub.models_helpers import create_model_if_missing, create_model_instance_or_version
88

@@ -47,7 +47,7 @@ def model_upload(
4747
create_model_if_missing(h.owner, h.model)
4848

4949
# Upload the model files to GCS
50-
tokens = upload_files(local_model_dir, "model")
50+
tokens = upload_files_and_directories(local_model_dir, "model")
5151

5252
# Create a model instance if it doesn't exist, and create a new instance version if an instance exists
5353
create_model_instance_or_version(h, tokens, license_name, version_notes)

src/kagglehub/models_helpers.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
22
from http import HTTPStatus
3-
from typing import List, Optional
3+
from typing import Optional
44

5-
from kagglehub.clients import KaggleApiV1Client
6-
from kagglehub.exceptions import BackendError, KaggleApiHTTPError
5+
from kagglehub.clients import BackendError, KaggleApiV1Client
6+
from kagglehub.exceptions import KaggleApiHTTPError
7+
from kagglehub.gcs_upload import UploadDirectoryInfo
78
from kagglehub.handle import ModelHandle
89

910
logger = logging.getLogger(__name__)
@@ -16,11 +17,15 @@ def _create_model(owner_slug: str, model_slug: str) -> None:
1617
logger.info(f"Model '{model_slug}' Created.")
1718

1819

19-
def _create_model_instance(model_handle: ModelHandle, files: List[str], license_name: Optional[str] = None) -> None:
20+
def _create_model_instance(
21+
model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, license_name: Optional[str] = None
22+
) -> None:
23+
serialized_data = files_and_directories.serialize()
2024
data = {
2125
"instanceSlug": model_handle.variation,
2226
"framework": model_handle.framework,
23-
"files": [{"token": file_token} for file_token in files],
27+
"files": [{"token": file_token} for file_token in files_and_directories.files],
28+
"directories": serialized_data["directories"],
2429
}
2530
if license_name is not None:
2631
data["licenseName"] = license_name
@@ -30,8 +35,15 @@ def _create_model_instance(model_handle: ModelHandle, files: List[str], license_
3035
logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}")
3136

3237

33-
def _create_model_instance_version(model_handle: ModelHandle, files: List[str], version_notes: str = "") -> None:
34-
data = {"versionNotes": version_notes, "files": [{"token": file_token} for file_token in files]}
38+
def _create_model_instance_version(
39+
model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, version_notes: str = ""
40+
) -> None:
41+
serialized_data = files_and_directories.serialize()
42+
data = {
43+
"versionNotes": version_notes,
44+
"files": [{"token": file_token} for file_token in files_and_directories.files],
45+
"directories": serialized_data["directories"],
46+
}
3547
api_client = KaggleApiV1Client()
3648
api_client.post(
3749
f"/models/{model_handle.owner}/{model_handle.model}/{model_handle.framework}/{model_handle.variation}/create/version",
@@ -43,7 +55,7 @@ def _create_model_instance_version(model_handle: ModelHandle, files: List[str],
4355

4456

4557
def create_model_instance_or_version(
46-
model_handle: ModelHandle, files: List[str], license_name: Optional[str], version_notes: str = ""
58+
model_handle: ModelHandle, files: UploadDirectoryInfo, license_name: Optional[str], version_notes: str = ""
4759
) -> None:
4860
try:
4961
_create_model_instance(model_handle, files, license_name)

tests/test_model_upload.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_model_upload_instance_with_valid_handle(self) -> None:
140140
test_filepath.touch() # Create a temporary file in the temporary directory
141141
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
142142
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
143-
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
143+
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
144144

145145
def test_model_upload_instance_with_nested_directories(self) -> None:
146146
# execution path: get_model -> create_model -> get_instance -> create_version
@@ -156,7 +156,7 @@ def test_model_upload_instance_with_nested_directories(self) -> None:
156156
test_filepath.touch()
157157
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
158158
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
159-
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
159+
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
160160

161161
def test_model_upload_version_with_valid_handle(self) -> None:
162162
# execution path: get_model -> get_instance -> create_instance
@@ -168,7 +168,7 @@ def test_model_upload_version_with_valid_handle(self) -> None:
168168
test_filepath.touch() # Create a temporary file in the temporary directory
169169
model_upload("metaresearch/llama-2/pyTorch/7b", temp_dir, APACHE_LICENSE, "model_type")
170170
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
171-
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
171+
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
172172

173173
def test_model_upload_with_too_many_files(self) -> None:
174174
with create_test_http_server(KaggleAPIHandler):
@@ -199,7 +199,7 @@ def test_model_upload_resumable(self) -> None:
199199
# Check that GcsAPIHandler received two PUT requests
200200
self.assertEqual(GcsAPIHandler.put_requests_count, 2)
201201
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
202-
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
202+
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
203203

204204
def test_model_upload_with_none_license(self) -> None:
205205
with create_test_http_server(KaggleAPIHandler):
@@ -209,7 +209,7 @@ def test_model_upload_with_none_license(self) -> None:
209209
test_filepath.touch() # Create a temporary file in the temporary directory
210210
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, None, "model_type")
211211
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
212-
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
212+
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
213213

214214
def test_model_upload_without_license(self) -> None:
215215
with create_test_http_server(KaggleAPIHandler):
@@ -219,7 +219,7 @@ def test_model_upload_without_license(self) -> None:
219219
test_filepath.touch() # Create a temporary file in the temporary directory
220220
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, version_notes="model_type")
221221
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
222-
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
222+
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
223223

224224
def test_model_upload_with_invalid_license_fails(self) -> None:
225225
with create_test_http_server(KaggleAPIHandler):
@@ -244,3 +244,29 @@ def test_single_file_upload(self) -> None:
244244

245245
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
246246
self.assertIn("single_dummy_file.txt", KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
247+
248+
def test_model_upload_with_directory_structure(self) -> None:
    """Verify every file in a nested directory tree is uploaded exactly once."""
    with create_test_http_server(KaggleAPIHandler):
        with create_test_http_server(GcsAPIHandler, "http://localhost:7778"):
            with TemporaryDirectory() as temp_dir:
                base_path = Path(temp_dir)
                (base_path / "dir1").mkdir()
                (base_path / "dir2").mkdir()  # intentionally left empty

                (base_path / "file1.txt").touch()

                (base_path / "dir1" / "file2.txt").touch()
                (base_path / "dir1" / "file3.txt").touch()

                (base_path / "dir1" / "subdir1").mkdir()
                (base_path / "dir1" / "subdir1" / "file4.txt").touch()

                model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")

                self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 4)
                expected_files = {"file1.txt", "file2.txt", "file3.txt", "file4.txt"}
                # Fix: set equality instead of issubset — length==4 plus issubset
                # would still pass if one file were uploaded four times.
                self.assertEqual(set(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), expected_files)

                # TODO: Add assertions on CreateModelInstanceRequest.Directories and
                # CreateModelInstanceRequest.Files to verify the expected structure
                # is sent.

0 commit comments

Comments
 (0)