|
| 1 | +from stable_audio_tools import get_pretrained_model |
1 | 2 | from stable_audio_tools.interface.gradio import create_ui |
2 | 3 | import json |
3 | 4 |
|
|
6 | 7 | def main(args): |
7 | 8 | torch.manual_seed(42) |
8 | 9 |
|
9 | | - print(f"Loading model config from {args.model_config}") |
10 | | - # Load config from json file |
11 | | - with open(args.model_config) as f: |
12 | | - model_config = json.load(f) |
13 | | - |
14 | | - |
15 | | - interface = create_ui(model_config, args.ckpt_path, args.pretransform_ckpt_path) |
| 10 | + interface = create_ui(model_config_path = args.model_config, ckpt_path=args.ckpt_path, pretrained_name=args.pretrained_name, pretransform_ckpt_path=args.pretransform_ckpt_path) |
16 | 11 | interface.queue() |
17 | 12 | interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None) |
18 | 13 |
|
19 | 14 | if __name__ == "__main__": |
20 | 15 | import argparse |
21 | 16 | parser = argparse.ArgumentParser(description='Run gradio interface') |
22 | | - parser.add_argument('--model-config', type=str, help='Path to model config', required=True) |
23 | | - parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=True) |
| 17 | + parser.add_argument('--pretrained-name', type=str, help='Name of pretrained model', required=False) |
| 18 | + parser.add_argument('--model-config', type=str, help='Path to model config', required=False) |
| 19 | + parser.add_argument('--ckpt-path', type=str, help='Path to model checkpoint', required=False) |
24 | 20 | parser.add_argument('--pretransform-ckpt-path', type=str, help='Optional to model pretransform checkpoint', required=False) |
25 | 21 | parser.add_argument('--username', type=str, help='Gradio username', required=False) |
26 | 22 | parser.add_argument('--password', type=str, help='Gradio password', required=False) |
|
0 commit comments