Skip to content

Commit

Permalink
Client Improvement
Browse files Browse the repository at this point in the history
1. [GUI] Redux the function to set TreeWidget so that items can be added to TreeWidget in a more readable way
2. [GUI] Allow users to do GPT-SoVITS model inference locally and checking output through a button
3. [Core] Merge source code from the latest version of GPT-SoVITS
4. [Core] Support doing GPT-SoVITS model inference without webui (gradio)
5. Update documentations & requirements
  • Loading branch information
Spr-Aachen committed May 31, 2024
1 parent 5f29ec7 commit f47296a
Show file tree
Hide file tree
Showing 31 changed files with 2,313 additions and 1,349 deletions.
67 changes: 56 additions & 11 deletions EVT_Core/TTS/GPT_SoVITS/Convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import subprocess
import torch
from pathlib import Path
from scipy.io.wavfile import write

from .config import python_exec, webui_port_infer_tts, is_share
#from .GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav


current_dir = Path(__file__).absolute().parent.as_posix()
Expand Down Expand Up @@ -138,16 +138,61 @@ def Convert(
Model_Path_Load_s2G: str = "GPT_SoVITS/pretrained_models/s2G488k.pth",
Model_Dir_Load_bert: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
Model_Dir_Load_ssl: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base",
Set_FP16_Run: bool = False
Ref_Audio: str = "",
Ref_Text_Free: bool = False,
Ref_Text: str = "",
Ref_Language: str = "多语种混合",
Text: str = '请输入语句',
Language: str = "多语种混合",
How_To_Cut: str = "按标点符号切",
Top_K: int = 5,
Top_P: float = 1.,
Temperature: float = 1.,
Set_FP16_Run: bool = False,
Audio_Path_Save: str = ...,
Use_WebUI: bool = False
):
# 1C-推理
change_tts_inference(
if_tts = True,
bert_path = Model_Dir_Load_bert,
cnhubert_base_path = Model_Dir_Load_ssl,
gpu_number = gpus,
is_half = Set_FP16_Run,
gpt_path = Model_Path_Load_s1,
sovits_path = Model_Path_Load_s2G
)
if Use_WebUI:
change_tts_inference(
if_tts = True,
bert_path = Model_Dir_Load_bert,
cnhubert_base_path = Model_Dir_Load_ssl,
gpu_number = gpus,
is_half = Set_FP16_Run,
gpt_path = Model_Path_Load_s1,
sovits_path = Model_Path_Load_s2G
)
else:
os.environ["gpt_path"] = Model_Path_Load_s1
os.environ["sovits_path"] = Model_Path_Load_s2G
os.environ["cnhubert_base_path"] = Model_Dir_Load_ssl
os.environ["bert_path"] = Model_Dir_Load_bert
os.environ["_CUDA_VISIBLE_DEVICES"] = gpus
os.environ["is_half"] = str(Set_FP16_Run)
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
os.environ["is_share"] = str(is_share)

from .GPT_SoVITS.inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav

change_gpt_weights(gpt_path)

change_sovits_weights(sovits_path)

TTS_Result = get_tts_wav(
ref_wav_path = Ref_Audio,
prompt_text = Ref_Text,
prompt_language = Ref_Language,
text = Text,
text_language = Language,
how_to_cut = How_To_Cut,
top_k = Top_K,
top_p = Top_P,
temperature = Temperature,
ref_free = Ref_Text_Free
)
SR, Audio = list(TTS_Result)[-1]

write(Audio_Path_Save, SR, Audio)

# 2-GPT-SoVITS-变声
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,12 @@ def process(data, res):
try:
wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"])
todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
if language in language_v1_to_language_v2.keys():
todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
else:
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
except:
print(line, traceback.format_exc())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ def name2go(wav_name,wav_path):
tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:
nan_fails.append(wav_name)
nan_fails.append((wav_name,wav_path))
print("nan filtered:%s"%wav_name)
return
wavfile.write(
"%s/%s"%(wav32dir,wav_name),
32000,
tmp_audio32.astype("int16"),
)
my_save(ssl,hubert_path )
my_save(ssl,hubert_path)

with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
Expand All @@ -103,8 +103,8 @@ def name2go(wav_name,wav_path):
if(len(nan_fails)>0 and is_half==True):
is_half=False
model=model.float()
for wav_name in nan_fails:
for wav in nan_fails:
try:
name2go(wav_name)
name2go(wav[0],wav[1])
except:
print(wav_name,traceback.format_exc())
74 changes: 47 additions & 27 deletions EVT_GUI/Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,17 @@ class CustomSignals_Functions(QObject):

