Skip to content

Commit

Permalink
Fix M4T for ASR pipeline (huggingface#32296)
Browse files Browse the repository at this point in the history
* tentative fix

* do the same for M4T
  • Loading branch information
ylacombe authored Jul 30, 2024
1 parent 084b509 commit 2fbbcf5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3154,6 +3154,7 @@ def generate(
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
input_features = input_features if input_features is not None else kwargs.pop("inputs")
if tgt_lang is not None:
inputs = kwargs.get("input_embeds") if input_features is None else input_features
inputs = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3422,6 +3422,7 @@ def generate(
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
input_features = input_features if input_features is not None else kwargs.pop("inputs")
if tgt_lang is not None:
inputs = kwargs.get("input_embeds") if input_features is None else input_features
inputs = (
Expand Down

0 comments on commit 2fbbcf5

Please sign in to comment.