|
8 | 8 |
|
9 | 9 | now_dir = Path(os.getcwd())
|
10 | 10 |
|
11 |
| -from src.infer_pack.models import ( |
12 |
| - SynthesizerTrnMs256NSFsid, |
13 |
| - SynthesizerTrnMs256NSFsid_nono, |
14 |
| - SynthesizerTrnMs768NSFsid, |
15 |
| - SynthesizerTrnMs768NSFsid_nono, |
16 |
| -) |
| 11 | +from src.infer_pack.models import Synthesizer, Synthesizer_nono |
17 | 12 | from src.my_utils import load_audio
|
18 | 13 | from src.vc_infer_pipeline import VC
|
19 | 14 |
|
@@ -97,10 +92,12 @@ def get_vc(device, is_half, config, model_path):
|
97 | 92 | pitch_guidance = cpt.get("f0", 1)
|
98 | 93 | version = cpt.get("version", "v1")
|
99 | 94 |
|
| 95 | + input_dim = 256 if version == "v1" else 768 |
| 96 | + |
100 | 97 | if version == "v1":
|
101 |
| - net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half) if pitch_guidance == 1 else SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) |
102 |
| - elif version == "v2": |
103 |
| - net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half) if pitch_guidance == 1 else SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) |
| 98 | + net_g = Synthesizer(input_dim, *cpt["config"], is_half=is_half, f0=pitch_guidance == 1) if pitch_guidance == 1 else Synthesizer_nono(input_dim, *cpt["config"]) |
| 99 | + else: |
| 100 | + net_g = Synthesizer(input_dim, *cpt["config"], is_half=is_half, f0=pitch_guidance == 1) if pitch_guidance == 1 else Synthesizer_nono(input_dim, *cpt["config"]) |
104 | 101 |
|
105 | 102 | del net_g.enc_q
|
106 | 103 | logger.info(net_g.load_state_dict(cpt["weight"], strict=False))
|
|
0 commit comments