Add Phi-3 Vision ONNX example #291

Open · wants to merge 1 commit into `main`

1 change: 1 addition & 0 deletions Cargo.toml
@@ -6,6 +6,7 @@ members = [
'examples/gpt2',
'examples/model-info',
'examples/yolov8',
'examples/phi-3-vision',
'examples/modnet',
'examples/sentence-transformers',
'examples/training'
22 changes: 22 additions & 0 deletions examples/phi-3-vision/Cargo.toml
@@ -0,0 +1,22 @@
[package]
publish = false
name = "phi-3-vision"
version = "0.0.0"
edition = "2021"

[dependencies]
anyhow = "1.0.89"
image = "0.25"
ndarray = "0.16"
ort = { path = "../../", features = ["fetch-models"] }
tokio = { version = "1.40.0", features = ["full"] }
tokenizers = "0.20"
tracing-subscriber = { version = "0.3", default-features = false, features = [
"env-filter",
"fmt",
] }
tracing = "0.1"

[features]
load-dynamic = ["ort/load-dynamic"]
cuda = ["ort/cuda"]
56 changes: 56 additions & 0 deletions examples/phi-3-vision/README.md
@@ -0,0 +1,56 @@
# Phi-3 Vision ONNX Example

This example demonstrates how to run Microsoft's [Phi-3 Vision model](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu) with `ort`.

Phi-3 Vision ONNX is a multimodal model that combines vision and language processing. It uses three interconnected ONNX models:

- Vision model: Processes images to extract visual features
- Text embedding model: Maps tokenized input text to embeddings that can be combined with the visual features
- Text generation model: Produces text outputs based on the combined visual and textual inputs
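
The three sessions can be created with `ort`'s standard `Session` API, much like the other examples in this repository. A minimal sketch, assuming the model files have already been fetched into `data/` with the download script described below:

```rust
use ort::Session;

fn load_sessions() -> ort::Result<(Session, Session, Session)> {
	// onnxruntime picks up the matching `.onnx.data` external-data files from the same directory.
	let vision = Session::builder()?.commit_from_file("data/phi-3-v-128k-instruct-vision.onnx")?;
	let text_embedding = Session::builder()?.commit_from_file("data/phi-3-v-128k-instruct-text-embedding.onnx")?;
	let text_generation = Session::builder()?.commit_from_file("data/phi-3-v-128k-instruct-text.onnx")?;
	Ok((vision, text_embedding, text_generation))
}
```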

This multi-model structure requires a coordinated process:

1. Image Processing:
- Preprocess the input image
- Pass it through the vision ONNX model for visual features

2. Text Embedding:
- Tokenize input text
- Process it with the text embedding ONNX model

3. Multimodal Fusion:
- Combine the visual features and text embeddings into a single input sequence (see the sketch after this list)

4. Text Generation:
- The combined input is fed into the text generation ONNX model.
- The model generates text tokens one by one in an autoregressive manner.
- For each token, the model uses past key/value states to maintain context.
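
The fusion in step 3 happens at the embedding level. The sketch below is a simplified, hypothetical illustration, not the exact pipeline: it assumes a single image placeholder token at `image_pos` whose embedding is replaced by the visual feature sequence, and that the visual features already share the text hidden size. `fuse_inputs` is a made-up helper name.

```rust
use ndarray::{concatenate, s, Array3, Axis};

/// Hypothetical fusion helper: splice the vision model's `visual_features` into the
/// text embedding sequence in place of the image placeholder token at `image_pos`.
/// All arrays are (batch, sequence, hidden).
fn fuse_inputs(text_embeds: &Array3<f32>, visual_features: &Array3<f32>, image_pos: usize) -> Array3<f32> {
	// Text before and after the placeholder; the placeholder embedding itself is dropped.
	let before = text_embeds.slice(s![.., ..image_pos, ..]);
	let after = text_embeds.slice(s![.., (image_pos + 1).., ..]);
	concatenate(Axis(1), &[before, visual_features.view(), after]).expect("batch and hidden dims must match")
}
```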

The specific configuration for the model can be found in `data/genai_config.json`.
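
To make the generation step more concrete, here is a simplified greedy decode-loop sketch. The input/output names and shapes follow `data/genai_config.json` (32 layers, 32 key/value heads, head size 96); `embed_tokens` and `run_decoder` are hypothetical helpers standing in for `Session::run` calls on the text-embedding and text-generation models, returning their outputs as `ndarray` arrays.

```rust
use ndarray::{s, Array2, Array3, Array4};

// Hypothetical helpers: wrap `Session::run` on the embedding and text models and
// convert their outputs to owned ndarrays. Shapes follow data/genai_config.json.
fn embed_tokens(_input_ids: &Array2<i64>) -> Array3<f32> {
	unimplemented!("run phi-3-v-128k-instruct-text-embedding.onnx")
}
fn run_decoder(
	_inputs_embeds: &Array3<f32>,            // (1, seq_len, 3072)
	_attention_mask: &Array2<i64>,           // (1, total_len)
	_past_kv: &[(Array4<f32>, Array4<f32>)]  // 32 x (1, 32, past_len, 96)
) -> (Array3<f32>, Vec<(Array4<f32>, Array4<f32>)>) {
	unimplemented!("run phi-3-v-128k-instruct-text.onnx")
}

/// Greedy, token-by-token generation: each step feeds only the newest embedding plus
/// the accumulated past key/value states back into the decoder.
fn greedy_decode(prompt_embeds: Array3<f32>, eos_token_id: i64, max_new_tokens: usize) -> Vec<i64> {
	// The "past" starts empty (sequence length 0) and grows by one position per step.
	let mut past_kv: Vec<(Array4<f32>, Array4<f32>)> =
		(0..32).map(|_| (Array4::zeros((1, 32, 0, 96)), Array4::zeros((1, 32, 0, 96)))).collect();
	let mut attention_mask = Array2::<i64>::ones((1, prompt_embeds.shape()[1]));
	let mut inputs_embeds = prompt_embeds;
	let mut tokens = Vec::new();

	for _ in 0..max_new_tokens {
		let (logits, present_kv) = run_decoder(&inputs_embeds, &attention_mask, &past_kv);
		// Greedy selection (top_k = 1 in genai_config.json): argmax over the last position's logits.
		let last = logits.slice(s![0, logits.shape()[1] - 1, ..]);
		let next = last
			.iter()
			.enumerate()
			.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
			.map(|(i, _)| i as i64)
			.unwrap();
		if next == eos_token_id {
			break;
		}
		tokens.push(next);
		// The returned present.* tensors become the next step's past_key_values.*.
		past_kv = present_kv;
		// Only the single new token is embedded for the next step; the mask covers all positions so far.
		inputs_embeds = embed_tokens(&Array2::from_elem((1, 1), next));
		attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
	}
	tokens
}
```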

## Limitations and Performance

This example currently supports only a single input image.

ONNX-based LLM inference can be quite slow, especially on CPU:

- On an Apple M1 Pro:
- For image+text input (about 300 tokens): ~5 seconds per output token
- For text-only input (about 10 tokens): ~200ms per output token

## Run this Example

Before running the example, you'll need to download the ONNX model files into the `data` directory. At present, `SessionBuilder::commit_from_url` can't initialize models that are split into `.onnx` and `.onnx.data` files, which is the case for the Phi-3 Vision models.

To get started, use the `data/download.sh` script to download the following files (three ONNX model pairs plus the tokenizer):

1. `phi-3-v-128k-instruct-vision.onnx` and `phi-3-v-128k-instruct-vision.onnx.data`
2. `phi-3-v-128k-instruct-text-embedding.onnx` and `phi-3-v-128k-instruct-text-embedding.onnx.data`
3. `phi-3-v-128k-instruct-text.onnx` and `phi-3-v-128k-instruct-text.onnx.data`
4. `tokenizer.json`

Once the model files are downloaded, you can run the example using Cargo:

```bash
cargo run -p phi-3-vision
```
5 changes: 5 additions & 0 deletions examples/phi-3-vision/build.rs
@@ -0,0 +1,5 @@
fn main() {
// Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml
#[cfg(target_os = "macos")]
println!("cargo:rustc-link-arg=-fapple-link-rtlib");
}
3 changes: 3 additions & 0 deletions examples/phi-3-vision/data/.gitignore
@@ -0,0 +1,3 @@
/*.onnx
/*.onnx.data
/tokenizer.json
16 changes: 16 additions & 0 deletions examples/phi-3-vision/data/download.sh
@@ -0,0 +1,16 @@
#!/bin/bash

BASE_URL="https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/resolve/main/cpu-int4-rtn-block-32-acc-level-4/"
FILES=(
"phi-3-v-128k-instruct-text-embedding.onnx"
"phi-3-v-128k-instruct-text-embedding.onnx.data"
"phi-3-v-128k-instruct-text.onnx"
"phi-3-v-128k-instruct-text.onnx.data"
"phi-3-v-128k-instruct-vision.onnx"
"phi-3-v-128k-instruct-vision.onnx.data"
"tokenizer.json"
)

for FILE in "${FILES[@]}"; do
wget "${BASE_URL}${FILE}"
done
Binary file added examples/phi-3-vision/data/example.jpg
68 changes: 68 additions & 0 deletions examples/phi-3-vision/data/genai_config.json
@@ -0,0 +1,68 @@
{
"model": {
"bos_token_id": 1,
"context_length": 131072,
"decoder": {
"session_options": {
"log_id": "onnxruntime-genai",
"provider_options": []
},
"filename": "phi-3-v-128k-instruct-text.onnx",
"head_size": 96,
"hidden_size": 3072,
"inputs": {
"inputs_embeds": "inputs_embeds",
"attention_mask": "attention_mask",
"past_key_names": "past_key_values.%d.key",
"past_value_names": "past_key_values.%d.value"
},
"outputs": {
"logits": "logits",
"present_key_names": "present.%d.key",
"present_value_names": "present.%d.value"
},
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32
},
"embedding": {
"filename": "phi-3-v-128k-instruct-text-embedding.onnx",
"inputs": {
"input_ids": "input_ids"
},
"outputs": {
"inputs_embeds": "inputs_embeds"
}
},
"vision": {
"filename": "phi-3-v-128k-instruct-vision.onnx",
"inputs": {
"pixel_values": "pixel_values",
"image_sizes": "image_sizes"
},
"outputs": {
"visual_features": "visual_features"
}
},
"eos_token_id": 32007,
"pad_token_id": 32000,
"type": "phi3v",
"vocab_size": 32064
},
"search": {
"diversity_penalty": 0.0,
"do_sample": false,
"early_stopping": true,
"length_penalty": 1.0,
"max_length": 131072,
"min_length": 0,
"no_repeat_ngram_size": 0,
"num_beams": 1,
"num_return_sequences": 1,
"past_present_share_buffer": true,
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_k": 1,
"top_p": 1.0
}
}
192 changes: 192 additions & 0 deletions examples/phi-3-vision/src/image_process.rs
@@ -0,0 +1,192 @@
//! This file is a Rust implementation of the image processing code for the Phi-3-vision-128k-instruct model.
//! The original Python version can be found at:
//! https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
//!
//! The image transformation is configured as Phi3ImageTransform in the processor config:
//! https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/blob/main/cpu-int4-rtn-block-32-acc-level-4/processor_config.json
//!
//! This Rust implementation aims to provide similar functionality for preprocessing images
//! to be used with the Phi-3 vision model, adapting the original Python code to Rust.
use anyhow::Result;
use image::{DynamicImage, GenericImageView, ImageBuffer};
use ndarray::{s, Array2, Array4, Array5, Axis};

/// see https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/blob/main/cpu-int4-rtn-block-32-acc-level-4/processor_config.json
/// NOTE: The default setting in processor_config.json is num_crops = 16,
/// but this is too slow for practical use. We use 1 here for better performance.
pub const NUM_CROPS: usize = 1;
pub const _NUM_IMG_TOKENS: usize = 144;

const OPENAI_CLIP_MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073];
const OPENAI_CLIP_STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711];

pub struct Phi3VImageProcessor {
num_crops: usize,
image_mean: Vec<f32>,
image_std: Vec<f32>,
do_convert_rgb: bool,
}

impl Phi3VImageProcessor {
pub fn new() -> Self {
Self {
num_crops: NUM_CROPS,
image_mean: OPENAI_CLIP_MEAN.to_vec(),
image_std: OPENAI_CLIP_STD.to_vec(),
do_convert_rgb: true,
}
}

pub fn _calc_num_image_tokens(&self, image: &DynamicImage) -> usize {
let transformed = self.hd_transform(image);
let (width, height) = transformed.dimensions();
self.calc_num_image_tokens_from_image_size(width, height)
}

pub fn calc_num_image_tokens_from_image_size(&self, width: u32, height: u32) -> usize {
let (new_width, new_height) = self.calc_hd_transform_size(width, height);
((new_height / 336 * new_width / 336 + 1) * 144 + 1 + (new_height / 336 + 1) * 12) as usize
}

pub fn preprocess(&self, image: &DynamicImage) -> Result<BatchFeature> {
		// The rest of the pipeline assumes RGB data, so both branches convert; `do_convert_rgb`
		// is kept for parity with the original Python processor config.
		let rgb_image = if self.do_convert_rgb { image.to_rgb8() } else { image.to_rgb8() };
let rgb_image = DynamicImage::ImageRgb8(rgb_image);

let transformed = self.hd_transform(&rgb_image);
let (width, height) = transformed.dimensions();
let shapes = vec![height as i64, width as i64];
let image_sizes = Array2::from_shape_vec((1, 2), shapes)?;

let num_img_tokens = self.calc_num_image_tokens_from_image_size(width, height);

let normalized = self.normalize_image(&transformed);
let global_image = self.create_global_image(&normalized);
let local_patches = self.create_local_patches(&normalized);

let mut all_patches = vec![global_image];
all_patches.extend(local_patches);

let padded_images = self.pad_to_max_num_crops_tensor(&all_patches, self.num_crops + 1);
let pixel_values = padded_images.insert_axis(Axis(0));

Ok(BatchFeature {
pixel_values,
image_sizes,
num_img_tokens: vec![num_img_tokens as i64],
})
}

fn hd_transform(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let mut transposed = false;
let (width, height) = if width < height {
transposed = true;
(height, width)
} else {
(width, height)
};

let ratio = width as f32 / height as f32;
let mut scale = 1;
while (scale as f32 * (scale as f32 / ratio).ceil()) <= self.num_crops as f32 {
scale += 1;
}
scale -= 1;

let new_width = scale * 336;
let new_height = (new_width as f32 / ratio) as u32;

let resized = image.resize_exact(new_width, new_height, image::imageops::FilterType::Lanczos3);
let padded = self.padding_336(&resized);

if transposed {
padded.rotate90()
} else {
padded
}
}

fn padding_336(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let tar = ((height as f32 / 336.0).ceil() * 336.0) as u32;
let top_padding = (tar - height) / 2;
let mut padded = ImageBuffer::from_pixel(width, tar, image::Rgba([255, 255, 255, 255]));
image::imageops::overlay(&mut padded, image, 0, top_padding as i64);
DynamicImage::ImageRgba8(padded)
}

fn calc_hd_transform_size(&self, width: u32, height: u32) -> (u32, u32) {
let (width, height) = if width < height { (height, width) } else { (width, height) };

let ratio = width as f32 / height as f32;
let mut scale = 1;
while (scale as f32 * (scale as f32 / ratio).ceil()) <= self.num_crops as f32 {
scale += 1;
}
scale -= 1;

let new_width = scale * 336;
let new_height = (new_width as f32 / ratio) as u32;

self.calc_padded_size(new_width, new_height)
}

fn calc_padded_size(&self, width: u32, height: u32) -> (u32, u32) {
let target_height = ((height as f32 / 336.0).ceil() * 336.0) as u32;
(width, target_height)
}

fn normalize_image(&self, image: &DynamicImage) -> Array4<f32> {
let (width, height) = image.dimensions();
let mut normalized = Array4::<f32>::zeros((1, 3, height as usize, width as usize));

for (x, y, pixel) in image.pixels() {
for c in 0..3 {
normalized[[0, c, y as usize, x as usize]] = (pixel[c] as f32 / 255.0 - self.image_mean[c]) / self.image_std[c];
}
}

normalized
}

	/// NOTE: The reference Python processor downscales the full image to a 336x336
	/// "global" crop; this implementation currently substitutes an all-zero tensor
	/// as a simplification.
	fn create_global_image(&self, _image: &Array4<f32>) -> Array4<f32> {
		Array4::<f32>::zeros((1, 3, 336, 336))
	}

fn create_local_patches(&self, image: &Array4<f32>) -> Vec<Array4<f32>> {
let (_, _, height, width) = image.dim();
let mut patches = Vec::new();

for h in (0..height).step_by(336) {
for w in (0..width).step_by(336) {
let patch = image
.slice(s![.., .., h..std::cmp::min(h + 336, height), w..std::cmp::min(w + 336, width)])
.to_owned();
patches.push(patch);
}
}

patches
}

fn pad_to_max_num_crops_tensor(&self, patches: &[Array4<f32>], max_crops: usize) -> Array4<f32> {
let (_, channels, height, width) = patches[0].dim();
let mut padded = Array4::<f32>::zeros((max_crops, channels, height, width));

for (i, patch) in patches.iter().enumerate() {
if i >= max_crops {
break;
}
// Remove the extra dimension when assigning
padded.slice_mut(s![i, .., .., ..]).assign(&patch.slice(s![0, .., .., ..]));
}

padded
}
}

pub struct BatchFeature {
pub pixel_values: Array5<f32>,
pub image_sizes: Array2<i64>,
pub num_img_tokens: Vec<i64>,
}
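
A short usage sketch of `Phi3VImageProcessor`, assuming it is used from within this example crate (`src/image_process.rs`): the returned `BatchFeature` fields map directly onto the `pixel_values` and `image_sizes` inputs of the vision model listed in `genai_config.json`.

```rust
use anyhow::Result;

// Assumes this lives in the same crate, e.g. in src/main.rs:
// mod image_process;
// use image_process::Phi3VImageProcessor;

fn preprocess_example() -> Result<()> {
	let image = image::open("data/example.jpg")?;
	let processor = Phi3VImageProcessor::new();
	let features = processor.preprocess(&image)?;

	// pixel_values: (1, num_crops + 1, 3, H, W) f32, normalized with the CLIP mean/std;
	// image_sizes: (1, 2) i64 holding [height, width] of the HD-transformed image.
	println!("pixel_values shape: {:?}", features.pixel_values.shape());
	println!("image_sizes: {:?}", features.image_sizes);
	println!("num_img_tokens: {:?}", features.num_img_tokens);
	Ok(())
}
```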