Skip to content

Commit dec3472

Browse files
authored
Merge pull request #27 from Stability-AI/os_hf_hub
Add huggingface hub support
2 parents ac7046d + a4d9274 commit dec3472

File tree

6 files changed

+42
-10
lines changed

6 files changed

+42
-10
lines changed

run_gradio.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from stable_audio_tools import get_pretrained_model
12
from stable_audio_tools.interface.gradio import create_ui
23
import json
34

@@ -6,21 +7,16 @@
67
def main(args):
78
torch.manual_seed(42)
89

9-
print(f"Loading model config from {args.model_config}")
10-
# Load config from json file
11-
with open(args.model_config) as f:
12-
model_config = json.load(f)
13-
14-
15-
interface = create_ui(model_config, args.ckpt_path, args.pretransform_ckpt_path)
10+
interface = create_ui(model_config_path = args.model_config, ckpt_path=args.ckpt_path, pretrained_name=args.pretrained_name, pretransform_ckpt_path=args.pretransform_ckpt_path)
1611
interface.queue()
1712
interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None)
1813

1914
if __name__ == "__main__":
2015
import argparse
2116
parser = argparse.ArgumentParser(description='Run gradio interface')
22-
parser.add_argument('--model-config', type=str, help='Path to model config', required=True)
23-
parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=True)
17+
parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False)
18+
parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
19+
parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False)
2420
parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False)
2521
parser.add_argument('--username', type=str, help='Gradio username', required=False)
2622
parser.add_argument('--password', type=str, help='Gradio password', required=False)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
'ema-pytorch==0.2.3',
1919
'encodec==0.1.1',
2020
'gradio==3.42.0',
21+
'huggingface_hub',
2122
'importlib-resources==5.12.0',
2223
'k-diffusion==0.1.1',
2324
'laion-clap==1.1.4',

stable_audio_tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .models.factory import create_model_from_config, create_model_from_config_path
2+
from .models.pretrained import get_pretrained_model
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .factory import create_model_from_config
1+
from .factory import create_model_from_config, create_model_from_config_path

stable_audio_tools/models/factory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
def create_model_from_config(model_config):
24
model_type = model_config.get('model_type', None)
35

@@ -20,7 +22,13 @@ def create_model_from_config(model_config):
2022
return create_musicgen_from_config(model_config)
2123
else:
2224
raise NotImplementedError(f'Unknown model type: {model_type}')
25+
26+
def create_model_from_config_path(model_config_path):
27+
with open(model_config_path) as f:
28+
model_config = json.load(f)
2329

30+
return create_model_from_config(model_config)
31+
2432
def create_pretransform_from_config(pretransform_config, sample_rate):
2533
pretransform_type = pretransform_config.get('type', None)
2634

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import json
2+
3+
from .factory import create_model_from_config
4+
from .utils import load_ckpt_state_dict
5+
6+
from huggingface_hub import hf_hub_download
7+
8+
def get_pretrained_model(name: str):
9+
10+
model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model')
11+
12+
with open(model_config_path) as f:
13+
model_config = json.load(f)
14+
15+
model = create_model_from_config(model_config)
16+
17+
# Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
18+
try:
19+
model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model')
20+
except Exception as e:
21+
model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model')
22+
23+
model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
24+
25+
return model, model_config

0 commit comments

Comments
 (0)