try s3 crt #519

Open
wants to merge 22 commits into base: main
1 change: 1 addition & 0 deletions .github/workflows/build.yaml
@@ -5,6 +5,7 @@ on:
   push:
     branches:
       - 'main'
+      - 's3-crt'
     tags:
       - 'v*'
 
13 changes: 10 additions & 3 deletions server/lorax_server/models/__init__.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import torch
+import os
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig

@@ -57,13 +58,19 @@ def get_model(
         # change the model id to be the local path to the folder so
         # we can load the config_dict locally
         logger.info("Using the local files since we are coming from s3")
-        model_path = get_s3_model_local_dir(model_id)
+        model_path = get_s3_model_local_dir(model_id) / "snapshots"
+        logger.info(f"model_path: {model_path}")
+
+        files = os.listdir(model_path.absolute().as_posix())
+        logger.info(files)
+        if len(files) == 1:
+            # TODO: why would there be more than one snapshot?
+            model_path = model_path / files[0]
+
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_path, revision=revision, trust_remote_code=trust_remote_code
         )
-        model_id = str(model_path)
+        logger.info(f"config_dict: {config_dict}")
+        model_id = model_path.absolute().as_posix()
     elif source == "hub":
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
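Note that the branch above only descends into a snapshot directory when exactly one exists, so model_path keeps pointing at snapshots/ itself whenever there are several. A minimal defensive sketch of the resolution step, assuming a hypothetical helper name and that the most recently modified snapshot is the one to load:

from pathlib import Path

def resolve_snapshot_dir(snapshots_root: Path) -> Path:
    """Pick a concrete snapshot directory under snapshots/.

    Assumes at least one snapshot exists; when several are present,
    prefer the most recently modified one instead of silently keeping
    the snapshots/ root as the model path.
    """
    candidates = [p for p in snapshots_root.iterdir() if p.is_dir()]
    if not candidates:
        raise FileNotFoundError(f"No snapshots under {snapshots_root}")
    # Most recent mtime wins when multiple snapshots are present.
    return max(candidates, key=lambda p: p.stat().st_mtime)

Failing loudly on an empty directory is deliberate: PretrainedConfig.get_config_dict would otherwise surface a far less obvious error further down.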
96 changes: 61 additions & 35 deletions server/lorax_server/utils/sources/s3.py
@@ -3,7 +3,8 @@
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple
-
+from s3transfer.crt import CRTTransferManager, create_s3_crt_client, BotocoreCRTRequestSerializer, BotocoreCRTCredentialsWrapper
+import botocore
 import boto3
 from botocore.config import Config
 from botocore.exceptions import ClientError
@@ -73,9 +74,11 @@ def _get_bucket_resource(bucket_name: str) -> "Bucket":
 
 
 def get_s3_model_local_dir(model_id: str):
+    _, model_id = _get_bucket_and_model_id(model_id)
     object_id = model_id.replace("/", "--")
-    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" / "snapshots"
-    return repo_cache
+    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
+
+    return repo_cache

def weight_s3_files(bucket: Any, model_id: str, extension: str = ".safetensors") -> List[str]:
Expand All @@ -97,42 +100,65 @@ def download_files_from_s3(
revision: str = "",
) -> List[Path]:
"""Download the safetensors files from the s3"""
import threading

def download_file(filename):
repo_cache = get_s3_model_local_dir(model_id)
local_file = try_to_load_from_cache(repo_cache, revision, filename)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
return Path(local_file)
logger.info(f"Download file: {filename}")
start_time = time.time()
local_file_path = get_s3_model_local_dir(model_id) / filename
# ensure cache dir exists and create it if needed
local_file_path.parent.mkdir(parents=True, exist_ok=True)
model_id_path = Path(model_id)
bucket_file_name = model_id_path / filename
logger.info(f"Downloading file {bucket_file_name} to {local_file_path}")
bucket.download_file(str(bucket_file_name), str(local_file_path))
# TODO: add support for revision
logger.info(f"Downloaded {local_file_path} in {timedelta(seconds=int(time.time() - start_time))}.")
if not local_file_path.is_file():
raise FileNotFoundError(f"File {local_file_path} not found")
return local_file_path
def download_file(filename, files):
try:
repo_cache = get_s3_model_local_dir(model_id)
local_file = try_to_load_from_cache(repo_cache, revision, filename)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
return Path(local_file)
logger.info(f"Download file: {filename}")
start_time = time.time()
local_file_path = get_s3_model_local_dir(model_id) / filename
# ensure cache dir exists and create it if needed
local_file_path.parent.mkdir(parents=True, exist_ok=True)
model_id_path = Path(model_id)
bucket_file_name = model_id_path / filename
session = botocore.session.get_session()
request_serializer = BotocoreCRTRequestSerializer(session)
logger.info(f"Downloading file {bucket_file_name} to {local_file_path}")
with CRTTransferManager(create_s3_crt_client(
"us-west-2",
BotocoreCRTCredentialsWrapper(session.get_credentials()).to_crt_credentials_provider(), target_throughput=50),
request_serializer
) as transfer:
future = transfer.download(bucket.name, str(bucket_file_name), str(local_file_path))
future.result()
# bucket.download_file(str(bucket_file_name), str(local_file_path))
# TODO: add support for revision
logger.info(f"Downloaded {local_file_path} in {timedelta(seconds=int(time.time() - start_time))}.")
if not local_file_path.is_file():
raise FileNotFoundError(f"File {local_file_path} not found")
files.append(local_file_path)
except Exception as e:
logger.info(e)
raise e

start_time = time.time()
files = []
for i, filename in enumerate(filenames):
# TODO: clean up name creation logic
if not filename:
continue
file = download_file(filename)

elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(filenames) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0

logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
files.append(file)
# for i, filename in enumerate(filenames):
# # TODO: clean up name creation logic
# if not filename:
# continue
# file = download_file(filename)

# elapsed = timedelta(seconds=int(time.time() - start_time))
# remaining = len(filenames) - (i + 1)
# eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0

# logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
# files.append(file)
threads = []
for filename in filenames:
t = threading.Thread(target=download_file, args=(filename,files,))
threads.append(t)
t.start()

# Wait for all threads to complete and collect results
for t in threads:
t.join()

return files

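The patch above creates a fresh CRT client and transfer manager for every file, then fans downloads out across Python threads. CRTTransferManager already parallelizes transfers internally, so a single shared manager with one future per file would likely do the same work with less setup overhead. A hedged sketch under that assumption (the region default, function name, and argument shape are illustrative, not from the PR):

from pathlib import Path

import botocore.session
from s3transfer.crt import (
    BotocoreCRTCredentialsWrapper,
    BotocoreCRTRequestSerializer,
    CRTTransferManager,
    create_s3_crt_client,
)


def download_all_crt(bucket_name, keys_to_paths, region="us-west-2"):
    """Download many S3 objects through one shared CRT transfer manager.

    keys_to_paths maps S3 object keys to local destination paths. The
    CRT client schedules multipart transfers concurrently on its own,
    so no extra Python threads are needed.
    """
    session = botocore.session.get_session()
    provider = BotocoreCRTCredentialsWrapper(
        session.get_credentials()
    ).to_crt_credentials_provider()
    client = create_s3_crt_client(region, provider)
    serializer = BotocoreCRTRequestSerializer(session)

    with CRTTransferManager(client, serializer) as manager:
        # Submit everything up front, then block on the futures; each
        # result() call re-raises any error from its transfer.
        futures = []
        for key, path in keys_to_paths.items():
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            futures.append((path, manager.download(bucket_name, key, str(path))))
        done = []
        for path, future in futures:
            future.result()
            done.append(path)
        return done

Whether this beats thread-per-file depends on object sizes and NIC bandwidth; create_s3_crt_client's target_throughput knob is the main tuning point either way.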
2 changes: 1 addition & 1 deletion server/lorax_server/utils/sources/source.py
@@ -10,7 +10,7 @@
 
 def try_to_load_from_cache(repo_cache: Path, revision: Optional[str], filename: str) -> Optional[Path]:
     """Try to load a file from the Hugging Face cache"""
-    if revision is None:
+    if revision is None or revision == "":
         revision = "main"
 
     if not repo_cache.is_dir():
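The widened guard treats None and the empty string the same way, which is exactly what falsiness expresses in Python; an equivalent sketch (the helper name is illustrative):

from typing import Optional

def normalize_revision(revision: Optional[str]) -> str:
    # None and "" are both falsy, so this matches the patched
    # `revision is None or revision == ""` check.
    return revision or "main"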
87 changes: 71 additions & 16 deletions server/poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion server/pyproject.toml
@@ -34,13 +34,14 @@ texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
 torch = { version = "2.3.0", optional = true }
 peft = { version = "0.4.0", optional = true }
-boto3 = "^1.28.34"
+boto3 = {extras = ["crt"], version = "^1.34.129"}
 urllib3 = "<=1.26.18"
 hqq = { version = "^0.1.7", optional = true }
 stanford-stk = { version = "^0.7.0", markers = "sys_platform == 'linux'" }
 outlines = { version = "^0.0.40", optional = true }
 prometheus-client = "^0.20.0"
 py-cpuinfo = "^9.0.0"
+s3transfer = "0.10.1"
 
 [tool.poetry.extras]
 torch = ["torch"]
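The crt extra is what makes s3transfer.crt importable at runtime, since it pulls in the awscrt wheel. A tiny guard (sketch, not part of this PR) lets a caller probe for it and fall back to the classic boto3 client when the wheel is missing:

import importlib.util

def crt_available() -> bool:
    # boto3's "crt" extra installs awscrt; importing s3transfer.crt
    # without it raises ImportError, so probe before importing.
    return importlib.util.find_spec("awscrt") is not None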