-
Notifications
You must be signed in to change notification settings - Fork 0
/
debate_components.py
128 lines (105 loc) · 4.37 KB
/
debate_components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import textwrap
from dataclasses import dataclass
from gradio_client import Client, handle_file
from ofrak.core import *
from ofrak.core.llm import LlmAnalyzer, LlmAttributes, LlmAnalyzerConfig
from ofrak.model.component_model import CC
@dataclass
class Person(ResourceView):
    """A debate participant: a display name plus a reference voice sample."""

    # Name used in LLM prompts and log messages.
    name: str
    # Reference audio for voice cloning; passed to gradio handle_file, so
    # presumably a local path or URL — TODO confirm against the caller.
    audio_sample: str
@dataclass
class AnswerQuestionConfig(ComponentConfig):
    """Configuration for AnswerQuestionModifier."""

    # Question posed to the LLM (used as the prompt).
    question: str
    # URL of the gradio text-to-speech service.
    tts_client_url: str
    # Ollama model name used to answer the question.
    ollama_model: str
    # Ollama chat API endpoint.
    ollama_url: str = "http://localhost:11434/api/chat"
class AnswerQuestionModifier(Modifier[AnswerQuestionConfig]):
    """
    Have a Person answer a question via an LLM, append the answer to the
    resource's data, and synthesize a spoken version of the answer.
    """

    targets = (Person,)

    async def modify(self, resource: Resource, config: AnswerQuestionConfig):
        """
        Answer question, add it to firmware object, and generate and run a wave file.
        1. Run LlmAnalyzer to answer the provided question.
        2. Run BinaryExtendModifier to append the response to the resource.
        3. Run TTSAnalyzer to generate a Wavfile.

        :param resource: resource viewable as a Person
        :param config: question text plus the LLM and TTS endpoints to use
        """
        await resource.save()
        person = await resource.view_as(Person)
        print(f"[+] {person.name} considering the question...")
        raw_data = await resource.get_data()
        # Line index (0-based) at which the appended response will begin.
        start_line = len(raw_data.decode().splitlines())
        await resource.run(
            LlmAnalyzer,
            LlmAnalyzerConfig(
                api_url=config.ollama_url,
                model=config.ollama_model,
                prompt=config.question,
                system_prompt=f"You are an actor playing {person.name} in a presidential debate. Do not break character.",
            ),
        )
        response = wrap_text(
            resource.get_attributes(LlmAttributes).description
        ).encode()
        # Bug fix: if the existing data does not end with a newline, the first
        # line of the response would merge with the last existing line and
        # start_line would point past the beginning of the answer. Force a
        # line boundary so TTSAnalyzer reads the whole response.
        if raw_data and not raw_data.endswith(b"\n"):
            response = b"\n" + response
        await resource.run(BinaryExtendModifier, BinaryExtendConfig(response))
        print(f"[+] {person.name} generating audio response...")
        await resource.run(
            TTSAnalyzer,
            TTSAnalyzerConfig(start_line=start_line, client_url=config.tts_client_url),
        )
        # Drop the LLM results so this modifier can be run again on the resource.
        resource.remove_attributes(LlmAttributes)
        resource.remove_component(LlmAnalyzer.get_id())
@dataclass(**ResourceAttributes.DATACLASS_PARAMS)
class WavFile(ResourceAttributes):
    """Resource attributes holding synthesized speech audio."""

    # Raw bytes of the generated .wav file.
    wav: bytes
@dataclass
class TTSAnalyzerConfig(ComponentConfig):
    """Configuration for TTSAnalyzer: which lines of the resource to speak."""

    # First line (0-indexed) of the resource's text to synthesize.
    start_line: int = 0
    # Number of lines to synthesize; if -1, read to end of file.
    num_lines: int = -1  # If -1, read to end of file
    # URL of the gradio TTS service.
    client_url: str = "http://0.0.0.0:7860"
class TTSAnalyzer(Analyzer[TTSAnalyzerConfig, WavFile]):
    """
    Convert a line range of a Person resource's text into speech via an
    F5-TTS gradio service, producing WavFile attributes.

    (The previous docstring — "extract its mimetype and magic description" —
    was a copy-paste error from another analyzer.)
    """

    targets = (Person,)
    outputs = (WavFile,)

    async def analyze(
        self, resource: Resource, config: TTSAnalyzerConfig = TTSAnalyzerConfig()
    ) -> WavFile:
        """
        Synthesize speech for the configured line range of the resource data.

        :param resource: resource viewable as a Person (supplies the voice sample)
        :param config: line range to speak and the TTS service URL
        :raises ValueError: if the requested line range exceeds the data
        :return: WavFile attributes wrapping the generated audio bytes
        """
        data = await resource.get_data()
        lines = data.decode().splitlines()
        if len(lines) < config.start_line:
            raise ValueError(
                f"start_line {config.start_line} is beyond the "
                f"{len(lines)} available lines"
            )
        if config.num_lines == -1:
            text = "\n".join(lines[config.start_line :])
        else:
            end_line = config.start_line + config.num_lines
            if len(lines) < end_line:
                raise ValueError(
                    f"requested lines {config.start_line}..{end_line} exceed the "
                    f"{len(lines)} available lines"
                )
            text = "\n".join(lines[config.start_line : end_line])
        client = Client(config.client_url)  # Local url to the gradio instance.
        person = await resource.view_as(Person)
        print(f"[+] Converting text to speech for {person.name}...")
        result = client.predict(
            ref_audio_orig=handle_file(person.audio_sample),
            ref_text="",  # set to manually define the text in the reference audio
            gen_text=text,
            model="F5-TTS",  # F5-TTS is the only real model you can use.
            remove_silence=False,  # Remove silence between batches the model generates.
            cross_fade_duration=0.15,  # Fade in for the time between batches the model generates.
            speed=1.2,  # speed of the output
            api_name="/infer",
        )
        # Use a context manager so the temporary wav file handle is not leaked.
        with open(result[0], mode="rb") as wav_file:
            wav_data = wav_file.read()
        print("[+] Conversion complete")
        return WavFile(wav_data)

    async def _run(self, resource: Resource, config: CC):
        # Remove component, so that the analyzer can be rerun again.
        if resource.has_component_run(self.get_id()):
            print("Removing!")
            resource.remove_component(self.get_id())
        await super()._run(resource, config)
def wrap_text(text: str):
    """Re-wrap each blank-line-separated paragraph of ``text`` to the
    default textwrap width (70 columns), preserving paragraph breaks."""
    wrapped_paragraphs = []
    for paragraph in text.split("\n\n"):
        wrapped_paragraphs.append(textwrap.fill(paragraph))
    return "\n\n".join(wrapped_paragraphs)