Skip to content

Commit

Permalink
model refactoring, wip, now move safe_snapshot_downlaod to single fil…
Browse files Browse the repository at this point in the history
…e, and move patchers to patcher.py
  • Loading branch information
tastelikefeet committed Sep 11, 2024
1 parent 4092b71 commit 4bf3acf
Show file tree
Hide file tree
Showing 13 changed files with 820 additions and 773 deletions.
5 changes: 2 additions & 3 deletions swift/hub/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from huggingface_hub.hf_api import CommitInfo, future_compatible
from requests.exceptions import HTTPError
from transformers.utils import logging, strtobool
from swift.utils.env import use_hf_hub

logger = logging.get_logger(__name__)

_use_hf_hub = strtobool(os.environ.get('USE_HF', 'False'))


class HubOperation:

Expand Down Expand Up @@ -136,7 +135,7 @@ def upload_folder(
class MSHub(HubOperation):
ms_token = None

if not _use_hf_hub:
if not use_hf_hub():
import huggingface_hub
from transformers import trainer
huggingface_hub.create_repo = create_repo
Expand Down
Empty file added swift/llm/model/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions swift/llm/model/checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@


def check_awq_ext() -> None:
try:
from awq.utils.packing_utils import dequantize_gemm
import awq_ext # with CUDA kernels (AutoAWQ_kernels)
except ImportError as e:
raise ImportError('You are training awq models, remember installing awq_ext by '
'`git clone https://github.com/casper-hansen/AutoAWQ_kernels '
'&& cd AutoAWQ_kernels && pip install -e .`') from e
86 changes: 86 additions & 0 deletions swift/llm/model/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
from contextlib import nullcontext
from typing import Optional, Dict, Any

from transformers.utils.versions import require_version

from swift import get_logger
from swift.hub import hub
from swift.llm.utils.utils import is_unsloth_available
from swift.utils import safe_ddp_context
from swift.utils.env import use_hf_hub

logger = get_logger()

# Model Home: 'https://modelscope.cn/models/{model_id_or_path}'
MODEL_MAPPING: Dict[str, Dict[str, Any]] = {}


def safe_snapshot_download(model_type: str,
model_id_or_path: Optional[str] = None,
revision: Optional[str] = None,
download_model: bool = True,
**kwargs) -> str:
"""Download model protected by DDP context
Args:
model_type: The model type, can be None
model_id_or_path: The model id or model path
revision: The model revision
download_model: Download model bin files or not
**kwargs:
Returns:
The model dir
"""
# Perform snapshot_download (ms or hf) based on model_type and model_id_or_path.
model_info = MODEL_MAPPING.get(model_type, {})

model_dir = None
if model_id_or_path is None:
model_dir = kwargs.pop('model_dir', None) # compat with swift<1.7
if model_dir is not None:
model_id_or_path = model_dir
else:
model_id_or_path = model_info['hf_model_id' if use_hf_hub() else 'model_id_or_path']

with safe_ddp_context():
if model_id_or_path is not None and not os.path.exists(model_id_or_path):
if model_id_or_path.startswith('/'):
raise ValueError(f"path: '{model_id_or_path}' not found")
ignore_file_pattern = model_info.get('ignore_file_pattern')
model_dir = hub.download_model(model_id_or_path, revision, download_model, ignore_file_pattern, **kwargs)
else:
model_dir = model_id_or_path
logger.info(f'Loading the model using model_dir: {model_dir}')

model_dir = os.path.expanduser(model_dir)
assert os.path.isdir(model_dir), f'model_dir: {model_dir}'
return model_dir


def load_by_unsloth(model_dir, torch_dtype, **kwargs):
assert is_unsloth_available(), 'please install unsloth if using `use_unsloth=True`'
from unsloth import FastLanguageModel
return FastLanguageModel.from_pretrained(
model_name=model_dir,
max_seq_length=kwargs.get('max_length', None),
dtype=torch_dtype,
load_in_4bit=kwargs.get('load_in_4bit', True),
trust_remote_code=True,
)


def load_by_transformers(automodel_class, model_dir, model_config, torch_dtype,
is_aqlm, is_training, model_kwargs, **kwargs):
context = kwargs.get('context', None)
if is_aqlm and is_training:
require_version('transformers>=4.39')
import aqlm
context = aqlm.optimize_for_training()
if context is None:
context = nullcontext()
with context:
model = automodel_class.from_pretrained(
model_dir, config=model_config, torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs)
return model
Loading

0 comments on commit 4bf3acf

Please sign in to comment.