Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Whisper static generation #1275

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open

Add Whisper static generation #1275

wants to merge 19 commits into from

Conversation

Spycsh
Copy link
Contributor

@Spycsh Spycsh commented Aug 20, 2024

What does this PR do?

Add static KV cache for Whisper family models. Previously it takes huge amount of time to do the Whisper inference (each first inference for each different input audio) on HPU because of dynamic shape of KV cache in SDPA attention module. With this PR the latency should be drastically reduced on HPU.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@Spycsh
Copy link
Contributor Author

Spycsh commented Aug 20, 2024

Hi @ssarkar2 @bhargaveede @vivekgoe @regisss , if you have any issues or questions on this PR, please don't hesitate to contact me :)

Here is the code to reproduced. This script is adapted from https://github.com/opea-project/GenAIComps/blob/main/comps/asr/whisper/whisper_model.py.

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import contextlib
import os
import time
import urllib.request

import numpy as np
import torch
from datasets import Audio, Dataset
from pydub import AudioSegment


class WhisperModel:
    """Convert audio to text."""

    def __init__(self, model_name_or_path="openai/whisper-small", language="english", device="cpu", hpu_max_len=8192):
        if device == "hpu":
            # Explicitly link HPU with Torch
            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

            adapt_transformers_to_gaudi()
        from transformers import WhisperForConditionalGeneration, WhisperProcessor

        self.device = device
        asr_model_name_or_path = os.environ.get("ASR_MODEL_PATH", model_name_or_path)
        print("Downloading model: {}".format(asr_model_name_or_path))
        self.model = WhisperForConditionalGeneration.from_pretrained(asr_model_name_or_path).to(self.device)
        self.processor = WhisperProcessor.from_pretrained(asr_model_name_or_path)
        self.model.eval()

        self.hpu_max_len = hpu_max_len
        self.language = language

    def _audiosegment_to_librosawav(self, audiosegment):
        # https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
        # This way is faster than librosa.load or HuggingFace Dataset wrapper
        channel_sounds = audiosegment.split_to_mono()[:1]  # only select the first channel
        samples = [s.get_array_of_samples() for s in channel_sounds]

        fp_arr = np.array(samples).T.astype(np.float32)
        fp_arr /= np.iinfo(samples[0].typecode).max
        fp_arr = fp_arr.reshape(-1)

        return fp_arr

    def audio2text(self, audio_path):
        """Convert audio to text.

        audio_path: the path to the input audio, e.g. ~/xxx.mp3
        """
        start = time.time()

        try:
            waveform = AudioSegment.from_file(audio_path).set_frame_rate(16000)
            waveform = self._audiosegment_to_librosawav(waveform)
        except Exception as e:
            print(f"[ASR] audiosegment to librosa wave fail: {e}")
            audio_dataset = Dataset.from_dict({"audio": [audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
            waveform = audio_dataset[0]["audio"]["array"]

        try:
            processed_inputs = self.processor(
                waveform,
                return_tensors="pt",
                truncation=False,
                padding="longest",
                return_attention_mask=True,
                sampling_rate=16000,
            )
        except RuntimeError as e:
            if "Padding size should be less than" in str(e):
                # short-form
                processed_inputs = self.processor(
                    waveform,
                    return_tensors="pt",
                    sampling_rate=16000,
                )
            else:
                raise e
        if processed_inputs.input_features.shape[-1] < 3000:
            # short-form
            processed_inputs = self.processor(
                waveform,
                return_tensors="pt",
                sampling_rate=16000,
            )

        predicted_ids = self.model.generate(
            **(
                processed_inputs.to(
                    self.device,
                )
            ),
            language=self.language,
        )
        # pylint: disable=E1101
        result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
        if self.language in ["chinese", "mandarin"]:
            from zhconv import convert

            result = convert(result, "zh-cn")
        print(f"generated text in {time.time() - start} seconds, and the result is: {result}")
        return result


if __name__ == "__main__":
    asr = WhisperModel(model_name_or_path="openai/whisper-tiny", language="english", device="hpu")
    urllib.request.urlretrieve(f"https://github.com/Spycsh/assets/raw/main/ljspeech_60s_audio.wav", "sample.wav")
    text = asr.audio2text("sample.wav")

    # Test multilanguage asr
    urllib.request.urlretrieve(
        "https://paddlespeech.bj.bcebos.com/Parakeet/docs/demos/labixiaoxin.wav",
        "sample.wav",
    )
    asr.language = "chinese"
    text = asr.audio2text("sample.wav")

    asr.language = "english"
    urllib.request.urlretrieve(
        "https://github.com/intel/intel-extension-for-transformers/raw/main/intel_extension_for_transformers/neural_chat/assets/audio/sample.wav",
        "sample.wav",
    )
    text = asr.audio2text("sample.wav")

    for i in [5, 10, 30, 60]:
        urllib.request.urlretrieve(f"https://github.com/Spycsh/assets/raw/main/ljspeech_{i}s_audio.wav", "sample.wav")
        text = asr.audio2text("sample.wav")

@emascarenhas
Copy link
Contributor

Hi @Spycsh ,

I haven't reviewed this code as yet and have few initial questions.

Please confirm code merges cleanly with OH main branch and runs with latest 1.17.1 as well as next 1.18.0 latest build.

How did you test these changes? Do you need to add any new tests to the CI?
How does this fit in with the run_speech_recogntion_seq2seq.py script in examples/speech-recognition?

In addition, please run "tests/ci/fast_tests.sh" after installing ".[tests]" and any slow tests, e.g., the ones that run the whisper model in the the CI test suite although e.g.,
GAUDI2_CI=1 RUN_SLOW=1 python -m pytest
tests/test_examples.py -v -s -k whisper

Also do "pip install -U ruff; make style" and check for any issues.

@Spycsh
Copy link
Contributor Author

Spycsh commented Sep 4, 2024

Hi @emascarenhas , thanks for suggestions. I've added the test, passed style check and passed the test on habana 1.16.1 and OH merged with the latest main branch. I currently do not have environment with newer habana driver but I think it should work.

Please tell me if you have further suggestions!

@emascarenhas
Copy link
Contributor

I got this result
***** eval metrics *****
eval_loss = 1.0118
eval_model_preparation_time = 0.0243
eval_runtime = 0:11:54.28
eval_samples = 2894
eval_samples_per_second = 4.052
eval_steps_per_second = 0.127
eval_wer = 3.024
max_memory_allocated (GB) = 53.45
memory_allocated (GB) = 53.44
total_memory_available (GB) = 94.62

Do you need to adjust the performance criteria?

tests/test_speech_recognition_example.py:69: AssertionError
=================================================================== short test summary info ====================================================================
FAILED tests/test_speech_recognition_example.py::test_speech_recognition_bf16[openai/whisper-small-32-2.892] - assert 4.052 >= (2 * 2.892)
================================================================ 1 failed in 733.05s (0:12:13) =================================================================

@emascarenhas
Copy link
Contributor

emascarenhas commented Sep 4, 2024

@Spycsh ,
How can I tell whether the new code in optimum/habana/transformers/models/whisper/modeling_whisper.py is used?
Since you are using the run_speech_recognition_seq2seq.py script from earlier without any change, does it invoke any of the new definitions/methods in the modeling_whisper.py file? This script is already part of the CI testing.

What I expected was a new or modified script that will use the code you introduced in the models directory.

@Spycsh
Copy link
Contributor Author

Spycsh commented Sep 5, 2024

Hi @emascarenhas , on my gaudi2 the throughput (eval_samples_per_second) should be 7.757 vs. baseline 2.892 so I assert the PR throughtput should be >=2 baseline throughput. However your tested data show a throughput 4.052. I have no idea why the PR throughput vary a lot between our hardwares. Could you please also check whether the baseline is close to 2.892 on your hardware? We can still lower the value of multiple a little bit to make it pass in your environment (e.g. PR throughput >= 1.2 baseline throughput).

My detailed logs of the latest run on PR branch are shown below:

***** eval metrics *****
  eval_loss                   =     1.0118
  eval_model_preparation_time =     0.0449
  eval_runtime                = 0:06:17.92
  eval_samples                =       2894
  eval_samples_per_second     =      7.658
  eval_steps_per_second       =      0.241
  eval_wer                    =     2.0325
  max_memory_allocated (GB)   =      13.16
  memory_allocated (GB)       =      13.15
  total_memory_available (GB) =      94.62
[INFO|modelcard.py:449] 2024-09-05 01:38:31,766 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Automatic Speech Recognition', 'type': 'automatic-speech-recognition'}, 'dataset': {'name': 'mozilla-foundation/common_voice_11_0 hi', 'type': 'mozilla-foundation/common_voice_11_0', 'config': 'hi', 'split': 'test', 'args': 'hi'}}

My command to trigger the test:

pytest -v -s tests/test_speech_recognition_example.py

Regarding to your another question about where is modeling_whisper.py used:

Previously run_speech_recognition_seq2seq.py already used openai/whisper-small to measure the throughput, do the finetuning on hindi language. It use the native transformers implementation for Whisper and do not provide any optimization. Meanwhile, it has a super slow warmup speed because it does not have a static KV cache shape optimization. Every time decoding iteration has a increasing input length so multiple different shape of HPU graphs will be initialized and every initialization takes some time so it would be very slow. My solution is to keep the KV shape the same padded to max-length during decoding so HPU graph will be as few as possible. That is what this PR aim to provide.

After this PR, the run_speech_recognition_seq2seq.py will not go into native transformers implementation of Whisper but go into the attention logics I write in modeling_whisper.py. And this should benefit from static generation.

Welcome to further questions!

@emascarenhas
Copy link
Contributor

It makes sense.

I suppose we don't need to have tests/test_speech_recognition_example.py because we already have an examples test in the CI that invokes run_speech_recognition_seq2seq.py. We don't need to run the same test run twice. If it is running the same test with exact same code coverage, then I think we should remove test_speech_recognition_example.py ?

@Spycsh
Copy link
Contributor Author

Spycsh commented Sep 6, 2024

@emascarenhas Sure :) I've removed the redundant test.

@emascarenhas
Copy link
Contributor

Looks good to me. @libinta , Please add run-test label.

@libinta libinta added run-test Run CI for PRs from external contributors and removed review wip labels Sep 11, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
run-test Run CI for PRs from external contributors
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants