Skip to content

Commit

Permalink
Merge pull request #5 from NielsRogge/add_hf
Browse files Browse the repository at this point in the history
Improve HF integration
  • Loading branch information
JusperLee authored Sep 21, 2024
2 parents 851a8c2 + e9d2942 commit e7ec6cd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion look2hear/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch
import torch.nn as nn

from huggingface_hub import PyTorchModelHubMixin


def _unsqueeze_to_3d(x):
"""Normalize shape of `x` to [batch, n_chan, time]."""
Expand All @@ -32,7 +34,7 @@ def pad_to_appropriate_length(x, lcm):
return x


class BaseModel(nn.Module):
class BaseModel(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/JusperLee/Apollo", pipeline_tag="audio-to-audio"):
def __init__(self, sample_rate, in_chan=1):
super().__init__()
self._sample_rate = sample_rate
Expand Down

0 comments on commit e7ec6cd

Please sign in to comment.