diff --git a/parler_tts/dac_wrapper/modeling_dac.py b/parler_tts/dac_wrapper/modeling_dac.py
index d3d5a44..cbc751d 100644
--- a/parler_tts/dac_wrapper/modeling_dac.py
+++ b/parler_tts/dac_wrapper/modeling_dac.py
@@ -12,6 +12,9 @@ class DACModel(PreTrainedModel):
 
     config_class = DACConfig
 
+    # Set main input to 'input_values' for voice steering
+    main_input_name = "input_values"
+
     def __init__(self, config):
         super().__init__(config)
 
diff --git a/parler_tts/modeling_parler_tts.py b/parler_tts/modeling_parler_tts.py
index 2d5391f..9e39d20 100644
--- a/parler_tts/modeling_parler_tts.py
+++ b/parler_tts/modeling_parler_tts.py
@@ -3483,13 +3483,16 @@ def generate(
     # Apply the pattern mask to the final ids
     output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
 
-    # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-    _, mask = self.decoder.build_delay_pattern_mask(
-        input_ids,
-        bos_token_id=generation_config._bos_token_tensor,
-        pad_token_id=generation_config._pad_token_tensor,
-        max_length=output_ids.shape[1],
-    )
-
-    mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+    if "input_values" in model_kwargs:
+        # Voice steering: filter the bos/pad control tokens directly out of the
+        # generated ids. NOTE(review): build the boolean mask exactly once —
+        # re-applying the token-id comparison to an already-boolean mask would
+        # yield an all-True mask and defeat the filtering.
+        mask = (output_ids != generation_config.bos_token_id) & (output_ids != generation_config.pad_token_id)
+    else:
+        # Revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
+        _, mask = self.decoder.build_delay_pattern_mask(
+            input_ids,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
+            max_length=output_ids.shape[1],
+        )
+        mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
     output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)