[None][fix] Fix encoder-decoder beam search corruption (cross-KV sharing, mixed-beam exception, gather_generation_logits)#15444
Conversation
In KVCacheManager::copyBlockOffsets, when building the cross-attention KV cache offset table for encoder-decoder models (e.g. Whisper), each beam of a request was assigned separate physical blocks. However, the encoder context phase only writes encoder features to beam-0's blocks (contextBeamWidth=1 means only slot 0 per request is addressed). Beams 1..N-1 pointed to uninitialised GPU memory, causing degenerate repetitive output such as "happ happ happ" during beam-search decoding. Fix: when isCrossKv(), always use beam-0's source block IDs and block count for every beam. The encoder output is identical for all beams of a request, so sharing the same physical blocks is semantically correct. Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
When concurrent requests with different beam widths (e.g. beam=1 and beam=5) triggered a verifyRequests() exception in TrtGptModelInflightBatching:: forwardAsync(), the exception handler freed the KV-cache sequence slots via terminateRequest() but did not remove the affected request IDs from mInflightReqIds. The subsequent changeBeamWidth() call checks TLLM_CHECK(mInflightReqIds.empty()) and would abort the process. Fix: in the exception catch block, erase each active request from mInflightReqIds before calling terminateRequest(), then call changeBeamWidth(mOperatingBeamWidth) to reset RuntimeBuffers and DecoderState so the next batch starts from a clean state. Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
…eam widths copyGenerationLogits() used mergeLogitsFragmentsKernel — a CUDA kernel that reads fragment GPU addresses from a pointer array (cachePointerDevice) — to gather kCACHE_LENGTH logit steps into a transposedLogits scratch buffer before copying to host. When gather_generation_logits=True was combined with concurrent requests of different beam widths (causing changeBeamWidth() to be called between batches), this kernel intermittently corrupted token generation, resulting in ~26% runs producing degenerate repetitive output. Fix: bypass the kernel entirely. Copy each (beam, step) logit slice directly from the fragment GPU tensor to the corresponding host location using bufferManager.copy(). This is functionally equivalent — the fragments already carry the correct data — and eliminates the pointer-indirection pattern that caused the corruption. Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
…ch buffers mergeLogitsFragmentsKernel and its wrapper mergeLogitsFragments are no longer called after copyGenerationLogits() was rewritten to copy generation logits directly from each fragment to the host without the GPU-side gather step. Remove: - mergeLogitsFragmentsKernel / invokeMergeLogitsFragments / mergeLogitsFragments from runtimeKernels.cu and its declaration from runtimeKernels.h - GenerationLogitsCache scratch fields that only served the kernel: transposedLogits, fragmentPointerDevice, fragmentPointerHost, workIdx, cycleWorkIdx(), getFragmentPointerHost() - Corresponding GPU/pinned-memory allocations in runtimeBuffers.cpp - Now-unused #include of runtimeKernels.h in inflightBatchingUtils.cpp Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Two tests in encDecBeamSearchTest.cpp: CrossKvBeamSharingTest/CopyBlockOffsetsAllBeamsShareBeam0Blocks Verifies that KVCacheManager::copyBlockOffsets for a cross-KV cache (CacheType::kCROSS, beam width > 1) writes the same physical block IDs into every beam slot. In the simple context-only setup the allocator shares blocks across beams, so this acts as a sanity check for the correct invariant. CopyGenerationLogitsTest/DirectCopyPlacesEachBeamStepAtCorrectHostOffset Verifies that copyGenerationLogits correctly places each (beam, step) logit from its GPU fragment tensor into the right slot of the host logits buffer. Making copyGenerationLogits a no-op produces 56 assertion failures. Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
cadd97b to
0c5a864
Compare
📝 WalkthroughWalkthroughTwo encoder-decoder beam search bugs are fixed: ChangesEncoder-Decoder Beam Search Bug Fixes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp (1)
119-124: ⚡ Quick winMake the new slice temporaries and slice length guideline-compliant.
fragBeamSliceandhostStepSliceare not reassigned, and the new slice-size literal should be named once.♻️ Proposed cleanup
+ SizeType32 constexpr kSingleLogitStep = 1; for (SizeType32 beam = 0; beam < reqBeamWidth; beam++) { auto const droppedSize = !numDroppedTokens.empty() ? numDroppedTokens.at(beam) : 0; auto const beamFragmentSize = fragmentSize - droppedSize; auto const numGenerationToken @@ for (SizeType32 stepIdx = 0; stepIdx < static_cast<SizeType32>(beamFragmentSize); ++stepIdx) { // frag shape: [1, beamWidth, vocabSizePadded]. Beam b starts at offset b*vocab. - auto fragBeamSlice = ITensor::slice(fragments.at(stepIdx), {0, beam}, 1); + auto const fragBeamSlice = ITensor::slice(fragments.at(stepIdx), {0, beam}, kSingleLogitStep); // host shape: [beamWidth, mMaxNewTokens, vocabSizePadded]. Target: [beam, hostOffset+stepIdx, :]. - auto hostStepSlice = ITensor::slice(llmReq.getGenerationLogitsHost(), {beam, hostOffset + stepIdx}, 1); + auto const hostStepSlice + = ITensor::slice(llmReq.getGenerationLogitsHost(), {beam, hostOffset + stepIdx}, kSingleLogitStep); bufferManager.copy(*fragBeamSlice, *hostStepSlice); }As per coding guidelines, “Variables not modified after initialization should be declared as
const” and “All literals except0,nullptr,true,falseshould only be used for variable initialization.”🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp` around lines 119 - 124, The variables fragBeamSlice and hostStepSlice are not reassigned after initialization, so they should be declared as const to comply with coding guidelines. Additionally, the literal value 1 is used directly in both ITensor::slice calls instead of being extracted to a named variable; create a const variable to hold this slice dimension value once and use it in both the ITensor::slice call for fragBeamSlice and the ITensor::slice call for hostStepSlice.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp`:
- Around line 125-128: The bufferManager.copy() method enqueues asynchronous
GPU-to-host copies on the internal stream without waiting for completion.
Calling llmReq.clearGenerationLogitsFragments() immediately after the copy loop
destroys the source tensors while the GPU may still be reading them, creating a
use-after-free condition. Add bufferManager.getStream().synchronize() after the
closing brace of the copy loop and before the
llmReq.clearGenerationLogitsFragments() call to ensure all asynchronous copies
complete before the fragments are cleared. Apply the same fix in
handleGenerationLogits.cpp where this same pattern exists around the
corresponding generation logits handling code.
---
Nitpick comments:
In `@cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp`:
- Around line 119-124: The variables fragBeamSlice and hostStepSlice are not
reassigned after initialization, so they should be declared as const to comply
with coding guidelines. Additionally, the literal value 1 is used directly in
both ITensor::slice calls instead of being extracted to a named variable; create
a const variable to hold this slice dimension value once and use it in both the
ITensor::slice call for fragBeamSlice and the ITensor::slice call for
hostStepSlice.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: cf608cb9-b683-44f3-86c9-33205876a4db
📒 Files selected for processing (9)
cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.hcpp/tensorrt_llm/batch_manager/kvCacheManager.cppcpp/tensorrt_llm/batch_manager/runtimeBuffers.cppcpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cppcpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cppcpp/tensorrt_llm/runtime/runtimeKernels.cucpp/tensorrt_llm/runtime/runtimeKernels.hcpp/tests/unit_tests/batch_manager/CMakeLists.txtcpp/tests/unit_tests/batch_manager/encDecBeamSearchTest.cpp
💤 Files with no reviewable changes (4)
- cpp/tensorrt_llm/runtime/runtimeKernels.h
- cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp
- cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
- cpp/tensorrt_llm/runtime/runtimeKernels.cu
…erationLogits BufferManager::copy() enqueues async GPU-to-host copies; a synchronize() before clearGenerationLogitsFragments() was initially added (and then reverted) in response to a code-review comment. Document the conclusion: no sync is required because clearGenerationLogitsFragments() only releases shared_ptr wrapper objects — the underlying CUDA allocation lives inside generationLogitsCache.logits (owned by RuntimeBuffers) until the next changeBeamWidth() call, and GPU ordering is already guaranteed by the CUDA event the decoder stream waits on. Also apply coding-guideline fixes to the same function: declare fragBeamSlice and hostStepSlice as const, and replace the bare literal 1 with a named constexpr kOneStep. Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
…s PR runtimeKernels.cu (1993-2022→1993-2026), runtimeKernels.h (2019-2023→2019-2026), runtimeBuffers.h (2023-2024→2023-2026), runtimeBuffers.cpp (2025→2025-2026) Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
f29710f to
adb7b8c
Compare
@coderabbitai summary
Description
Fixes three bugs in the TRT backend (
ModelRunnerCpp) that caused corrupted output when concurrent requests with different beam widths (e.g. beam=1 and beam=5) were submitted to the same executor — specifically for encoder-decoder models (Whisper) with beam search.Bug 1 — Cross-KV block sharing (Fix 1)
KVCacheManager::copyBlockOffsetsassigned separate physical blocks to each beam of an encoder-decoder cross-attention sequence. The TRT encoder runs withcontextBeamWidth=1so only beam-0's blocks were ever populated with encoder features. Beams 1..N-1 attended to uninitialised GPU memory, producing degenerate repetitive output ("happ happ happ").Fix: when
isCrossKv()is true, always use beam-0's source block IDs and count for all beams. The encoder output is identical for all beams of a request.Bug 2 —
mInflightReqIdsleak on mixed-beam exception (Fix 2)When concurrent requests with different beam widths triggered the
verifyRequests()exception inTrtGptModelInflightBatching::forwardAsync, the exception handler freed the KV-cache but did not remove request IDs frommInflightReqIds. The subsequentchangeBeamWidth()call checksTLLM_CHECK(mInflightReqIds.empty())and aborted.Fix: erase affected IDs from
mInflightReqIdsand callchangeBeamWidth(mOperatingBeamWidth)in the exception catch block so the next batch starts from a clean state.Bug 3 —
gather_generation_logitscorruption with concurrent beam widths (Fix 3)copyGenerationLogits()usedmergeLogitsFragmentsKernel— a CUDA kernel that reads fragment GPU addresses via a sharedfragmentPointerDevicepointer array — to gather kCACHE_LENGTH logit steps into a scratch buffer before copying to host. Whengather_generation_logits=Truewas combined with concurrent mixed beam-width requests (causingchangeBeamWidth()to be called between batches), this kernel intermittently corrupted token generation, producing ~26% degenerate output.Fix: bypass the kernel entirely. Copy each (beam, step) logit slice directly from its fragment GPU tensor to the corresponding host location via
bufferManager.copy(). This is functionally equivalent and eliminates the pointer-indirection pattern that caused the corruption.The dead kernel (
mergeLogitsFragmentsKernel,invokeMergeLogitsFragments,mergeLogitsFragments) and its now-unused scratch buffers (transposedLogits,fragmentPointerDevice,fragmentPointerHost) are removed in commit 4.Test Coverage
C++ unit tests (
cpp/tests/unit_tests/batch_manager/encDecBeamSearchTest.cpp):CrossKvBeamSharingTest/CopyBlockOffsetsAllBeamsShareBeam0Blocks— verifies Fix 1 invariant: all beam slots in the cross-KV offset table equal beam-0.CopyGenerationLogitsTest/DirectCopyPlacesEachBeamStepAtCorrectHostOffset— verifies Fix 3: each (beam, step) logit lands at the correct host offset. MakingcopyGenerationLogits()a no-op produces 56 assertion failures.Validation on Whisper large-v3 with concurrent mixed-beam execution:
gather_generation_logitsgather_generation_logits=TruePR Checklist
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.