Fix XTTS streaming for transformers update #46

Merged · 3 commits · Jun 18, 2024
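This change set updates TTS/tts/layers/xtts/stream_generator.py for the generation-API changes in transformers 4.41 and raises the dependency floor to match. For orientation, a minimal sketch (not part of the PR) of the streaming entry point this change keeps working, assuming the documented XTTS-v2 API (`Xtts.inference_stream`, `get_conditioning_latents`); paths are placeholders:

```python
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Placeholder checkpoint paths; any XTTS-v2 checkpoint directory works the same way.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/")
model.cuda()

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

# inference_stream() drives the patched stream_generator.py: it yields audio
# chunks while the GPT model is still sampling tokens, instead of waiting for
# the full utterance.
for i, chunk in enumerate(model.inference_stream("Hello world.", "en", gpt_cond_latent, speaker_embedding)):
    print(f"chunk {i}: {chunk.shape[-1]} samples")
```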
Empty file added TTS/tts/layers/xtts/__init__.py
44 changes: 18 additions & 26 deletions TTS/tts/layers/xtts/stream_generator.py
@@ -4,7 +4,7 @@
 import inspect
 import random
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import torch
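The `typing.List` import can go because, on the Python versions where PEP 585 applies (3.9 and later), annotations can use the built-in `list` directly. A micro-example of the style the rest of the diff switches to:

```python
import torch


def prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    # PEP 585 built-in generics: `list[int]` needs no typing.List import.
    return [0, 1]
```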
@@ -21,10 +21,11 @@
     PreTrainedModel,
     StoppingCriteriaList,
 )
+from transformers.generation.stopping_criteria import validate_stopping_criteria
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger
 
 
-def setup_seed(seed):
+def setup_seed(seed: int) -> None:
     if seed == -1:
         return
     torch.manual_seed(seed)
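`setup_seed` only gains annotations here, but the `-1` sentinel is worth noting since the streaming API exposes it; a short usage note (the collapsed lines presumably seed `random`/`numpy` alongside torch):

```python
setup_seed(42)  # fixed seed: reproducible sampling across streaming runs
setup_seed(-1)  # sentinel: return early and leave global RNG state untouched
```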
@@ -49,9 +50,9 @@ def generate(  # noqa: PLR0911
         generation_config: Optional[StreamGenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
         synced_gpus: Optional[bool] = False,
-        seed=0,
+        seed: int = 0,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -90,7 +91,7 @@ def generate(  # noqa: PLR0911
                 Custom stopping criteria that complement the default stopping criteria built from arguments and a
                 generation config. If a stopping criteria is passed that is already created with the arguments or a
                 generation config an error is thrown. This feature is intended for advanced users.
-            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                 If provided, this function constraints the beam search to allowed tokens only at each step. If not
                 provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                 `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
@@ -151,18 +152,7 @@ def generate(  # noqa: PLR0911
         # 2. Set generation parameters if not already defined
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
 
         # 3. Define model inputs
         # inputs_tensor has to be defined
@@ -174,6 +164,9 @@
         )
         batch_size = inputs_tensor.shape[0]
 
+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # 4. Define other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -182,7 +175,7 @@
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
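`kwargs_has_attention_mask` is captured once, before `model_kwargs` is mutated, and reused here instead of re-querying the dict. When the caller supplies no mask, the helper builds the usual default; conceptually (a simplification, not the transformers implementation):

```python
from typing import Optional

import torch


def default_attention_mask(inputs_tensor: torch.Tensor, pad_token_id: Optional[int]) -> torch.Tensor:
    # Attend to every position that is not padding; all ones when there is no pad id.
    if pad_token_id is None:
        return torch.ones_like(inputs_tensor)
    return (inputs_tensor != pad_token_id).long()
```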
@@ -209,16 +202,15 @@
 
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                batch_size,
-                decoder_start_token_id=generation_config.decoder_start_token_id,
-                bos_token_id=generation_config.bos_token_id,
+            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+                batch_size=batch_size,
+                model_input_name=model_input_name,
                 model_kwargs=model_kwargs,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
                 device=inputs_tensor.device,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
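Two 4.41 changes land in this hunk: `_prepare_decoder_input_ids_for_generation` now takes `model_input_name` and returns the updated `model_kwargs` alongside the ids, and the decoder-only branch no longer assumes the prompt arrived as token ids. Spelled out, with the same names the hunk uses, the decoder-only rule is:

```python
# If the prompt came in as plain token ids, use the input tensor directly;
# otherwise (e.g. an inputs_embeds prompt) the ids ride along in model_kwargs.
if model_input_name == "input_ids":
    input_ids = inputs_tensor
else:
    input_ids = model_kwargs.pop("input_ids")
```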
@@ -577,7 +569,7 @@ def generate(  # noqa: PLR0911
 
         def typeerror():
             raise ValueError(
-                "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+                "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
                 f"of positive integers, but is {generation_config.force_words_ids}."
             )

@@ -649,7 +641,7 @@ def sample_stream(
         logits_warper: Optional[LogitsProcessorList] = None,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
+        eos_token_id: Optional[Union[int, list[int]]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
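`sample_stream` keeps the upstream `sample` contract for `eos_token_id`: a single id or a list of ids, any one of which finishes a sequence. The two accepted shapes, normalized:

```python
eos_token_id = 2           # one terminator id ...
eos_token_id = [2, 32000]  # ... or several; hitting any of them stops a sequence
eos_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
```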
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -69,7 +69,7 @@ dependencies = [
     "gruut[de,es,fr]==2.2.3",
     # Tortoise
     "einops>=0.6.0",
-    "transformers>=4.33.0,<4.41.0",
+    "transformers>=4.41.1",
     # Bark
     "encodec>=0.1.1",
     # XTTS
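The pin flips from an upper bound (`<4.41.0`) to a floor (`>=4.41.1`), since stream_generator.py now depends on helpers that only exist in 4.41 and later. A hedged runtime guard, assuming the ubiquitous `packaging` package is importable:

```python
from importlib.metadata import version

from packaging.version import Version

if Version(version("transformers")) < Version("4.41.1"):
    raise RuntimeError("XTTS streaming requires transformers>=4.41.1")
```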