Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable modelscope for itrex #1655

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
pretrained_model_name_or_path,
*model_args,
config=config,
model_hub=model_hub,
**kwargs,
)
logger.info(
Expand Down Expand Up @@ -1451,9 +1452,13 @@ def train_func(model):
model.quantization_config = None
return model
else:
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
if model_hub=="modelscope":
from modelscope import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True)
else:
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
if (
not torch.cuda.is_available()
or device_map == "cpu"
Expand Down Expand Up @@ -1519,7 +1524,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
device_map = kwargs.pop("device_map", "auto")
use_safetensors = kwargs.pop("use_safetensors", None)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)

model_hub = kwargs.pop("model_hub", None)
# lm-eval device map is dictionary
device_map = device_map[""] if isinstance(device_map, dict) and "" in device_map else device_map

Expand Down Expand Up @@ -1708,7 +1713,19 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
if model_hub=="modelscope":
from modelscope import snapshot_download
model_dir = snapshot_download(pretrained_model_name_or_path)
if os.path.exists(model_dir+"/model.safetensors"):
resolved_archive_file = model_dir+"/model.safetensors"
elif os.path.exists(model_dir+"/pytorch_model.bin"):
resolved_archive_file = model_dir+"/pytorch_model.bin"
else:
assert (
resolved_archive_file is not None
), "Don't detect this model checkpoint"
else:
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
Expand Down
Loading