Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ipex backend #3083

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

add ipex backend #3083

wants to merge 3 commits into from

Conversation

jiqing-feng
Copy link

@jiqing-feng jiqing-feng commented Nov 25, 2024

This PR enables ipex backend, script:

from sentence_transformers import SentenceTransformer

sentences = ["This is an example sentence", "Each sentence is converted"]
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', backend="ipex")
embeddings = model.encode(sentences)
print(embeddings) 

Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
@tomaarsen
Copy link
Collaborator

Hello!

Thanks for this PR - it's quite extensive already. I ran some tests with IPEX on WSL (as I'm on Windows currently) and translated the performance gains relative to fp32 to my normal benchmark graph here:
image

In essence, the performance gain is seemingly not very substantial, staying behind ONNX and OpenVINO. I'm curious if this roughly matches the expectations. As it stands right now, IPEX likely doesn't seem like an improvement over ONNX/OpenVINO and it might result in the backends becoming more complex without any notable gain.
I only tested CPU - according to https://huggingface.co/docs/optimum/en/intel/ipex/inference, that's all that's supported via Optimum right now.

I also had some issues with running the IPEXModel initially because the model itself was loaded as CPU whereas the dummy_inputs were moved to self._device in optimum-intel - which was automatically set to cuda as my machine as a CUDA-enabled GPU and my torch was installed with cuda support. I have a feeling like I can't fix that in Sentence Transformers via some parameters.

Another bottleneck that I'm personally noticing is that installing all optimum backends into one virtualenvironment becomes quite constrictive. For example, my CI won't test the latest transformers because optimum-intel[ipex] doesn't support it yet. So I may have to separate my tests per backend and make CI runners for each backend. That way I can still test the latest transformers for example.

cc @echarlaix @IlyasMoutawwakil as I believe you've both briefly worked on IPEX in Optimum Intel.

  • Tom Aarsen

@echarlaix
Copy link
Contributor

Thanks for the detailed feedback @tomaarsen! Yes that makes sense, @jiqing-feng I think we can instead add new integrations to optimum-intel directly cc @IlyasMoutawwakil

@jiqing-feng
Copy link
Author

jiqing-feng commented Nov 26, 2024

HI @tomaarsen , thanks for your benchmarking! There are 2 main issues in your comment: performance and tests constrictive, so I suppose you will consider merging this PR after the 2 issues are solved? (I am fixing device issue that can be easily fixed in optimum-intel)

Hi @echarlaix sentence-transformers is also in our ipex scope, we aim to upstream ipex in sentence-transformers. As you know optimum-intel ipex is under big refactoring, I just found IPEXModel and this PR is enough for sentence-transformers so we don't plan to integrate the sentence-transformers model specifically in optimum-intel. What we need is to fix the transformers' version compatibility and performance. Please let me know your concerns. Thanks!

@jiqing-feng
Copy link
Author

Hi @tomaarsen , can you share your benchmark script? Thanks!

@jiqing-feng
Copy link
Author

jiqing-feng commented Nov 26, 2024

Hi @tomaarsen . I use the evaluation_inference_speed.py for benchmarking, and make some little changes:

COMMAND: python evaluation_inference_speed.py

import sys
import time

import torch
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf

bind_cores_for_best_perf()

model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-nli-mean-tokens"

# Load a sentence transformer model
model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "cpu"}
model = SentenceTransformer(model_name, model_kwargs=model_kwargs, device="cpu", backend="ipex")

max_sentences = 100_000
all_nli_dataset = load_dataset("sentence-transformers/all-nli", "pair", split="train")
sentences = list(set(all_nli_dataset["anchor"]))[:max_sentences]

print("Model Name:", model_name)
print("Number of sentences:", len(sentences))

for i in range(3):
    print("Run", i)
    start_time = time.time()
    emb = model.encode(sentences, batch_size=32)
    end_time = time.time()
    diff_time = end_time - start_time
    print(f"Done after {diff_time:.2f} seconds")
    print(f"Speed: {len(sentences) / diff_time:.2f} sentences / second")
    print("=====")

The results show:
Speed ratio: ipex / torch = 1.6
Data collected from Intel 4th Gen Xeon.

