refactor!: make Session::run take &mut self
decahedron1 committed Feb 8, 2025
1 parent 0f02d79 commit bd2aff7
Showing 25 changed files with 193 additions and 168 deletions.
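The breaking change is mechanical but pervasive: every `Session::run`, `run_async`, and `run_binding` call site now needs mutable access to the session, as the hunks below show. A minimal before/after sketch of the migration (the model path and output name are hypothetical; the tensor APIs follow the usage visible in the diff):

```rust
use ort::{session::Session, value::TensorRef};

fn main() -> ort::Result<()> {
    // Before this commit, `let session = ...` sufficed because `run` took `&self`.
    // After it, the binding must be `mut`.
    let mut session = Session::builder()?.commit_from_file("model.onnx")?;

    let data = [1.0_f32, 2.0, 3.0, 4.0];
    let input = TensorRef::from_array_view((vec![1_i64, 4], data.as_slice()))?;

    // `run` now borrows the session mutably for the duration of the call.
    let outputs = session.run(ort::inputs![input])?;
    let (shape, values) = outputs["output"].try_extract_raw_tensor::<f32>()?;
    println!("{shape:?}: {values:?}");
    Ok(())
}
```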
2 changes: 1 addition & 1 deletion backends/tract/tests/session.rs
@@ -14,7 +14,7 @@ fn mnist_5() -> ort::Result<()> {

ort::set_api(ort_tract::api());

let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")
.expect("Could not download model from file");
2 changes: 1 addition & 1 deletion docs/pages/index.mdx
@@ -53,7 +53,7 @@ Once you've got a model, load it via `ort` by creating a [`Session`](/fundamenta
```rust
use ort::session::{builder::GraphOptimizationLevel, Session};

let model = Session::builder()?
let mut model = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.commit_from_file("yolov8m.onnx")?;
12 changes: 6 additions & 6 deletions docs/pages/perf/io-binding.mdx
@@ -62,9 +62,9 @@ Outputs can be bound to any device -- they can even stay on the EP device if you
If you do bind an output to the session's device, it is not guaranteed to be synchronized after `run`, just like `bind_input`. You can force outputs to synchronize immediately using `IoBinding::synchronize_outputs`.

## Running
To run a session using an I/O binding, you simply call `run()` on `IoBinding` instead of on the session:
To run a session using an I/O binding, you simply call the session's `run_binding()` function with the created `IoBinding`:
```rs
let outputs = io_binding.run()?;
let outputs = session.run_binding(&binding)?;
```

`outputs` provides the same interface as the outputs returned by `Session::run`; it just contains the outputs that you bound earlier.
@@ -77,10 +77,10 @@ let action: Tensor<f32> = outputs.remove("action").unwrap().downcast()?;
Here is a more complete example of the I/O binding API in a scenario where I/O performance can be improved significantly. This example features a typical text-to-image diffusion pipeline, using a text encoder like CLIP to create the condition tensor and a UNet for diffusion.

```rs
let text_encoder = Session::builder()?
let mut text_encoder = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.commit_from_file("text_encoder.onnx")?;
let unet = Session::builder()?
let mut unet = Session::builder()?
.with_execution_providers([CUDAExecutionProvider::default().build()])?
.commit_from_file("unet.onnx")?;

@@ -91,7 +91,7 @@ let text_condition = {
vec![49, 272, 503, 286, 1396, 353, 9653, 284, 1234, 287, 616, 2438, 11, 7926, 13, 3423, 338, 3362, 25, 12520, 238, 242]
))?)?;
binding.bind_output_to_device("output0", &text_encoder.allocator().memory_info())?;
binding.run()?.remove("output0").unwrap()
text_encoder.run_binding(&binding)?.remove("output0").unwrap()
};

let input_allocator = Allocator::new(
@@ -111,7 +111,7 @@ io_binding.bind_output("noise_pred", Tensor::<f32>::new(&output_allocator, [1, 4

for _ in 0..20 {
io_binding.bind_input("latents", &latents)?;
let noise_pred = io_binding.run()?.remove("noise_pred").unwrap();
let noise_pred = unet.run_binding(&io_binding)?.remove("noise_pred").unwrap();

let mut latents = latents.extract_tensor_mut();
latents += &noise_pred.try_extract_tensor::<f32>()?;
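For I/O binding, the same rule applies: the old `IoBinding::run` is gone, and `Session::run_binding` borrows the session mutably. A minimal sketch of the new call shape, assuming a pre-built binding with an output bound under the hypothetical name `y` (the `IoBinding` import path is an assumption, since binding construction isn't shown in this hunk):

```rust
use ort::{io_binding::IoBinding, session::Session};

// Sketch only: `binding` is assumed to have been created from `session`
// earlier, with an output bound under the (hypothetical) name "y".
fn run_bound(session: &mut Session, binding: &IoBinding) -> ort::Result<()> {
    // `run_binding` replaces `IoBinding::run` and takes the session as `&mut self`.
    let mut outputs = session.run_binding(binding)?;
    // The result has the same interface as `Session::run`'s outputs.
    let y = outputs.remove("y").unwrap();
    let _values = y.try_extract_tensor::<f32>()?;
    Ok(())
}
```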
20 changes: 13 additions & 7 deletions examples/async-gpt2-api/examples/async-gpt2-api.rs
@@ -17,7 +17,7 @@ use ort::{
};
use rand::Rng;
use tokenizers::Tokenizer;
use tokio::net::TcpListener;
use tokio::{net::TcpListener, sync::Mutex};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
@@ -44,7 +44,7 @@ async fn main() -> anyhow::Result<()> {
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();

let app_state = AppState {
session: Arc::new(session),
session: Arc::new(Mutex::new(session)),
tokenizer: Arc::new(tokenizer)
};

@@ -59,15 +59,21 @@ async fn main() -> anyhow::Result<()> {

#[derive(Clone)]
struct AppState {
session: Arc<Session>,
session: Arc<Mutex<Session>>,
tokenizer: Arc<Tokenizer>
}

fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, mut tokens: Vec<i64>, gen_tokens: usize) -> impl Stream<Item = ort::Result<Event>> + Send {
fn generate_stream(
tokenizer: Arc<Tokenizer>,
session: Arc<Mutex<Session>>,
mut tokens: Vec<i64>,
gen_tokens: usize
) -> impl Stream<Item = ort::Result<Event>> + Send {
async_stream_lite::try_async_stream(|yielder| async move {
for _ in 0..gen_tokens {
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let probabilities = {
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let mut session = session.lock().await;
let outputs = session.run_async(ort::inputs![input])?.await?;
let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?;

@@ -93,7 +99,7 @@ fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, mut tokens:
})
}

impl FromRef<AppState> for Arc<Session> {
impl FromRef<AppState> for Arc<Mutex<Session>> {
fn from_ref(input: &AppState) -> Self {
Arc::clone(&input.session)
}
@@ -104,6 +110,6 @@ impl FromRef<AppState> for Arc<Tokenizer> {
}
}

async fn generate(State(session): State<Arc<Session>>, State(tokenizer): State<Arc<Tokenizer>>) -> Sse<impl Stream<Item = ort::Result<Event>>> {
async fn generate(State(session): State<Arc<Mutex<Session>>>, State(tokenizer): State<Arc<Tokenizer>>) -> Sse<impl Stream<Item = ort::Result<Event>>> {
Sse::new(generate_stream(tokenizer, session, vec![0], 50)).keep_alive(KeepAlive::new())
}
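Since `run_async` now also requires `&mut Session`, this example moves the shared session behind `tokio::sync::Mutex`: an async-aware lock, because the guard is held across the `.await` on `run_async`. The core pattern, reduced to a sketch (the output name `output1` is taken from the diff above; the `TensorRef` import path is an assumption):

```rust
use std::sync::Arc;

use ort::{session::Session, value::TensorRef};
use tokio::sync::Mutex;

async fn infer(session: Arc<Mutex<Session>>, tokens: Vec<i64>) -> ort::Result<()> {
    let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
    // A `tokio::sync::Mutex` (not `std::sync::Mutex`) is used so the guard
    // can safely be held across the `.await` below.
    let mut session = session.lock().await;
    let outputs = session.run_async(ort::inputs![input])?.await?;
    let (_shape, probabilities) = outputs["output1"].try_extract_raw_tensor::<f32>()?;
    println!("got {} logits", probabilities.len());
    Ok(())
}
```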
4 changes: 2 additions & 2 deletions examples/cudarc/src/main.rs
@@ -19,7 +19,7 @@ fn main() -> anyhow::Result<()> {
.with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()])
.commit()?;

let model =
let mut session =
Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?;

let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap();
@@ -45,7 +45,7 @@ fn main() -> anyhow::Result<()> {
)
.unwrap()
};
let outputs = model.run(ort::inputs![tensor])?;
let outputs = session.run(ort::inputs![tensor])?;

let output = outputs["output"].try_extract_tensor::<f32>()?;

2 changes: 1 addition & 1 deletion examples/custom-ops/examples/custom-ops.rs
@@ -74,7 +74,7 @@ impl Operator for CustomOpTwo {
}

fn main() -> ort::Result<()> {
let session = Session::builder()?
let mut session = Session::builder()?
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
.commit_from_file("tests/data/custom_op_test.onnx")?;

2 changes: 1 addition & 1 deletion examples/gpt2/examples/gpt2.rs
@@ -35,7 +35,7 @@ fn main() -> ort::Result<()> {
let mut rng = rand::thread_rng();

// Load our model
let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?;
4 changes: 2 additions & 2 deletions examples/modnet/examples/modnet.rs
@@ -15,7 +15,7 @@ fn main() -> ort::Result<()> {
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;

let model =
let mut session =
Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/modnet_photographic_portrait_matting.onnx")?;

let original_img = image::open(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("photo.jpg")).unwrap();
@@ -31,7 +31,7 @@ fn main() -> ort::Result<()> {
input[[0, 2, y, x]] = (b as f32 - 127.5) / 127.5;
}

let outputs = model.run(inputs!["input" => TensorRef::from_array_view(input.view())?])?;
let outputs = session.run(inputs!["input" => TensorRef::from_array_view(input.view())?])?;

let output = outputs["output"].try_extract_tensor::<f32>()?;

20 changes: 10 additions & 10 deletions examples/phi-3-vision/src/main.rs
@@ -24,7 +24,7 @@ fn get_current_time() -> Instant {
Instant::now()
}

fn get_image_embedding(vision_model: &Session, img: &Option<DynamicImage>) -> Result<Array3<f32>> {
fn get_image_embedding(vision_model: &mut Session, img: &Option<DynamicImage>) -> Result<Array3<f32>> {
let visual_features = if let Some(img) = img {
let image_processor = image_process::Phi3VImageProcessor::new();
let result = image_processor.preprocess(img)?;
@@ -46,7 +46,7 @@ fn get_image_embedding(vision_model: &Session, img: &Option<DynamicImage>) -> Re
Ok(visual_features)
}

fn get_text_embedding(text_embedding_model: &Session, input_ids: &Array2<i64>) -> Result<Array3<f32>> {
fn get_text_embedding(text_embedding_model: &mut Session, input_ids: &Array2<i64>) -> Result<Array3<f32>> {
let outputs = text_embedding_model.run(ort::inputs![
"input_ids" => TensorRef::from_array_view(input_ids)?,
])?;
@@ -102,9 +102,9 @@ fn format_chat_template(img: &Option<DynamicImage>, txt: &str) -> String {

pub async fn generate_text(
tokenizer: &Tokenizer,
vision_model: &Session,
text_embedding_model: &Session,
generation_model: &Session,
vision_model: &mut Session,
text_embedding_model: &mut Session,
generation_model: &mut Session,
image: &Option<DynamicImage>,
text: &str
) -> Result<()> {
@@ -213,19 +213,19 @@ async fn main() -> Result<()> {

let data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("data");
let tokenizer = Tokenizer::from_file(data_dir.join("tokenizer.json")).map_err(|e| anyhow::anyhow!("Error loading tokenizer: {:?}", e))?;
let vision_model = Session::builder()?.commit_from_file(data_dir.join(VISION_MODEL_NAME))?;
let text_embedding_model = Session::builder()?.commit_from_file(data_dir.join(TEXT_EMBEDDING_MODEL_NAME))?;
let generation_model = Session::builder()?.commit_from_file(data_dir.join(GENERATION_MODEL_NAME))?;
let mut vision_model = Session::builder()?.commit_from_file(data_dir.join(VISION_MODEL_NAME))?;
let mut text_embedding_model = Session::builder()?.commit_from_file(data_dir.join(TEXT_EMBEDDING_MODEL_NAME))?;
let mut generation_model = Session::builder()?.commit_from_file(data_dir.join(GENERATION_MODEL_NAME))?;

// Generate text from text
let image: Option<DynamicImage> = None;
let text = "Who are you?".to_string();
generate_text(&tokenizer, &vision_model, &text_embedding_model, &generation_model, &image, &text).await?;
generate_text(&tokenizer, &mut vision_model, &mut text_embedding_model, &mut generation_model, &image, &text).await?;

// Generate text from image and text
let image: Option<DynamicImage> = Some(image::open(data_dir.join("example.jpg"))?);
let text = "What is shown in this image?".to_string();
generate_text(&tokenizer, &vision_model, &text_embedding_model, &generation_model, &image, &text).await?;
generate_text(&tokenizer, &mut vision_model, &mut text_embedding_model, &mut generation_model, &image, &text).await?;

Ok(())
}
@@ -25,7 +25,7 @@ fn main() -> ort::Result<()> {
.commit()?;

// Load our model
let session = Session::builder()?
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?
.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/all-MiniLM-L6-v2.onnx")?;
2 changes: 1 addition & 1 deletion examples/training/examples/train-clm-simple.rs
@@ -110,7 +110,7 @@ fn main() -> ort::Result<()> {

trainer.export("trained-clm.onnx", ["probs"])?;

let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;
let mut session = Session::builder()?.commit_from_file("trained-clm.onnx")?;

let mut stdout = std::io::stdout();

2 changes: 1 addition & 1 deletion examples/training/examples/train-clm.rs
@@ -102,7 +102,7 @@ fn main() -> ort::Result<()> {

trainer.export("trained-clm.onnx", ["probs"])?;

let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;
let mut session = Session::builder()?.commit_from_file("trained-clm.onnx")?;

let mut stdout = std::io::stdout();

2 changes: 1 addition & 1 deletion examples/yolov8/Cargo.toml
@@ -5,7 +5,7 @@ version = "0.0.0"
edition = "2021"

[dependencies]
ort = { path = "../../" }
ort = { path = "../../", features = [ "fetch-models" ] }
ndarray = "0.16"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
image = "0.25"
2 changes: 1 addition & 1 deletion examples/yolov8/examples/yolov8.rs
@@ -64,7 +64,7 @@ fn main() -> ort::Result<()> {
input[[0, 2, y, x]] = (b as f32) / 255.;
}

let model = Session::builder()?.commit_from_url(YOLOV8M_URL)?;
let mut model = Session::builder()?.commit_from_url(YOLOV8M_URL)?;

// Run YOLOv8 inference
let outputs: SessionOutputs = model.run(inputs!["images" => TensorRef::from_array_view(&input)?])?;
10 changes: 5 additions & 5 deletions src/adapter.rs
@@ -36,7 +36,7 @@ impl Drop for AdapterInner {
/// ```
/// # use ort::{adapter::Adapter, session::{run_options::RunOptions, Session}, value::Tensor};
/// # fn main() -> ort::Result<()> {
/// let model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
/// let mut model = Session::builder()?.commit_from_file("tests/data/lora_model.onnx")?;
/// let lora = Adapter::from_file("tests/data/adapter.orl", None)?;
///
/// let mut run_options = RunOptions::new()?;
@@ -69,7 +69,7 @@ impl Adapter {
/// # value::Tensor
/// # };
/// # fn main() -> ort::Result<()> {
/// let model = Session::builder()?
/// let mut model = Session::builder()?
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
/// .commit_from_file("tests/data/lora_model.onnx")?;
///
@@ -116,7 +116,7 @@ impl Adapter {
/// # value::Tensor
/// # };
/// # fn main() -> ort::Result<()> {
/// let model = Session::builder()?
/// let mut model = Session::builder()?
/// .with_execution_providers([CUDAExecutionProvider::default().build()])?
/// .commit_from_file("tests/data/lora_model.onnx")?;
///
@@ -166,7 +166,7 @@ mod tests {
#[test]
fn test_lora() -> crate::Result<()> {
let model = std::fs::read("tests/data/lora_model.onnx").expect("");
let session = Session::builder()?.commit_from_memory(&model)?;
let mut session = Session::builder()?.commit_from_memory(&model)?;
let lora = std::fs::read("tests/data/adapter.orl").expect("");
let lora = Adapter::from_memory(&lora, None)?;

@@ -190,7 +190,7 @@ mod tests {
#[test]
fn test_lora_from_memory() -> crate::Result<()> {
let model = std::fs::read("tests/data/lora_model.onnx").expect("");
let session = Session::builder()?.commit_from_memory(&model)?;
let mut session = Session::builder()?.commit_from_memory(&model)?;

let lora_bytes = std::fs::read("tests/data/adapter.orl").expect("");
let lora = Adapter::from_memory(&lora_bytes, None)?;