158 changes: 155 additions & 3 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -8,6 +8,7 @@

#include <cuda_runtime.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
@@ -16,6 +17,7 @@
#include <filesystem>
#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>

// Include our shim layer headers
@@ -46,9 +48,27 @@ using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;

// Structure to hold cached GPU tensor data for "keep on device" optimization
struct CachedGpuData {
void* data_ptr; // GPU memory pointer
size_t size_bytes; // Total size in bytes
int32_t scalar_type; // Data type
std::vector<int64_t> sizes; // Original shape
};

// Global device cache - maps name to cached GPU data
// Using raw GPU pointers instead of tensor handles for format independence
static std::unordered_map<std::string, CachedGpuData> g_device_cache;
Copilot AI (Dec 2, 2025):
The global g_device_cache is never cleaned up, causing a memory leak. GPU memory allocated via cudaMalloc (line 418) is stored in this cache but never freed when the backend is destroyed. Consider adding cleanup logic in the destroy() method to iterate through g_device_cache and call cudaFree() on each cached data_ptr, or implement a RAII wrapper for cache management.
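A minimal sketch of that cleanup, assuming the backend's destroy() (or process teardown) is an acceptable place to drain the cache; clear_device_cache is a hypothetical helper, not part of this diff:

```cpp
// Hypothetical helper: free every cached GPU buffer and empty the cache.
// Could be called from CudaBackend::destroy() or owned by a RAII wrapper.
static void clear_device_cache() {
  for (auto& entry : g_device_cache) {
    cudaError_t err = cudaFree(entry.second.data_ptr);
    if (err != cudaSuccess) {
      // cudaFree rarely fails; log and keep releasing the remaining entries.
      ET_LOG(
          Error,
          "Failed to free cached GPU buffer '%s': %s",
          entry.first.c_str(),
          cudaGetErrorString(err));
    }
  }
  g_device_cache.clear();
}
```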

Copilot AI (Dec 2, 2025):
The global g_device_cache is accessed without synchronization, creating potential race conditions in multi-threaded environments. If multiple threads call execute() concurrently, simultaneous reads/writes to the cache could cause data corruption or crashes. Consider using a mutex (e.g., std::mutex) to protect all accesses to g_device_cache, or document that the CUDA backend is not thread-safe.
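A hedged sketch of the mutex approach, assuming one process-wide lock is acceptable for this cache (g_device_cache_mutex and find_cached are illustrative names, not existing code):

```cpp
#include <mutex>

static std::mutex g_device_cache_mutex; // guards g_device_cache

// Illustrative lookup done under the lock; the insert/erase paths in execute()
// would take the same lock before touching g_device_cache.
static bool find_cached(const std::string& name, CachedGpuData& out) {
  std::lock_guard<std::mutex> guard(g_device_cache_mutex);
  auto it = g_device_cache.find(name);
  if (it == g_device_cache.end()) {
    return false;
  }
  out = it->second; // copies pointer and shape metadata, not the GPU buffer
  return true;
}
```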


class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:
// Cache control options (set via set_option before execute)
mutable int cache_output_slot_ = -1; // Which output slot to cache (-1 = none)
mutable std::string cache_output_name_; // Name to cache output under
mutable int use_cache_input_slot_ = -1; // Which input slot to use cache for (-1 = none)
mutable std::string use_cache_input_name_; // Name of cached tensor to use
Comment on lines +67 to +70
Copilot AI (Dec 2, 2025):
These cache control fields are marked mutable and modified in const methods (set_option and execute), which is unusual and suggests potential design issues. The mutable keyword typically indicates thread-safety concerns or shared state that bypasses const-correctness. Since execute() is const but modifies these fields, concurrent calls to execute() on the same backend instance will have race conditions when accessing these member variables. Consider using proper synchronization or redesigning to avoid mutable state in const methods.
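One possible shape for the redesign the comment hints at, sketched as a self-contained pattern rather than a drop-in change: keep the cache-control settings in a small lock-guarded holder and have the const execute() work on a snapshot (all names here are illustrative):

```cpp
#include <mutex>
#include <string>

// Illustrative only: the cache-control settings as plain data...
struct CacheControl {
  int cache_output_slot = -1;
  std::string cache_output_name;
  int use_cache_input_slot = -1;
  std::string use_cache_input_name;
};

// ...and a holder that hands out consistent copies under a lock, so a const
// execute() never observes half-updated state.
class CacheControlBox {
 public:
  void set(const CacheControl& next) {
    std::lock_guard<std::mutex> guard(mutex_);
    control_ = next;
  }
  CacheControl snapshot() const {
    std::lock_guard<std::mutex> guard(mutex_);
    return control_;
  }

 private:
  mutable std::mutex mutex_; // mutable so snapshot() can stay const
  CacheControl control_;
};
```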


Error load_function_pointers_into_handle(
void* so_handle,
AOTIDelegateHandle* handle) const {
@@ -91,6 +111,51 @@ class ET_EXPERIMENTAL CudaBackend final
return 1;
}

Error set_option(
__ET_UNUSED executorch::runtime::BackendOptionContext& context,
const executorch::runtime::Span<executorch::runtime::BackendOption>&
backend_options) override {
for (size_t i = 0; i < backend_options.size(); i++) {
const auto& option = backend_options[i];
// Handle cache_output: "slot:name" format (e.g., "0:encoder_output")
if (strcmp(option.key, "cache_output") == 0) {
if (auto* arr = std::get_if<
std::array<char, executorch::runtime::kMaxOptionValueLength>>(
&option.value)) {
std::string val(arr->data());
auto colon_pos = val.find(':');
if (colon_pos != std::string::npos) {
cache_output_slot_ = std::stoi(val.substr(0, colon_pos));
Copilot AI (Dec 2, 2025):
The std::stoi call can throw std::invalid_argument or std::out_of_range exceptions if the input string is malformed or the number is too large. Since this function should return an Error rather than throwing exceptions, consider wrapping this in a try-catch block and returning an appropriate error, or validate the input string before calling std::stoi.
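A hedged sketch of the validation this comment asks for; parse_slot is a hypothetical helper, not an existing ExecuTorch API:

```cpp
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <string>

// Hypothetical helper: parse the "slot" part of a "slot:name" option without
// letting std::stoi throw. Returns false on malformed or out-of-range input.
static bool parse_slot(const std::string& text, int* slot_out) {
  if (text.empty()) {
    return false;
  }
  errno = 0;
  char* end = nullptr;
  long value = std::strtol(text.c_str(), &end, 10);
  if (end == text.c_str() || *end != '\0' || errno == ERANGE || value < 0 ||
      value > INT_MAX) {
    return false;
  }
  *slot_out = static_cast<int>(value);
  return true;
}
```

set_option could then assign cache_output_slot_ only when parse_slot succeeds, and otherwise return an error such as Error::InvalidArgument instead of throwing.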

cache_output_name_ = val.substr(colon_pos + 1);
}
}
}
// Handle use_cache_input: "slot:name" format (e.g., "1:encoder_output")
else if (strcmp(option.key, "use_cache_input") == 0) {
if (auto* arr = std::get_if<
std::array<char, executorch::runtime::kMaxOptionValueLength>>(
&option.value)) {
std::string val(arr->data());
auto colon_pos = val.find(':');
if (colon_pos != std::string::npos) {
use_cache_input_slot_ = std::stoi(val.substr(0, colon_pos));
Copilot AI (Dec 2, 2025):
The std::stoi call can throw std::invalid_argument or std::out_of_range exceptions if the input string is malformed or the number is too large. Since this function should return an Error rather than throwing exceptions, consider wrapping this in a try-catch block and returning an appropriate error, or validate the input string before calling std::stoi.

use_cache_input_name_ = val.substr(colon_pos + 1);
}
}
}
// Handle clear_cache_input: reset input cache settings
else if (strcmp(option.key, "clear_cache_input") == 0) {
if (auto* val = std::get_if<bool>(&option.value)) {
if (*val) {
use_cache_input_slot_ = -1;
use_cache_input_name_.clear();
}
}
}
}
return Error::Ok;
}

// Once per loaded binary blob
Result<DelegateHandle*> init(
BackendInitContext& context,
@@ -223,14 +288,14 @@ class ET_EXPERIMENTAL CudaBackend final
n_outputs); // GPU tensors for kernel output

// Process input tensors: ExecuTorch provides CPU tensors, create GPU
// copies
// copies. For cached inputs, use GPU-to-GPU copy instead of CPU-to-GPU.
for (int i = 0; i < n_inputs; i++) {
// Get tensor dimensions and properties from ExecuTorch CPU tensor
auto cpu_tensor = &(args[i]->toTensor());
auto sizes = cpu_tensor->sizes();
auto scalar_type = cpu_tensor->scalar_type();

// Create GPU tensor with same shape
// Create GPU tensor with same shape (always needed for AOTI format)
std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());

AOTITensorHandle gpu_input_handle;
@@ -251,7 +316,43 @@

gpu_inputs[i] = gpu_input_handle;

// Copy data from CPU to GPU
// Check if this input slot should use cached GPU data
if (i == use_cache_input_slot_ && !use_cache_input_name_.empty()) {
auto cache_it = g_device_cache.find(use_cache_input_name_);
if (cache_it != g_device_cache.end()) {
const CachedGpuData& cached = cache_it->second;
// GPU-to-GPU copy: fast DMA transfer, normalizes tensor format
size_t numel = gpu_inputs[i]->numel();
size_t elem_size = gpu_inputs[i]->element_size();
size_t copy_bytes = numel * elem_size;

ET_CHECK_OR_RETURN_ERROR(
copy_bytes == cached.size_bytes,
Internal,
"Cached tensor size mismatch: expected %zu bytes, got %zu",
copy_bytes,
cached.size_bytes);

cudaError_t cuda_err = cudaMemcpy(
gpu_inputs[i]->data_ptr(),
cached.data_ptr,
copy_bytes,
cudaMemcpyDeviceToDevice);

ET_CHECK_OR_RETURN_ERROR(
cuda_err == cudaSuccess,
Internal,
"Failed GPU-to-GPU copy for cached input %d: %s",
i,
cudaGetErrorString(cuda_err));

// Skip the CPU-to-GPU copy below
continue;
}
// Cache miss: fall through to normal CPU-to-GPU copy
}

// Copy data from CPU to GPU (normal path)
ET_CHECK_OR_RETURN_ERROR(
aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
Internal,
@@ -303,6 +404,57 @@ class ET_EXPERIMENTAL CudaBackend final
"AOTInductorModelContainerRun failed with error code %d",
error);

// Cache output GPU tensor data if requested
// We store the raw GPU pointer for later GPU-to-GPU copy
if (cache_output_slot_ >= 0 && cache_output_slot_ < static_cast<int>(n_outputs) &&
!cache_output_name_.empty()) {
auto* gpu_tensor = gpu_outputs[cache_output_slot_];
size_t numel = gpu_tensor->numel();
size_t elem_size = gpu_tensor->element_size();
size_t size_bytes = numel * elem_size;

// Allocate persistent GPU memory for the cache
void* cache_ptr = nullptr;
cudaError_t alloc_err = cudaMalloc(&cache_ptr, size_bytes);
ET_CHECK_OR_RETURN_ERROR(
alloc_err == cudaSuccess,
Internal,
"Failed to allocate GPU cache memory: %s",
cudaGetErrorString(alloc_err));

// Copy from tensor to cache (GPU-to-GPU)
cudaError_t copy_err = cudaMemcpy(
cache_ptr,
gpu_tensor->data_ptr(),
size_bytes,
cudaMemcpyDeviceToDevice);
ET_CHECK_OR_RETURN_ERROR(
copy_err == cudaSuccess,
Internal,
"Failed to copy output to GPU cache: %s",
cudaGetErrorString(copy_err));
Comment on lines +431 to +435
Copilot AI (Dec 2, 2025):
Memory leak: If cudaMemcpy fails (line 426-430), the function returns early via ET_CHECK_OR_RETURN_ERROR, but the cache_ptr allocated at line 418 is never freed. Consider adding cudaFree(cache_ptr) before returning on error, or use a RAII wrapper like a unique_ptr with a custom deleter to ensure automatic cleanup.

Suggested change: replace
      ET_CHECK_OR_RETURN_ERROR(
          copy_err == cudaSuccess,
          Internal,
          "Failed to copy output to GPU cache: %s",
          cudaGetErrorString(copy_err));
with
      if (copy_err != cudaSuccess) {
        cudaFree(cache_ptr);
        ET_CHECK_OR_RETURN_ERROR(
            false,
            Internal,
            "Failed to copy output to GPU cache: %s",
            cudaGetErrorString(copy_err));
      }


// Free old cache if exists
auto old_it = g_device_cache.find(cache_output_name_);
if (old_it != g_device_cache.end()) {
cudaFree(old_it->second.data_ptr);
Copilot AI (Dec 2, 2025):
The old cached GPU memory is freed without checking for errors. While cudaFree rarely fails, if it does, the error is silently ignored, potentially indicating GPU state corruption. Consider checking the return value and logging a warning if the free operation fails, similar to how other CUDA operations are checked in this file.

Suggested change: replace
        cudaFree(old_it->second.data_ptr);
with
        cudaError_t free_err = cudaFree(old_it->second.data_ptr);
        if (free_err != cudaSuccess) {
          std::fprintf(
              stderr,
              "Warning: Failed to free old GPU cache memory: %s\n",
              cudaGetErrorString(free_err));
        }

g_device_cache.erase(old_it);
}

// Store in cache
CachedGpuData cached;
cached.data_ptr = cache_ptr;
cached.size_bytes = size_bytes;
cached.scalar_type = static_cast<int32_t>(gpu_tensor->scalar_type());
auto sizes = gpu_tensor->sizes();
cached.sizes.assign(sizes.begin(), sizes.end());
g_device_cache[cache_output_name_] = std::move(cached);

// Reset cache_output settings after caching
cache_output_slot_ = -1;
cache_output_name_.clear();
Copilot AI (Dec 2, 2025):
Unlike cache_output_slot_ which is reset after use (lines 454-455), the use_cache_input_slot_ and use_cache_input_name_ are never cleared after execution. This means once set, they will continue to be used for all subsequent executions, which may not be the intended behavior. Consider resetting these fields after the execute() method completes to match the pattern used for cache_output settings.

Suggested change: after
      cache_output_name_.clear();
add
      // Reset cache_input settings after use (fix for CodeQL warning)
      use_cache_input_slot_ = -1;
      use_cache_input_name_.clear();

}

// Copy GPU output results back to CPU output tensors
for (int i = 0; i < n_outputs; i++) {
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
26 changes: 26 additions & 0 deletions extension/asr/runner/runner.cpp
@@ -16,6 +16,8 @@
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/llm/sampler/util.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>
@@ -196,6 +198,17 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
}
}

// Tell CUDA backend to cache encoder output (slot 0) as "encoder_output"
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("cache_output", "0:encoder_output");
auto err =
::executorch::runtime::set_option("CudaBackend", opts.view());
if (err != ::executorch::runtime::Error::Ok) {
ET_LOG(Info, "Failed to set cache_output option (backend may not support caching)");
}
}

auto encoder_result =
module_->execute(kEncoderMethodName, preprocessed_features);
ET_CHECK_OK_OR_RETURN_ERROR(encoder_result.error());
@@ -249,6 +262,19 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
decoder_inputs.emplace_back(decoder_input_ptr);
decoder_inputs.emplace_back(encoder_output_ptr);
decoder_inputs.emplace_back(cache_position_ptr);

// Tell CUDA backend to use cached encoder output for decoder input slot 2
// Note: Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("use_cache_input", "2:encoder_output");
Comment on lines +266 to +270
Copilot AI (Dec 2, 2025):
The comment states "Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]", but the actual order in the code is: decoder_input_ptr[0] (input_ids), encoder_output_ptr[1], cache_position_ptr[2]. The encoder_output is at index 1, not index 2. Either the comment is wrong or the cache input slot should be "1:encoder_output" instead of "2:encoder_output".

Suggested change: replace
    // Tell CUDA backend to use cached encoder output for decoder input slot 2
    // Note: Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]
    {
      ::executorch::runtime::BackendOptions<1> opts;
      opts.set_option("use_cache_input", "2:encoder_output");
with
    // Tell CUDA backend to use cached encoder output for decoder input slot 1
    // Note: Decoder input order in AOTI is: input_ids[0], encoder_output[1], cache_position[2]
    {
      ::executorch::runtime::BackendOptions<1> opts;
      opts.set_option("use_cache_input", "1:encoder_output");

auto err =
::executorch::runtime::set_option("CudaBackend", opts.view());
if (err != ::executorch::runtime::Error::Ok) {
ET_LOG(Info, "Failed to set use_cache_input option (backend may not support caching)");
}
}
Comment on lines +268 to +276
Copilot AI (Dec 2, 2025):
The cache input option is set before the decoder loop but never cleared afterward. This means the cache will persist and be used for subsequent calls to transcribe(), which may not be intended. If the encoder is run again with different input, the cached encoder output will be stale but still used. Consider clearing the cache input option after the decoding loop completes (after line 332) using the "clear_cache_input" option, or document this behavior clearly.
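A minimal sketch of the clearing step the comment describes, mirroring the set_option calls above; placing it right after the decoding loop, and using a bool value for "clear_cache_input" (consistent with the std::get_if<bool> handling on the backend side), are assumptions:

```cpp
// After the decoding loop: stop substituting the cached encoder output so a
// later transcribe() call does not silently reuse stale data.
{
  ::executorch::runtime::BackendOptions<1> opts;
  opts.set_option("clear_cache_input", true);
  auto err = ::executorch::runtime::set_option("CudaBackend", opts.view());
  if (err != ::executorch::runtime::Error::Ok) {
    ET_LOG(Info, "Failed to clear use_cache_input option");
  }
}
```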


// Add some green coloring for the first generated token
// token_callback("\033[1;32m");
while (generated_tokens < config.max_new_tokens) {