diff --git a/mmaudio/ext/autoencoder/vae.py b/mmaudio/ext/autoencoder/vae.py
index c5c9ba3..b49a12c 100644
--- a/mmaudio/ext/autoencoder/vae.py
+++ b/mmaudio/ext/autoencoder/vae.py
@@ -75,11 +75,11 @@ def __init__(
         super().__init__()
 
         if data_dim == 80:
-            self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
-            self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
+            self.register_buffer("data_mean", torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
+            self.register_buffer("data_std", torch.tensor(DATA_STD_80D, dtype=torch.float32))
         elif data_dim == 128:
-            self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
-            self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
+            self.register_buffer("data_mean", torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
+            self.register_buffer("data_std", torch.tensor(DATA_STD_128D, dtype=torch.float32))
 
         self.data_mean = self.data_mean.view(1, -1, 1)
         self.data_std = self.data_std.view(1, -1, 1)
diff --git a/mmaudio/model/embeddings.py b/mmaudio/model/embeddings.py
index 297feb4..77dc90a 100644
--- a/mmaudio/model/embeddings.py
+++ b/mmaudio/model/embeddings.py
@@ -20,13 +20,13 @@ def __init__(self, dim, frequency_embedding_size, max_period):
         self.max_period = max_period
         assert dim % 2 == 0, 'dim must be even.'
 
-        with torch.autocast('cuda', enabled=False):
-            self.freqs = nn.Buffer(
-                1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
-                               frequency_embedding_size)),
-                persistent=False)
-            freq_scale = 10000 / max_period
-            self.freqs = freq_scale * self.freqs
+        self.register_buffer('freqs',
+                             1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
+                                            frequency_embedding_size)),
+                             persistent=False)
+        freq_scale = 10000 / max_period
+        self.freqs = freq_scale * self.freqs
 
     def timestep_embedding(self, t):
         """
diff --git a/mmaudio/model/networks.py b/mmaudio/model/networks.py
index f378585..74e9a11 100644
--- a/mmaudio/model/networks.py
+++ b/mmaudio/model/networks.py
@@ -166,8 +166,8 @@ def initialize_rotations(self):
                                           self._clip_seq_len,
                                           device=self.device)
 
-        self.latent_rot = nn.Buffer(latent_rot, persistent=False)
-        self.clip_rot = nn.Buffer(clip_rot, persistent=False)
+        self.register_buffer('latent_rot', latent_rot, persistent=False)
+        self.register_buffer('clip_rot', clip_rot, persistent=False)
 
     def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
         self._latent_seq_len = latent_seq_len
@@ -238,9 +238,9 @@ def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor,
         sync_f = sync_f.flatten(1, 2)  # (B, VN, D)
 
         # extend vf to match x
-        clip_f = self.clip_input_proj(clip_f)  # (B, VN, D)
+        clip_f = self.clip_input_proj(clip_f.clone())  # (B, VN, D)
         sync_f = self.sync_input_proj(sync_f)  # (B, VN, D)
-        text_f = self.text_input_proj(text_f)  # (B, VN, D)
+        text_f = self.text_input_proj(text_f.clone())  # (B, VN, D)
 
         # upsample the sync features to match the audio
         sync_f = sync_f.transpose(1, 2)  # (B, D, VN)
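For context on the buffer change: `nn.Buffer` only exists in newer PyTorch releases, while `Module.register_buffer` has the same semantics on all versions — the tensor follows the module on `.to()`/`.cuda()` and, unless `persistent=False` is passed, is saved in the state dict. That compatibility gap is presumably the motivation for the patch. Below is a minimal sketch of the equivalence; the module and attribute names (`ToyStats`, `data_mean`, `freqs`) are illustrative, not from the patch:

```python
import torch
import torch.nn as nn


class ToyStats(nn.Module):
    """Illustrative only -- not part of the patch."""

    def __init__(self, data_dim: int = 80):
        super().__init__()
        # Same effect as `self.data_mean = nn.Buffer(...)` on newer PyTorch:
        # a non-parameter tensor that moves with the module and is saved in
        # the state dict.
        self.register_buffer('data_mean', torch.zeros(data_dim, dtype=torch.float32))
        # persistent=False keeps the buffer out of the state dict, matching
        # the `persistent=False` buffers in the patch.
        self.register_buffer('freqs', torch.ones(data_dim // 2), persistent=False)
        # Assigning a tensor to a registered-buffer attribute updates the
        # buffer in place of the old value, which is why lines like
        # `self.data_mean = self.data_mean.view(1, -1, 1)` in vae.py keep
        # working unchanged after the migration.
        self.data_mean = self.data_mean.view(1, -1, 1)


if __name__ == '__main__':
    m = ToyStats()
    assert 'data_mean' in m.state_dict()  # persistent buffer is saved
    assert 'freqs' not in m.state_dict()  # non-persistent buffer is not
    print(m.data_mean.shape)  # torch.Size([1, 80, 1])
```

The dropped `torch.autocast('cuda', enabled=False)` context in embeddings.py has no effect on a buffer created from an explicit `float32` tensor at module construction time, so removing it along with the `nn.Buffer` call does not change the computed `freqs` values.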