Commit 1413acb

Merge pull request #69 from SkywardAI/fix/load_model
fix download model issue
2 parents 37c1f01 + d8d2ba2 commit 1413acb

File tree

4 files changed: +171, -21 lines

examples/download_examples.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/kimchima/pkg/model_factory.py

Lines changed: 11 additions & 15 deletions
@@ -36,20 +36,20 @@ def __init__(self):
     @classmethod
     def auto_model(cls, *args, **kwargs)-> AutoModel:
         r"""
-        It is used to get the model from the Hugging Face Transformers AutoModel.
-
-        Args:
-            pretrained_model_name_or_path: pretrained model name or path
+        Here we will use AutoModel from Huggingface to load the model from local.
+        It supports a wider range of models beyond causal language models,
+        like BERT, RoBERTa, BART, T5 and more.
+
+        It returns the base model without a specific head; it does not directly
+        perform tasks like text generation or translation.
         """
         pretrained_model_name_or_path=kwargs.pop("pretrained_model_name_or_path", None)
         if pretrained_model_name_or_path is None:
             raise ValueError("pretrained_model_name_or_path cannot be None")

-        quantization_config=kwargs.pop("quantization_config", None)
         model = AutoModel.from_pretrained(
             pretrained_model_name_or_path,
-            quantization_config,
             **kwargs
         )
         logger.debug(f"Loaded model: {pretrained_model_name_or_path}")

@@ -58,21 +58,17 @@ def auto_model(cls, *args, **kwargs)-> AutoModel:
     @classmethod
     def auto_model_for_causal_lm(cls, *args, **kwargs)-> AutoModelForCausalLM:
         r"""
-        It is used to get the model from the Hugging Face Transformers AutoModelForCausalLM.
-
-        Args:
-            pretrained_model_name_or_path: pretrained model name or path
-
+        Here we will use AutoModelForCausalLM to load the model from local,
+        like GPT-2, XLNet, etc.
+        It returns a model with a language modeling head which can be used to
+        generate text, translate text, and answer questions in an informative way.
         """
         pretrained_model_name_or_path=kwargs.pop("pretrained_model_name_or_path", None)
         if pretrained_model_name_or_path is None:
             raise ValueError("pretrained_model_name_or_path cannot be None")

-        quantization_config=kwargs.pop("quantization_config", None)
         model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path,
-            quantization_config=quantization_config,
-            device_map='auto',
+            pretrained_model_name_or_path,
             **kwargs
         )
         logger.debug(f"Loaded model: {pretrained_model_name_or_path}")
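
As a quick illustration (not part of the diff), a minimal sketch of calling the updated factory methods after this change; "gpt2" is just an example checkpoint, and since extra kwargs are now forwarded verbatim to from_pretrained, options like device_map must be passed explicitly:

    from kimchima.pkg import ModelFactory

    # Base model without a task head (embeddings, feature extraction, etc.)
    encoder = ModelFactory.auto_model(pretrained_model_name_or_path="gpt2")

    # Causal LM with a language modeling head; the hard-coded device_map='auto'
    # default was removed in this commit, so opt in to it explicitly if wanted
    lm = ModelFactory.auto_model_for_causal_lm(
        pretrained_model_name_or_path="gpt2",
        device_map="auto",
    )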
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# coding=utf-8
+# Copyright [2024] [SkywardAI]
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from kimchima.utils import Downloader
+from kimchima.pipelines import PipelinesFactory
+from kimchima.pkg import ModelFactory
+from kimchima.pkg import TokenizerFactory
+
+
+class TestDownloader(unittest.TestCase):
+
+    # prepare test data
+    def setUp(self):
+        self.model_name="gpt2"
+        self.folder_name="gpt2"
+        self.model_name_auto="sentence-transformers/all-MiniLM-L6-v2"
+        self.folder_name_auto="encoder"
+
+    @unittest.skip("skip test_model_downloader")
+    def test_model_downloader(self):
+        """
+        Test model_downloader method
+        """
+        Downloader.model_downloader(model_name=self.model_name, folder_name=self.folder_name)
+
+        # load the pipeline back from the folder
+        pipe=PipelinesFactory.customized_pipe(model=self.folder_name, device_map='auto')
+
+        # pipe is not None
+        self.assertIsNotNone(pipe)
+        self.assertEqual(pipe.model.name_or_path, self.folder_name)
+
+    # @unittest.skip("skip test_auto_downloader")
+    def test_auto_downloader(self):
+        """
+        Test auto_downloader method
+        """
+        Downloader.auto_downloader(model_name=self.model_name_auto, folder_name=self.folder_name_auto)
+        Downloader.auto_token_downloader(model_name=self.model_name_auto, folder_name=self.folder_name_auto)
+
+        # load the model from the local folder
+        model=ModelFactory.auto_model(pretrained_model_name_or_path=self.folder_name_auto)
+
+        # load the tokenizer from the local folder
+        tokenizer=TokenizerFactory.auto_tokenizer(pretrained_model_name_or_path=self.folder_name_auto)
+
+        self.assertIsNotNone(model)
+        self.assertEqual(model.name_or_path, self.folder_name_auto)
+
+        self.assertIsNotNone(tokenizer)
+
+        prompt="test"
+        inputs=tokenizer(prompt, return_tensors="pt")
+        output=model(**inputs)
+
+        self.assertIsNotNone(output[0])
+
+    @unittest.skip("skip test_casual_downloader")
+    def test_casual_downloader(self):
+        """
+        Test casual_downloader method
+        """
+        Downloader.casual_downloader(model_name=self.model_name_auto, folder_name=self.folder_name_auto)
+        Downloader.auto_token_downloader(model_name=self.model_name_auto, folder_name=self.folder_name_auto)
+
+        model=ModelFactory.auto_model_for_causal_lm(pretrained_model_name_or_path=self.folder_name_auto)
+        tokenizer=TokenizerFactory.auto_tokenizer(pretrained_model_name_or_path=self.folder_name_auto)
+
+        self.assertIsNotNone(model)
+        self.assertEqual(model.name_or_path, self.folder_name_auto)
+
+        self.assertIsNotNone(tokenizer)
+
+        prompt="test"
+        inputs=tokenizer(prompt, return_tensors="pt")
+        output=model(**inputs)
+        self.assertIsNotNone(output[0])
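
To run just the enabled test, something like the following works from the directory containing the test file; the module name test_downloader is an assumption about how the new file is named, so adjust it to the actual path:

    import unittest

    # load only test_auto_downloader (the other two tests are skipped anyway)
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test_downloader.TestDownloader.test_auto_downloader"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)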

src/kimchima/utils/downloader.py

Lines changed: 66 additions & 2 deletions
@@ -16,7 +16,12 @@
 from __future__ import annotations

 from kimchima.pkg import logging
-from transformers import pipeline
+from transformers import (
+    pipeline,
+    AutoModel,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+)

 logger=logging.get_logger(__name__)

@@ -31,8 +36,10 @@ def __init__(self):
         )

     @classmethod
-    def model_downloader(cls, *args, **kwargs)->str:
+    def model_downloader(cls, *args, **kwargs):
         r"""
+        Here we will use pipeline from Huggingface to download the model
+        and save it to the specified folder.
         """
         model_name=kwargs.pop("model_name", None)
         if model_name is None:

@@ -41,3 +48,60 @@ def model_downloader(cls, *args, **kwargs)->str:
         folder_name=kwargs.pop("folder_name", None)
         pipe=pipeline(model=model_name)
         pipe.save_pretrained(folder_name if folder_name is not None else model_name)
+        logger.info(f"Model {model_name} has been downloaded successfully")
+
+    @classmethod
+    def auto_downloader(cls, *args, **kwargs):
+        r"""
+        Here we will use AutoModel from Huggingface to download the model.
+        It supports a wider range of models beyond causal language models,
+        like BERT, RoBERTa, BART, T5 and more.
+
+        It returns the base model without a specific head; it does not directly
+        perform tasks like text generation or translation.
+        """
+        model_name=kwargs.pop("model_name", None)
+        if model_name is None:
+            raise ValueError("model_name is required")
+        folder_name=kwargs.pop("folder_name", None)
+
+        model=AutoModel.from_pretrained(model_name)
+        model.save_pretrained(folder_name if folder_name is not None else model_name)
+        logger.info(f"Model {model_name} has been downloaded successfully")
+
+    @classmethod
+    def casual_downloader(cls, *args, **kwargs):
+        r"""
+        Here we will use AutoModelForCausalLM from Huggingface to download the model,
+        like GPT-2, XLNet, etc.
+        It returns a model with a language modeling head which can be used to
+        generate text, translate text, and answer questions in an informative way.
+        """
+        model_name=kwargs.pop("model_name", None)
+        if model_name is None:
+            raise ValueError("model_name is required")
+
+        folder_name=kwargs.pop("folder_name", None)
+        # https://github.com/huggingface/transformers/issues/25296
+        # https://github.com/huggingface/accelerate/issues/661
+        model=AutoModelForCausalLM.from_pretrained(model_name)
+        model.save_pretrained(folder_name if folder_name is not None else model_name)
+        logger.info(f"Model {model_name} has been downloaded successfully")
+
+    @classmethod
+    def auto_token_downloader(cls, *args, **kwargs):
+        r"""
+        Here we will use AutoTokenizer from Huggingface to download the tokenizer configuration.
+        """
+        model_name=kwargs.pop("model_name", None)
+        if model_name is None:
+            raise ValueError("model_name is required")
+
+        folder_name=kwargs.pop("folder_name", None)
+
+        tokenizer=AutoTokenizer.from_pretrained(model_name)
+        tokenizer.save_pretrained(folder_name if folder_name is not None else model_name)
+        logger.info(f"Tokenizer {model_name} has been downloaded successfully")
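
Putting the new downloader methods together, a minimal end-to-end sketch mirroring the test above: download a model and its tokenizer into a local folder, then load both back from it. The model and folder names are the ones the test uses; substitute your own:

    from kimchima.utils import Downloader
    from kimchima.pkg import ModelFactory, TokenizerFactory

    # download the weights and tokenizer files into ./encoder
    Downloader.auto_downloader(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        folder_name="encoder",
    )
    Downloader.auto_token_downloader(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        folder_name="encoder",
    )

    # load everything back from the local folder
    model = ModelFactory.auto_model(pretrained_model_name_or_path="encoder")
    tokenizer = TokenizerFactory.auto_tokenizer(pretrained_model_name_or_path="encoder")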
