Fix crash in greedy assisted generation with different tokenizers #46936

Open
Sunt-ing wants to merge 1 commit into huggingface:main from Sunt-ing:2

@Sunt-ing

What does this PR do?

Universal Assisted Generation with a greedy main model (do_sample=False, main and assistant on different tokenizers) crashes with a shape mismatch:

RuntimeError: The size of tensor a (23) must match the size of tensor b (14) at non-singleton dimension 1

generate(..., do_sample=False) with different tokenizers routes to AssistedCandidateGeneratorDifferentTokenizers. Its assistant_kwargs inherit the main model's kwargs, including a position_ids tensor sized to the main tokenizer's sequence length. The assistant re-encodes the prompt with its own tokenizer, usually to a different length, but the inherited position_ids is passed unchanged into the assistant's draft round. An absolute-position assistant (GPT-2) then crashes on any length mismatch, and a rotary assistant (Llama) crashes whenever its re-encoding is longer than the main sequence. The generator already pops the inherited attention_mask for this exact reason, but never pops position_ids.
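The length mismatch can be illustrated without loading any model. This is a toy sketch, not transformers code: `main_tokenize`, `assistant_tokenize`, and `add_position_embeddings` are hypothetical stand-ins for the two tokenizers and for an absolute-position lookup that expects exactly one position per input token.

```python
def main_tokenize(text):
    """Toy stand-in for the main tokenizer: one token per whitespace-separated word."""
    return text.split()


def assistant_tokenize(text):
    """Toy stand-in for the assistant tokenizer: finer-grained two-character pieces."""
    return [text[i:i + 2] for i in range(0, len(text), 2)]


def add_position_embeddings(tokens, position_ids):
    """Toy stand-in for an absolute-position model: needs one position per token."""
    if len(tokens) != len(position_ids):
        raise ValueError(
            f"got {len(tokens)} tokens but {len(position_ids)} position ids"
        )
    return list(zip(tokens, position_ids))


prompt = "The quick brown fox jumps over the lazy dog"

# position_ids are built for the main model's encoding of the prompt.
main_tokens = main_tokenize(prompt)
inherited_position_ids = list(range(len(main_tokens)))

# The assistant re-encodes the same prompt and ends up with a different length.
assistant_tokens = assistant_tokenize(prompt)

# Passing the inherited positions alongside the assistant's tokens fails.
try:
    add_position_embeddings(assistant_tokens, inherited_position_ids)
    mismatch_error = None
except ValueError as err:
    mismatch_error = str(err)
```

In the real models the failure surfaces as the tensor-size RuntimeError quoted above rather than as a ValueError; the toy check only mirrors the one-position-per-token requirement.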

The fix drops the inherited position_ids before the assistant generates, right next to the existing attention_mask pop that handles the identical main-length-mismatch problem. The assistant then rebuilds position_ids from its own input. Assisted decoding stays lossless: greedy UAG output is now token-identical to plain greedy.
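The shape of the fix can be sketched on a plain dict. This is a minimal illustration under stated assumptions, not the actual patch: `build_assistant_kwargs` and `default_position_ids` are hypothetical helpers, and the lengths 14 and 23 are borrowed from the error message, with the assignment of each length to a model assumed for the example.

```python
import copy

# Assumption for illustration: these are the kwargs whose shape follows the
# main tokenizer's sequence length. The names mirror the PR description.
LENGTH_DEPENDENT_KEYS = ("attention_mask", "position_ids")


def build_assistant_kwargs(main_kwargs):
    """Hypothetical helper: copy the main kwargs, dropping main-length entries."""
    assistant_kwargs = copy.copy(main_kwargs)
    for key in LENGTH_DEPENDENT_KEYS:
        assistant_kwargs.pop(key, None)
    return assistant_kwargs


def default_position_ids(num_tokens):
    """Hypothetical helper: positions rebuilt from an unpadded input of num_tokens."""
    return [list(range(num_tokens))]


main_kwargs = {
    "attention_mask": [[1] * 14],       # sized to the main encoding (assumed 14)
    "position_ids": [list(range(14))],  # sized to the main encoding (assumed 14)
    "use_cache": True,
}
assistant_kwargs = build_assistant_kwargs(main_kwargs)

# The assistant's re-encoding has its own length (assumed 23 here), so its
# positions are derived from that length instead of being inherited.
assistant_position_ids = default_position_ids(23)
```

Popping on a shallow copy leaves the main model's own kwargs untouched, so the main model keeps its position_ids while the assistant starts without them.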

Reproduction (CPU, real tiny checkpoints) and before/after
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load(name):
    tok = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name, dtype=torch.float32).eval()
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    return model, tok

main, main_tok = load("hf-internal-testing/tiny-random-LlamaForCausalLM")   # rotary
asst, asst_tok = load("hf-internal-testing/tiny-random-gpt2")               # absolute pos
ids = main_tok("The quick brown fox jumps over the lazy dog", return_tensors="pt").input_ids

kw = dict(attention_mask=torch.ones_like(ids), max_new_tokens=20,
          do_sample=False, num_beams=1, pad_token_id=main_tok.pad_token_id)
greedy = main.generate(ids, **kw)
uag = main.generate(ids, assistant_model=asst, tokenizer=main_tok,
                    assistant_tokenizer=asst_tok, **kw)
assert greedy[0].tolist() == uag[0].tolist()   # lossless

Before this PR:

main=Llama(rotary)  assistant=GPT-2(absolute pos)  -> RuntimeError: size of tensor a (23) must match b (14)
rotary assistant, re-encoding longer than main     -> RuntimeError
absolute-pos assistant, any length mismatch        -> RuntimeError
rotary assistant, re-encoding shorter than main    -> OK

After this PR: all of the above run, and greedy UAG output is token-identical to plain greedy across every case. Reverting the fix brings the crash back. ruff check and ruff format are clean.
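If a losslessness check like the assert in the reproduction ever fails, it helps to know where the two outputs part ways. The helper below is a hypothetical debugging aid, not part of this PR, and the token ids in the usage lines are made up for illustration.

```python
def first_divergence(reference, candidate):
    """Return the first index where two token-id sequences differ, or None if identical."""
    for index, (ref_id, cand_id) in enumerate(zip(reference, candidate)):
        if ref_id != cand_id:
            return index
    # One sequence is a strict prefix of the other: they diverge where the shorter ends.
    if len(reference) != len(candidate):
        return min(len(reference), len(candidate))
    return None


# Made-up token ids standing in for greedy[0].tolist() and uag[0].tolist().
identical = first_divergence([5, 8, 13, 21], [5, 8, 13, 21])
diverged = first_divergence([5, 8, 13, 21], [5, 8, 14, 21])
truncated = first_divergence([5, 8, 13, 21], [5, 8])
```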

  • I confirm that this is not a pure code agent PR.

Who can review?

@gante

Greedy Universal Assisted Generation (do_sample=False with the main and
assistant models on different tokenizers) crashed with a tensor-size
RuntimeError. The assistant re-encodes the prompt to a different length, but
the main model's position_ids were inherited unchanged into the assistant's
first draft round, mismatching its input length. Drop the inherited
position_ids before the assistant generates, next to the existing
attention_mask pop, so the assistant rebuilds them from its own input. Greedy
assisted decoding stays lossless versus plain greedy.
@github-actions

CI Dashboard: View test results in Grafana
