Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove invalid references to facebook/parler-tts-small #132

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions parler_tts/configuration_parler_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

logger = logging.get_logger(__name__)

MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/parler_tts-small": "https://huggingface.co/facebook/parler_tts-small/resolve/main/config.json",
PARLER_TTS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"parler-tts/parler-tts-mini-v1": "https://huggingface.co/parler-tts/parler-tts-mini-v1/resolve/main/config.json",
# See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
}

Expand All @@ -31,7 +31,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
This is the configuration class to store the configuration of an [`ParlerTTSDecoder`]. It is used to instantiate a
Parler-TTS decoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Parler-TTS
[facebook/parler_tts-small](https://huggingface.co/facebook/parler_tts-small) architecture.
[parler-tts/parler-tts-mini-v1](https://huggingface.co/parler-tts/parler-tts-mini-v1) architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Expand Down Expand Up @@ -203,7 +203,7 @@ class ParlerTTSConfig(PretrainedConfig):
... text_encoder_config, audio_encoder_config, decoder_config
... )

>>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the facebook/parler_tts-small style configuration
>>> # Initializing a ParlerTTSForConditionalGeneration (with random weights) from the parler-tts/parler-tts-mini-v1 style configuration
>>> model = ParlerTTSForConditionalGeneration(configuration)

>>> # Accessing the model configuration
Expand Down
14 changes: 7 additions & 7 deletions parler_tts/modeling_parler_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@
logger.warn("Flash attention 2 is not installed")

_CONFIG_FOR_DOC = "ParlerTTSConfig"
_CHECKPOINT_FOR_DOC = "facebook/parler_tts-small"
_CHECKPOINT_FOR_DOC = "parler-tts/parler-tts-mini-v1"

MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/parler_tts-small",
"parler-tts/parler-tts-mini-v1",
# See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
]

Expand Down Expand Up @@ -2357,7 +2357,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
```python
>>> from parler_tts import ParlerTTSForConditionalGeneration

>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")
```"""

# At the moment fast initialization is not supported for composite models
Expand Down Expand Up @@ -2411,7 +2411,7 @@ def from_sub_models_pretrained(

- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or
organization name, like `facebook/parler_tts-small`.
organization name, like `parler-tts/parler-tts-mini-v1`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

Expand Down Expand Up @@ -2440,7 +2440,7 @@ def from_sub_models_pretrained(
>>> model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
... text_encoder_pretrained_model_name_or_path="t5-base",
... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
... decoder_pretrained_model_name_or_path="facebook/parler_tts-small",
... decoder_pretrained_model_name_or_path="parler-tts/parler-tts-mini-v1",
... )
>>> # saving model after fine-tuning
>>> model.save_pretrained("./parler_tts-ft")
Expand Down Expand Up @@ -2607,8 +2607,8 @@ def forward(
>>> from transformers import AutoProcessor, ParlerTTSForConditionalGeneration
>>> import torch

>>> processor = AutoProcessor.from_pretrained("facebook/parler_tts-small")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
>>> processor = AutoProcessor.from_pretrained("parler-tts/parler-tts-mini-v1")
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")

>>> inputs = processor(
... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
Expand Down