diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index 0cef859ddfb..e482cf29e29 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -8,6 +8,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -16,6 +17,7 @@
 #include
 #include
 #include
+#include
 #include
 
 // Include our shim layer headers
@@ -46,9 +48,27 @@
 using executorch::runtime::Result;
 using executorch::runtime::Span;
 using executorch::runtime::etensor::Tensor;
+// Structure to hold cached GPU tensor data for "keep on device" optimization
+struct CachedGpuData {
+  void* data_ptr; // GPU memory pointer
+  size_t size_bytes; // Total size in bytes
+  int32_t scalar_type; // Data type
+  std::vector<int64_t> sizes; // Original shape
+};
+
+// Global device cache - maps name to cached GPU data
+// Using raw GPU pointers instead of tensor handles for format independence
+static std::unordered_map<std::string, CachedGpuData> g_device_cache;
+
 class ET_EXPERIMENTAL CudaBackend final
     : public ::executorch::runtime::BackendInterface {
  private:
+  // Cache control options (set via set_option before execute)
+  mutable int cache_output_slot_ = -1; // Which output slot to cache (-1 = none)
+  mutable std::string cache_output_name_; // Name to cache output under
+  mutable int use_cache_input_slot_ = -1; // Which input slot to use cache for (-1 = none)
+  mutable std::string use_cache_input_name_; // Name of cached tensor to use
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
@@ -91,6 +111,51 @@
     return 1;
   }
 
+  Error set_option(
+      __ET_UNUSED executorch::runtime::BackendOptionContext& context,
+      const executorch::runtime::Span<executorch::runtime::BackendOption>&
+          backend_options) override {
+    for (size_t i = 0; i < backend_options.size(); i++) {
+      const auto& option = backend_options[i];
+      // Handle cache_output: "slot:name" format (e.g., "0:encoder_output")
+      if (strcmp(option.key, "cache_output") == 0) {
+        if (auto* arr = std::get_if<
+                std::array<char, executorch::runtime::kMaxOptionValueLength>>(
+                &option.value)) {
+          std::string val(arr->data());
+          auto colon_pos = val.find(':');
+          if (colon_pos != std::string::npos) {
+            cache_output_slot_ = std::stoi(val.substr(0, colon_pos));
+            cache_output_name_ = val.substr(colon_pos + 1);
+          }
+        }
+      }
+      // Handle use_cache_input: "slot:name" format (e.g., "1:encoder_output")
+      else if (strcmp(option.key, "use_cache_input") == 0) {
+        if (auto* arr = std::get_if<
+                std::array<char, executorch::runtime::kMaxOptionValueLength>>(
+                &option.value)) {
+          std::string val(arr->data());
+          auto colon_pos = val.find(':');
+          if (colon_pos != std::string::npos) {
+            use_cache_input_slot_ = std::stoi(val.substr(0, colon_pos));
+            use_cache_input_name_ = val.substr(colon_pos + 1);
+          }
+        }
+      }
+      // Handle clear_cache_input: reset input cache settings
+      else if (strcmp(option.key, "clear_cache_input") == 0) {
+        if (auto* val = std::get_if<bool>(&option.value)) {
+          if (*val) {
+            use_cache_input_slot_ = -1;
+            use_cache_input_name_.clear();
+          }
+        }
+      }
+    }
+    return Error::Ok;
+  }
+
   // Once per loaded binary blob
   Result<DelegateHandle*> init(
       BackendInitContext& context,
@@ -223,14 +288,14 @@
       n_outputs); // GPU tensors for kernel output
 
     // Process input tensors: ExecuTorch provides CPU tensors, create GPU
-    // copies
+    // copies. For cached inputs, use GPU-to-GPU copy instead of CPU-to-GPU.
     for (int i = 0; i < n_inputs; i++) {
       // Get tensor dimensions and properties from ExecuTorch CPU tensor
       auto cpu_tensor = &(args[i]->toTensor());
       auto sizes = cpu_tensor->sizes();
       auto scalar_type = cpu_tensor->scalar_type();
 
-      // Create GPU tensor with same shape
+      // Create GPU tensor with same shape (always needed for AOTI format)
       std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
 
       AOTITensorHandle gpu_input_handle;
@@ -251,7 +316,43 @@
       gpu_inputs[i] = gpu_input_handle;
 
-      // Copy data from CPU to GPU
+      // Check if this input slot should use cached GPU data
+      if (i == use_cache_input_slot_ && !use_cache_input_name_.empty()) {
+        auto cache_it = g_device_cache.find(use_cache_input_name_);
+        if (cache_it != g_device_cache.end()) {
+          const CachedGpuData& cached = cache_it->second;
+          // GPU-to-GPU copy: fast DMA transfer, normalizes tensor format
+          size_t numel = gpu_inputs[i]->numel();
+          size_t elem_size = gpu_inputs[i]->element_size();
+          size_t copy_bytes = numel * elem_size;
+
+          ET_CHECK_OR_RETURN_ERROR(
+              copy_bytes == cached.size_bytes,
+              Internal,
+              "Cached tensor size mismatch: expected %zu bytes, got %zu",
+              copy_bytes,
+              cached.size_bytes);
+
+          cudaError_t cuda_err = cudaMemcpy(
+              gpu_inputs[i]->data_ptr(),
+              cached.data_ptr,
+              copy_bytes,
+              cudaMemcpyDeviceToDevice);
+
+          ET_CHECK_OR_RETURN_ERROR(
+              cuda_err == cudaSuccess,
+              Internal,
+              "Failed GPU-to-GPU copy for cached input %d: %s",
+              i,
+              cudaGetErrorString(cuda_err));
+
+          // Skip the CPU-to-GPU copy below
+          continue;
+        }
+        // Cache miss: fall through to normal CPU-to-GPU copy
+      }
+
+      // Copy data from CPU to GPU (normal path)
       ET_CHECK_OR_RETURN_ERROR(
           aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
           Internal,
@@ -303,6 +404,57 @@
         "AOTInductorModelContainerRun failed with error code %d",
         error);
 
+    // Cache output GPU tensor data if requested
+    // We store the raw GPU pointer for later GPU-to-GPU copy
+    if (cache_output_slot_ >= 0 && cache_output_slot_ < static_cast<int>(n_outputs) &&
+        !cache_output_name_.empty()) {
+      auto* gpu_tensor = gpu_outputs[cache_output_slot_];
+      size_t numel = gpu_tensor->numel();
+      size_t elem_size = gpu_tensor->element_size();
+      size_t size_bytes = numel * elem_size;
+
+      // Allocate persistent GPU memory for the cache
+      void* cache_ptr = nullptr;
+      cudaError_t alloc_err = cudaMalloc(&cache_ptr, size_bytes);
+      ET_CHECK_OR_RETURN_ERROR(
+          alloc_err == cudaSuccess,
+          Internal,
+          "Failed to allocate GPU cache memory: %s",
+          cudaGetErrorString(alloc_err));
+
+      // Copy from tensor to cache (GPU-to-GPU)
+      cudaError_t copy_err = cudaMemcpy(
+          cache_ptr,
+          gpu_tensor->data_ptr(),
+          size_bytes,
+          cudaMemcpyDeviceToDevice);
+      ET_CHECK_OR_RETURN_ERROR(
+          copy_err == cudaSuccess,
+          Internal,
+          "Failed to copy output to GPU cache: %s",
+          cudaGetErrorString(copy_err));
+
+      // Free old cache if exists
+      auto old_it = g_device_cache.find(cache_output_name_);
+      if (old_it != g_device_cache.end()) {
+        cudaFree(old_it->second.data_ptr);
+        g_device_cache.erase(old_it);
+      }
+
+      // Store in cache
+      CachedGpuData cached;
+      cached.data_ptr = cache_ptr;
+      cached.size_bytes = size_bytes;
+      cached.scalar_type = static_cast<int32_t>(gpu_tensor->scalar_type());
+      auto sizes = gpu_tensor->sizes();
+      cached.sizes.assign(sizes.begin(), sizes.end());
+      g_device_cache[cache_output_name_] = std::move(cached);
+
+      // Reset cache_output settings after caching
+      cache_output_slot_ = -1;
+      cache_output_name_.clear();
+    }
+
     // Copy GPU output results back to CPU output tensors
     for (int i = 0; i < n_outputs; i++) {
       auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
diff --git a/extension/asr/runner/runner.cpp b/extension/asr/runner/runner.cpp
index 4f2523989c1..ebd5d29bdb3 100644
--- a/extension/asr/runner/runner.cpp
+++ b/extension/asr/runner/runner.cpp
@@ -16,6 +16,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -196,6 +198,17 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
     }
   }
 
+  // Tell CUDA backend to cache encoder output (slot 0) as "encoder_output"
+  {
+    ::executorch::runtime::BackendOptions<1> opts;
+    opts.set_option("cache_output", "0:encoder_output");
+    auto err =
+        ::executorch::runtime::set_option("CudaBackend", opts.view());
+    if (err != ::executorch::runtime::Error::Ok) {
+      ET_LOG(Info, "Failed to set cache_output option (backend may not support caching)");
+    }
+  }
+
   auto encoder_result =
       module_->execute(kEncoderMethodName, preprocessed_features);
   ET_CHECK_OK_OR_RETURN_ERROR(encoder_result.error());
@@ -249,6 +262,19 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
   decoder_inputs.emplace_back(decoder_input_ptr);
   decoder_inputs.emplace_back(encoder_output_ptr);
   decoder_inputs.emplace_back(cache_position_ptr);
+
+  // Tell CUDA backend to use cached encoder output for decoder input slot 2
+  // Note: Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]
+  {
+    ::executorch::runtime::BackendOptions<1> opts;
+    opts.set_option("use_cache_input", "2:encoder_output");
+    auto err =
+        ::executorch::runtime::set_option("CudaBackend", opts.view());
+    if (err != ::executorch::runtime::Error::Ok) {
+      ET_LOG(Info, "Failed to set use_cache_input option (backend may not support caching)");
+    }
+  }
+
   // Add some green coloring for the first generated token
   // token_callback("\033[1;32m");
   while (generated_tokens < config.max_new_tokens) {
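
Note for reviewers: the calling pattern the backend expects is sketched below, mirroring the runner.cpp changes above. This is a minimal sketch, not part of the patch: the include path, the function name, and the placement of the encoder/decoder runs are assumptions; only the option keys ("cache_output", "use_cache_input", "clear_cache_input"), the BackendOptions/view()/set_option calls, and the slot:name string format come from the code in this diff, and the bool overload used for clear_cache_input is assumed to match the backend's std::get_if<bool> handling.

// Reviewer sketch (not part of the patch). Assumed header path for
// BackendOptions / set_option; adjust to wherever the backend options API lives.
#include <executorch/runtime/backend/options.h>

using ::executorch::runtime::BackendOptions;
using ::executorch::runtime::Error;

void drive_encoder_decoder_with_device_cache() {
  // 1) Before the encoder runs: keep its output slot 0 on the GPU, keyed as
  //    "encoder_output". The backend resets this request after one execute().
  BackendOptions<1> cache_opts;
  cache_opts.set_option("cache_output", "0:encoder_output");
  if (::executorch::runtime::set_option("CudaBackend", cache_opts.view()) !=
      Error::Ok) {
    // Backend may not support caching; execution still works, just with the
    // usual host round-trip for the encoder output.
  }
  // ... run the encoder method here ...

  // 2) Before the decoder steps: feed decoder input slot 2 from the cached
  //    "encoder_output" via a device-to-device copy instead of re-uploading
  //    the CPU tensor on every step.
  BackendOptions<1> use_opts;
  use_opts.set_option("use_cache_input", "2:encoder_output");
  (void)::executorch::runtime::set_option("CudaBackend", use_opts.view());
  // ... run decoder steps in a loop here ...

  // 3) After decoding: drop the input binding (assumes the bool overload of
  //    BackendOptions::set_option, matching the backend's std::get_if<bool>).
  BackendOptions<1> clear_opts;
  clear_opts.set_option("clear_cache_input", true);
  (void)::executorch::runtime::set_option("CudaBackend", clear_opts.view());
}

Because the cache is a process-global map keyed by the string name, one encoder output can be reused across every subsequent decoder call until clear_cache_input resets the binding; the cached allocation itself is only freed when a new tensor is cached under the same name.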