Skip to content

Commit aa98ba0

Browse files
Introduce seq_len as inference param, and improve warnings (#15716)
Summary: Changes: 1. Add a `--seq_len` param to the llama script to distinguish it from `max_seq_len`, which is a compile-time param. 2. Add warnings in the runner when `seq_len` is clamped to `max_seq_len`, to avoid silently clamping it. 3. Add warnings in the token generator when EOS is not reached due to an insufficient `seq_len` or `max_seq_len`. Differential Revision: D86696759
1 parent 3e90b44 commit aa98ba0

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,13 @@ def post_process():
941941
with open(f"{args.artifact}/outputs/outputs.txt", "r") as f:
942942
outputs.append(f.read())
943943

944-
seq_len = args.max_seq_len
944+
# Use --seq_len if provided (inference-only), otherwise fall back to --max_seq_len
945+
seq_len = args.seq_len
946+
if seq_len is None:
947+
logging.info(
948+
f"--seq_len not provided, using --max_seq_len ({args.max_seq_len}) as fallback"
949+
)
950+
seq_len = args.max_seq_len
945951
multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
946952
lookahead_args = " ".join(
947953
[
@@ -1170,11 +1176,18 @@ def _build_parser():
11701176

11711177
parser.add_argument(
11721178
"--max_seq_len",
1173-
help="This refers to maximum number of tokens that the model can process & consider at once to generate predictions/responses.",
1179+
help="[Compile-time] Maximum sequence length compiled into the model (sets buffer sizes and context_len). This is the hard limit for the model's context window.",
11741180
default=512,
11751181
type=int,
11761182
)
11771183

1184+
parser.add_argument(
1185+
"--seq_len",
1186+
help="[Runtime] Maximum number of tokens to generate (prompt + output). If not specified, uses --max_seq_len. Will be clamped to compiled max_seq_len if exceeded.",
1187+
default=None,
1188+
type=int,
1189+
)
1190+
11781191
parser.add_argument(
11791192
"--prefill_ar_len",
11801193
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid and lookahead mode.",

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,22 @@ Error Runner<T>::generate_from_prompt_or_file(
376376
stats_.inference_start_ms = time_in_ms();
377377

378378
int32_t seq_len = config.seq_len;
379-
seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
379+
if (seq_len > context_len_) {
380+
ET_LOG(
381+
Info,
382+
"Warning: Requested seq_len (%d) exceeds compiled max_seq_len (%d). Clamping to %d.",
383+
seq_len,
384+
context_len_,
385+
context_len_);
386+
seq_len = context_len_;
387+
} else if (seq_len <= 0) {
388+
ET_LOG(
389+
Info,
390+
"Warning: Invalid seq_len (%d). Using compiled max_seq_len (%d).",
391+
seq_len,
392+
context_len_);
393+
seq_len = context_len_;
394+
}
380395
int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
381396

382397
// encode the (string) prompt into tokens sequence

examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,30 @@ Result<int64_t> TokenGenerator<T>::generate(
323323
break;
324324
}
325325
}
326+
327+
// Check if generation was truncated due to seq_len limit (no EOS token)
328+
if (eos_ids_->count(cur_token) == 0 && pos >= seq_len - 1) {
329+
printf("\n");
330+
ET_LOG(
331+
Info,
332+
"Warning: Generation stopped at seq_len limit (%d) without reaching EOS token. Response may be incomplete.",
333+
seq_len);
334+
if (seq_len >= metadata_.context_len) {
335+
ET_LOG(
336+
Info,
337+
"- seq_len (%d) already equals compiled max_seq_len (%d). Consider recompiling with larger --max_seq_len.",
338+
seq_len,
339+
metadata_.context_len);
340+
} else {
341+
ET_LOG(
342+
Info,
343+
"- seq_len (%d) is less than compiled max_seq_len (%d). Consider increasing --seq_len (up to %d).",
344+
seq_len,
345+
metadata_.context_len,
346+
metadata_.context_len);
347+
}
348+
}
349+
326350
return pos - start_pos;
327351
}
328352
// Explicit instantiations

0 commit comments

Comments
 (0)