Skip to content
Open
15 changes: 10 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -208,7 +208,9 @@ class RuntimeBuffers
//! Temporarily store the transposed results of multiple fragment logits, [maxBeamWidth, kCACHE_LENGTH]
TensorPtr transposedLogits;

//! Temporarily store logits buffer address during the transposing, [kCACHE_LENGTH]
//! Temporarily store logits buffer address during the transposing, [maxBatchSize, kCACHE_LENGTH]
//! One row per batch slot (same layout as fragmentPointerHost) so concurrent flushes for
//! different requests in the same batch never clobber each other's pointer arrays.
TensorPtr fragmentPointerDevice;

//! Temporarily store logits buffer address during the transposing, [maxBatchSize, kCACHE_LENGTH]
Expand All @@ -222,11 +224,14 @@ class RuntimeBuffers
workIdx = (workIdx + 1) % (fragmentPointerHost->getShape().d[0]);
}

[[nodiscard]] TensorPtr getFragmentPointerHost()
//! Returns matching host and device pointer rows for the current workIdx, then advances
//! workIdx. Always call this instead of the individual getters to avoid ordering bugs.
[[nodiscard]] std::pair<TensorPtr, TensorPtr> getFragmentPointerSlot()
{
TensorPtr slice = runtime::ITensor::slice(fragmentPointerHost, workIdx, 1);
TensorPtr host = runtime::ITensor::slice(fragmentPointerHost, workIdx, 1);
TensorPtr device = runtime::ITensor::slice(fragmentPointerDevice, workIdx, 1);
cycleWorkIdx();
return slice;
return {std::move(host), std::move(device)};
};
};

Expand Down
13 changes: 9 additions & 4 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3957,16 +3957,21 @@ SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSl
{
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto const beamBlockCount = cacheBlockIds[beamIdx].size();
auto const copyChunkSize = beamBlockCount * sizeof(tk::KVCacheIndex);
// For cross-KV (encoder features), all beams of a request share the same encoder output.
// Always use beam 0's block IDs so all beams attend to the correct encoder features.
auto const srcBeamIdx = isCrossKv() ? 0 : beamIdx;
TLLM_CHECK_WITH_INFO(!isCrossKv() || !cacheBlockIds.empty(),
"Cross-KV sequence has no block IDs for request %lu", requestId);
auto const effectiveBlockCount = isCrossKv() ? cacheBlockIds[0].size() : cacheBlockIds[beamIdx].size();
auto const copyChunkSize = effectiveBlockCount * sizeof(tk::KVCacheIndex);
for (auto xIdx : {kIdx, vIdx})
{
auto const srcIndex = tc::flat_index(srcShape.d, poolIdx, beamIdx, xIdx, 0);
auto const srcIndex = tc::flat_index(srcShape.d, poolIdx, srcBeamIdx, xIdx, 0);
auto const dstIndex
= tc::flat_index(dstShape.d, absolutePoolIdx, outputSlotOffset + beamIdx, xIdx, 0);
std::memcpy(dstPtr + dstIndex, srcPtr + srcIndex, copyChunkSize);
}
maxBlockCount = std::max<SizeType32>(maxBlockCount, static_cast<SizeType32>(beamBlockCount));
maxBlockCount = std::max<SizeType32>(maxBlockCount, static_cast<SizeType32>(effectiveBlockCount));
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -152,8 +152,8 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH, maxBatchSize * maxBeamWidth, vocabSizePadded}),
logitsType);

generationLogitsCache.fragmentPointerDevice
= manager.gpu(ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
generationLogitsCache.fragmentPointerDevice = manager.gpu(
ITensor::makeShape({maxBatchSize, GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
generationLogitsCache.fragmentPointerHost = tensorrt_llm::runtime::BufferManager::pinnedPool(
ITensor::makeShape({maxBatchSize, GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
}
Expand Down
13 changes: 13 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,8 +1216,21 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
{
for (auto const& llmReq : activeRequests)
{
// Remove from mInflightReqIds so changeBeamWidth can proceed on the next iteration.
// terminateRequest frees seqSlot/KV cache but does not clean up mInflightReqIds.
mInflightReqIds.erase(llmReq->mRequestId);
terminateRequest(llmReq);
}
// Force buffer/decoder reset to clean up any partial state from the aborted batch
// (e.g. partially-filled cross-KV block offsets from mid-context-chunk processing).
// Guard on mInflightReqIds.empty(): in pipeline-parallel multi-micro-batch mode,
// other micro-batches may still have requests tracked here; changeBeamWidth asserts
// emptiness so we skip the reset and let the next successful forwardAsync iteration
// perform it when the set is clear.
if (mWorldConfig.isLastPipelineParallelRank() && mInflightReqIds.empty())
{
changeBeamWidth(mOperatingBeamWidth);
}
}
catch (std::exception const& e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ void copyGenerationLogits(RuntimeBuffers::GenerationLogitsCache& generationLogit

auto const fragmentSize = llmReq.getGenerationLogitsFragmentsSize();

// Merge logits fragments on device
// Merge logits fragments on device. getFragmentPointerSlot() returns the matching host and
// device rows for the current workIdx and advances the index atomically, so concurrent flushes
// for different requests in the same batch never clobber each other's pointer arrays.
auto const& transposeBufferPtr = generationLogitsCache.transposedLogits;
auto const& cachePointerDevice = generationLogitsCache.fragmentPointerDevice;
auto const& cachePointerHost = generationLogitsCache.getFragmentPointerHost();
auto [cachePointerHost, cachePointerDevice] = generationLogitsCache.getFragmentPointerSlot();
tensorrt_llm::runtime::kernels::mergeLogitsFragments(bufferManager, *transposeBufferPtr,
llmReq.getGenerationLogitsFragments(), *cachePointerDevice, *cachePointerHost, 0, 1, reqBeamWidth,
bufferManager.getStream(), 0);
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/unit_tests/batch_manager/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ add_gtest(rnnCacheFormatterTest rnnCacheFormatterTest.cpp)
add_gtest(cudaGraphExecutorCacheTest cudaGraphExecutorCacheTest.cpp)
add_gtest(agentTreeTest agentTreeTest.cpp)
add_gtest(truncateBlocksTest truncateBlocksTest.cpp)
add_gtest(encDecBeamSearchTest encDecBeamSearchTest.cpp)
Loading
Loading