From 49a0bef4c1d959d9008d4a7128ca5e24c2ac7fc1 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Sat, 21 Sep 2024 07:43:30 +0800 Subject: [PATCH] enable low-precision pipeline (#31625) * enable low-precision pipeline * fix parameter for ASR * reformat * fix asr bug * fix bug for zero-shot * add dtype check * rm useless comments * add np.float16 check * Update src/transformers/pipelines/image_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/pipelines/token_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix comments * fix asr check * make fixup * No more need for is_torch_available() --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Matt Co-authored-by: Matt --- .../pipelines/automatic_speech_recognition.py | 5 +- .../pipelines/token_classification.py | 8 ++- src/transformers/testing_utils.py | 2 +- ..._pipelines_automatic_speech_recognition.py | 42 ++++++++++++++ .../test_pipelines_token_classification.py | 35 ++++++++++++ tests/pipelines/test_pipelines_zero_shot.py | 55 ++++++++++++++++++- 6 files changed, 143 insertions(+), 4 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 4301982f1e901c..9b82b67820c51b 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -565,7 +565,10 @@ def postprocess( key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: - items = outputs[key].numpy() + if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16): + items = outputs[key].to(torch.float32).numpy() + else: + items = outputs[key].numpy() stride = outputs.get("stride", None) if stride is not None and self.type in {"ctc", "ctc_with_lm"}: total_n, left, right = stride diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index e1d763eafa8b71..9256f238148476 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -19,6 +19,8 @@ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES if is_torch_available(): + import torch + from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES @@ -299,7 +301,11 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE ignore_labels = ["O"] all_entities = [] for model_outputs in all_outputs: - logits = model_outputs["logits"][0].numpy() + if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16): + logits = model_outputs["logits"][0].to(torch.float32).numpy() + else: + logits = model_outputs["logits"][0].numpy() + sentence = all_outputs[0]["sentence"] input_ids = model_outputs["input_ids"][0] offset_mapping = ( diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index b86e3af91ca727..e0608acfeb8a54 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -2143,7 +2143,7 @@ def nested_simplify(obj, decimals=3): return nested_simplify(obj.numpy().tolist()) elif isinstance(obj, float): return round(obj, decimals) - elif isinstance(obj, (np.int32, np.float32)): + elif isinstance(obj, (np.int32, np.float32, np.float16)): return nested_simplify(obj.item(), decimals) else: raise Exception(f"Not supported: {type(obj)}") diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 842933d2b76c94..c12292fc3370d3 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -167,6 +167,48 @@ def test_small_model_pt(self): ): _ = speech_recognizer(waveform, return_timestamps="char") + @require_torch + def test_small_model_pt_fp16(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="facebook/s2t-small-mustc-en-fr-st", + tokenizer="facebook/s2t-small-mustc-en-fr-st", + framework="pt", + torch_dtype=torch.float16, + ) + waveform = np.tile(np.arange(1000, dtype=np.float32), 34) + output = speech_recognizer(waveform) + self.assertEqual(output, {"text": "(Applaudissements)"}) + output = speech_recognizer(waveform, chunk_length_s=10) + self.assertEqual(output, {"text": "(Applaudissements)"}) + + # Non CTC models cannot use return_timestamps + with self.assertRaisesRegex( + ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$" + ): + _ = speech_recognizer(waveform, return_timestamps="char") + + @require_torch + def test_small_model_pt_bf16(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="facebook/s2t-small-mustc-en-fr-st", + tokenizer="facebook/s2t-small-mustc-en-fr-st", + framework="pt", + torch_dtype=torch.bfloat16, + ) + waveform = np.tile(np.arange(1000, dtype=np.float32), 34) + output = speech_recognizer(waveform) + self.assertEqual(output, {"text": "(Applaudissements)"}) + output = speech_recognizer(waveform, chunk_length_s=10) + self.assertEqual(output, {"text": "(Applaudissements)"}) + + # Non CTC models cannot use return_timestamps + with self.assertRaisesRegex( + ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$" + ): + _ = speech_recognizer(waveform, return_timestamps="char") + @slow @require_torch_accelerator def test_whisper_fp16(self): diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index 41415c8c34589e..5e804bec199ab0 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -27,6 +27,7 @@ from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler from transformers.testing_utils import ( is_pipeline_test, + is_torch_available, nested_simplify, require_tf, require_torch, @@ -38,6 +39,10 @@ from .test_pipelines_common import ANY +if is_torch_available(): + import torch + + VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]] # These 2 model types require different inputs than those of the usual text models. @@ -841,6 +846,36 @@ def test_small_model_pt(self): ], ) + @require_torch + def test_small_model_pt_fp16(self): + model_name = "hf-internal-testing/tiny-bert-for-token-classification" + token_classifier = pipeline( + task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16 + ) + outputs = token_classifier("This is a test !") + self.assertEqual( + nested_simplify(outputs), + [ + {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4}, + {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7}, + ], + ) + + @require_torch + def test_small_model_pt_bf16(self): + model_name = "hf-internal-testing/tiny-bert-for-token-classification" + token_classifier = pipeline( + task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16 + ) + outputs = token_classifier("This is a test !") + self.assertEqual( + nested_simplify(outputs), + [ + {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4}, + {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7}, + ], + ) + @require_torch def test_pt_ignore_subwords_slow_tokenizer_raises(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" diff --git a/tests/pipelines/test_pipelines_zero_shot.py b/tests/pipelines/test_pipelines_zero_shot.py index 1003898df6c968..7c437b0a418da2 100644 --- a/tests/pipelines/test_pipelines_zero_shot.py +++ b/tests/pipelines/test_pipelines_zero_shot.py @@ -21,11 +21,22 @@ ZeroShotClassificationPipeline, pipeline, ) -from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow +from transformers.testing_utils import ( + is_pipeline_test, + is_torch_available, + nested_simplify, + require_tf, + require_torch, + slow, +) from .test_pipelines_common import ANY +if is_torch_available(): + import torch + + # These 2 model types require different inputs than those of the usual text models. _TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"} @@ -176,6 +187,48 @@ def test_small_model_pt(self): }, ) + @require_torch + def test_small_model_pt_fp16(self): + zero_shot_classifier = pipeline( + "zero-shot-classification", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + framework="pt", + torch_dtype=torch.float16, + ) + outputs = zero_shot_classifier( + "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"] + ) + + self.assertEqual( + nested_simplify(outputs), + { + "sequence": "Who are you voting for in 2020?", + "labels": ["science", "public health", "politics"], + "scores": [0.333, 0.333, 0.333], + }, + ) + + @require_torch + def test_small_model_pt_bf16(self): + zero_shot_classifier = pipeline( + "zero-shot-classification", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + framework="pt", + torch_dtype=torch.bfloat16, + ) + outputs = zero_shot_classifier( + "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"] + ) + + self.assertEqual( + nested_simplify(outputs), + { + "sequence": "Who are you voting for in 2020?", + "labels": ["science", "public health", "politics"], + "scores": [0.333, 0.333, 0.333], + }, + ) + @require_tf def test_small_model_tf(self): zero_shot_classifier = pipeline(