Skip to content

Commit bbf8aa5

Browse files
committed
bidi streaming impl
1 parent 332b4c1 commit bbf8aa5

File tree

8 files changed

+242
-20
lines changed

8 files changed

+242
-20
lines changed

protos/kaldi_serve.proto

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ syntax = "proto3";
22
package kaldi_serve;
33

44
// Speech-recognition service exposed by the kaldi-serve server.
service KaldiServe {
  // Performs synchronous non-streaming speech recognition.
  rpc Recognize(RecognizeRequest) returns (RecognizeResponse) {}

  // Performs synchronous client-to-server streaming speech recognition:
  // receive results after all audio has been streamed and processed.
  rpc StreamingRecognize(stream RecognizeRequest) returns (RecognizeResponse) {}

  // Performs synchronous bidirectional streaming speech recognition:
  // receive results as the audio is being streamed and processed.
  rpc BidiStreamingRecognize(stream RecognizeRequest) returns (stream RecognizeResponse) {}
}
1216

1317
message RecognizeRequest {

python/kaldi_serve/core.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,18 @@ def recognize(self, config: RecognitionConfig, audio, uuid: str, timeout=None):
1919
request = RecognizeRequest(config=config, audio=audio, uuid=uuid)
2020
return self._client.Recognize(request, timeout=timeout)
2121

def streaming_recognize(self, config: RecognitionConfig, audio_chunks_gen, uuid: str, timeout=None):
    """Client-to-server streaming recognition: send every audio chunk with a
    shared config, then return the single final RecognizeResponse."""
    def request_stream():
        for chunk in audio_chunks_gen:
            yield RecognizeRequest(config=config, audio=chunk, uuid=uuid)
    return self._client.StreamingRecognize(request_stream(), timeout=timeout)
2525

def streaming_recognize_raw(self, audio_params_gen, uuid: str, timeout=None):
    """Client-to-server streaming recognition where each item of
    audio_params_gen is a (config, audio_chunk) pair, letting every chunk
    carry its own config."""
    def request_stream():
        for cfg, chunk in audio_params_gen:
            yield RecognizeRequest(config=cfg, audio=chunk, uuid=uuid)
    return self._client.StreamingRecognize(request_stream(), timeout=timeout)
29+
def bidi_streaming_recognize(self, config: RecognitionConfig, audio_chunks_gen, uuid: str, timeout=None):
    """Bidirectional streaming recognition: returns an iterator of
    RecognizeResponse messages produced while audio chunks are still being
    sent, all chunks sharing one config."""
    def request_stream():
        for chunk in audio_chunks_gen:
            yield RecognizeRequest(config=config, audio=chunk, uuid=uuid)
    return self._client.BidiStreamingRecognize(request_stream(), timeout=timeout)
33+
def bidi_streaming_recognize_raw(self, audio_params_gen, uuid: str, timeout=None):
    """Bidirectional streaming recognition where each item of
    audio_params_gen is a (config, audio_chunk) pair; returns an iterator
    of RecognizeResponse messages."""
    def request_stream():
        for cfg, chunk in audio_params_gen:
            yield RecognizeRequest(config=cfg, audio=chunk, uuid=uuid)
    return self._client.BidiStreamingRecognize(request_stream(), timeout=timeout)

python/kaldi_serve/kaldi_serve_pb2.py

Lines changed: 11 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/kaldi_serve/kaldi_serve_pb2_grpc.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,19 @@ def __init__(self, channel):
2424
request_serializer=kaldi__serve__pb2.RecognizeRequest.SerializeToString,
2525
response_deserializer=kaldi__serve__pb2.RecognizeResponse.FromString,
2626
)
27+
self.BidiStreamingRecognize = channel.stream_stream(
28+
'/kaldi_serve.KaldiServe/BidiStreamingRecognize',
29+
request_serializer=kaldi__serve__pb2.RecognizeRequest.SerializeToString,
30+
response_deserializer=kaldi__serve__pb2.RecognizeResponse.FromString,
31+
)
2732

2833

2934
class KaldiServeServicer(object):
3035
# missing associated documentation comment in .proto file
3136
pass
3237

3338
def Recognize(self, request, context):
34-
"""Performs synchronous non-streaming speech recognition;
39+
"""Performs synchronous non-streaming speech recognition.
3540
"""
3641
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
3742
context.set_details('Method not implemented!')
@@ -45,6 +50,14 @@ def StreamingRecognize(self, request_iterator, context):
4550
context.set_details('Method not implemented!')
4651
raise NotImplementedError('Method not implemented!')
4752

53+
def BidiStreamingRecognize(self, request_iterator, context):
54+
"""Performs synchronous bidirectional streaming speech recognition:
55+
receive results as the audio is being streamed and processed.
56+
"""
57+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
58+
context.set_details('Method not implemented!')
59+
raise NotImplementedError('Method not implemented!')
60+
4861

4962
def add_KaldiServeServicer_to_server(servicer, server):
5063
rpc_method_handlers = {
@@ -58,6 +71,11 @@ def add_KaldiServeServicer_to_server(servicer, server):
5871
request_deserializer=kaldi__serve__pb2.RecognizeRequest.FromString,
5972
response_serializer=kaldi__serve__pb2.RecognizeResponse.SerializeToString,
6073
),
74+
'BidiStreamingRecognize': grpc.stream_stream_rpc_method_handler(
75+
servicer.BidiStreamingRecognize,
76+
request_deserializer=kaldi__serve__pb2.RecognizeRequest.FromString,
77+
response_serializer=kaldi__serve__pb2.RecognizeResponse.SerializeToString,
78+
),
6179
}
6280
generic_handler = grpc.method_handlers_generic_handler(
6381
'kaldi_serve.KaldiServe', rpc_method_handlers)

python/kaldi_serve/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def chunks_from_mic(secs: int, frame_rate: int, channels: int):
3434
p = pyaudio.PyAudio()
3535
sample_format = pyaudio.paInt16
3636

37-
# This is in samples not seconds
38-
chunk_size = 4000
37+
# 8k samples ~ 1sec of audio
38+
chunk_size = 8000
3939

4040
stream = p.open(format=sample_format,
4141
channels=channels,
@@ -45,6 +45,7 @@ def chunks_from_mic(secs: int, frame_rate: int, channels: int):
4545

4646
sample_width = p.get_sample_size(sample_format)
4747

48+
print('recording...')
4849
for _ in range(0, int(frame_rate / chunk_size * secs)):
4950
# The right way probably is to not send headers at all and let the
5051
# server side's chunk handler maintain state, taking data from

python/scripts/example_client.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def parse_response(response):
5656
return output
5757

5858

59-
def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw: bool=False, word_level: bool=False):
59+
def transcribe_chunks_streaming(client, audio_chunks, model: str, language_code: str, raw: bool=False, word_level: bool=False):
6060
"""
6161
Transcribe the given audio chunks
6262
"""
@@ -66,7 +66,6 @@ def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw:
6666

6767
try:
6868
if raw:
69-
print('streaming raw')
7069
config = lambda chunk_len: RecognitionConfig(
7170
sample_rate_hertz=SR,
7271
encoding=encoding,
@@ -80,7 +79,6 @@ def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw:
8079
audio_params = [(config(len(chunk)), RecognitionAudio(content=chunk)) for chunk in audio_chunks]
8180
response = client.streaming_recognize_raw(audio_params, uuid="")
8281
else:
83-
print('streaming with headers')
8482
audio = (RecognitionAudio(content=chunk) for chunk in audio_chunks)
8583
config = RecognitionConfig(
8684
sample_rate_hertz=SR,
@@ -97,6 +95,38 @@ def transcribe_chunks(client, audio_chunks, model: str, language_code: str, raw:
9795

9896
pprint(parse_response(response))
9997

def transcribe_chunks_bidi_streaming(client, audio_chunks, model: str, language_code: str, word_level: bool=False):
    """Transcribe the given audio chunks over the bidirectional streaming RPC,
    printing each transcription result as the server returns it.

    :param client: kaldi-serve client exposing `bidi_streaming_recognize_raw`.
    :param audio_chunks: iterable of raw audio chunks (bytes).
    :param model: name of the model to decode with.
    :param language_code: language of the audio.
    :param word_level: request word-level results when True.
    """
    encoding = RecognitionConfig.AudioEncoding.LINEAR16

    # Chunk sizes can vary, so each chunk gets its own config carrying
    # its byte length in `data_bytes`.
    config = lambda chunk_len: RecognitionConfig(
        sample_rate_hertz=SR,
        encoding=encoding,
        language_code=language_code,
        max_alternatives=10,
        model=model,
        raw=True,
        word_level=word_level,
        data_bytes=chunk_len
    )

    def audio_params_gen(audio_chunks_gen):
        for chunk in audio_chunks_gen:
            yield config(len(chunk)), RecognitionAudio(content=chunk)

    try:
        response_gen = client.bidi_streaming_recognize_raw(audio_params_gen(audio_chunks), uuid="")
        # gRPC surfaces most streaming errors only while iterating the
        # response iterator, so the loop must live inside the try block.
        # (Bug fix: previously the loop ran after the except clause, so a
        # failed call left `response_gen` unbound and raised NameError.)
        for response in response_gen:
            pprint(parse_response(response))
    except Exception as e:
        traceback.print_exc()
        print(f'error: {str(e)}')
100130

101131
def decode_files(client, audio_paths: List[str], model: str, language_code: str, raw: bool=False, pcm: bool=False, word_level: bool=False):
102132
"""
@@ -105,7 +135,7 @@ def decode_files(client, audio_paths: List[str], model: str, language_code: str,
105135
chunked_audios = [chunks_from_file(x, chunk_size=random.randint(1, 3), raw=raw, pcm=pcm) for x in audio_paths]
106136

107137
threads = [
108-
threading.Thread(target=transcribe_chunks, args=(client, chunks, model, language_code, raw, word_level))
138+
threading.Thread(target=transcribe_chunks_streaming, args=(client, chunks, model, language_code, raw, word_level))
109139
for chunks in chunked_audios
110140
]
111141

@@ -126,6 +156,6 @@ def decode_files(client, audio_paths: List[str], model: str, language_code: str,
126156
word_level = args["--word-level"]
127157

128158
if args["mic"]:
129-
transcribe_chunks(client, chunks_from_mic(int(args["--n-secs"]), SR, 1), model, language_code, raw, word_level)
159+
transcribe_chunks_bidi_streaming(client, chunks_from_mic(int(args["--n-secs"]), SR, 1), model, language_code, word_level)
130160
else:
131161
decode_files(client, args["<file>"], model, language_code, raw, pcm, word_level)

src/decoder.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ class Decoder final {
177177
kaldi::SingleUtteranceNnet3Decoder &,
178178
const std::size_t &,
179179
utterance_results_t &,
180-
const bool &) const;
180+
const bool &,
181+
const bool & = false) const;
181182
};
182183

183184
Decoder::Decoder(const kaldi::BaseFloat &beam,
@@ -532,9 +533,13 @@ void Decoder::decode_stream_final(kaldi::OnlineNnet2FeaturePipeline &feature_pip
532533
kaldi::SingleUtteranceNnet3Decoder &decoder,
533534
const std::size_t &n_best,
534535
utterance_results_t &results,
535-
const bool &word_level) const {
536-
feature_pipeline.InputFinished();
537-
decoder.FinalizeDecoding();
536+
const bool &word_level,
537+
const bool &bidi_streaming) const {
538+
539+
if (!bidi_streaming) {
540+
feature_pipeline.InputFinished();
541+
decoder.FinalizeDecoding();
542+
}
538543

539544
if (decoder.NumFramesDecoded() == 0) {
540545
KALDI_WARN << "audio may be empty :: decoded no frames";

0 commit comments

Comments (0)