Skip to content

Commit

Permalink
feat: MIGraphX execution provider, ref #212
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jun 20, 2024
1 parent a92dd30 commit 19d66de
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
79 changes: 79 additions & 0 deletions src/execution_providers/migraphx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use std::{ffi::CString, ptr};

use super::ExecutionProvider;
use crate::{ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder};

#[derive(Debug, Default, Clone)]
pub struct MIGraphXExecutionProvider {
device_id: i32,
enable_fp16: bool,
enable_int8: bool,
use_native_calibration_table: bool,
int8_calibration_table_name: Option<CString>
}

impl MIGraphXExecutionProvider {
#[must_use]
pub fn with_device_id(mut self, device_id: i32) -> Self {
self.device_id = device_id;
self
}

#[must_use]
pub fn with_fp16(mut self) -> Self {
self.enable_fp16 = true;
self
}

#[must_use]
pub fn with_int8(mut self) -> Self {
self.enable_int8 = true;
self
}

#[must_use]
pub fn with_native_calibration_table(mut self, table_name: Option<impl AsRef<str>>) -> Self {
self.use_native_calibration_table = true;
self.int8_calibration_table_name = table_name.map(|c| CString::new(c.as_ref()).expect("invalid string"));
self
}

#[must_use]
pub fn build(self) -> ExecutionProviderDispatch {
self.into()
}
}

impl From<MIGraphXExecutionProvider> for ExecutionProviderDispatch {
fn from(value: MIGraphXExecutionProvider) -> Self {
ExecutionProviderDispatch::new(value)
}
}

impl ExecutionProvider for MIGraphXExecutionProvider {
fn as_str(&self) -> &'static str {
"MIGraphXExecutionProvider"
}

fn supported_by_platform(&self) -> bool {
cfg!(any(all(target_os = "linux", target_arch = "x86_64"), all(target_os = "windows", target_arch = "x86_64")))
}

#[allow(unused, unreachable_code)]
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
#[cfg(any(feature = "load-dynamic", feature = "migraphx"))]
{
let options = ort_sys::OrtMIGraphXProviderOptions {
device_id: self.device_id,
migraphx_fp16_enable: self.enable_fp16.into(),
migraphx_int8_enable: self.enable_int8.into(),
migraphx_use_native_calibration_table: self.use_native_calibration_table.into(),
migraphx_int8_calibration_table_name: self.int8_calibration_table_name.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
};
ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider];
return Ok(());
}

Err(Error::ExecutionProviderNotRegistered(self.as_str()))
}
}
2 changes: 2 additions & 0 deletions src/execution_providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ mod xnnpack;
pub use self::xnnpack::XNNPACKExecutionProvider;
mod armnn;
pub use self::armnn::ArmNNExecutionProvider;
mod migraphx;
pub use self::migraphx::MIGraphXExecutionProvider;

/// ONNX Runtime works with different hardware acceleration libraries through its extensible **Execution Providers**
/// (EP) framework to optimally execute the ONNX models on the hardware platform. This interface enables flexibility for
Expand Down

0 comments on commit 19d66de

Please sign in to comment.