Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmaudio/ext/autoencoder/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def __init__(
super().__init__()

if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
self.register_buffer("data_mean", torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.register_buffer("data_std", torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
self.register_buffer("data_mean", torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.register_buffer("data_std", torch.tensor(DATA_STD_128D, dtype=torch.float32))

self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
Expand Down
14 changes: 7 additions & 7 deletions mmaudio/model/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def __init__(self, dim, frequency_embedding_size, max_period):
self.max_period = max_period
assert dim % 2 == 0, 'dim must be even.'

with torch.autocast('cuda', enabled=False):
self.freqs = nn.Buffer(
1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
frequency_embedding_size)),
persistent=False)
freq_scale = 10000 / max_period
self.freqs = freq_scale * self.freqs
self.register_buffer('freqs',
1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
frequency_embedding_size)),
persistent=False
)
freq_scale = 10000 / max_period
self.freqs = freq_scale * self.freqs

def timestep_embedding(self, t):
"""
Expand Down
8 changes: 4 additions & 4 deletions mmaudio/model/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def initialize_rotations(self):
self._clip_seq_len,
device=self.device)

self.latent_rot = nn.Buffer(latent_rot, persistent=False)
self.clip_rot = nn.Buffer(clip_rot, persistent=False)
self.register_buffer('latent_rot', latent_rot, persistent=False)
self.register_buffer('clip_rot', clip_rot, persistent=False)

def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
self._latent_seq_len = latent_seq_len
Expand Down Expand Up @@ -238,9 +238,9 @@ def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor,
sync_f = sync_f.flatten(1, 2) # (B, VN, D)

# extend vf to match x
clip_f = self.clip_input_proj(clip_f) # (B, VN, D)
clip_f = self.clip_input_proj(clip_f.clone()) # (B, VN, D)
sync_f = self.sync_input_proj(sync_f) # (B, VN, D)
text_f = self.text_input_proj(text_f) # (B, VN, D)
text_f = self.text_input_proj(text_f.clone()) # (B, VN, D)

# upsample the sync features to match the audio
sync_f = sync_f.transpose(1, 2) # (B, D, VN)
Expand Down