@@ -97,7 +97,7 @@ class MoshiConfig(PretrainedConfig):
97
97
98
98
Example:
99
99
100
- ```python
100
+ ```python # TODO(YL): update
101
101
>>> from transformers import (
102
102
... MoshiConfig,
103
103
... EncodecConfig,
@@ -189,21 +189,24 @@ def __init__(self,
189
189
self .depth_head_dim = depth_head_dim or depth_hidden_size // depth_num_attention_heads
190
190
self .depth_num_key_value_heads = depth_num_key_value_heads if depth_num_key_value_heads is not None else depth_num_attention_heads
191
191
192
- super ().__init__ (tie_word_embeddings = tie_word_embeddings , ** kwargs )
193
-
194
- if "audio_encoder" not in kwargs :
192
+ audio_encoder_config = kwargs .pop ("audio_encoder" , None )
193
+ if audio_encoder_config is None :
195
194
raise ValueError ("Config has to be initialized with audio_encoder config" )
196
-
197
- audio_encoder_config = kwargs .pop ("audio_encoder" )
195
+
198
196
audio_encoder_model_type = audio_encoder_config .pop ("model_type" )
199
197
200
198
self .audio_encoder = AutoConfig .for_model (audio_encoder_model_type , ** audio_encoder_config )
199
+
201
200
if self .num_codebooks > self .audio_encoder .num_codebooks :
202
201
raise ValueError (f"`num_codebooks={ num_codebooks } ` is greater than the maximum number of codebooks that the audio encoder can deal with ({ self .audio_encoder .num_codebooks } ). Please lower it." )
203
202
204
203
self .audio_vocab_size = self .audio_encoder .codebook_size if audio_vocab_size is None else audio_vocab_size
204
+
205
+ super ().__init__ (tie_word_embeddings = tie_word_embeddings , ** kwargs )
205
206
206
-
207
+ @property
208
+ def sampling_rate (self ):
209
+ return self .audio_encoder .sampling_rate
207
210
208
211
@classmethod
209
212
def from_audio_encoder_config (
@@ -213,17 +216,12 @@ def from_audio_encoder_config(
213
216
):
214
217
r"""
215
218
Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration.
216
-
219
+
217
220
Returns:
218
221
[`MoshiConfig`]: An instance of a configuration object
219
222
"""
220
-
223
+
221
224
return cls (
222
225
audio_encoder = audio_encoder_config .to_dict (),
223
226
** kwargs ,
224
227
)
225
-
226
- @property
227
- # This is a property because you might want to change the codec model on the fly
228
- def sampling_rate (self ):
229
- return self .audio_encoder .sampling_rate
0 commit comments