Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 1 addition & 25 deletions cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -204,30 +204,6 @@ class RuntimeBuffers
TensorPtr logits;
//! Record the usage offset of the cacheGenerationLogits buffer
SizeType32 offset{0};

//! Temporarily store the transposed results of multiple fragment logits, [maxBeamWidth, kCACHE_LENGTH]
TensorPtr transposedLogits;

//! Temporarily store logits buffer address during the transposing, [kCACHE_LENGTH]
TensorPtr fragmentPointerDevice;

//! Temporarily store logits buffer address during the transposing, [maxBatchSize, kCACHE_LENGTH]
TensorPtr fragmentPointerHost;

//! Cycling index for workspace
size_t workIdx{0};

void cycleWorkIdx()
{
workIdx = (workIdx + 1) % (fragmentPointerHost->getShape().d[0]);
}

[[nodiscard]] TensorPtr getFragmentPointerHost()
{
TensorPtr slice = runtime::ITensor::slice(fragmentPointerHost, workIdx, 1);
cycleWorkIdx();
return slice;
};
};

GenerationLogitsCache generationLogitsCache;
Expand Down
11 changes: 7 additions & 4 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3945,16 +3945,19 @@ SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSl
{
for (SizeType32 beamIdx = 0; beamIdx < beamWidth; ++beamIdx)
{
auto const beamBlockCount = cacheBlockIds[beamIdx].size();
auto const copyChunkSize = beamBlockCount * sizeof(tk::KVCacheIndex);
// For cross-KV (encoder features), all beams of a request share the same encoder output.
// Always use beam 0's block IDs so all beams attend to the correct encoder features.
auto const srcBeamIdx = isCrossKv() ? 0 : beamIdx;
auto const effectiveBlockCount = isCrossKv() ? cacheBlockIds[0].size() : cacheBlockIds[beamIdx].size();
auto const copyChunkSize = effectiveBlockCount * sizeof(tk::KVCacheIndex);
for (auto xIdx : {kIdx, vIdx})
{
auto const srcIndex = tc::flat_index(srcShape.d, poolIdx, beamIdx, xIdx, 0);
auto const srcIndex = tc::flat_index(srcShape.d, poolIdx, srcBeamIdx, xIdx, 0);
auto const dstIndex
= tc::flat_index(dstShape.d, absolutePoolIdx, outputSlotOffset + beamIdx, xIdx, 0);
std::memcpy(dstPtr + dstIndex, srcPtr + srcIndex, copyChunkSize);
}
maxBlockCount = std::max<SizeType32>(maxBlockCount, static_cast<SizeType32>(beamBlockCount));
maxBlockCount = std::max<SizeType32>(maxBlockCount, static_cast<SizeType32>(effectiveBlockCount));
}
}
}
Expand Down
9 changes: 1 addition & 8 deletions cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -146,16 +146,9 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
auto const logitsType = engine.getTensorDataType(batch_manager::RuntimeBuffers::kLogitsTensorName);

generationLogitsCache.transposedLogits = manager.gpu(
ITensor::makeShape({maxBeamWidth, GenerationLogitsCache::kCACHE_LENGTH, vocabSizePadded}), logitsType);
generationLogitsCache.logits = manager.gpu(
ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH, maxBatchSize * maxBeamWidth, vocabSizePadded}),
logitsType);

generationLogitsCache.fragmentPointerDevice
= manager.gpu(ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
generationLogitsCache.fragmentPointerHost = tensorrt_llm::runtime::BufferManager::pinnedPool(
ITensor::makeShape({maxBatchSize, GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
}

if (modelConfig.useCrossAttention())
Expand Down
10 changes: 10 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,8 +1216,18 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
{
for (auto const& llmReq : activeRequests)
{
// Remove from mInflightReqIds so changeBeamWidth can proceed on the next iteration.
// terminateRequest frees seqSlot/KV cache but does not clean up mInflightReqIds.
mInflightReqIds.erase(llmReq->mRequestId);
terminateRequest(llmReq);
}
// Force buffer/decoder reset to clean up any partial state from the aborted batch
// (e.g. partially-filled cross-KV block offsets from mid-context-chunk processing).
// This prevents subsequent requests from reusing stale RuntimeBuffers.
if (mWorldConfig.isLastPipelineParallelRank())
{
changeBeamWidth(mOperatingBeamWidth);
}
}
catch (std::exception const& e)
{
Expand Down
42 changes: 23 additions & 19 deletions cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/

#include "inflightBatchingUtils.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"

namespace tensorrt_llm::batch_manager::utils
{
Expand Down Expand Up @@ -102,34 +101,39 @@ void copyGenerationLogits(RuntimeBuffers::GenerationLogitsCache& generationLogit
"Dropped tokens have to be defined for all beams.");

auto const fragmentSize = llmReq.getGenerationLogitsFragmentsSize();
auto const& fragments = llmReq.getGenerationLogitsFragments();

// Merge logits fragments on device
auto const& transposeBufferPtr = generationLogitsCache.transposedLogits;
auto const& cachePointerDevice = generationLogitsCache.fragmentPointerDevice;
auto const& cachePointerHost = generationLogitsCache.getFragmentPointerHost();
tensorrt_llm::runtime::kernels::mergeLogitsFragments(bufferManager, *transposeBufferPtr,
llmReq.getGenerationLogitsFragments(), *cachePointerDevice, *cachePointerHost, 0, 1, reqBeamWidth,
bufferManager.getStream(), 0);
llmReq.clearGenerationLogitsFragments();

// Copy logits to host
// Bypass mergeLogitsFragmentsKernel: copy each beam's logits directly from fragment GPU memory
// to the host, step by step. Each fragment has shape [1, beamWidth, vocabSizePadded] after
// unsqueeze(0) in HandleGenerationLogits; beam b's data starts at offset b*vocab from the
// fragment's base pointer. This avoids a kernel+pointer-indirection pattern that causes
// intermittent token corruption in gather_generation_logits+concurrent-beam-width scenarios.
for (SizeType32 beam = 0; beam < reqBeamWidth; beam++)
{
auto const droppedSize = !numDroppedTokens.empty() ? numDroppedTokens.at(beam) : 0;
// Ignore logits of dropped tokens
auto const beamFragmentSize = fragmentSize - droppedSize;
// If this function is called before the decoder, the request does not contain the generated token of the
// current iteration, so we add 1 to the number of tokens.
auto const numGenerationToken
= static_cast<SizeType32>(beforeDecoder) + llmReq.getNumTokens(beam) - llmReq.mPromptLen;
auto const hostOffset = numGenerationToken - beamFragmentSize;

// [beamWidth, GENERATION_LOGITS_BUFFER_LENGTH, vocabSizePadded] -> [beamFragmentSize, vocabSizePadded]
auto beamDeviceTensorPtr = ITensor::slice(transposeBufferPtr, {beam, 0}, beamFragmentSize);
// [beamWidth, mMaxNewTokens, vocabSizePadded] -> [beamFragmentSize, vocabSizePadded]
auto beamHostTensorPtr = ITensor::slice(llmReq.getGenerationLogitsHost(), {beam, hostOffset}, beamFragmentSize);
bufferManager.copy(*beamDeviceTensorPtr, *beamHostTensorPtr);
SizeType32 constexpr kOneStep = 1;
for (SizeType32 stepIdx = 0; stepIdx < static_cast<SizeType32>(beamFragmentSize); ++stepIdx)
{
// frag shape: [1, beamWidth, vocabSizePadded]. Beam b starts at offset b*vocab.
auto const fragBeamSlice = ITensor::slice(fragments.at(stepIdx), {0, beam}, kOneStep);
// host shape: [beamWidth, mMaxNewTokens, vocabSizePadded]. Target: [beam, hostOffset+stepIdx, :].
auto const hostStepSlice
= ITensor::slice(llmReq.getGenerationLogitsHost(), {beam, hostOffset + stepIdx}, kOneStep);
bufferManager.copy(*fragBeamSlice, *hostStepSlice);
}
}
// Clear the fragment list. Although BufferManager::copy() enqueues async GPU-to-host
// transfers, no explicit stream synchronization is required here: clearing the list
// releases the fragment tensor *objects*, not the underlying GPU buffer. The source
// memory lives inside generationLogitsCache.logits which is owned by RuntimeBuffers and
// remains valid until the next changeBeamWidth() call. GPU-side ordering is ensured by
// the CUDA event that the decoder stream waits on before reading any shared state.
llmReq.clearGenerationLogitsFragments();

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
Expand Down
102 changes: 1 addition & 101 deletions cpp/tensorrt_llm/runtime/runtimeKernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -232,75 +232,6 @@ void invokeTileTensor(ITensor& output, ITensor const& input, SizeType32 const be
inputRowSize, outputRowSize, static_cast<uint32_t>(beamWidth));
}

// In the following kernel, we launch a grid with (microBatchSize * beamWidth, outputLen) blocks of threads. Each thread
// block copies a `vocabSizePadded` length logits tensor from the "inputLogits (microBatchSize, beamWidth,
// vocabSizePadded)" to the "outputGenerationLogits (batchSize, beamWidth, outputLen, vocabSizePadded)"
template <typename T>
__global__ void mergeLogitsFragmentsKernel(T* output, T** fragmentsVector, int const outputLen, int firstBatchSlotIdx,
int beamWidth, int vocabSizePadded, int stepOffset)
{
// output: shape: [batchSize, beamWidth, outputLen, vocabSize]
// inputVecor.at(i): shape: [microBatchSize, beamWidth, vocabSize]

// Current step
int const curStep = blockIdx.y;

// The relatively batch slot index that this thread block in microBatchSize.
int const relativeBatchSlotIdx = blockIdx.x / beamWidth;

// The Absolute batch slot index in batchSize.
int const absoluteBatchSlotIdx = firstBatchSlotIdx + relativeBatchSlotIdx;

// The beam index that this thread block process
int const mbeamIdx = blockIdx.x % beamWidth;

// The output pointer
unsigned int const outputOffset
= (absoluteBatchSlotIdx * beamWidth * outputLen + mbeamIdx * outputLen + curStep + stepOffset)
* vocabSizePadded;

T* outputPtr = &output[outputOffset];

unsigned int const inputOffset = (relativeBatchSlotIdx * beamWidth + mbeamIdx) * vocabSizePadded;
// The input pointer.
T const* inputPtr = &fragmentsVector[curStep][inputOffset];

// The threads in the block collaborate to copy the logits.
for (int idx = threadIdx.x; idx < vocabSizePadded; idx += blockDim.x)
{
outputPtr[idx] = inputPtr[idx];
}
}

template <typename T>
void invokeMergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
std::vector<TensorPtr> const& fragmentsVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
SizeType32 firstBatchSlotIdx, SizeType32 microBatchSize, SizeType32 beamWidth, CudaStream const& stream,
int stepOffset)
{
size_t const fragmentsVectorSize = fragmentsVector.size();

auto cachePointerHostPtr = bufferCast<T*>(cachePointerHost);

for (int i = 0; i < fragmentsVectorSize; i++)
{
cachePointerHostPtr[i] = bufferCast<T>(*fragmentsVector.at(i));
}
bufferManager.copy(cachePointerHost, cachePointerDevice);

dim3 const blockSize(256);
dim3 const gridSize{(unsigned int) (microBatchSize * beamWidth), (unsigned int) (fragmentsVectorSize)};

auto const& outputShape = output.getShape();
auto const vocabSizePadded = static_cast<SizeType32>(outputShape.d[outputShape.nbDims - 1]);
auto const outputLen = static_cast<SizeType32>(outputShape.d[outputShape.nbDims - 2]);

TLLM_CHECK_WITH_INFO(outputLen >= fragmentsVectorSize, "Fragments size does not match outputLen size");

mergeLogitsFragmentsKernel<T><<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<T>(output),
bufferCast<T*>(cachePointerDevice), outputLen, firstBatchSlotIdx, beamWidth, vocabSizePadded, stepOffset);
}

} // namespace

template <typename T>
Expand Down Expand Up @@ -437,37 +368,6 @@ void tileTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, Cud
}
}

void mergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
std::vector<TensorPtr> const& fragmentsVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
SizeType32 firstBatchSlotIdx, SizeType32 const microBatchSize, SizeType32 const beamWidth, CudaStream const& stream,
int stepOffset)
{
switch (output.getDataType())
{
case nvinfer1::DataType::kFLOAT:
invokeMergeLogitsFragments<float>(bufferManager, output, fragmentsVector, cachePointerDevice, cachePointerHost,
firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
break;
case nvinfer1::DataType::kHALF:
invokeMergeLogitsFragments<half>(bufferManager, output, fragmentsVector, cachePointerDevice, cachePointerHost,
firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
break;
#ifdef ENABLE_BF16
case nvinfer1::DataType::kBF16:
invokeMergeLogitsFragments<__nv_bfloat16>(bufferManager, output, fragmentsVector, cachePointerDevice,
cachePointerHost, firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
break;
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
case nvinfer1::DataType::kFP8:
invokeMergeLogitsFragments<__nv_fp8_e4m3>(bufferManager, output, fragmentsVector, cachePointerDevice,
cachePointerHost, firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
break;
#endif // ENABLE_FP8
default: TLLM_THROW("data type not supported");
}
}

void invokeUpdateKVBlockArrayDraftTokenLocation(ITensor const& seqAcceptedDraftTokenOffsets,
ITensor const& packedAcceptedDraftTokensIndices, ITensor const& pastKeyValueLengths, void* const* pointerArray,
::tensorrt_llm::kernels::KVCacheIndex const* offsetArray, SizeType32 layerCount, SizeType32 seqCount,
Expand Down
7 changes: 1 addition & 6 deletions cpp/tensorrt_llm/runtime/runtimeKernels.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,11 +47,6 @@ void scatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth,

void tileTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream);

void mergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
std::vector<TensorPtr> const& fragmentsVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
SizeType32 firstBatchSlotIdx, SizeType32 microBatchSize, SizeType32 beamWidth, CudaStream const& stream,
int stepOffset);

void invokeUpdateKVBlockArrayDraftTokenLocation(ITensor const& seqAcceptedDraftTokenOffsets,
ITensor const& packedAcceptedDraftTokensIndices, ITensor const& pastKeyValueLengths, void* const* pointerArray,
::tensorrt_llm::kernels::KVCacheIndex const* offsetArray, SizeType32 layerCount, SizeType32 seqCount,
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/unit_tests/batch_manager/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ add_gtest(rnnCacheFormatterTest rnnCacheFormatterTest.cpp)
add_gtest(cudaGraphExecutorCacheTest cudaGraphExecutorCacheTest.cpp)
add_gtest(agentTreeTest agentTreeTest.cpp)
add_gtest(truncateBlocksTest truncateBlocksTest.cpp)
add_gtest(encDecBeamSearchTest encDecBeamSearchTest.cpp)
Loading
Loading