Skip to content

Commit

Permalink
enable low-precision pipeline (#31625)
Browse files Browse the repository at this point in the history
* enable low-precision pipeline

* fix parameter for ASR

* reformat

* fix asr bug

* fix bug for zero-shot

* add dtype check

* rm useless comments

* add np.float16 check

* Update src/transformers/pipelines/image_classification.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Marc Sun <[email protected]>

* fix comments

* fix asr check

* make fixup

* No more need for is_torch_available()

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Matt <[email protected]>
Co-authored-by: Matt <[email protected]>
  • Loading branch information
4 people committed Sep 20, 2024
1 parent 7b2b536 commit 49a0bef
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,10 @@ def postprocess(
key = "logits" if self.type == "ctc_with_lm" else "tokens"
stride = None
for outputs in model_outputs:
items = outputs[key].numpy()
if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
items = outputs[key].to(torch.float32).numpy()
else:
items = outputs[key].numpy()
stride = outputs.get("stride", None)
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
total_n, left, right = stride
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/pipelines/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES


Expand Down Expand Up @@ -299,7 +301,11 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
ignore_labels = ["O"]
all_entities = []
for model_outputs in all_outputs:
logits = model_outputs["logits"][0].numpy()
if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
logits = model_outputs["logits"][0].to(torch.float32).numpy()
else:
logits = model_outputs["logits"][0].numpy()

sentence = all_outputs[0]["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = (
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,7 +2143,7 @@ def nested_simplify(obj, decimals=3):
return nested_simplify(obj.numpy().tolist())
elif isinstance(obj, float):
return round(obj, decimals)
elif isinstance(obj, (np.int32, np.float32)):
elif isinstance(obj, (np.int32, np.float32, np.float16)):
return nested_simplify(obj.item(), decimals)
else:
raise Exception(f"Not supported: {type(obj)}")
Expand Down
42 changes: 42 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,48 @@ def test_small_model_pt(self):
):
_ = speech_recognizer(waveform, return_timestamps="char")

@require_torch
def test_small_model_pt_fp16(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-small-mustc-en-fr-st",
tokenizer="facebook/s2t-small-mustc-en-fr-st",
framework="pt",
torch_dtype=torch.float16,
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
output = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(output, {"text": "(Applaudissements)"})

# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
_ = speech_recognizer(waveform, return_timestamps="char")

@require_torch
def test_small_model_pt_bf16(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-small-mustc-en-fr-st",
tokenizer="facebook/s2t-small-mustc-en-fr-st",
framework="pt",
torch_dtype=torch.bfloat16,
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
output = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(output, {"text": "(Applaudissements)"})

# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
_ = speech_recognizer(waveform, return_timestamps="char")

@slow
@require_torch_accelerator
def test_whisper_fp16(self):
Expand Down
35 changes: 35 additions & 0 deletions tests/pipelines/test_pipelines_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
Expand All @@ -38,6 +39,10 @@
from .test_pipelines_common import ANY


if is_torch_available():
import torch


VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]

# These 2 model types require different inputs than those of the usual text models.
Expand Down Expand Up @@ -841,6 +846,36 @@ def test_small_model_pt(self):
],
)

@require_torch
def test_small_model_pt_fp16(self):
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
token_classifier = pipeline(
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16
)
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
],
)

@require_torch
def test_small_model_pt_bf16(self):
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
token_classifier = pipeline(
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16
)
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
],
)

@require_torch
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
Expand Down
55 changes: 54 additions & 1 deletion tests/pipelines/test_pipelines_zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,22 @@
ZeroShotClassificationPipeline,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
slow,
)

from .test_pipelines_common import ANY


if is_torch_available():
import torch


# These 2 model types require different inputs than those of the usual text models.
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}

Expand Down Expand Up @@ -176,6 +187,48 @@ def test_small_model_pt(self):
},
)

@require_torch
def test_small_model_pt_fp16(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="pt",
torch_dtype=torch.float16,
)
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)

self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)

@require_torch
def test_small_model_pt_bf16(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="pt",
torch_dtype=torch.bfloat16,
)
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)

self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)

@require_tf
def test_small_model_tf(self):
zero_shot_classifier = pipeline(
Expand Down

0 comments on commit 49a0bef

Please sign in to comment.