@@ -91,26 +91,52 @@


def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    trust_remote_code: bool,
) -> Model:
    if "facebook/galactica" in model_id:
        if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
+            return GalacticaSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return Galactica(model_id, revision, quantize=quantize)
+            return Galactica(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_id.startswith("bigcode/"):
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

-    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    config = AutoConfig.from_pretrained(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
    model_type = config.model_type

    if model_type == "gpt_bigcode":
@@ -119,52 +145,133 @@ def get_model(
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
            )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "bloom":
        if sharded:
-            return BLOOMSharded(model_id, revision, quantize=quantize)
+            return BLOOMSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return BLOOM(model_id, revision, quantize=quantize)
+            return BLOOM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "llama":
        if sharded:
            if FLASH_ATTENTION:
-                return FlashLlamaSharded(model_id, revision, quantize=quantize)
+                return FlashLlamaSharded(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
        else:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
-            return llama_cls(model_id, revision, quantize=quantize)
+            return llama_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if config.model_type == "opt":
        if sharded:
-            return OPTSharded(model_id, revision, quantize=quantize)
+            return OPTSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return OPT(model_id, revision, quantize=quantize)
+            return OPT(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "t5":
        if sharded:
-            return T5Sharded(model_id, revision, quantize=quantize)
+            return T5Sharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return Seq2SeqLM(model_id, revision, quantize=quantize)
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(model_id, revision, quantize=quantize)
+        return CausalLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM(model_id, revision, quantize=quantize)
+        return Seq2SeqLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
+
+    auto_map = getattr(config, "auto_map", None)
+    if trust_remote_code and auto_map is not None:
+        if "AutoModelForCausalLM" in auto_map.keys():
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+        if "AutoModelForSeq2SeqLM" in auto_map.keys():
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    raise ValueError(f"Unsupported model type {model_type}")
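The new `auto_map` fallback is what lets architectures absent from transformers' static mappings still load once the user has opted in. A rough standalone illustration of the same check, assuming a repo whose config declares custom classes (the repo id is hypothetical; `auto_map` is the attribute transformers stores on such configs):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("some-org/custom-model", trust_remote_code=True)

# Remote-code configs carry a dict mapping auto classes to the custom
# implementations, e.g. {"AutoModelForCausalLM": "modeling_x.XForCausalLM"}.
auto_map = getattr(config, "auto_map", None)
if auto_map is not None and "AutoModelForCausalLM" in auto_map:
    print("repo provides a custom causal LM; the generic CausalLM wrapper applies")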