Skip to content

Commit

Permalink
feat: arbitrarily configurable execution providers
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Sep 21, 2024
1 parent e16fd5b commit abd527b
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 324 deletions.
85 changes: 42 additions & 43 deletions src/execution_providers/cann.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
use crate::{
error::{Error, Result},
execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch},
Expand Down Expand Up @@ -28,65 +29,79 @@ pub enum CANNExecutionProviderImplementationMode {

#[derive(Default, Debug, Clone)]
pub struct CANNExecutionProvider {
device_id: Option<i32>,
npu_mem_limit: Option<usize>,
arena_extend_strategy: Option<ArenaExtendStrategy>,
enable_cann_graph: Option<bool>,
dump_graphs: Option<bool>,
precision_mode: Option<CANNExecutionProviderPrecisionMode>,
op_select_impl_mode: Option<CANNExecutionProviderImplementationMode>,
optypelist_for_impl_mode: Option<String>
options: ExecutionProviderOptions
}

impl CANNExecutionProvider {
/// Configure which device the provider runs on, via the `device_id` option.
#[must_use]
pub fn with_device_id(mut self, device_id: i32) -> Self {
	self.options.set("device_id", device_id.to_string());
	self
}

/// Configure the size limit of the device memory arena in bytes. This size limit is only for the execution
/// provider’s arena. The total device memory usage may be higher.
#[must_use]
pub fn with_memory_limit(mut self, limit: usize) -> Self {
	self.options.set("npu_mem_limit", limit.to_string());
	self
}

/// Configure the strategy for extending the device's memory arena.
#[must_use]
pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
	self.options.set(
		"arena_extend_strategy",
		// Option values use ONNX Runtime's internal enum spellings.
		match strategy {
			ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
			ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
		}
	);
	self
}

/// Configure whether to use the graph inference engine to speed up performance. The recommended and default setting
/// is true. If false, it will fall back to the single-operator inference engine.
#[must_use]
pub fn with_cann_graph(mut self, enable: bool) -> Self {
	self.options.set("enable_cann_graph", if enable { "1" } else { "0" });
	self
}

/// Configure whether to dump the subgraph into ONNX format for analysis of subgraph segmentation.
#[must_use]
pub fn with_dump_graphs(mut self) -> Self {
	self.options.set("dump_graphs", "1");
	self
}

/// Set the precision mode of the operator. See [`CANNExecutionProviderPrecisionMode`].
#[must_use]
pub fn with_precision_mode(mut self, mode: CANNExecutionProviderPrecisionMode) -> Self {
	self.options.set(
		"precision_mode",
		// Option values use CANN's string spellings for each mode.
		match mode {
			CANNExecutionProviderPrecisionMode::ForceFP32 => "force_fp32",
			CANNExecutionProviderPrecisionMode::ForceFP16 => "force_fp16",
			CANNExecutionProviderPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
			CANNExecutionProviderPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
			CANNExecutionProviderPrecisionMode::AllowMixedPrecision => "allow_mix_precision"
		}
	);
	self
}

/// Configure the implementation mode for operators. Some CANN operators can have both high-precision and
/// high-performance implementations.
#[must_use]
pub fn with_implementation_mode(mut self, mode: CANNExecutionProviderImplementationMode) -> Self {
	self.options.set(
		"op_select_impl_mode",
		match mode {
			CANNExecutionProviderImplementationMode::HighPrecision => "high_precision",
			CANNExecutionProviderImplementationMode::HighPerformance => "high_performance"
		}
	);
	self
}

Expand All @@ -100,7 +115,7 @@ impl CANNExecutionProvider {
/// - `ROIAlign`
/// Configure the list of op types affected by [`CANNExecutionProvider::with_implementation_mode`], via the
/// `optypelist_for_impl_mode` option.
#[must_use]
pub fn with_implementation_mode_oplist(mut self, list: impl ToString) -> Self {
	self.options.set("optypelist_for_impl_mode", list.to_string());
	self
}

Expand All @@ -110,6 +125,13 @@ impl CANNExecutionProvider {
}
}

impl ArbitrarilyConfigurableExecutionProvider for CANNExecutionProvider {
	/// Store an arbitrary `key`/`value` pair in the provider's option map, for
	/// options without a dedicated builder method.
	fn with_arbitrary_config(mut self, key: impl ToString, value: impl ToString) -> Self {
		let (option_key, option_value) = (key.to_string(), value.to_string());
		self.options.set(option_key, option_value);
		self
	}
}

impl From<CANNExecutionProvider> for ExecutionProviderDispatch {
fn from(value: CANNExecutionProvider) -> Self {
ExecutionProviderDispatch::new(value)
Expand All @@ -131,39 +153,16 @@ impl ExecutionProvider for CANNExecutionProvider {
{
let mut cann_options: *mut ort_sys::OrtCANNProviderOptions = std::ptr::null_mut();
crate::ortsys![unsafe CreateCANNProviderOptions(&mut cann_options)?];
let (key_ptrs, value_ptrs, len, keys, values) = super::map_keys! {
device_id = self.device_id,
npu_mem_limit = self.npu_mem_limit,
arena_extend_strategy = self.arena_extend_strategy.as_ref().map(|v| match v {
ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
}),
enable_cann_graph = self.enable_cann_graph.map(<bool as Into<i32>>::into),
dump_graphs = self.dump_graphs.map(<bool as Into<i32>>::into),
precision_mode = self.precision_mode.as_ref().map(|v| match v {
CANNExecutionProviderPrecisionMode::ForceFP32 => "force_fp32",
CANNExecutionProviderPrecisionMode::ForceFP16 => "force_fp16",
CANNExecutionProviderPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
CANNExecutionProviderPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
CANNExecutionProviderPrecisionMode::AllowMixedPrecision => "allow_mix_precision"
}),
op_select_impl_mode = self.op_select_impl_mode.as_ref().map(|v| match v {
CANNExecutionProviderImplementationMode::HighPrecision => "high_precision",
CANNExecutionProviderImplementationMode::HighPerformance => "high_performance"
}),
optypelist_for_impl_mode = self.optypelist_for_impl_mode.clone()
};
if let Err(e) =
crate::error::status_to_result(crate::ortsys![unsafe UpdateCANNProviderOptions(cann_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
{
let ffi_options = self.options.to_ffi();
if let Err(e) = crate::error::status_to_result(
crate::ortsys![unsafe UpdateCANNProviderOptions(cann_options, ffi_options.key_ptrs(), ffi_options.value_ptrs(), ffi_options.len())]
) {
crate::ortsys![unsafe ReleaseCANNProviderOptions(cann_options)];
std::mem::drop((keys, values));
return Err(e);
}

let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CANN(session_builder.session_options_ptr.as_ptr(), cann_options)];
crate::ortsys![unsafe ReleaseCANNProviderOptions(cann_options)];
std::mem::drop((keys, values));
return crate::error::status_to_result(status);
}

Expand Down
97 changes: 40 additions & 57 deletions src/execution_providers/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::ops::BitOr;

use super::{ArbitrarilyConfigurableExecutionProvider, ExecutionProviderOptions};
use crate::{
error::{Error, Result},
execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch},
Expand Down Expand Up @@ -83,41 +84,34 @@ impl Default for CUDAExecutionProviderCuDNNConvAlgoSearch {

#[derive(Debug, Default, Clone)]
pub struct CUDAExecutionProvider {
device_id: Option<i32>,
gpu_mem_limit: Option<usize>,
user_compute_stream: Option<*mut ()>,
arena_extend_strategy: Option<ArenaExtendStrategy>,
cudnn_conv_algo_search: Option<CUDAExecutionProviderCuDNNConvAlgoSearch>,
do_copy_in_default_stream: Option<bool>,
cudnn_conv_use_max_workspace: Option<bool>,
cudnn_conv1d_pad_to_nc1d: Option<bool>,
enable_cuda_graph: Option<bool>,
enable_skip_layer_norm_strict_mode: Option<bool>,
use_tf32: Option<bool>,
prefer_nhwc: Option<bool>,
sdpa_kernel: Option<u32>,
fuse_conv_bias: Option<bool>
options: ExecutionProviderOptions
}

impl CUDAExecutionProvider {
/// Configure which device the provider runs on, via the `device_id` option.
#[must_use]
pub fn with_device_id(mut self, device_id: i32) -> Self {
	self.options.set("device_id", device_id.to_string());
	self
}

/// Configure the size limit of the device memory arena in bytes. This size limit is only for the execution
/// provider’s arena. The total device memory usage may be higher.
#[must_use]
pub fn with_memory_limit(mut self, limit: usize) -> Self {
	self.options.set("gpu_mem_limit", limit.to_string());
	self
}

/// Configure the strategy for extending the device's memory arena.
#[must_use]
pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
	self.options.set(
		"arena_extend_strategy",
		// Option values use ONNX Runtime's internal enum spellings.
		match strategy {
			ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
			ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
		}
	);
	self
}

Expand All @@ -127,15 +121,22 @@ impl CUDAExecutionProvider {
/// done for cuDNN convolution algorithms. See [`CUDAExecutionProviderCuDNNConvAlgoSearch`] for more info.
/// Configure how convolution algorithms are searched for; see
/// [`CUDAExecutionProviderCuDNNConvAlgoSearch`] for more info.
#[must_use]
pub fn with_conv_algorithm_search(mut self, search: CUDAExecutionProviderCuDNNConvAlgoSearch) -> Self {
	self.options.set(
		"cudnn_conv_algo_search",
		match search {
			CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive => "EXHAUSTIVE",
			CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC",
			CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT"
		}
	);
	self
}

/// Whether to do copies in the default stream or use separate streams. The recommended setting is true. If false,
/// there are race conditions and possibly better performance.
#[must_use]
pub fn with_copy_in_default_stream(mut self, enable: bool) -> Self {
	self.options.set("do_copy_in_default_stream", if enable { "1" } else { "0" });
	self
}

Expand All @@ -149,7 +150,7 @@ impl CUDAExecutionProvider {
/// cuDNN selecting a suboptimal convolution algorithm. The recommended (and default) value is `true`.
/// Whether cuDNN may use its maximum workspace when searching convolution algorithms, via the
/// `cudnn_conv_use_max_workspace` option. The recommended (and default) value is `true`.
#[must_use]
pub fn with_conv_max_workspace(mut self, enable: bool) -> Self {
	self.options.set("cudnn_conv_use_max_workspace", if enable { "1" } else { "0" });
	self
}

Expand All @@ -161,7 +162,7 @@ impl CUDAExecutionProvider {
/// true to instead use `[N, C, 1, D]`.
/// Whether to pad 1D convolution inputs to `[N, C, 1, D]` instead of `[N, C, D, 1]`, via the
/// `cudnn_conv1d_pad_to_nc1d` option.
#[must_use]
pub fn with_conv1d_pad_to_nc1d(mut self, enable: bool) -> Self {
	self.options.set("cudnn_conv1d_pad_to_nc1d", if enable { "1" } else { "0" });
	self
}

Expand Down Expand Up @@ -190,15 +191,15 @@ impl CUDAExecutionProvider {
/// > `run()`s only perform graph replays of the graph captured and cached in the first `run()`.
/// Enable CUDA graph capture/replay via the `enable_cuda_graph` option; subsequent
/// `run()`s replay the graph captured and cached in the first `run()`.
#[must_use]
pub fn with_cuda_graph(mut self) -> Self {
	self.options.set("enable_cuda_graph", "1");
	self
}

/// Whether to use strict mode in the `SkipLayerNormalization` implementation. The default and recommended setting
/// is `false`. If enabled, accuracy may improve slightly, but performance may decrease.
#[must_use]
pub fn with_skip_layer_norm_strict_mode(mut self) -> Self {
	self.options.set("enable_skip_layer_norm_strict_mode", "1");
	self
}

Expand All @@ -207,33 +208,33 @@ impl CUDAExecutionProvider {
/// rounded with 10 bits of mantissa and results are accumulated with float32 precision.
/// Whether to allow TF32 math (inputs rounded to 10-bit mantissa, accumulation in float32), via the
/// `use_tf32` option.
#[must_use]
pub fn with_tf32(mut self, enable: bool) -> Self {
	self.options.set("use_tf32", if enable { "1" } else { "0" });
	self
}

/// Prefer NHWC-layout operators, via the `prefer_nhwc` option.
#[must_use]
pub fn with_prefer_nhwc(mut self) -> Self {
	self.options.set("prefer_nhwc", "1");
	self
}

/// Use a caller-provided CUDA compute stream; the pointer is passed as an integer via the
/// `user_compute_stream` option.
///
/// # Safety
/// The provided `stream` must outlive the environment/session created with the execution provider.
#[must_use]
pub unsafe fn with_compute_stream(mut self, stream: *mut ()) -> Self {
	self.options.set("user_compute_stream", (stream as usize).to_string());
	self
}

/// Select the attention kernel backend(s) via the `sdpa_kernel` option; `flags` carries the raw
/// bit flags of [`CUDAExecutionProviderAttentionBackend`].
#[must_use]
pub fn with_attention_backend(mut self, flags: CUDAExecutionProviderAttentionBackend) -> Self {
	self.options.set("sdpa_kernel", flags.0.to_string());
	self
}

/// Whether to fuse convolution and bias addition, via the `fuse_conv_bias` option.
#[must_use]
pub fn with_fuse_conv_bias(mut self, enable: bool) -> Self {
	self.options.set("fuse_conv_bias", if enable { "1" } else { "0" });
	self
}

Expand All @@ -246,6 +247,13 @@ impl CUDAExecutionProvider {
}
}

impl ArbitrarilyConfigurableExecutionProvider for CUDAExecutionProvider {
	/// Store an arbitrary `key`/`value` pair in the provider's option map, for
	/// options without a dedicated builder method.
	fn with_arbitrary_config(mut self, key: impl ToString, value: impl ToString) -> Self {
		let (option_key, option_value) = (key.to_string(), value.to_string());
		self.options.set(option_key, option_value);
		self
	}
}

impl From<CUDAExecutionProvider> for ExecutionProviderDispatch {
fn from(value: CUDAExecutionProvider) -> Self {
ExecutionProviderDispatch::new(value)
Expand All @@ -267,41 +275,16 @@ impl ExecutionProvider for CUDAExecutionProvider {
{
let mut cuda_options: *mut ort_sys::OrtCUDAProviderOptionsV2 = std::ptr::null_mut();
crate::ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)?];
let (key_ptrs, value_ptrs, len, keys, values) = super::map_keys! {
device_id = self.device_id,
arena_extend_strategy = self.arena_extend_strategy.as_ref().map(|v| match v {
ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
}),
cudnn_conv_algo_search = self.cudnn_conv_algo_search.as_ref().map(|v| match v {
CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive => "EXHAUSTIVE",
CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC",
CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT"
}),
// has_user_compute_stream = self.user_compute_stream.as_ref().map(|_| 1),
user_compute_stream = self.user_compute_stream.map(|x| x as usize),
gpu_mem_limit = self.gpu_mem_limit,
do_copy_in_default_stream = self.do_copy_in_default_stream.map(<bool as Into<i32>>::into),
cudnn_conv_use_max_workspace = self.cudnn_conv_use_max_workspace.map(<bool as Into<i32>>::into),
cudnn_conv1d_pad_to_nc1d = self.cudnn_conv1d_pad_to_nc1d.map(<bool as Into<i32>>::into),
enable_cuda_graph = self.enable_cuda_graph.map(<bool as Into<i32>>::into),
enable_skip_layer_norm_strict_mode = self.enable_skip_layer_norm_strict_mode.map(<bool as Into<i32>>::into),
use_tf32 = self.use_tf32.map(<bool as Into<i32>>::into),
prefer_nhwc = self.prefer_nhwc.map(<bool as Into<i32>>::into),
sdpa_kernel = self.sdpa_kernel,
fuse_conv_bias = self.fuse_conv_bias.map(<bool as Into<i32>>::into)
};
if let Err(e) =
crate::error::status_to_result(crate::ortsys![unsafe UpdateCUDAProviderOptions(cuda_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
{
let ffi_options = self.options.to_ffi();
if let Err(e) = crate::error::status_to_result(
crate::ortsys![unsafe UpdateCUDAProviderOptions(cuda_options, ffi_options.key_ptrs(), ffi_options.value_ptrs(), ffi_options.len())]
) {
crate::ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
std::mem::drop((keys, values));
return Err(e);
}

let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CUDA_V2(session_builder.session_options_ptr.as_ptr(), cuda_options)];
crate::ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
std::mem::drop((keys, values));
return crate::error::status_to_result(status);
}

Expand Down
Loading

0 comments on commit abd527b

Please sign in to comment.