Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BAAI/bge-small-en-v1.5 Optimization #1634

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
97 changes: 97 additions & 0 deletions examples/bge/bge-small-en-v1.5_ptq_qnn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
{
"input_model": {
"type": "HfModel",
"model_path": "BAAI/bge-small-en-v1.5",
"task": "feature-extraction",
"io_config": {
"input_names": [ "input_ids", "attention_mask", "token_type_ids" ],
"input_shapes": [ [ 1, 128 ], [ 1, 128 ], [ 1, 128 ] ],
"input_types": [ "int64", "int64", "int64" ],
"output_names": [ "last_hidden_state", "state" ]
}
},
"systems": {
"local_system": {
"type": "LocalSystem",
"accelerators": [ { "device": "npu", "execution_providers": [ "QNNExecutionProvider" ] } ]
}
},
"data_configs": [
{
"name": "quantize_data_config",
"type": "HuggingfaceContainer",
"load_dataset_config": { "data_name": "mteb/banking77", "split": "test" },
"pre_process_data_config": { "max_length": 128, "padding": "max_length", "input_cols": [ "text" ] },
"dataloader_config": { "batch_size": 1 }
}
],
"evaluators": {
"common_evaluator": {
"metrics": [
{
"name": "accuracy",
"type": "custom",
"sub_types": [
{
"name": "accuracy_custom",
"priority": 1,
"higher_is_better": true,
"goal": { "type": "max-degradation", "value": 0.05 }
}
],
"user_config": {
"user_script": "user_script.py",
"evaluate_func": "eval_accuracy",
"evaluate_func_kwargs": { "tasks": [ "Banking77Classification" ] }
}
},
{
"name": "latency_qnn",
"type": "latency",
"data_config": "quantize_data_config",
"sub_types": [ { "name": "avg" }, { "name": "max" }, { "name": "min" } ]
}
]
}
},
"passes": {
"conversion": { "type": "OnnxConversion", "target_opset": 17 },
"dynamic_shape_to_fixed": {
"type": "DynamicToFixedShape",
"dim_param": [ "batch_size", "sequence_length" ],
"dim_value": [ 1, 128 ]
},
"QNNPreprocess": { "type": "QNNPreprocess", "fuse_layernorm": true },
"OnnxQuantization": {
"type": "OnnxQuantization",
"data_config": "quantize_data_config",
"activation_type": "QUInt16",
"weight_type": "QUInt8",
"calibrate_method": "MinMax",
"quant_preprocess": true,
"prepare_qnn_config": true,
"op_types_to_quantize": [
Copy link
Contributor

@jambayk jambayk Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a dev branch where I introduce an option called op_type_to_exclude which is used to modify op_types_to_quantize and nodes_to_exclude.

"op_types_to_exclude": PassConfigParam(

Looks like it might be useful here too when it gets merged

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise, we need to know all of the op types present in the model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently use append_first_op_types_to_quantize_list with nodes_to_exclude will do this. Will we also update this logic?

if run_config["append_first_op_types_to_quantize_list"]:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

honestly, I am not sure why this option was added and if it is used for anything right now.

Not sure if we will touch this option and related logic but I plan to update the logic to be able to use op_types_to_exclude and nodes_to_exclude. The op_types_to_exclude has been very useful for me when I know I don't want to quantize all nodes for an op.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also created this PR in ort microsoft/onnxruntime#23779.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could merge following the clip?

"Mul",
"Transpose",
"MatMul",
"LayerNormalization",
"Gemm",
"Gelu",
"Unsqueeze",
"Gather",
"Sub",
"Where",
"Expand",
"Tanh",
"Reshape"
]
}
},
"pass_flows": [ [ "conversion", "dynamic_shape_to_fixed", "QNNPreprocess", "OnnxQuantization" ] ],
"evaluator": "common_evaluator",
"host": "local_system",
"target": "local_system",
"cache_dir": "cache",
"output_dir": "models/bge-small-en-v1.5",
"evaluate_input_model": false
}
60 changes: 60 additions & 0 deletions examples/bge/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# BAAI/bge-small-en-v1.5 Optimization

This folder contains examples of [BAAI/bge-small-en-v1.5 ](https://huggingface.co/BAAI/bge-small-en-v1.5) optimization using different workflows.

- NPU: [Optimization with PTQ using QNN EP](#ptq-using-qnn-ep)

## Optimization Workflows

### PTQ using QNN EP

This workflow performs the optimization pipeline:
- *PyTorch Model -> Onnx Model -> Static shaped Onnx Model -> Quantized Onnx Model*

The precision will drop when Add or Softmax types of op are quantized, so they are not included.

| Quantized Ops | precision | latency (avg) |
|-|-|-|
| None (original model) | 0.8574675324675324 | N/A |
| All ("Mul", "Transpose", "Unsqueeze", "Add", "Softmax", "Gelu", "LayerNormalization", "Gather", "MatMul", "Sub", "Where", "Expand", "Gemm", "Tanh", "Reshape") | 0.19707792207792205 | 24.95298 |
| Without Softmax | 0.19675324675324674 | 24.08456 |
| Without Add | 0.1968831168831169 | 64.3278 |
| Without Add, Softmax | 0.8511038961038961 | 40.48591 |

TODO(anyone): debug Add and Softmax to add them back to improve latency

## How to run
### Pip requirements
Install the necessary python packages:
```sh
# [NPU]
pip install git+https://github.com/microsoft/Olive#egg=olive-ai[qnn]
```

### Other dependencies
```sh
python -m pip install -r requirements.txt
```

### Run sample using config

The optimization techniques to run are specified in the relevant config json file.

First, install required packages according to passes.
```sh
olive run --config <config_file>.json --setup
```

Then, optimize the model
```sh
olive run --config <config_file>.json
```

or run simply with python code:
```python
from olive.workflows import run as olive_run
olive_run("<config_file>.json")
```

After running the above command, the model candidates and corresponding config will be saved in the output directory.
You can then select the best model and config from the candidates and run the model with the selected config.
1 change: 1 addition & 0 deletions examples/bge/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mteb
108 changes: 108 additions & 0 deletions examples/bge/user_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import json
from pathlib import Path
from typing import List

import mteb
import numpy as np
import torch
from transformers import AutoTokenizer

from olive.constants import Framework
from olive.engine.footprint import Footprint, FootprintNode
from olive.model import OliveModelHandler
from olive.workflows import run as olive_run


class OliveEncoder:
def __init__(self, model, session):
self.model = model
self.session = session
self.tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")

def encode(self, corpus: List, **kwargs):
model_output = None
if self.model.framework == Framework.ONNX:
encoded_input = self.tokenizer(
corpus, padding="max_length", max_length=128, truncation=True, return_tensors="np"
)
# batch_size is 1 for static model
model_outputs = []
for i in range(len(corpus)):
model_inputs = {
"input_ids": encoded_input.input_ids[i : i + 1, :].astype(np.int64),
"attention_mask": encoded_input.attention_mask[i : i + 1, :].astype(np.int64),
"token_type_ids": encoded_input.token_type_ids[i : i + 1, :].astype(np.int64),
}
model_output = self.model.run_session(self.session, model_inputs)[0]
model_outputs.append(model_output[0])
model_output = np.array(model_outputs)
elif self.model.framework == Framework.PYTORCH:
encoded_input = self.tokenizer(corpus, padding=True, truncation=True, return_tensors="pt")
model_inputs = {
"input_ids": encoded_input.input_ids,
"attention_mask": encoded_input.attention_mask,
"token_type_ids": encoded_input.token_type_ids,
}
with torch.no_grad():
model_output = self.model.run_session(self.session, model_inputs)
model_output = model_output.last_hidden_state.numpy()
# select the last hidden state of the first token (i.e., [CLS]) as the sentence embedding.
return model_output[:, 0, :]


def eval_accuracy(model: OliveModelHandler, device, execution_providers, tasks):
sess = model.prepare_session(inference_settings=None, device=device, execution_providers=execution_providers)

evaluation = mteb.MTEB(tasks=tasks)
olive_encoder = OliveEncoder(model, sess)
results = evaluation.run(olive_encoder, output_folder=None)
return results[0].scores["test"][0]["main_score"]


if __name__ == "__main__":
import logging
import sys

logger = logging.getLogger("bge")
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.INFO)

# Greedy search for the best combination of ops to quantize
all_ops = [
"Mul",
"Transpose",
"Unsqueeze",
"Add",
"Softmax",
"Gelu",
"LayerNormalization",
"Gather",
"MatMul",
"Sub",
"Where",
"Expand",
"Gemm",
"Tanh",
"Reshape",
]
target_accuracy = 0.8
with Path("bge-small-en-v1.5_ptq_qnn.json").open() as fin:
olive_config = json.load(fin)
for op in all_ops:
if op in olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"]:
continue
olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].append(op)
result = olive_run(olive_config)
footprint: Footprint = next(iter(result.values()))
node: FootprintNode = next(iter(footprint.nodes.values()))
accuracy = node.metrics.value["accuracy-accuracy_custom"].value
logger.info(
"Ops: %s Accuracy: %f", olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"], accuracy
)
if accuracy < target_accuracy:
olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"].remove(op)
logger.info("Final Ops: %s", olive_config["passes"]["OnnxQuantization"]["op_types_to_quantize"])
Loading