diff --git a/audiodiffusion/__init__.py b/audiodiffusion/__init__.py
index fbf3c43..7a972bb 100644
--- a/audiodiffusion/__init__.py
+++ b/audiodiffusion/__init__.py
@@ -9,7 +9,7 @@
 # from diffusers import AudioDiffusionPipeline
 from .pipeline_audio_diffusion import AudioDiffusionPipeline
 
-VERSION = "1.5.4"
+VERSION = "1.5.5"
 
 
 class AudioDiffusion:
diff --git a/audiodiffusion/audio_encoder.py b/audiodiffusion/audio_encoder.py
index 6561a72..d154662 100644
--- a/audiodiffusion/audio_encoder.py
+++ b/audiodiffusion/audio_encoder.py
@@ -82,7 +82,7 @@ def forward(self, x):
         return x
 
     @torch.no_grad()
-    def encode(self, audio_files):
+    def encode(self, audio_files, pool="average"):
         self.eval()
         y = []
         for audio_file in audio_files:
@@ -97,7 +97,13 @@ def encode(self, audio_files):
                 )
                 for slice in range(self.mel.get_number_of_slices())
             ]
-            y += [torch.mean(self(torch.Tensor(x)), dim=0)]
+            y += [self(torch.Tensor(x))]
+            if pool == "average":
+                y[-1] = torch.mean(y[-1], dim=0)
+            elif pool == "max":
+                y[-1] = torch.max(y[-1], dim=0).values
+            else:
+                assert pool is None, f"Unknown pooling method {pool}"
         return torch.stack(y)
 
 
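For reference, a minimal usage sketch of the new pool argument. This is an illustration only, not part of the change: it assumes AudioEncoder is loaded via from_pretrained with the teticio/audio-encoder checkpoint, and the audio file paths are placeholders.

    from audiodiffusion.audio_encoder import AudioEncoder

    # Checkpoint name and file paths are placeholders for this sketch.
    encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")

    # Default behaviour is unchanged: average-pool over slices, one vector per file.
    averaged = encoder.encode(["track1.mp3", "track2.mp3"])

    # New options: max-pool over slices, or pool=None to keep the per-slice encodings.
    max_pooled = encoder.encode(["track1.mp3"], pool="max")
    per_slice = encoder.encode(["track1.mp3"], pool=None)

Note that with pool=None the per-slice encodings are stacked as-is, so all files in one call should yield the same number of slices.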