We will fix the device and transformers version issue ASAP. Before that, please help to verify the performance.
I suppose HF has access to Intel 4th Gen Xeon, do you mind validating on the PVC node (the pvc's CPU is 4th gen xeon)?

@tomaarsen
Copy link
Collaborator

The script that I used is quite messy, i.e. something like this:

Benchmarking script
from __future__ import annotations
import gc
import json
import tempfile
import time
from typing import Any, Generator, Literal

import numpy as np
from tqdm import tqdm, trange

from sentence_transformers import SentenceTransformer, export_optimized_onnx_model, export_dynamic_quantized_onnx_model
from datasets import load_dataset

from sentence_transformers.backend import export_static_quantized_openvino_model

def get_models_from_model_id(model_id: str, device: Literal["cpu", "cuda"] = "cpu", trust_remote_code: bool = False) -> Generator[tuple[SentenceTransformer, str], Any, None]:
    # Torch (fp32, fp16, bf16)
    yield SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code), f"{model_id}-torch-fp32"
    # yield SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code).half(), f"{model_id}-torch-fp16"
    # yield SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code).bfloat16(), f"{model_id}-torch-bf16"

    # ONNX (default, O1, O2, O3, O4)
    # provider = "CUDAExecutionProvider" if device == "cuda" else "CPUExecutionProvider"
    # onnx_model = SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code, backend="onnx", model_kwargs={"provider": provider})
    # yield onnx_model, f"{model_id}-onnx"
    # onnx_model = SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code, backend="onnx", model_kwargs={"provider": provider, "file_name": "onnx/model_quantized.onnx"})
    # yield onnx_model, f"{model_id}-onnx-quantized"

    # for optimization_config in (["O1", "O2", "O3", "O4"] if device == "cuda" else ["O1", "O2", "O3"]):
    # # for optimization_config in (["O4"] if device == "cuda" else ["O1", "O2", "O3"]):
    #     with tempfile.TemporaryDirectory() as tempdir:
    #         onnx_model = SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code, backend="onnx", model_kwargs={"provider": provider})
    #         onnx_model.save_pretrained(tempdir)
    #         export_optimized_onnx_model(onnx_model, optimization_config, tempdir)
    #         del onnx_model
    #         yield SentenceTransformer(tempdir, device=device, trust_remote_code=trust_remote_code, backend="onnx", model_kwargs={"provider": provider, "file_name": f"onnx/model_{optimization_config}.onnx"}), f"{model_id}-onnx-{optimization_config}"

    # for quantization_config in (["arm64", "avx2", "avx512", "avx512_vnni"] if device == "cuda" else ["arm64", "avx2", "avx512", "avx512_vnni"]):
    # for quantization_config in (["avx512_vnni"] if device == "cuda" else ["avx512_vnni"]):
    #     with tempfile.TemporaryDirectory() as tempdir:
    #         onnx_model = SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code, backend="onnx", model_kwargs={"provider": provider, "export": True})
    #         onnx_model.save_pretrained(tempdir)
    #         export_dynamic_quantized_onnx_model(onnx_model, quantization_config, tempdir)
    #         del onnx_model
    #         yield SentenceTransformer(tempdir, device=device, trust_remote_code=trust_remote_code, backend="onnx", model_kwargs={"provider": provider, "file_name": f"onnx/model_{quantization_config}.onnx"}), f"{model_id}-onnx-{quantization_config}"

    # OpenVINO
    # if device == "cpu":
    #     from optimum.intel import OVWeightQuantizationConfig, OVQuantizationConfig
    #     # yield SentenceTransformer(model_id, trust_remote_code=trust_remote_code, backend="openvino", model_kwargs={"device": "GPU", "quantization_config": OVWeightQuantizationConfig(bits=4)}), f"{model_id}-openvino-igpu-i4"
    #     # yield SentenceTransformer(model_id, trust_remote_code=trust_remote_code, backend="openvino", model_kwargs={"device": "GPU", "load_in_8bit": True}), f"{model_id}-openvino-igpu-i8"
    #     # yield SentenceTransformer(model_id, trust_remote_code=trust_remote_code, backend="openvino", model_kwargs={"device": "GPU"}), f"{model_id}-openvino-igpu"
    #     # yield SentenceTransformer(model_id, trust_remote_code=trust_remote_code, backend="openvino", model_kwargs={"quantization_config": OVWeightQuantizationConfig(bits=4)}), f"{model_id}-openvino-i4"
    #     # yield SentenceTransformer(model_id, trust_remote_code=trust_remote_code, backend="openvino", model_kwargs={"load_in_8bit": True}), f"{model_id}-openvino-i8"
    #     # yield SentenceTransformer(model_id, trust_remote_code=trust_remote_code, backend="openvino"), f"{model_id}-openvino"
    #     # else:
    #     # yield SentenceTransformer(model_id, trust_remote_code=trust_remote_code, backend="openvino"), f"{model_id}-openvino"

    #     for quantization_config in [OVQuantizationConfig(bits=8)]:
    #         with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tempdir:
    #             openvino_model = SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code, backend="openvino")
    #             openvino_model.save_pretrained(tempdir)
    #             export_static_quantized_openvino_model(openvino_model, quantization_config, tempdir)
    #             del openvino_model
    #             yield SentenceTransformer(tempdir, device=device, trust_remote_code=trust_remote_code, backend="openvino", model_kwargs={"file_name": f"openvino/openvino_model_qint{quantization_config.bits}_quantized.xml"}), f"{model_id}-openvino-qint{quantization_config.bits}"

    # IPEX
    yield SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code, backend="ipex"), f"{model_id}-ipex"


