PyTorch backend (1.3.0rc15): KV-cache block reuse changes first-token logits between first and subsequent identical requests #15503

Description

@ShuaiShao93

Summary

In tensorrt-llm==1.3.0rc15, running LLM(...).generate([prompt], ..., max_tokens=1, temperature=0.0) repeatedly with the same prompt returns different generation logits on the first call than on every subsequent identical call. Runs 2..N are bit-identical to each other, but each differs from run 1 across ~71-87% of vocab positions, with max_abs up to ~0.16 (~1 BF16 ULP at the relevant logit scale).

Confirmed root cause: KV-cache block reuse. Disabling it via kv_cache_config=KvCacheConfig(enable_block_reuse=False) makes runs 1..N bit-identical (neq=0 everywhere). The first call computes K/V freshly during prefill and stores them in cache blocks; subsequent identical calls detect the matching prefix and reuse the cached K/V instead of recomputing. The two attention code paths — fresh compute-and-store vs paged-context-FMHA reading from cache — disagree by ~1 BF16 ULP per position, compounding through 32 layers into a ~0.1 head-logit delta.
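How can two individually deterministic paths disagree at all? The pure-Python sketch below is not TensorRT-LLM code and does not model its attention kernels. It only assumes, consistent with the explanation above, that the fresh-prefill and cached-K/V paths round or accumulate in a different order. The helper names to_bf16, bf16_ulp and sum_bf16 are hypothetical. The sketch emulates BF16 round-to-nearest-even and sums the same three values in two orders.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (round-half-to-even).

    Emulation only: x is first rounded to float32, then the low 16 bits of
    the float32 encoding are rounded away. NaN/Inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                      # makes ties round to even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # drop the 16 low mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent BF16 values near a normal, nonzero x."""
    exponent = math.frexp(abs(x))[1] - 1        # |x| = m * 2**exponent, 1 <= m < 2
    return 2.0 ** (exponent - 7)                # BF16 keeps 7 explicit mantissa bits


def sum_bf16(values):
    """Accumulate left to right, rounding to BF16 after every addition."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)
    return acc


print(sum_bf16([256.0, 1.0, 1.0]))  # each +1 gives 257, a tie that rounds back to 256
print(sum_bf16([1.0, 1.0, 256.0]))  # the small terms combine first, giving 258
print(bf16_ulp(20.0))               # BF16 spacing for values in [16, 32)
```

Each order is reproducible on its own, yet the two results (256.0 and 258.0) differ by one BF16 ULP at that magnitude (bf16_ulp(256.0) is 2.0). This is the same qualitative pattern as run 1 versus runs 2..N.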

What we're comparing

To be unambiguous about scope: each generate() call produces one new token (max_tokens=1). For each call I take the generation_logits of that one step — shape [1, 1, vocab_size], flattened to [vocab_size] — and compare across separate identical generate() calls of the same prompt.

This repro involves no multi-token generation, no LoRA, no FP8, no chunked prefill and no batching: a single prompt is sent in separate calls.

Environment

  • tensorrt-llm 1.3.0rc15 (PyTorch backend, default tensorrt_llm.LLM API)
  • Python 3.12, NVIDIA L40S (SM 8.9), driver 595.71.05, CUDA 13.2
  • Model: meta-llama/Meta-Llama-3.1-8B-Instruct, vanilla BF16 (no quantization)

Minimal repro

```python
import numpy as np
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

def trial(enable_block_reuse: bool):
    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct",
        max_seq_len=512,
        max_batch_size=8,
        enable_chunked_prefill=False,
        kv_cache_config=KvCacheConfig(enable_block_reuse=enable_block_reuse),
    )
    sp = SamplingParams(
        max_tokens=1, temperature=0.0, ignore_eos=True,
        return_generation_logits=True, detokenize=False,
    )
    prompt = "The quick brown fox jumps over the lazy dog. " * 4

    logits = []
    for _ in range(6):                              # 6 separate generate() calls
        out = llm.generate([prompt], sampling_params=sp, use_tqdm=False)
        g = out[0].outputs[0].generation_logits     # [1, 1, vocab_size]
        if hasattr(g, "detach"):
            # .float() is exact for BF16/FP16 and lets .numpy() work (NumPy has no bfloat16).
            arr = g.detach().float().cpu().numpy()
        else:
            arr = np.asarray(g, dtype=np.float32)
        logits.append(arr.reshape(-1))              # (vocab_size,): the single token's logits

    print(f"--- enable_block_reuse={enable_block_reuse} ---")
    for i in range(1, len(logits)):
        neq = int(np.count_nonzero(logits[0] != logits[i]))
        mx  = float(np.max(np.abs(logits[0].astype(np.float32) - logits[i].astype(np.float32))))
        print(f"  run 1 vs run {i+1}: neq={neq:6d}  max_abs={mx:.6f}")
    llm.shutdown()

trial(enable_block_reuse=True)   # default: run 1 diverges from runs 2..6
trial(enable_block_reuse=False)  # divergence disappears
```
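The repro only prints run 1 against each later run, so the printed numbers by themselves do not show that runs 2..N match one another. A check like the following could be called on the logits list inside trial(), before llm.shutdown(). The function name is hypothetical, and the usage lines use synthetic arrays in place of real model output.

```python
import numpy as np


def later_runs_identical(logits):
    """Return True if runs 2..N are equal element for element.

    `logits` is a list of 1-D arrays, one per generate() call, as built by
    trial() in the repro. Run 1 is deliberately excluded from the comparison.
    """
    later = logits[1:]
    if len(later) < 2:
        raise ValueError("need at least 3 runs to compare runs 2..N")
    return all(np.array_equal(later[0], other) for other in later[1:])


# Synthetic usage: run 1 differs, runs 2..4 agree with each other.
base = np.array([0.5, 1.25, -3.0], dtype=np.float32)
first = base + np.float32(0.125)
print(later_runs_identical([first, base.copy(), base.copy(), base.copy()]))
```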

Observed

```
--- enable_block_reuse=True  (default) ---
  run 1 vs run 2: neq= 91315  max_abs=0.109375
  run 1 vs run 3: neq= 91315  max_abs=0.109375
  run 1 vs run 4: neq= 91315  max_abs=0.109375
  run 1 vs run 5: neq= 91315  max_abs=0.109375
  run 1 vs run 6: neq= 91315  max_abs=0.109375

--- enable_block_reuse=False ---
  run 1 vs run 2: neq=     0  max_abs=0.000000
  run 1 vs run 3: neq=     0  max_abs=0.000000
  run 1 vs run 4: neq=     0  max_abs=0.000000
  run 1 vs run 5: neq=     0  max_abs=0.000000
  run 1 vs run 6: neq=     0  max_abs=0.000000
```
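Dividing neq by the vocabulary size turns the observed count into the percentage quoted in the summary. The sketch assumes vocab_size = 128256, the Llama 3.1 vocabulary size, which should match generation_logits.shape[-1]. divergence_stats is a hypothetical helper that bundles the repro's two metrics with that fraction.

```python
import numpy as np


def divergence_stats(a, b):
    """Compare two flattened logit vectors from separate generate() calls.

    Returns (neq, fraction_of_vocab, max_abs), mirroring the repro's metrics.
    """
    a32 = np.asarray(a, dtype=np.float32)
    b32 = np.asarray(b, dtype=np.float32)
    neq = int(np.count_nonzero(a32 != b32))
    max_abs = float(np.max(np.abs(a32 - b32)))
    return neq, neq / a32.size, max_abs


# Same fraction, applied to the observed enable_block_reuse=True numbers.
VOCAB_SIZE = 128_256          # assumed: Llama 3.1 vocabulary size
observed_neq = 91_315
print(f"{observed_neq / VOCAB_SIZE:.1%}")
```

This gives 71.2%, the low end of the ~71-87% range reported in the summary.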

With reuse on, every "run 1 vs run i" comparison reports the same neq and max_abs, and runs 2..6 are bit-identical to each other. This indicates that run 1 takes one code path (fresh prefill, writing K/V to the cache), while runs 2..N take a different but equally deterministic path (loading K/V from the cache instead of recomputing it). The ~1 BF16 ULP per-position disagreement between the two paths compounds through the model into the observed ~0.1 head-logit delta.

Additional ruled-out causes

  • CUDA graphs (cuda_graph_config=None): no change, gap still present.
  • Overlap scheduler (disable_overlap_scheduler=True): no change.
  • Both disabled simultaneously: no change.
  • Chunked prefill: already disabled in the repro (enable_chunked_prefill=False).

Of the settings tried, KV-cache block reuse is the only one that removes the gap.
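The ablation above amounts to a small table of LLM keyword overrides, each checked for the run-1 gap. In this sketch only the override names cuda_graph_config and disable_overlap_scheduler come from the report. enable_block_reuse is a hypothetical shorthand that a real runner would turn into kv_cache_config=KvCacheConfig(enable_block_reuse=...), and run_repeated is a hypothetical callable (for example, a variant of trial() that returns its logits list). A fake runner stands in for the model so the snippet runs without a GPU.

```python
from typing import Callable, Dict, List

import numpy as np

# Overrides tried in the report. "enable_block_reuse" is a hypothetical shorthand:
# a real run_repeated would translate it into kv_cache_config=KvCacheConfig(...).
ABLATIONS: Dict[str, dict] = {
    "baseline": {},
    "no CUDA graphs": {"cuda_graph_config": None},
    "no overlap scheduler": {"disable_overlap_scheduler": True},
    "both disabled": {"cuda_graph_config": None, "disable_overlap_scheduler": True},
    "no block reuse": {"enable_block_reuse": False},
}


def gap_present(runs: List[np.ndarray]) -> bool:
    """True if any later run differs from run 1 at any vocab position."""
    return any(np.count_nonzero(runs[0] != r) > 0 for r in runs[1:])


def ablate(run_repeated: Callable[..., List[np.ndarray]]) -> Dict[str, bool]:
    """Map each ablation name to whether the first-call gap is still present."""
    return {name: gap_present(run_repeated(**kw)) for name, kw in ABLATIONS.items()}


def fake_run_repeated(**kw) -> List[np.ndarray]:
    """Hypothetical stand-in: run 1 differs only while block reuse is enabled."""
    cached = np.zeros(4, dtype=np.float32)
    if kw.get("enable_block_reuse", True):
        fresh = cached + np.float32(0.109375)
    else:
        fresh = cached.copy()
    return [fresh] + [cached.copy() for _ in range(5)]


print(ablate(fake_run_repeated))
```

With a real runner, the pattern described in the report corresponds to True for every row except "no block reuse".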

Why this matters

For any consumer that compares generation logits or softmax-derived probabilities for the same prompt (golden tests, classifier-style probes that sum probabilities over fixed token ids, score-reproducibility checks), the first call for a never-seen prompt produces a different answer than every repeat. With max_abs ≈ 0.1 on BF16 logits, a probability near a decision threshold can land on one side of it for the first call and on the other side for every later call.
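A toy illustration of the threshold effect, using made-up logits rather than model output: a probe that sums softmax probabilities over a fixed set of token ids lands on opposite sides of a threshold when one competing logit moves by 0.109375, the max_abs seen in the repro. The token ids, the 0.5 threshold and the helper name probe_prob are hypothetical.

```python
import numpy as np


def probe_prob(logits, token_ids):
    """Sum of softmax probabilities over token_ids, computed in float64."""
    x = np.asarray(logits, dtype=np.float64)
    x = x - x.max()                         # subtract the max to keep exp() stable
    p = np.exp(x) / np.exp(x).sum()
    return float(p[token_ids].sum())


YES_IDS = [0]                               # hypothetical "yes" token ids
THRESHOLD = 0.5                             # hypothetical decision threshold

first_call = np.array([2.05, 2.0])          # made-up logits for the first call
later_call = first_call.copy()
later_call[1] += 0.109375                   # shift one competing logit by the observed max_abs

for name, logits in [("first", first_call), ("later", later_call)]:
    p = probe_prob(logits, YES_IDS)
    print(name, round(p, 4), "yes" if p >= THRESHOLD else "no")
```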

enable_block_reuse is on by default and is a real performance feature; users typically don't want to disable it. The structural problem is that the fresh-prefill code path and the cache-reload code path are not bit-equivalent, despite both being deterministic individually.

Expected

Either the fresh-prefill and cache-reload paths should produce bit-identical outputs (so the optimization is invisible), or the API/docs should call out that enable_block_reuse=True makes logits depend on request history (whether an identical prefix is already cached) and recommend disabling it for use cases that need stable per-prompt logits.

Metadata

  • Assignees: no one assigned
  • Labels: Inference runtime<NV> (General operational aspects of TRTLLM execution not in other categories), Pytorch<NV> (Pytorch backend related issues)
  • Type: none
  • Projects: none
  • Milestone: none
  • Relationships: none yet
  • Development: no branches or pull requests