def Function_ScrollToWidget(
Trigger: QWidget,
TargetWidget: Optional[QWidget],
ScrollArea: QScrollArea,
TargetWidget: QWidget,
ScrollArea: Optional[QScrollArea] = None,
#Alignment: str = 'Top'
):
'''
'''
if ScrollArea is None:
ScrollArea = Function_FindParentUI(TargetWidget, QScrollArea)

def ScrollToWidget():
if TargetWidget is not None:
TargetRect = TargetWidget.mapToGlobal(QPoint(0, 0))
else:
try:
TargetRect = ScrollArea.widget().layout().itemAt(Trigger.property("Index")).mapToGlobal(QPoint(0, 0))
except:
raise Exception("Please set property 'Index' for Trigger widget!")
TargetRect = TargetWidget.mapToGlobal(QPoint(0, 0))
TargetYPos = TargetRect.y() - ScrollArea.widget().mapToGlobal(QPoint(0, 0)).y()

ScrollArea.verticalScrollBar().setValue(TargetYPos)
Expand All @@ -60,11 +57,49 @@ def TreeWidgetEvent(Item, Column):
Trigger.clicked.connect(ScrollToWidget)


def Function_AddToTreeWidget(
Widget: QWidget,
TreeWidget: TreeWidgetBase,
RootItemText: str,
ChildItemText: Optional[str] = None,
ScrollArea: Optional[QScrollArea] = None
):
'''
'''
RootItemTexts = TreeWidget.rootItemTexts()
if RootItemText in RootItemTexts:
RootItem = TreeWidget.topLevelItem(RootItemTexts.index(RootItemText))
else:
RootItem = QTreeWidgetItem(TreeWidget)
RootItem.setText(0, RootItemText)
RootItemTextFont = QFont()
RootItemTextFont.setPixelSize(15)
RootItem.setFont(0, RootItemTextFont)
RootItem.setExpanded(True) if not RootItem.isExpanded() else None

ChildItemTexts = TreeWidget.childItemTexts(RootItem)
if ChildItemText is None:
ChildItem = None
elif ChildItemText in ChildItemTexts:
ChildItem = RootItem.child(ChildItemTexts.index(ChildItemText))
else:
ChildItem = QTreeWidgetItem(RootItem)
ChildItem.setText(0, ChildItemText)
ChildItemTextFont = QFont()
ChildItemTextFont.setPixelSize(12)
ChildItem.setFont(0, ChildItemTextFont)

Function_ScrollToWidget(
Trigger = ChildItem if ChildItem is not None else RootItem,
TargetWidget = Widget,
ScrollArea = ScrollArea
)


def Function_SetTreeWidget(
TreeWidget: QTreeWidget,
ItemTexts: dict = {'RootItemText': ('ChildItemText', )},
AddVertically: bool = False,
HideHeader: bool = True,
ExpandItems: bool = True
):
'''
Expand All @@ -91,23 +126,8 @@ def Function_SetTreeWidget(
TreeWidget.setColumnCount(1) if AddVertically else None
TreeWidget.addTopLevelItems(RootItems)

TreeWidget.setHeaderHidden(HideHeader)

TreeWidget.expandAll() if ExpandItems else None

'''
def Function_SetTreeView(
TreeView: QTreeView,
HeaderTexts: list = [],
RootItemTexts: list = [()],
ChildItemTexts: list = [(())],
AddVertically: bool = False
):
for Index, HeaderText in enumerate(HeaderTexts):
TreeView.setHeaderLabels(HeaderTexts)
TreeView.header().setOrientation(Qt.Vertical)
'''

def Function_SetChildWidgetsVisibility(
Container: QWidget,
Expand Down Expand Up @@ -279,7 +299,7 @@ def Function_ParamsHandler(
def Function_ParamsSynchronizer(
Trigger: Union[QObject, list],
FromTo: dict = {},
Times: Optional[float] = None,
Times: Union[int, float] = 1,
Connection: str = "Connect"
):
'''
Expand Down Expand Up @@ -316,7 +336,7 @@ def Function_ParamsChecker(
Params = []

for UI in ParamsFrom:
Param = Function_ParamsHandler(UI, "Get")
Param = Function_ParamsHandler(UI, "Get") if isinstance(UI, QWidget) else UI
if isinstance(Param, str):
if Param.strip() == "None" or Param.strip() == "":
if UI in ToIterable(EmptyAllowed):
Expand Down
Loading

0 comments on commit f47296a

Please sign in to comment.