def get_models(device: Literal["cpu", "cuda"] = "cpu", trust_remote_code: bool = False) -> Generator[tuple[SentenceTransformer, str], Any, None]:
    model_ids = [
        "all-MiniLM-L6-v2",
        "BAAI/bge-base-en-v1.5",
        "mixedbread-ai/mxbai-embed-large-v1",
        "BAAI/bge-m3",
    ]
    for model_id in model_ids:
        yield from get_models_from_model_id(model_id, device=device, trust_remote_code=trust_remote_code)

def get_datasets(max_samples: int = 10_000) -> Generator[tuple[list[str], str], Any, None]:
    # Short texts
    stsb_dataset = load_dataset("sentence-transformers/stsb", split="train")
    yield (stsb_dataset["sentence1"] + stsb_dataset["sentence2"])[:max_samples], "stsb"

    # Longer texts
    nq_dataset = load_dataset("sentence-transformers/natural-questions", split=f"train[:{max_samples}]")
    yield nq_dataset["answer"], "nq (answers)"

    # Very long texts
    # imdb_dataset = load_dataset("stanfordnlp/imdb", split="train")
    # long_texts = imdb_dataset.map(lambda sample: {"text": (sample["text"] * 4)[:10_000]}).filter(lambda sample: len(sample["text"]) > 8_000)
    # yield long_texts["text"][:max_samples], "imdb (8k+ chars)"

def get_batch_sizes(model_name: str):
    if model_name.startswith("all-MiniLM-L6-v2"):
        return [16, 32, 64, 128, 256]
    if model_name.startswith("Xenova/all-MiniLM-L6-v2"):
        return [16, 32, 64, 128, 256]
    elif model_name.startswith("BAAI/bge-base-en-v1.5"):
        return [16, 32, 64, 128]
    elif model_name.startswith("mixedbread-ai/mxbai-embed-large-v1"):
        return [128, 256]
    elif model_name.startswith("BAAI/bge-m3"):
        return [2, 4]

def main() -> None:
    try:
        with open("outputs_no_sort_cpu.json", "r") as f:
            outputs = json.load(f)
    except FileNotFoundError:
        outputs = {}

    # updated = False
    # for device in tqdm(["cuda", "cpu"], desc="Devices", leave=False):
    for device in tqdm(["cpu"], desc="Devices", leave=False):
        outputs[device] = outputs.get(device, {})
        for model, model_name in tqdm(get_models(device=device), desc="Models", leave=False):
            outputs[device][model_name] = outputs[device].get(model_name, {})
            for sentences, dataset_name in tqdm(get_datasets(max_samples=1000), desc="Datasets", leave=False):
                outputs[device][model_name][dataset_name] = outputs[device][model_name].get(dataset_name, {})
                for batch_size in tqdm(get_batch_sizes(model_name), desc="Batch sizes", leave=False):
                    # Warmup
                    model.encode(sentences[:batch_size], batch_size=batch_size)

                    outputs[device][model_name][dataset_name][str(batch_size)] = outputs[device][model_name][dataset_name].get(str(batch_size), [])
                    try:
                        while len(outputs[device][model_name][dataset_name][str(batch_size)]) < 3:
                            start_time = time.time()
                            model.encode(sentences, batch_size=batch_size)
                            outputs[device][model_name][dataset_name][str(batch_size)].append(len(sentences) / (time.time() - start_time))
                            # updated = True
                    except Exception as e:
                        outputs[device][model_name][dataset_name][str(batch_size)].append("Error")
                    with open("outputs_no_sort_cpu.json", "w") as f:
                        json.dump(outputs, f, indent=4)
            del model
            gc.collect()
            # if updated:
            #     quit()

if __name__ == "__main__":
    main()

I've ran this for a lot of different backend types, 4 different models, and 3 datasets. See more details in https://sbert.net/docs/sentence_transformer/usage/efficiency.html#benchmarks
I'm using an i7-17300K CPU for the CPU tests, i.e. consumer-grade hardware


I'm running your script now, with the bind_cores_for_best_perf (I didn't use that one previously). I see it also requires pip install py-libnuma.
It seems that my hardware does not contain the required instructions for torch.bfloat16:

