diff --git a/backends/tract/tests/session.rs b/backends/tract/tests/session.rs index e6f696f0..5320482a 100644 --- a/backends/tract/tests/session.rs +++ b/backends/tract/tests/session.rs @@ -14,7 +14,7 @@ fn mnist_5() -> ort::Result<()> { ort::set_api(ort_tract::api()); - let session = Session::builder()? + let mut session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx") .expect("Could not download model from file"); diff --git a/docs/pages/index.mdx b/docs/pages/index.mdx index 898d1c73..37f7677e 100644 --- a/docs/pages/index.mdx +++ b/docs/pages/index.mdx @@ -53,7 +53,7 @@ Once you've got a model, load it via `ort` by creating a [`Session`](/fundamenta ```rust use ort::session::{builder::GraphOptimizationLevel, Session}; -let model = Session::builder()? +let mut model = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(4)? .commit_from_file("yolov8m.onnx")?; diff --git a/docs/pages/perf/io-binding.mdx b/docs/pages/perf/io-binding.mdx index 0de5df27..2489fad0 100644 --- a/docs/pages/perf/io-binding.mdx +++ b/docs/pages/perf/io-binding.mdx @@ -62,9 +62,9 @@ Outputs can be bound to any device -- they can even stay on the EP device if you If you do bind an output to the session's device, it is not guaranteed to be synchronized after `run`, just like `bind_input`. You can force outputs to synchronize immediately using `IoBinding::synchronize_outputs`. 
## Running -To run a session using an I/O binding, you simply call `run()` on `IoBinding` instead of on the session: +To run a session using an I/O binding, you simply call the session's `run_binding()` function with the created `IoBinding`: ```rs -let outputs = io_binding.run()?; +let outputs = session.run_binding(&binding)?; ``` `outputs` provides the same interface as the outputs returned by `Session::run`, it just returns the outputs that you bound earlier. @@ -77,10 +77,10 @@ let action: Tensor = outputs.remove("action").unwrap().downcast()?; Here is a more complete example of the I/O binding API in a scenario where I/O performance can be improved significantly. This example features a typical text-to-image diffusion pipeline, using a text encoder like CLIP to create the condition tensor and a UNet for diffusion. ```rs -let text_encoder = Session::builder()? +let mut text_encoder = Session::builder()? .with_execution_providers([CUDAExecutionProvider::default().build()])? .commit_from_file("text_encoder.onnx")?; -let unet = Session::builder()? +let mut unet = Session::builder()? .with_execution_providers([CUDAExecutionProvider::default().build()])? 
.commit_from_file("unet.onnx")?; @@ -91,7 +91,7 @@ let text_condition = { vec![49, 272, 503, 286, 1396, 353, 9653, 284, 1234, 287, 616, 2438, 11, 7926, 13, 3423, 338, 3362, 25, 12520, 238, 242] ))?)?; binding.bind_output_to_device("output0", &text_encoder.allocator().memory_info())?; - binding.run()?.remove("output0").unwrap() + text_encoder.run_binding(&binding)?.remove("output0").unwrap() }; let input_allocator = Allocator::new( @@ -111,7 +111,7 @@ io_binding.bind_output("noise_pred", Tensor::::new(&output_allocator, [1, 4 for _ in 0..20 { io_binding.bind_input("latents", &latents)?; - let noise_pred = io_binding.run()?.remove("noise_pred").unwrap(); + let noise_pred = unet.run_binding(&io_binding)?.remove("noise_pred").unwrap(); let mut latents = latents.extract_tensor_mut(); latents += &noise_pred.try_extract_tensor::()?; diff --git a/examples/async-gpt2-api/examples/async-gpt2-api.rs b/examples/async-gpt2-api/examples/async-gpt2-api.rs index 8a783c0e..cdb2a0ca 100644 --- a/examples/async-gpt2-api/examples/async-gpt2-api.rs +++ b/examples/async-gpt2-api/examples/async-gpt2-api.rs @@ -17,7 +17,7 @@ use ort::{ }; use rand::Rng; use tokenizers::Tokenizer; -use tokio::net::TcpListener; +use tokio::{net::TcpListener, sync::Mutex}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] @@ -44,7 +44,7 @@ async fn main() -> anyhow::Result<()> { let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap(); let app_state = AppState { - session: Arc::new(session), + session: Arc::new(Mutex::new(session)), tokenizer: Arc::new(tokenizer) }; @@ -59,15 +59,21 @@ async fn main() -> anyhow::Result<()> { #[derive(Clone)] struct AppState { - session: Arc, + session: Arc>, tokenizer: Arc } -fn generate_stream(tokenizer: Arc, session: Arc, mut tokens: Vec, gen_tokens: usize) -> impl Stream> + Send { +fn generate_stream( + tokenizer: Arc, + session: Arc>, + mut tokens: Vec, + gen_tokens: 
usize +) -> impl Stream> + Send { async_stream_lite::try_async_stream(|yielder| async move { for _ in 0..gen_tokens { + let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; let probabilities = { - let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?; + let mut session = session.lock().await; let outputs = session.run_async(ort::inputs![input])?.await?; let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?; @@ -93,7 +99,7 @@ fn generate_stream(tokenizer: Arc, session: Arc, mut tokens: }) } -impl FromRef for Arc { +impl FromRef for Arc> { fn from_ref(input: &AppState) -> Self { Arc::clone(&input.session) } @@ -104,6 +110,6 @@ impl FromRef for Arc { } } -async fn generate(State(session): State>, State(tokenizer): State>) -> Sse>> { +async fn generate(State(session): State>>, State(tokenizer): State>) -> Sse>> { Sse::new(generate_stream(tokenizer, session, vec![0], 50)).keep_alive(KeepAlive::new()) } diff --git a/examples/cudarc/src/main.rs b/examples/cudarc/src/main.rs index 14c333c4..afc76964 100644 --- a/examples/cudarc/src/main.rs +++ b/examples/cudarc/src/main.rs @@ -19,7 +19,7 @@ fn main() -> anyhow::Result<()> { .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()]) .commit()?; - let model = + let mut session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?; let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap(); @@ -45,7 +45,7 @@ fn main() -> anyhow::Result<()> { ) .unwrap() }; - let outputs = model.run(ort::inputs![tensor])?; + let outputs = session.run(ort::inputs![tensor])?; let output = outputs["output"].try_extract_tensor::()?; diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index f92c0646..3a0539a4 100644 --- 
a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -74,7 +74,7 @@ impl Operator for CustomOpTwo { } fn main() -> ort::Result<()> { - let session = Session::builder()? + let mut session = Session::builder()? .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? .commit_from_file("tests/data/custom_op_test.onnx")?; diff --git a/examples/gpt2/examples/gpt2.rs b/examples/gpt2/examples/gpt2.rs index bf679b2d..7366c4ca 100644 --- a/examples/gpt2/examples/gpt2.rs +++ b/examples/gpt2/examples/gpt2.rs @@ -35,7 +35,7 @@ fn main() -> ort::Result<()> { let mut rng = rand::thread_rng(); // Load our model - let session = Session::builder()? + let mut session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? .commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?; diff --git a/examples/modnet/examples/modnet.rs b/examples/modnet/examples/modnet.rs index 5ce39a93..264811b8 100644 --- a/examples/modnet/examples/modnet.rs +++ b/examples/modnet/examples/modnet.rs @@ -15,7 +15,7 @@ fn main() -> ort::Result<()> { .with_execution_providers([CUDAExecutionProvider::default().build()]) .commit()?; - let model = + let mut session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?; let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap(); @@ -31,7 +31,7 @@ fn main() -> ort::Result<()> { input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5; } - let outputs = model.run(inputs!["input" => TensorRef::from_array_view(input.view())?])?; + let outputs = session.run(inputs!["input" => TensorRef::from_array_view(input.view())?])?; let output = outputs["output"].try_extract_tensor::()?; diff --git a/examples/phi-3-vision/src/main.rs b/examples/phi-3-vision/src/main.rs index 
ffd8cf31..d48eff9d 100644 --- a/examples/phi-3-vision/src/main.rs +++ b/examples/phi-3-vision/src/main.rs @@ -24,7 +24,7 @@ fn get_current_time() -> Instant { Instant::now() } -fn get_image_embedding(vision_model: &Session, img: &Option) -> Result> { +fn get_image_embedding(vision_model: &mut Session, img: &Option) -> Result> { let visual_features = if let Some(img) = img { let image_processor = image_process::Phi3VImageProcessor::new(); let result = image_processor.preprocess(img)?; @@ -46,7 +46,7 @@ fn get_image_embedding(vision_model: &Session, img: &Option) -> Re Ok(visual_features) } -fn get_text_embedding(text_embedding_model: &Session, input_ids: &Array2) -> Result> { +fn get_text_embedding(text_embedding_model: &mut Session, input_ids: &Array2) -> Result> { let outputs = text_embedding_model.run(ort::inputs![ "input_ids" => TensorRef::from_array_view(input_ids)?, ])?; @@ -102,9 +102,9 @@ fn format_chat_template(img: &Option, txt: &str) -> String { pub async fn generate_text( tokenizer: &Tokenizer, - vision_model: &Session, - text_embedding_model: &Session, - generation_model: &Session, + vision_model: &mut Session, + text_embedding_model: &mut Session, + generation_model: &mut Session, image: &Option, text: &str ) -> Result<()> { @@ -213,19 +213,19 @@ async fn main() -> Result<()> { let data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("data"); let tokenizer = Tokenizer::from_file(data_dir.join("tokenizer.json")).map_err(|e| anyhow::anyhow!("Error loading tokenizer: {:?}", e))?; - let vision_model = Session::builder()?.commit_from_file(data_dir.join(VISION_MODEL_NAME))?; - let text_embedding_model = Session::builder()?.commit_from_file(data_dir.join(TEXT_EMBEDDING_MODEL_NAME))?; - let generation_model = Session::builder()?.commit_from_file(data_dir.join(GENERATION_MODEL_NAME))?; + let mut vision_model = Session::builder()?.commit_from_file(data_dir.join(VISION_MODEL_NAME))?; + let mut text_embedding_model = 
Session::builder()?.commit_from_file(data_dir.join(TEXT_EMBEDDING_MODEL_NAME))?; + let mut generation_model = Session::builder()?.commit_from_file(data_dir.join(GENERATION_MODEL_NAME))?; // Generate text from text let image: Option = None; let text = "Who are you?".to_string(); - generate_text(&tokenizer, &vision_model, &text_embedding_model, &generation_model, &image, &text).await?; + generate_text(&tokenizer, &mut vision_model, &mut text_embedding_model, &mut generation_model, &image, &text).await?; // Generate text from image and text let image: Option = Some(image::open(data_dir.join("example.jpg"))?); let text = "What is shown in this image?".to_string(); - generate_text(&tokenizer, &vision_model, &text_embedding_model, &generation_model, &image, &text).await?; + generate_text(&tokenizer, &mut vision_model, &mut text_embedding_model, &mut generation_model, &image, &text).await?; Ok(()) } diff --git a/examples/sentence-transformers/examples/semantic-similarity.rs b/examples/sentence-transformers/examples/semantic-similarity.rs index 6aed6e63..4f228f1f 100644 --- a/examples/sentence-transformers/examples/semantic-similarity.rs +++ b/examples/sentence-transformers/examples/semantic-similarity.rs @@ -25,7 +25,7 @@ fn main() -> ort::Result<()> { .commit()?; // Load our model - let session = Session::builder()? + let mut session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? 
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/all-MiniLM-L6-v2.onnx")?; diff --git a/examples/training/examples/train-clm-simple.rs b/examples/training/examples/train-clm-simple.rs index 187cf827..b1bc3551 100644 --- a/examples/training/examples/train-clm-simple.rs +++ b/examples/training/examples/train-clm-simple.rs @@ -110,7 +110,7 @@ fn main() -> ort::Result<()> { trainer.export("trained-clm.onnx", ["probs"])?; - let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + let mut session = Session::builder()?.commit_from_file("trained-clm.onnx")?; let mut stdout = std::io::stdout(); diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs index 2c29c69d..3f94da03 100644 --- a/examples/training/examples/train-clm.rs +++ b/examples/training/examples/train-clm.rs @@ -102,7 +102,7 @@ fn main() -> ort::Result<()> { trainer.export("trained-clm.onnx", ["probs"])?; - let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + let mut session = Session::builder()?.commit_from_file("trained-clm.onnx")?; let mut stdout = std::io::stdout(); diff --git a/examples/yolov8/Cargo.toml b/examples/yolov8/Cargo.toml index bfa1def2..5ca11aba 100644 --- a/examples/yolov8/Cargo.toml +++ b/examples/yolov8/Cargo.toml @@ -5,7 +5,7 @@ version = "0.0.0" edition = "2021" [dependencies] -ort = { path = "../../" } +ort = { path = "../../", features = [ "fetch-models" ] } ndarray = "0.16" tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } image = "0.25" diff --git a/examples/yolov8/examples/yolov8.rs b/examples/yolov8/examples/yolov8.rs index 7c24ddf5..f7625982 100644 --- a/examples/yolov8/examples/yolov8.rs +++ b/examples/yolov8/examples/yolov8.rs @@ -64,7 +64,7 @@ fn main() -> ort::Result<()> { input[[0, 2, y, x]] = (b as f32) / 255.; } - let model = Session::builder()?.commit_from_url(YOLOV8M_URL)?; + let mut model = 
Session::builder()?.commit_from_url(YOLOV8M_URL)?; // Run YOLOv8 inference let outputs: SessionOutputs = model.run(inputs!["images" => TensorRef::from_array_view(&input)?])?; diff --git a/src/adapter.rs b/src/adapter.rs index e5ff368c..db6fc421 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -36,7 +36,7 @@ impl Drop for AdapterInner { /// ``` /// # use ort::{adapter::Adapter, session::{run_options::RunOptions, Session}, value::Tensor}; /// # fn main() -> ort::Result<()> { -/// let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?; +/// let mut model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?; /// let lora = Adapter::from_file("tests/data/adapter.orl", None)?; /// /// let mut run_options = RunOptions::new()?; @@ -69,7 +69,7 @@ impl Adapter { /// # value::Tensor /// # }; /// # fn main() -> ort::Result<()> { - /// let model = Session::builder()? + /// let mut model = Session::builder()? /// .with_execution_providers([CUDAExecutionProvider::default().build()])? /// .commit_from_file("tests/data/lora_model.onnx")?; /// @@ -116,7 +116,7 @@ impl Adapter { /// # value::Tensor /// # }; /// # fn main() -> ort::Result<()> { - /// let model = Session::builder()? + /// let mut model = Session::builder()? /// .with_execution_providers([CUDAExecutionProvider::default().build()])? 
/// .commit_from_file("tests/data/lora_model.onnx")?; /// @@ -166,7 +166,7 @@ mod tests { #[test] fn test_lora() -> crate::Result<()> { let model = std::fs::read("tests/data/lora_model.onnx").expect(""); - let session = Session::builder()?.commit_from_memory(&model)?; + let mut session = Session::builder()?.commit_from_memory(&model)?; let lora = std::fs::read("tests/data/adapter.orl").expect(""); let lora = Adapter::from_memory(&lora, None)?; @@ -190,7 +190,7 @@ mod tests { #[test] fn test_lora_from_memory() -> crate::Result<()> { let model = std::fs::read("tests/data/lora_model.onnx").expect(""); - let session = Session::builder()?.commit_from_memory(&model)?; + let mut session = Session::builder()?.commit_from_memory(&model)?; let lora_bytes = std::fs::read("tests/data/adapter.orl").expect(""); let lora = Adapter::from_memory(&lora_bytes, None)?; diff --git a/src/io_binding.rs b/src/io_binding.rs index 2fe2ca78..59989a0d 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -3,13 +3,11 @@ use alloc::{ ffi::CString, string::{String, ToString}, - sync::Arc, - vec::Vec + sync::Arc }; use core::{ fmt::Debug, - ptr::{self, NonNull}, - slice + ptr::{self, NonNull} }; use crate::{ @@ -17,7 +15,7 @@ use crate::{ error::Result, memory::MemoryInfo, ortsys, - session::{NoSelectedOutputs, RunOptions, Session, SharedSessionInner, output::SessionOutputs}, + session::{Session, SharedSessionInner}, util::MiniMap, value::{DynValue, Value, ValueInner, ValueTypeMarker} }; @@ -50,10 +48,10 @@ use crate::{ /// # value::Tensor /// # }; /// # fn main() -> ort::Result<()> { -/// let text_encoder = Session::builder()? +/// let mut text_encoder = Session::builder()? /// .with_execution_providers([CUDAExecutionProvider::default().build()])? /// .commit_from_file("text_encoder.onnx")?; -/// let unet = Session::builder()? +/// let mut unet = Session::builder()? /// .with_execution_providers([CUDAExecutionProvider::default().build()])? 
/// .commit_from_file("unet.onnx")?; /// @@ -82,7 +80,7 @@ use crate::{ /// /// for _ in 0..20 { /// io_binding.bind_input("latents", &latents)?; -/// let noise_pred = io_binding.run()?.remove("noise_pred").unwrap(); +/// let noise_pred = unet.run_binding(&io_binding)?.remove("noise_pred").unwrap(); /// /// let mut latents = latents.extract_tensor_mut(); /// latents += &noise_pred.try_extract_tensor::()?; @@ -99,8 +97,8 @@ use crate::{ pub struct IoBinding { ptr: NonNull, held_inputs: MiniMap>, - output_values: MiniMap>, - session: Arc + pub(crate) output_values: MiniMap>, + _session: Arc } impl IoBinding { @@ -109,9 +107,9 @@ impl IoBinding { ortsys![unsafe CreateIoBinding(session.ptr().cast_mut(), &mut ptr)?; nonNull(ptr)]; Ok(Self { ptr: unsafe { NonNull::new_unchecked(ptr) }, - session: session.inner(), held_inputs: MiniMap::new(), - output_values: MiniMap::new() + output_values: MiniMap::new(), + _session: session.inner() }) } @@ -197,55 +195,6 @@ impl IoBinding { self.synchronize_outputs()?; Ok(()) } - - /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. - pub fn run(&mut self) -> Result> { - self.run_inner(None) - } - - /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. 
- pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result> { - self.run_inner(Some(run_options)) - } - - fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result> { - let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { ptr::null() }; - ortsys![unsafe RunWithBinding(self.session.ptr().cast_mut(), run_options_ptr, self.ptr())?]; - - // let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(), - // c)).collect(); - let mut count = self.output_values.len(); - if count > 0 { - let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); - ortsys![unsafe GetBoundOutputValues(self.ptr(), self.session.allocator.ptr().cast_mut(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)]; - - let output_values = unsafe { slice::from_raw_parts(output_values_ptr, count).to_vec() } - .into_iter() - .zip(self.output_values.iter()) - .map(|(ptr, (_, value))| unsafe { - if let Some(value) = value { - DynValue::clone_of(value) - } else { - DynValue::from_ptr( - NonNull::new(ptr).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), - Some(Arc::clone(&self.session)) - ) - } - }) - .collect::>(); - - // output values will be freed when the `Value`s in `SessionOutputs` drop - - Ok(SessionOutputs::new_backed( - self.output_values.iter().map(|(k, _)| k.as_str()).collect(), - output_values, - &self.session.allocator, - output_values_ptr.cast() - )) - } else { - Ok(SessionOutputs::new_empty()) - } - } } unsafe impl Send for IoBinding {} @@ -308,7 +257,7 @@ mod tests { #[test] #[cfg(all(feature = "ndarray", feature = "fetch-models"))] fn test_mnist_input_bound() -> Result<()> { - let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; + let mut session = 
Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; let array = get_image(); @@ -316,7 +265,7 @@ mod tests { binding.bind_input(&session.inputs[0].name, &Tensor::from_array(array)?)?; binding.bind_output_to_device(&session.outputs[0].name, &MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::CPUOutput)?)?; - let outputs = binding.run()?; + let outputs = session.run_binding(&binding)?; let probabilities = extract_probabilities(&outputs[0])?; assert_eq!(probabilities[0].0, 5); @@ -326,7 +275,7 @@ mod tests { #[test] #[cfg(all(feature = "ndarray", feature = "fetch-models"))] fn test_mnist_input_output_bound() -> Result<()> { - let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; + let mut session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; let array = get_image(); @@ -336,7 +285,7 @@ mod tests { let output = Array2::from_shape_simple_fn((1, 10), || 0.0_f32); binding.bind_output(&session.outputs[0].name, Tensor::from_array(output)?)?; - let outputs = binding.run()?; + let outputs = session.run_binding(&binding)?; let probabilities = extract_probabilities(&outputs[0])?; assert_eq!(probabilities[0].0, 5); @@ -346,7 +295,7 @@ mod tests { #[test] #[cfg(all(feature = "ndarray", feature = "fetch-models"))] fn test_send_iobinding() -> Result<()> { - let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; + let mut session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; let array = get_image(); @@ -356,7 +305,7 @@ mod tests { let probabilities = std::thread::spawn(move || { binding.bind_input(&session.inputs[0].name, &Tensor::from_array(array)?)?; - let outputs = binding.run()?; + let outputs = 
session.run_binding(&binding)?; let probabilities = extract_probabilities(&outputs[0])?; Ok::, crate::Error>(probabilities) }) @@ -371,7 +320,7 @@ mod tests { #[test] #[cfg(all(feature = "ndarray", feature = "fetch-models"))] fn test_mnist_clear_bounds() -> Result<()> { - let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; + let mut session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?; let array = get_image(); @@ -382,7 +331,7 @@ mod tests { binding.bind_output(&session.outputs[0].name, Tensor::from_array(output)?)?; { - let outputs = binding.run()?; + let outputs = session.run_binding(&binding)?; let probabilities = extract_probabilities(&outputs[0])?; assert_eq!(probabilities[0].0, 5); } @@ -391,13 +340,13 @@ mod tests { binding.bind_output_to_device(&session.outputs[0].name, &MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::CPUOutput)?)?; { - let outputs = binding.run()?; + let outputs = session.run_binding(&binding)?; let probabilities = extract_probabilities(&outputs[0])?; assert_eq!(probabilities[0].0, 5); } binding.clear_inputs(); - assert!(binding.run().is_err()); + assert!(session.run_binding(&binding).is_err()); Ok(()) } diff --git a/src/operator/tests.rs b/src/operator/tests.rs index e9b8d242..3ec78a75 100644 --- a/src/operator/tests.rs +++ b/src/operator/tests.rs @@ -78,7 +78,7 @@ impl Operator for CustomOpTwo { #[test] fn test_custom_ops() -> crate::Result<()> { let model = std::fs::read("tests/data/custom_op_test.onnx").expect(""); - let session = Session::builder()? + let mut session = Session::builder()? .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? .commit_from_memory(&model)?; diff --git a/src/session/mod.rs b/src/session/mod.rs index 1fe2676a..c3da343a 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -3,7 +3,7 @@ //! 
``` //! # use ort::{session::Session, value::TensorRef}; //! # fn main() -> ort::Result<()> { -//! let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +//! let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; //! let input = ndarray::Array4::::zeros((1, 64, 64, 3)); //! let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?; //! # Ok(()) @@ -16,8 +16,9 @@ use core::{ ffi::{CStr, c_char}, iter, marker::PhantomData, - ops::Deref, - ptr::{self, NonNull} + ops::{Deref, DerefMut}, + ptr::{self, NonNull}, + slice }; use crate::{ @@ -27,7 +28,7 @@ use crate::{ memory::Allocator, metadata::ModelMetadata, ortsys, - value::{Value, ValueType} + value::{DynValue, Value, ValueType} }; #[cfg(feature = "std")] @@ -83,7 +84,7 @@ impl Drop for SharedSessionInner { /// ``` /// # use ort::{session::Session, value::TensorRef}; /// # fn main() -> ort::Result<()> { -/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); /// let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?; /// # Ok(()) @@ -113,6 +114,11 @@ impl Deref for InMemorySession<'_> { &self.session } } +impl DerefMut for InMemorySession<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.session + } +} /// Information about a [`Session`] input. 
#[derive(Debug)] @@ -188,13 +194,13 @@ impl Session { /// # use std::sync::Arc; /// # use ort::{session::{run_options::RunOptions, Session}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}}; /// # fn main() -> ort::Result<()> { - /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); /// let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?; /// # Ok(()) /// # } /// ``` - pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { + pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s mut self, input_values: impl Into>) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) @@ -218,7 +224,7 @@ impl Session { /// # use std::sync::Arc; /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType}; /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// # let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; /// let run_options = Arc::new(RunOptions::new()?); /// @@ -237,7 +243,7 @@ impl Session { /// # } /// ``` pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, O: SelectedOutputMarker, const N: usize>( - &'s self, + &'s mut self, input_values: impl Into>, run_options: &'r RunOptions ) -> Result> { @@ -333,6 +339,58 @@ impl Session { Ok(SessionOutputs::new(output_names, outputs)) } + pub fn run_binding<'b, 's: 'b>(&'s mut self, binding: &'b IoBinding) -> Result> { + self.run_binding_inner(binding, None) + } + + pub fn 
run_binding_with_options<'r, 'b, 's: 'b>( + &'s mut self, + binding: &'b IoBinding, + run_options: &'r RunOptions + ) -> Result> { + self.run_binding_inner(binding, Some(run_options)) + } + + fn run_binding_inner<'r, 'b, 's: 'b>( + &'s self, + binding: &'b IoBinding, + run_options: Option<&'r RunOptions> + ) -> Result> { + let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { ptr::null() }; + ortsys![unsafe RunWithBinding(self.inner.ptr().cast_mut(), run_options_ptr, binding.ptr())?]; + + // let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(), + // c)).collect(); + let mut count = binding.output_values.len(); + if count > 0 { + let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); + ortsys![unsafe GetBoundOutputValues(binding.ptr(), self.inner.allocator.ptr().cast_mut(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)]; + + let output_values = unsafe { slice::from_raw_parts(output_values_ptr, count).to_vec() } + .into_iter() + .zip(binding.output_values.iter()) + .map(|(ptr, (_, value))| unsafe { + if let Some(value) = value { + DynValue::clone_of(value) + } else { + DynValue::from_ptr(NonNull::new(ptr).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), Some(self.inner())) + } + }) + .collect::>(); + + // output values will be freed when the `Value`s in `SessionOutputs` drop + + Ok(SessionOutputs::new_backed( + binding.output_values.iter().map(|(k, _)| k.as_str()).collect(), + output_values, + self.allocator(), + output_values_ptr.cast() + )) + } else { + Ok(SessionOutputs::new_empty()) + } + } + /// Asynchronously run input data through the ONNX graph, performing inference. /// /// Inference will be performed on a thread in the session's thread pool. 
**Thus, the session must have been @@ -346,16 +404,16 @@ impl Session { /// # use std::sync::Arc; /// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType}; /// # fn main() -> ort::Result<()> { tokio_test::block_on(async { - /// let session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?; + /// let mut session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); /// let outputs = session.run_async(ort::inputs![TensorRef::from_array_view(&input)?])?.await?; /// # Ok(()) /// # }) } /// ``` #[cfg(feature = "std")] - #[cfg_attr(docsrs, doc(cfg(feature = "std")))] // TODO: parking_lot + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>( - &'s self, + &'s mut self, input_values: impl Into> ) -> Result> { match input_values.into() { @@ -372,9 +430,9 @@ impl Session { /// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`]. /// See [`Session::run_with_options`] and [`Session::run_async`] for more details. #[cfg(feature = "std")] - #[cfg_attr(docsrs, doc(cfg(feature = "std")))] // TODO: parking_lot + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>( - &'s self, + &'s mut self, input_values: impl Into>, run_options: &'r RunOptions ) -> Result> { @@ -468,7 +526,7 @@ impl Session { /// Ends profiling for this session. /// /// Note that this must be explicitly called at the end of profiling, otherwise the profiling file will be empty. 
- pub fn end_profiling(&self) -> Result { + pub fn end_profiling(&mut self) -> Result { let mut profiling_name: *mut c_char = ptr::null_mut(); ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)]; @@ -483,7 +541,7 @@ impl Session { /// # use std::sync::Arc; /// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}}; /// # fn main() -> ort::Result<()> { - /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// session.set_workload_type(WorkloadType::Efficient)?; /// /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); @@ -491,7 +549,7 @@ impl Session { /// # Ok(()) /// # } /// ``` - pub fn set_workload_type(&self, workload_type: WorkloadType) -> Result<()> { + pub fn set_workload_type(&mut self, workload_type: WorkloadType) -> Result<()> { static KEY: &[u8] = b"ep.dynamic.workload_type\0"; match workload_type { WorkloadType::Default => self.set_dynamic_option(KEY.as_ptr().cast(), b"Default\0".as_ptr().cast()), @@ -499,7 +557,7 @@ impl Session { } } - pub(crate) fn set_dynamic_option(&self, key: *const c_char, value: *const c_char) -> Result<()> { + pub(crate) fn set_dynamic_option(&mut self, key: *const c_char, value: *const c_char) -> Result<()> { ortsys![unsafe SetEpDynamicOptions(self.inner.session_ptr.as_ptr(), &key, &value, 1)?]; Ok(()) } diff --git a/src/session/output.rs b/src/session/output.rs index 81ff92cc..f2faa9a0 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -19,7 +19,7 @@ use crate::{ /// ``` /// # use ort::{value::TensorRef, session::{builder::GraphOptimizationLevel, Session}}; /// # fn main() -> ort::Result<()> { -/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let mut session = 
Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let input = ndarray::Array4::::zeros((1, 64, 64, 3)); /// let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?; /// diff --git a/src/session/run_options.rs b/src/session/run_options.rs index 0f79d5b6..ca722487 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -22,7 +22,7 @@ use crate::{ /// # use std::sync::Arc; /// # use ort::{session::{Session, run_options::{RunOptions, OutputSelector}}, memory::Allocator, value::Tensor}; /// # fn main() -> ort::Result<()> { -/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; /// /// let output0 = session.outputs[0].name.as_str(); @@ -101,7 +101,7 @@ impl OutputSelector { /// # use std::sync::Arc; /// # use ort::{session::{Session, run_options::{RunOptions, OutputSelector}}, memory::Allocator, value::Tensor}; /// # fn main() -> ort::Result<()> { - /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; /// /// let output0 = session.outputs[0].name.as_str(); @@ -199,7 +199,7 @@ impl RunOptions { /// # use std::sync::Arc; /// # use ort::{session::{Session, run_options::{RunOptions, OutputSelector}}, memory::Allocator, value::Tensor}; /// # fn main() -> ort::Result<()> { - /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; /// /// let output0 = session.outputs[0].name.as_str(); @@ -254,7 +254,7 @@ impl RunOptions { /// # use std::sync::Arc; 
 	/// # use ort::{session::{Session, run_options::{RunOptions, OutputSelector}}, value::Value};
 	/// # fn main() -> ort::Result<()> {
-	/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
+	/// # let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
 	/// # let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
 	/// let run_options = Arc::new(RunOptions::new()?);
 	///
@@ -283,7 +283,7 @@ impl RunOptions {
 	/// # use std::sync::Arc;
 	/// # use ort::{session::{Session, run_options::{RunOptions, OutputSelector}}, value::Value};
 	/// # fn main() -> ort::Result<()> {
-	/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
+	/// # let mut session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
 	/// # let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
 	/// let run_options = Arc::new(RunOptions::new()?);
 	///
diff --git a/src/value/mod.rs b/src/value/mod.rs
index 2a7e85e8..a31b8fbf 100644
--- a/src/value/mod.rs
+++ b/src/value/mod.rs
@@ -203,7 +203,7 @@ impl DerefMut for ValueRefMut<'_, Type> {
 /// ```
 /// # use ort::{session::Session, value::Tensor};
 /// # fn main() -> ort::Result<()> {
-/// # let upsample = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
+/// # let mut upsample = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
 /// // Create a Tensor value from a raw data vector
 /// let value = Tensor::from_array(([1usize, 1, 1, 3], vec![1.0_f32, 2.0, 3.0].into_boxed_slice()))?;
 ///
diff --git a/tests/mnist.rs b/tests/mnist.rs
index 49b1f6dd..b3813c1d 100644
--- a/tests/mnist.rs
+++ b/tests/mnist.rs
@@ -15,21 +15,25 @@ fn mnist_5() -> ort::Result<()> {
 	ort::init().with_name("integration_test").commit()?;
 
-	let session = Session::builder()?
+	let mut session = Session::builder()?
 		.with_optimization_level(GraphOptimizationLevel::Level1)?
 		.with_intra_threads(1)?
 		.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")
 		.expect("Could not download model from file");
 
-	let metadata = session.metadata()?;
-	assert_eq!(metadata.name()?, "CNTKGraph");
-	assert_eq!(metadata.producer()?, "CNTK");
+	let input0_shape = {
+		let metadata = session.metadata()?;
+		assert_eq!(metadata.name()?, "CNTKGraph");
+		assert_eq!(metadata.producer()?, "CNTK");
 
-	let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
-	let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
+		let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
+		let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
 
-	assert_eq!(input0_shape, &[1, 1, 28, 28]);
-	assert_eq!(output0_shape, &[1, 10]);
+		assert_eq!(input0_shape, &[1, 1, 28, 28]);
+		assert_eq!(output0_shape, &[1, 10]);
+
+		input0_shape
+	};
 
 	// Load image and resize to model's shape, converting to RGB format
 	let image_buffer: ImageBuffer<Luma<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs
index fd4fac05..0c52e48f 100644
--- a/tests/squeezenet.rs
+++ b/tests/squeezenet.rs
@@ -21,23 +21,27 @@ fn squeezenet_mushroom() -> ort::Result<()> {
 	ort::init().with_name("integration_test").commit()?;
 
-	let session = Session::builder()?
+	let mut session = Session::builder()?
 		.with_optimization_level(GraphOptimizationLevel::Level1)?
 		.with_intra_threads(1)?
 		.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx")
 		.expect("Could not download model from file");
 
-	let metadata = session.metadata()?;
-	assert_eq!(metadata.name()?, "main_graph");
-	assert_eq!(metadata.producer()?, "pytorch");
 	let class_labels = get_imagenet_labels()?;
-	let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
-	let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
+	let input0_shape = {
+		let metadata = session.metadata()?;
+		assert_eq!(metadata.name()?, "main_graph");
+		assert_eq!(metadata.producer()?, "pytorch");
+
+		let input0_shape: &Vec<i64> = session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type");
+		let output0_shape: &Vec<i64> = session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type");
+
+		assert_eq!(input0_shape, &[1, 3, 224, 224]);
+		assert_eq!(output0_shape, &[1, 1000]);
 
-	assert_eq!(input0_shape, &[1, 3, 224, 224]);
-	assert_eq!(output0_shape, &[1, 1000]);
+		input0_shape
+	};
 
 	// Load image and resize to model's shape, converting to RGB format
 	let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join(IMAGE_TO_LOAD))
diff --git a/tests/upsample.rs b/tests/upsample.rs
index 47d12d76..f81a9c10 100644
--- a/tests/upsample.rs
+++ b/tests/upsample.rs
@@ -52,18 +52,20 @@ fn upsample() -> ort::Result<()> {
 	let session_data = std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
 
-	let session = Session::builder()?
+	let mut session = Session::builder()?
 		.with_optimization_level(GraphOptimizationLevel::Level1)?
 		.with_intra_threads(1)?
 		.commit_from_memory(&session_data)
 		.expect("Could not read model from memory");
 
-	let metadata = session.metadata()?;
-	assert_eq!(metadata.name()?, "tf2onnx");
-	assert_eq!(metadata.producer()?, "tf2onnx");
+	{
+		let metadata = session.metadata()?;
+		assert_eq!(metadata.name()?, "tf2onnx");
+		assert_eq!(metadata.producer()?, "tf2onnx");
 
-	assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
-	assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
+		assert_eq!(session.inputs[0].input_type.tensor_dimensions().expect("input0 to be a tensor type"), &[-1, -1, -1, 3]);
+		assert_eq!(session.outputs[0].output_type.tensor_dimensions().expect("output0 to be a tensor type"), &[-1, -1, -1, 3]);
+	}
 
 	// Load image, converting to RGB format
 	let image_buffer = load_input_image(IMAGE_TO_LOAD);
@@ -93,7 +95,7 @@ fn upsample_with_ort_model() -> ort::Result<()> {
 	let session_data = std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.ort")).expect("Could not open model from file");
 
-	let session = Session::builder()?
+	let mut session = Session::builder()?
 		.with_optimization_level(GraphOptimizationLevel::Level1)?
 		.with_intra_threads(1)?
 		.commit_from_memory_directly(&session_data) // Zero-copy.
diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs
index 00810857..aa01d082 100644
--- a/tests/vectorizer.rs
+++ b/tests/vectorizer.rs
@@ -12,17 +12,19 @@ use test_log::test;
 #[test]
 fn vectorizer() -> ort::Result<()> {
-	let session = Session::builder()?
+	let mut session = Session::builder()?
 		.with_optimization_level(GraphOptimizationLevel::Level1)?
 		.with_intra_threads(1)?
 		.commit_from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("vectorizer.onnx"))
 		.expect("Could not load model");
 
-	let metadata = session.metadata()?;
-	assert_eq!(metadata.producer()?, "skl2onnx");
-	assert_eq!(metadata.description()?, "test description");
-	assert_eq!(metadata.custom_keys()?, ["custom_key"]);
-	assert_eq!(metadata.custom("custom_key")?.as_deref(), Some("custom_value"));
+	{
+		let metadata = session.metadata()?;
+		assert_eq!(metadata.producer()?, "skl2onnx");
+		assert_eq!(metadata.description()?, "test description");
+		assert_eq!(metadata.custom_keys()?, ["custom_key"]);
+		assert_eq!(metadata.custom("custom_key")?.as_deref(), Some("custom_value"));
+	}
 
 	let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap());