
GREAT MODELS, but a number of issues ... #125

Open
apresence opened this issue Aug 30, 2024 · 5 comments

@apresence

First off -- AMAZING TTS!!!

I know I'm repeating several other issues that have been opened, but I've spent several days testing and code tweaking to try to resolve the issues I have found, and wanted to share. Plus, I figured rolling them all up into one place might be helpful.

It would be AWESOME if we could get this thing working reliably!

I've tried the following:

  • Every bit of sample code I could get my hands on, including this, the samples in the HF repo, the sample Gradio app, and the samples given in INFERENCE.md here on GitHub
  • These models: parler-tts-mini-v1, parler-tts-large-v1, parler-tts-mini-expresso, parler-tts-mini-jenny-30H
  • Several hacks that are supposed to fix the issues, including this and this, one that sets the random seed across all components (CPU, GPU, torch, numpy, etc.) to try to increase determinism (roughly sketched just after this list), and some others
  • All three supported attention implementations
  • Using torch.compile() with 'default', 'reduce-overhead' and 'max-autotune', just to see if it made a difference
  • Padding or not padding the inputs on either the text or description
  • Providing an attention mask or not providing it on either the text or description
  • Using the names that are supposed to make the output more consistent like Jerry, Thomas, Talia, and Elisabeth for expresso and the ones listed here for mini and large. I've also used several of the example descriptions verbatim
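
For reference, the seed-everything hack I mention above looks roughly like this (a minimal sketch; the helper name is mine and the exact set of calls may differ from what others use):

import os
import random

import numpy as np
import torch

def seed_everything(seed: int = 42) -> None:
    """Seed every RNG source I could find, to try to make generation deterministic."""
    os.environ["PYTHONHASHSEED"] = str(seed)  # only effective if set before the interpreter starts
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                   # seeds CPU and, in recent torch versions, all CUDA devices
    torch.cuda.manual_seed_all(seed)
    # optional: trade speed for determinism in cuDNN kernels
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)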

I wrote a program that works through a sampling of all of the above combinations, used it to generate 500 WAV files from the same paragraph of input text and description (description varies by model, of course), then randomly sampled about 10% of them.
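
Stripped down, the test harness looks roughly like this (a sketch only; the model list, the combination set, and the file naming are illustrative, and it assumes soundfile for writing the WAVs):

import itertools

import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

device = "cuda:0"
text = "The same paragraph of input text used for every run ..."
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively ..."

# illustrative subset of the combinations described above
combos = itertools.product(
    ["parler-tts/parler-tts-mini-v1", "parler-tts/parler-tts-large-v1"],  # models
    ["eager", "sdpa", "flash_attention_2"],                               # attention implementations
    [True, False],                                                        # pad inputs to max_length or not
)

for i, (model_name, attn_impl, pad) in enumerate(combos):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = ParlerTTSForConditionalGeneration.from_pretrained(
        model_name,
        attn_implementation=attn_impl,   # flash_attention_2 also needs the flash-attn package
        torch_dtype=torch.float16,
    ).to(device)

    pad_kwargs = {"padding": "max_length", "max_length": 50} if pad else {}
    desc = tokenizer(description, return_tensors="pt", **pad_kwargs).to(device)
    prompt = tokenizer(text, return_tensors="pt", **pad_kwargs).to(device)

    generation = model.generate(
        input_ids=desc.input_ids,
        attention_mask=desc.attention_mask,
        prompt_input_ids=prompt.input_ids,
        prompt_attention_mask=prompt.attention_mask,
    )
    audio = generation.to(torch.float32).cpu().numpy().squeeze()
    name = model_name.split("/")[-1]
    sf.write(f"run_{i:04d}.{name}.{attn_impl}.pad={pad}.wav", audio, model.config.sampling_rate)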

At least one of the following issues occurs regardless of the model or which combination of the above is used:

  • Voice change between generations (the one exception is parler-tts-mini-jenny-30H, and even then there is variance in tone, speed and emotion between generations)
  • Words mispronounced randomly -- the same word might be fine in one generation but not the next. This is particularly bad when a multi-syllable word is at the end of the text: it may only get partially pronounced, or be mispronounced after a syllable or two
  • Entire words dropped
  • Groups of words/phrases dropped
  • Words/phrases spoken out of order
  • Silent pauses, sometimes up to several seconds long
  • Jarring variations in volume level
  • Random pause lengths regardless of commas or periods

Any more than about 50 input tokens and the issues get much worse.
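
(For reference, I'm counting input tokens with the prompt tokenizer, roughly like this; just a sketch:)

from transformers import AutoTokenizer

prompt_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
num_tokens = len(prompt_tokenizer("The paragraph of input text ...").input_ids)
print(num_tokens)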

I'm wondering if there isn't an issue with the way attention or KV caching is implemented. That seems to fit as a cause for the issues.

For one thing, this message is logged at the first generation:
prompt_attention_mask is specified but attention_mask is not. A full attention_mask will be created. Make sure this is the intended behaviour.

A further hint towards attention issues is in the code examples: sometimes only the text mask is given, sometimes only the description mask, sometimes both, and sometimes neither. Sometimes they're padded, sometimes not.
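
For what it's worth, the warning implies it can be avoided by passing the description's attention_mask explicitly alongside prompt_attention_mask. Roughly (a sketch, not a guaranteed fix; the generate kwargs are the same four used elsewhere in this thread):

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "parler-tts/parler-tts-mini-v1"

model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

description = "A female speaker delivers her words quite expressively, with clear audio quality."
text = "Hey, there, I'm Parly!"

desc = tokenizer(description, return_tensors="pt").to(device)
prompt = tokenizer(text, return_tensors="pt").to(device)

# pass the description mask (attention_mask) *and* the prompt mask explicitly,
# so the model does not have to build a full attention_mask on its own
generation = model.generate(
    input_ids=desc.input_ids,
    attention_mask=desc.attention_mask,
    prompt_input_ids=prompt.input_ids,
    prompt_attention_mask=prompt.attention_mask,
)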

Looking at the code, it seems that the input attention mask is ignored in some cases and generated/re-generated/reshaped/modified several times throughout the generation cycle, and the same is true of the cache. The code that manipulates them is spread around and repeated in different places. There are also multiple conditionals that check the torch version and the like and change the way things are processed. Then there are comments like:

  • "In the generation case of prompt_cross_attention=True, we need to recreate an attention mask from scratch"
  • "As it is, the masked ids from the prompt will still count in the positions embeddings" (so the mask is ignored?)
  • "Force float32 since bfloat16 loses precision on long contexts" ... then later ... "In PEFT, usually we cast the layer norms in float32 ... need cast them back in the correct dtype just to be sure everything works as expected. This might slowdown training & inference ..."
  • "no matter the length, we just slice it"
  • "These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view."
    • "TODO: Remove the query_length != 1 check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 init."
  • "There is a memcpy here, that is very bad."
  • "in this case we assume that the mask comes already in inverted form and requires no inversion or slicing"
  • "flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference." and then later " ... The -q_len: slice assumes left padding."
  • "For SDPA, when possible, we will rely on its is_causal argument instead of its attn_mask argument .. This feature is not compatible with static cache, as SDPA will fail to infer the attention mask."

I do not intend this as a form of criticism -- the quality coming out of Parler is amazing! I highlight these in case anyone with the requisite knowledge might be able to review them. While I am a developer, I am brand new to transformers and don't really understand the underlying concepts at play here.

It's worth noting that I can see a lot of this code was copied and pasted from somewhere else (most notably, MusicGen), so many of these little wrinkles may have been pre-existing.

Here's my setup:

Hardware: RTX 4090 FE 24GB VRAM
Drivers: 555.42.06, Cuda 12.5
OS: Ubuntu 22.04.4 LTS

I also tried it on an RTX 8000 48GB VRAM, same results.

Thanks!

@apresence
Author

Something definitely seems strange with the cache. I noticed the following warning when compiling:

V0901 20:36:51.238000 125202030948352 torch/_dynamo/guards.py:2611] [0/1] [__recompiles] Recompiling function forward in /work/parler-tts/parler_tts/modeling_parler_tts.py:2576
V0901 20:36:51.238000 125202030948352 torch/_dynamo/guards.py:2611] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0901 20:36:51.238000 125202030948352 torch/_dynamo/guards.py:2611] [0/1] [__recompiles]     - tensor 'L['cache_position']' size mismatch at index 0. expected 51, actual 1

It's saying it expected cache_position to be of size 51, but it's actually 1. Using the handy Python module icecream, we can trace this through a generation. As you'll see, cache_position starts out as a tensor of size 51 containing the numbers 0 .. 50. But then, suddenly, in prepare_inputs_for_generation() it is converted to a tensor with only one value. This tracks with the compiler warning.
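
(For anyone who wants to reproduce the trace: the ic() calls are just dropped inline in modeling_parler_tts.py. A runnable sketch of what that looks like, with a stand-in tensor:)

from icecream import ic
import torch

# include the "file.py:NNN in function()" prefix seen in the log below
ic.configureOutput(includeContext=True)

# stand-in for the real cache_position inside prepare_inputs_for_generation()
cache_position = torch.arange(51)
ic(cache_position)
ic(cache_position.reshape(-1, 1))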

Log snippet:

2024-09-01 19:08:14,490 [MainThread  ] [INFO ] >>> Performing inference into test_comp_inf_20240901190805.attn_impl=eager.model=expresso.pad_dir=left.pad_len=50.seed=42.tdev=cuda0.ttyp=torch.bfloat16_gen.wav
2024-09-01 19:08:14,492 [MainThread  ] [INFO ] Chunk 1/7: numtoks=11/50 text="Hey, there, I'm Parly!"
...
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
...
ic| modeling_parler_tts.py:2833 in prepare_inputs_for_generation()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
...
ic| modeling_parler_tts.py:1360 in forward()
    cache_position.unsqueeze(0): tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                                          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                                          36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]],
                                        device='cuda:0')
...
ic| modeling_parler_tts.py:1607 in _update_causal_mask()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
ic| modeling_parler_tts.py:1608 in _update_causal_mask()
    cache_position.reshape(-1, 1): tensor([[ 0],
                                           [ 1],
                                           [ 2],
                                           [ 3],
                                           [ 4],
                                           [ 5],
                                           [ 6],
                                           [ 7],
                                           [ 8],
                                           [ 9],
                                           [10],
                                           [11],
                                           [12],
                                           [13],
                                           [14],
                                           [15],
                                           [16],
                                           [17],
                                           [18],
                                           [19],
                                           [20],
                                           [21],
                                           [22],
                                           [23],
                                           [24],
                                           [25],
                                           [26],
                                           [27],
                                           [28],
                                           [29],
                                           [30],
                                           [31],
                                           [32],
                                           [33],
                                           [34],
                                           [35],
                                           [36],
                                           [37],
                                           [38],
                                           [39],
                                           [40],
                                           [41],
                                           [42],
                                           [43],
                                           [44],
                                           [45],
                                           [46],
                                           [47],
                                           [48],
                                           [49],
                                           [50]], device='cuda:0')
ic| modeling_parler_tts.py:433 in forward()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
...
ic| modeling_parler_tts.py:433 in forward()- cache_position: None
ic| modeling_parler_tts.py:2801 in prepare_inputs_for_generation()
    cache_position[0] if cache_position is not None else past_key_values.get_seq_length(): tensor(51, device='cuda:0')
ic| modeling_parler_tts.py:2802 in prepare_inputs_for_generation()
    past_key_values.get_seq_length(): tensor(51, device='cuda:0')
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([51], device='cuda:0')
...
ic| modeling_parler_tts.py:433 in forward()
    cache_position: tensor([51], device='cuda:0')
ic| modeling_parler_tts.py:2801 in prepare_inputs_for_generation()
    cache_position[0] if cache_position is not None else past_key_values.get_seq_length(): tensor(52, device='cuda:0')
ic| modeling_parler_tts.py:2802 in prepare_inputs_for_generation()
    past_key_values.get_seq_length(): tensor(52, device='cuda:0')
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([52], device='cuda:0')
ic| modeling_parler_tts.py:2826 in prepare_inputs_for_generation()
    decoder_input_ids.shape[1]: 1
...
ic| modeling_parler_tts.py:433 in forward()
    cache_position: tensor([52], device='cuda:0')
ic| modeling_parler_tts.py:2801 in prepare_inputs_for_generation()
    cache_position[0] if cache_position is not None else past_key_values.get_seq_length(): tensor(53, device='cuda:0')
ic| modeling_parler_tts.py:2802 in prepare_inputs_for_generation()
    past_key_values.get_seq_length(): tensor(53, device='cuda:0')
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([53], device='cuda:0')
ic| modeling_parler_tts.py:2826 in prepare_inputs_for_generation()
...

@apresence
Author

apresence commented Sep 1, 2024

BTW, I used the expresso model for that, but the same thing happens for the other models. Here's the config I am using:
attention_implementation='eager'
padding_side='left' (for input), padding='max_length', padding_size=50
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode='reduce-overhead', fullgraph=True)

Also, to clarify, the tensor remains size 1 for the rest of that generation; it is only reset to size 51 at the start of the next generation.

@ylacombe
Collaborator

ylacombe commented Sep 2, 2024

Hey @apresence, thanks for the thorough feedback, there's definitely a lot to unpack.

In theory, every issue regarding generation (inconsistency, dropped words, pauses, etc.) is explained by the data the model was trained on and the tokenizer that we used. The model is an LLM that learns to associate tokens with sounds. As such, the model can have difficulty pronouncing infrequent tokens or infrequent sequences of tokens.
Since it's an LLM, it also suffers from classic LLM issues: hallucinations, inconsistent behavior, etc.

Regarding the length of the generated audio, the model was trained on audio that is mostly under 20 seconds, and thus can't generalize to long prompts!

These are issues we're aware of. Hopefully we'll solve some of them in a next version (if any!)

Also cc @eustlb regarding the compilation warning with the cache position

@eustlb
Contributor

eustlb commented Sep 5, 2024

Hey @apresence,

Thank you very much for your detailed feedback. Concerning the point you've raised about cache_position and recompilation, that's actually an expected result.
Indeed, when running generation:

  1. for the first forward pass, the hidden states of the prompt text are prepended to the one-token start-of-sequence tensor of the decoder. This way cache_position, which indicates where to store keys and values in the cache, should be a tensor of size number of tokens in the prompt + 1 → the value 51 in your example.
  2. after that, at each new time step, the position in the cache is a single value, since we auto-regressively generate one new token at a time.

It is therefore expected to see this recompilation during the warmup step: torch will first compile the case where cache_position has 51 values and then recompile for the case where it has only one. You can also read this issue for more information about this necessary warmup step: #93
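
In other words, the two cases look roughly like this (illustration only):

import torch

# 1. prefill: prompt hidden states (e.g. 50 tokens) prepended to the 1-token decoder start
prompt_length = 50
cache_position = torch.arange(prompt_length + 1)   # size 51 -> positions 0 .. 50

# 2. each subsequent decoding step writes a single new key/value into the cache
cache_position = cache_position[-1:] + 1           # size 1 -> tensor([51]), then [52], [53], ...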

Test it yourself using the following snippet 🤗: you'll see that there is no recompilation on the second generate call.

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch

# debugging
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)

# reproducibility
torch.manual_seed(0)

# set-up device args
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
attn_implementation = "sdpa"

# model
model_name = "parler-tts/parler-tts-mini-v1"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="default", fullgraph=True)

# tokenizers
padding_side = "left"
description_tokenizer = AutoTokenizer.from_pretrained(model_name) 
prompt_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)

def tokenize_inputs(description, prompt):
    tokenized_description = description_tokenizer(description, return_tensors="pt", padding='max_length', max_length=50)
    input_ids = tokenized_description.input_ids.to(torch_device)
    attention_mask = tokenized_description.attention_mask.to(torch_device)

    tokenized_prompt = prompt_tokenizer(prompt, return_tensors="pt", padding='max_length', max_length=50)
    prompt_input_ids = tokenized_prompt.input_ids.to(torch_device)
    prompt_attention_mask = tokenized_prompt.attention_mask.to(torch_device)

    return input_ids, prompt_input_ids, attention_mask, prompt_attention_mask 

# first generation
prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
input_ids, prompt_input_ids, attention_mask, prompt_attention_mask = tokenize_inputs(description, prompt)

_ = model.generate(
    input_ids=input_ids, 
    prompt_input_ids=prompt_input_ids,
    attention_mask=attention_mask,
    prompt_attention_mask=prompt_attention_mask
)

print("Completed first generate!")

# second generation, debugging parameters will show us if recompilation happens
prompt = "Hey, how are you doing?"
description = "A male speaker with a slightly low-pitched voice delivers his words quite expressively, in a very confined sounding environment with clear audio quality. He speaks very fast."
input_ids, prompt_input_ids, attention_mask, prompt_attention_mask = tokenize_inputs(description, prompt)

_ = model.generate(
    input_ids=input_ids, 
    prompt_input_ids=prompt_input_ids,
    attention_mask=attention_mask,
    prompt_attention_mask=prompt_attention_mask
)

print("Completed second generate!")

@kunci115

kunci115 commented Sep 11, 2024

I also experienced numbers being mispronounced or skipped, and spelling out acronyms seems really hard for the model; for example, instead of spelling out CCB it will pronounce it like "seb".