AssertionError: BF16 weight prepack needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq, but the desired instruction sets are not available. Please set dtype to torch.float or set weights_prepack to False.

Or torch.float16:

AssertionError: FP16 weight prepack needs the cpu support avx_ne_convert or avx512_core_fp16, but the desired instruction sets are not available. Please set dtype to torch.float or set weights_prepack to False.

Only with torch.float does it work correctly - and here it has a small performance improvement around 3%. I was also under the impression that this was running on float16 due to some of the warnings I saw, but I suspect that it actually ran on fp32. I didn't realise that the recent hardware was required to get the performance gain, but it makes sense given that only recent hardware can run BF16.

I'll try and get access to a Intel 4th Gen Xeon CPU.

And yes, if it's possible to get bf16 performance preservation (e.g. 99.9%+) with ~1.6x speedup, then I'll definitely consider merging this. If we can make that work, then I'll try and fix the tests issue that I mentioned.
Some questions:

  • Is 4th gen Xeon accessible for individual consumers? I see that there's also "Workstation" processors, but I'm not sure if many common users are expected to have this hardware. I recognize that there's cloud offerings too, though.
  • Can the performance gain (e.g. 1.6x+) be expected to be reached with 4th Gen Xeon only, or also with newer hardware (5th gen, 6th gen, etc.)?

  • Tom Aarsen

@jiqing-feng
Copy link
Author

For your question:

  1. Yes, individual customers can easily get access to Xeon on AWS.
  2. I've tested on 6th Gen, it has the same speed-up ratio.

@echarlaix
Copy link
Contributor

Hi @echarlaix sentence-transformers is also in our ipex scope, we aim to upstream ipex in sentence-transformers. As you know optimum-intel ipex is under big refactoring,

Yes and would make sense to wait for the refactorization from huggingface/optimum-intel#1009 before doing a benchmark @jiqing-feng

@jiqing-feng
Copy link
Author

jiqing-feng commented Dec 6, 2024

Hi @tomaarsen . The refactor of optimum-intel has been completed. Please install optimum-intel by git clone https://github.com/huggingface/optimum-intel.git && cd optimum-intel && pip install .[ipex] and then run the following codes:

import sys
import time

import torch
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf

bind_cores_for_best_perf()

model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-nli-mean-tokens"

# Load a sentence transformer model
model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "cpu"}
ipex_model = SentenceTransformer(model_name, model_kwargs=model_kwargs, device="cpu", backend="ipex")
trans_model = SentenceTransformer(model_name, model_kwargs=model_kwargs, device="cpu")

max_sentences = 100_00
all_nli_dataset = load_dataset("sentence-transformers/all-nli", "pair", split="train")
sentences = list(set(all_nli_dataset["anchor"]))[:max_sentences]

print("Model Name:", model_name)
print("Number of sentences:", len(sentences))

for batci_size in [32, 16, 8, 4, 1]:
    print(f"test with batch_size = {batci_size}")
    diff_times = []
    for model in [trans_model, ipex_model]:
        for i in range(2):
            print("Run", i)
            start_time = time.time()
            emb = model.encode(sentences, batch_size=batci_size)
            end_time = time.time()
            diff_time = end_time - start_time
        print(f"Done after {diff_time:.2f} seconds")
        print(f"Speed: {len(sentences) / diff_time:.2f} sentences / second")
        print("=====")
        diff_times.append(diff_time)

    print(f"speed-up ratio: {round(diff_times[0]/diff_times[1], 2)}")

My torch version is 2.5.1 and intel_extension_for_pytorch version is 2.5.0.

From my observation on Intel 4th Gen Xeon, the speed-up compared ipex backend with torch backend is:

|  batch size  |   1   |   4   |   8   |  16   |  32   |
-------------------------------------------------------
|  speed-up    | 1.93x | 1.76x | 1.75x | 1.54x | 1.4x  |

@echarlaix
Copy link
Contributor

In case we don't want to integrate IPEX models to sentence-transformers for now, here is an integration in optimum-intel to have support there in meantime huggingface/optimum-intel#1034 cc @jiqing-feng @tomaarsen

@jiqing-feng
Copy link
Author

jiqing-feng commented Dec 9, 2024

In case we don't want to integrate IPEX models to sentence-transformers for now, here is an integration in optimum-intel to have support there in meantime huggingface/optimum-intel#1034 cc @jiqing-feng @tomaarsen

Hi @echarlaix , thanks for your integration. Let's release the new version of optimum-intel so @tomaarsen can set up the tests ASAP. Thanks!

@jiqing-feng
Copy link
Author

Hi @tomaarsen . The latest version of optimum-intel has integrated our changes, you can set up the tests on optimum-intel==1.21.0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants