Skip to content

Commit

Permalink
add pooling to encoder
Browse files (browse the repository at this point in the history)
  • Loading branch information
teticio committed May 23, 2023
1 parent a923a95 commit 0455127
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion audiodiffusion/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,7 @@
# from diffusers import AudioDiffusionPipeline
from .pipeline_audio_diffusion import AudioDiffusionPipeline

VERSION = "1.5.4"
VERSION = "1.5.5"


class AudioDiffusion:
Expand Down
10 changes: 8 additions & 2 deletions audiodiffusion/audio_encoder.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,7 +82,7 @@ def forward(self, x):
return x

@torch.no_grad()
def encode(self, audio_files):
def encode(self, audio_files, pool="average"):
self.eval()
y = []
for audio_file in audio_files:
Expand All @@ -97,7 +97,13 @@ def encode(self, audio_files):
)
for slice in range(self.mel.get_number_of_slices())
]
y += [torch.mean(self(torch.Tensor(x)), dim=0)]
y += [self(torch.Tensor(x))]
if pool == "average":
y[-1] = torch.mean(y[-1], dim=0)
elif pool == "max":
y[-1] = torch.max(y[-1], dim=0)
else:
assert pool is None, f"Unknown pooling method {pool}"
return torch.stack(y)


Expand Down

0 comments on commit 0455127

Please sign in to comment.