Commit

Update attentions.py
Bebra777228 authored Aug 4, 2024
1 parent 599d457 commit 9ae1709
Showing 1 changed file with 9 additions and 9 deletions.
src/infer_pack/attentions.py: 18 changes (9 additions, 9 deletions)
@@ -6,7 +6,7 @@
 
 now_dir = os.getcwd()
 
-from src.infer_pack import commons
+from src.infer_pack.commons import subsequent_mask, convert_pad_shape
 from src.infer_pack.modules import LayerNorm
 
 def init_layer_list(num_layers, layer_fn, *args, **kwargs):
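For reference, the two helpers that are now imported directly rather than through the commons namespace are, in VITS-derived codebases like this one, typically defined along these lines (a sketch of the assumed standard definitions, not code taken from this commit):

import torch

def convert_pad_shape(pad_shape):
    # F.pad takes a flat list of pad widths ordered from the last dimension
    # backwards, so reverse the per-dimension [left, right] pairs and flatten.
    return [item for sublist in pad_shape[::-1] for item in sublist]

def subsequent_mask(length):
    # Lower-triangular causal mask of shape (1, 1, length, length):
    # position i may attend only to positions <= i.
    return torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)

If src.infer_pack.commons matches these definitions, the change is purely a namespace cleanup: every commons.* call below becomes a direct call with identical behaviour.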
@@ -49,7 +49,7 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_s
         self.norm_layers_2 = init_layer_list(n_layers, LayerNorm, hidden_channels)
 
     def forward(self, x, x_mask, h, h_mask):
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         x = x * x_mask
         for self_attn_layer, norm0, encdec_attn_layer, norm1, ffn_layer, norm2 in zip(self.self_attn_layers, self.norm_layers_0, self.encdec_attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2):
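As a quick illustration of the mask built in forward() above (hypothetical toy values, assuming the subsequent_mask sketch shown after the first hunk):

mask = subsequent_mask(4)   # here x_mask.size(2) == 4
print(mask.shape)           # torch.Size([1, 1, 4, 4])
print(mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])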
@@ -149,23 +149,23 @@ def _get_relative_embeddings(self, relative_embeddings, length):
         slice_start_position = max((self.window_size + 1) - length, 0)
         slice_end_position = slice_start_position + 2 * length - 1
         if pad_length > 0:
-            padded_relative_embeddings = F.pad(relative_embeddings, commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+            padded_relative_embeddings = F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
         else:
             padded_relative_embeddings = relative_embeddings
         return padded_relative_embeddings[:, slice_start_position:slice_end_position]
 
     def _relative_position_to_absolute_position(self, x):
         batch, heads, length, _ = x.size()
-        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
         x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
         return x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
 
     def _absolute_position_to_relative_position(self, x):
         batch, heads, length, _ = x.size()
-        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         return x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
 
     def _attention_bias_proximal(self, length):
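The two methods above implement the usual padding-and-reshape ("skewing") trick for converting relative-position attention scores to absolute positions and back. A standalone shape check of the relative-to-absolute direction, written outside the class with hypothetical sizes and plain F.pad arguments equivalent to the convert_pad_shape calls above:

import torch
import torch.nn.functional as F

batch, heads, length = 1, 1, 4
x = torch.randn(batch, heads, length, 2 * length - 1)  # one score per relative offset
x = F.pad(x, [0, 1])                                    # pad last dim on the right
x_flat = x.view(batch, heads, length * 2 * length)
x_flat = F.pad(x_flat, [0, length - 1])
out = x_flat.view(batch, heads, length + 1, 2 * length - 1)[:, :, :length, length - 1:]
print(out.shape)  # torch.Size([1, 1, 4, 4]): one score per (query, key) pair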
@@ -205,11 +205,11 @@ def _causal_padding(self, x):
             return x
         pad_l = self.kernel_size - 1
         pad_r = 0
-        return F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]]))
+        return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]]))
 
     def _same_padding(self, x):
         if self.kernel_size == 1:
             return x
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
-        return F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]]))
+        return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]]))
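Both padding helpers feed F.pad with a (batch, channels, time) tensor before the FFN's 1D convolutions. A hypothetical illustration for kernel_size = 3, assuming the behaviour shown in the diff:

import torch
import torch.nn.functional as F

kernel_size = 3
x = torch.randn(2, 8, 10)                     # (batch, channels, time)

# Causal padding: all kernel_size - 1 zeros go on the left, so the
# convolution never looks at future frames.
causal = F.pad(x, [kernel_size - 1, 0])

# Same padding: split between both sides so the convolution output
# keeps the original time length.
same = F.pad(x, [(kernel_size - 1) // 2, kernel_size // 2])

print(causal.shape, same.shape)               # both torch.Size([2, 8, 12])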
