diff --git a/rvc/lib/jit/__init__.py b/rvc/lib/jit/__init__.py new file mode 100644 index 0000000..03950d3 --- /dev/null +++ b/rvc/lib/jit/__init__.py @@ -0,0 +1,164 @@ +import pickle +import time +from collections import OrderedDict +from io import BytesIO + +import torch +from tqdm import tqdm + + +def load_inputs(path, device, is_half=False): + parm = torch.load(path, map_location=torch.device("cpu")) + for key in parm.keys(): + parm[key] = parm[key].to(device) + if is_half and parm[key].dtype == torch.float32: + parm[key] = parm[key].half() + elif not is_half and parm[key].dtype == torch.float16: + parm[key] = parm[key].float() + return parm + + +def benchmark( + model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False +): + parm = load_inputs(inputs_path, device, is_half) + total_ts = 0.0 + bar = tqdm(range(epoch)) + for i in bar: + start_time = time.perf_counter() + o = model(**parm) + total_ts += time.perf_counter() - start_time + print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}") + + +def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False): + benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half) + + +def to_jit_model( + model_path, + model_type: str, + mode: str = "trace", + inputs_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + model = None + if model_type.lower() == "synthesizer": + from .get_synthesizer import get_synthesizer + + model, _ = get_synthesizer(model_path, device) + model.forward = model.infer + elif model_type.lower() == "rmvpe": + from .get_rmvpe import get_rmvpe + + model = get_rmvpe(model_path, device) + elif model_type.lower() == "hubert": + from .get_hubert import get_hubert_model + + model = get_hubert_model(model_path, device) + model.forward = model.infer + else: + raise ValueError(f"No model type named {model_type}") + model = model.eval() + model = model.half() if is_half else model.float() + if mode == "trace": + assert not inputs_path + inputs = load_inputs(inputs_path, device, is_half) + model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) + elif mode == "script": + model_jit = torch.jit.script(model) + model_jit.to(device) + model_jit = model_jit.half() if is_half else model_jit.float() + # model = model.half() if is_half else model.float() + return (model, model_jit) + + +def export( + model: torch.nn.Module, + mode: str = "trace", + inputs: dict = None, + device=torch.device("cpu"), + is_half: bool = False, +) -> dict: + model = model.half() if is_half else model.float() + model.eval() + if mode == "trace": + assert inputs is not None + model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs) + elif mode == "script": + model_jit = torch.jit.script(model) + model_jit.to(device) + model_jit = model_jit.half() if is_half else model_jit.float() + buffer = BytesIO() + # model_jit=model_jit.cpu() + torch.jit.save(model_jit, buffer) + del model_jit + cpt = OrderedDict() + cpt["model"] = buffer.getvalue() + cpt["is_half"] = is_half + return cpt + + +def load(path: str): + with open(path, "rb") as f: + return pickle.load(f) + + +def save(ckpt: dict, save_path: str): + with open(save_path, "wb") as f: + pickle.dump(ckpt, f) + + +def rmvpe_jit_export( + model_path: str, + mode: str = "script", + inputs_path: str = None, + save_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + if not save_path: + save_path = model_path.rstrip(".pth") + save_path += ".half.jit" if is_half else ".jit" + if "cuda" in str(device) and ":" not in str(device): + device = torch.device("cuda:0") + from .get_rmvpe import get_rmvpe + + model = get_rmvpe(model_path, device) + inputs = None + if mode == "trace": + inputs = load_inputs(inputs_path, device, is_half) + ckpt = export(model, mode, inputs, device, is_half) + ckpt["device"] = str(device) + save(ckpt, save_path) + return ckpt + + +def synthesizer_jit_export( + model_path: str, + mode: str = "script", + inputs_path: str = None, + save_path: str = None, + device=torch.device("cpu"), + is_half=False, +): + if not save_path: + save_path = model_path.rstrip(".pth") + save_path += ".half.jit" if is_half else ".jit" + if "cuda" in str(device) and ":" not in str(device): + device = torch.device("cuda:0") + from .get_synthesizer import get_synthesizer + + model, cpt = get_synthesizer(model_path, device) + assert isinstance(cpt, dict) + model.forward = model.infer + inputs = None + if mode == "trace": + inputs = load_inputs(inputs_path, device, is_half) + ckpt = export(model, mode, inputs, device, is_half) + cpt.pop("weight") + cpt["model"] = ckpt["model"] + cpt["device"] = device + save(cpt, save_path) + return cpt diff --git a/rvc/lib/jit/get_hubert.py b/rvc/lib/jit/get_hubert.py index aec7132..a073f74 100644 --- a/rvc/lib/jit/get_hubert.py +++ b/rvc/lib/jit/get_hubert.py @@ -1,10 +1,11 @@ import math import random from typing import Optional, Tuple -from fairseq.checkpoint_utils import load_model_ensemble_and_task + import numpy as np import torch import torch.nn.functional as F +from fairseq.checkpoint_utils import load_model_ensemble_and_task # from fairseq.data.data_utils import compute_mask_indices from fairseq.utils import index_put