Commit e3e487d

feat(server): support trust_remote_code (#363)
1 parent e9669a4 commit e3e487d

File tree

17 files changed (+321, -72 lines)

.github/workflows/build.yaml

Lines changed: 1 addition & 2 deletions
@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests
 
   stop-runner:
     name: Stop self-hosted EC2 runner

launcher/src/main.rs

Lines changed: 26 additions & 2 deletions
@@ -53,7 +53,7 @@ struct Args {
     #[clap(long, env)]
     revision: Option<String>,
 
-    /// Wether to shard or not the model across multiple GPUs
+    /// Whether to shard the model across multiple GPUs
     /// By default text-generation-inference will use all available GPUs to run
     /// the model. Setting it to `false` deactivates `num_shard`.
     #[clap(long, env)]
@@ -66,11 +66,17 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
     /// quantization on the fly, or `gptq`.
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
+    /// encouraged when loading a model with custom code to ensure no malicious code has been
+    /// contributed in a newer revision.
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
     /// wait for too long and is usually good to handle backpressure correctly.
@@ -239,6 +245,7 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
+    trust_remote_code: bool,
     uds_path: String,
     rank: usize,
     world_size: usize,
@@ -272,6 +279,11 @@ fn shard_manager(
         "--json-output".to_string(),
     ];
 
+    // Activate trust remote code
+    if trust_remote_code {
+        shard_argv.push("--trust-remote-code".to_string());
+    }
+
     // Activate tensor parallelism
     if world_size > 1 {
         shard_argv.push("--sharded".to_string());
@@ -692,6 +704,16 @@ fn spawn_shards(
     status_sender: mpsc::Sender<ShardStatus>,
     running: Arc<AtomicBool>,
 ) -> Result<(), LauncherError> {
+    if args.trust_remote_code {
+        tracing::warn!(
+            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
+            args.model_id
+        );
+        if args.revision.is_none() {
+            tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
+        }
+    }
+
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
@@ -705,6 +727,7 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
+        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
@@ -714,6 +737,7 @@ fn spawn_shards(
                 model_id,
                 revision,
                 quantize,
+                trust_remote_code,
                 uds_path,
                 rank,
                 num_shard,
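
The launcher itself only parses the flag, prints the warnings above, and forwards `--trust-remote-code` to each Python shard; the actual opt-in happens when the shard calls into `transformers`. As a rough illustration of what the "explicitly pass a `revision`" warning is asking for (not part of this diff; the repo name and revision hash below are placeholders), pinning a revision means later pushes to the Hub repo cannot slip new code into an already-deployed model:

    from transformers import AutoConfig

    # Hypothetical repo that ships its own modelling code.
    config = AutoConfig.from_pretrained(
        "some-org/custom-model",   # placeholder model id
        revision="0123abcd",       # placeholder commit sha; pinning it freezes the code that runs
        trust_remote_code=True,    # allow the repo's Python files to be imported and executed
    )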

server/text_generation_server/cli.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -63,7 +64,7 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
 
 
 @app.command()
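
For context, `serve` here is a typer command (note the `@app.command()` decorator above), so adding the keyword `trust_remote_code: bool = False` is enough to surface a CLI flag. A minimal standalone sketch of that mechanism, with a stubbed body instead of the real server start-up:

    from typing import Optional

    import typer

    app = typer.Typer()


    @app.command()
    def serve(
        model_id: str,
        revision: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        # typer turns the boolean default into `--trust-remote-code` /
        # `--no-trust-remote-code` switches on the generated CLI
        typer.echo(f"serving {model_id} (trust_remote_code={trust_remote_code})")


    if __name__ == "__main__":
        app()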

server/text_generation_server/models/__init__.py

Lines changed: 127 additions & 20 deletions
@@ -91,26 +91,52 @@
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    trust_remote_code: bool,
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
+            return GalacticaSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return Galactica(model_id, revision, quantize=quantize)
+            return Galactica(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     if model_id.startswith("bigcode/"):
         if sharded:
             if not FLASH_ATTENTION:
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
-    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    config = AutoConfig.from_pretrained(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
     model_type = config.model_type
 
     if model_type == "gpt_bigcode":
@@ -119,52 +145,133 @@ def get_model(
             raise NotImplementedError(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
             )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     if model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_id, revision, quantize=quantize)
+            return BLOOMSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return BLOOM(model_id, revision, quantize=quantize)
+            return BLOOM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     if model_type == "gpt_neox":
         if sharded:
             neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     if model_type == "llama":
         if sharded:
             if FLASH_ATTENTION:
-                return FlashLlamaSharded(model_id, revision, quantize=quantize)
+                return FlashLlamaSharded(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
         else:
             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
-            return llama_cls(model_id, revision, quantize=quantize)
+            return llama_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     if config.model_type == "opt":
         if sharded:
-            return OPTSharded(model_id, revision, quantize=quantize)
+            return OPTSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return OPT(model_id, revision, quantize=quantize)
+            return OPT(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     if model_type == "t5":
         if sharded:
-            return T5Sharded(model_id, revision, quantize=quantize)
+            return T5Sharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return Seq2SeqLM(model_id, revision, quantize=quantize)
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
 
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(model_id, revision, quantize=quantize)
+        return CausalLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM(model_id, revision, quantize=quantize)
+        return Seq2SeqLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
+
+    auto_map = getattr(config, "auto_map", None)
+    if trust_remote_code and auto_map is not None:
+        if "AutoModelForCausalLM" in auto_map.keys():
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+        if "AutoModelForSeq2SeqLM" in auto_map.keys():
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
 
     raise ValueError(f"Unsupported model type {model_type}")

server/text_generation_server/models/bloom.py

Lines changed: 19 additions & 4 deletions
@@ -54,9 +54,13 @@ def __init__(
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         super(BLOOM, self).__init__(
-            model_id=model_id, revision=revision, quantize=quantize
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
         )
 
     @property
@@ -70,6 +74,7 @@ def __init__(
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -80,19 +85,29 @@ def __init__(
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id,
+            revision=revision,
+            slow_but_exact=False,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         config.pad_token_id = 3
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )
 
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
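
The sharded BLOOM path keeps its existing pattern: build the model skeleton under `init_empty_weights()` (parameters land on the meta device, so nothing is allocated), then stream the real tensors from the `.safetensors` files; the change is only that the tokenizer, config, and skeleton construction now all receive `trust_remote_code`. A simplified sketch of that loading pattern in isolation (checkpoint name is illustrative; for stock BLOOM the flag is effectively a no-op since the architecture is built into `transformers`):

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    model_id = "bigscience/bloom-560m"  # illustrative checkpoint

    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    with init_empty_weights():
        # weights are created on the "meta" device and take no memory;
        # real tensors are loaded separately (here, from .safetensors shards)
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)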
