From 54f50d5cb004963d5376e5c0e9cfaca93b550cdd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 11 May 2023 12:50:19 +0800 Subject: [PATCH] WIP: Add shallow fusion to C API --- sherpa-onnx/c-api/c-api.cc | 3 +++ sherpa-onnx/c-api/c-api.h | 6 +++++ sherpa-onnx/csharp-api/CPPLINT.cfg | 1 - swift-api-examples/SherpaOnnx.swift | 34 ++++++++++++++++--------- swift-api-examples/decode-file.swift | 37 +++++++++++++++++++--------- 5 files changed, 57 insertions(+), 24 deletions(-) delete mode 100644 sherpa-onnx/csharp-api/CPPLINT.cfg diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 45909bba0..5baf1d9de 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -43,6 +43,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( recognizer_config.model_config.num_threads = config->model_config.num_threads; recognizer_config.model_config.debug = config->model_config.debug; + recognizer_config.lm_config.model = config->lm_config.model; + recognizer_config.lm_config.scale = config->lm_config.scale; + recognizer_config.decoding_method = config->decoding_method; recognizer_config.max_active_paths = config->max_active_paths; diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 9678e3463..489b4fcbb 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -67,9 +67,15 @@ SHERPA_ONNX_API typedef struct SherpaOnnxFeatureConfig { int32_t feature_dim; } SherpaOnnxFeatureConfig; +SHERPA_ONNX_API typedef struct SherpaOnnxOnlineLMConfig { + const char *model; + float scale; +} SherpaOnnxOnlineLMConfig; + SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { SherpaOnnxFeatureConfig feat_config; SherpaOnnxOnlineTransducerModelConfig model_config; + SherpaOnnxOnlineLMConfig lm_config; /// Possible values are: greedy_search, modified_beam_search const char *decoding_method; diff --git a/sherpa-onnx/csharp-api/CPPLINT.cfg b/sherpa-onnx/csharp-api/CPPLINT.cfg deleted file mode 100644 index 51ff339c1..000000000 --- a/sherpa-onnx/csharp-api/CPPLINT.cfg +++ /dev/null @@ -1 +0,0 @@ -exclude_files=.* diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 4838b27d0..dee1258e1 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -36,7 +36,7 @@ func sherpaOnnxOnlineTransducerModelConfig( tokens: String, numThreads: Int = 2, debug: Int = 0 -) -> SherpaOnnxOnlineTransducerModelConfig{ +) -> SherpaOnnxOnlineTransducerModelConfig { return SherpaOnnxOnlineTransducerModelConfig( encoder: toCPointer(encoder), decoder: toCPointer(decoder), @@ -56,19 +56,30 @@ func sherpaOnnxFeatureConfig( feature_dim: Int32(featureDim)) } +func sherpaOnnxOnlineLMConfig( + model: String = "", + scale: Float = 0.5 +) -> SherpaOnnxOnlineLMConfig { + return SherpaOnnxOnlineLMConfig( + model: toCPointer(model), + scale: scale) +} + func sherpaOnnxOnlineRecognizerConfig( - featConfig: SherpaOnnxFeatureConfig, - modelConfig: SherpaOnnxOnlineTransducerModelConfig, - enableEndpoint: Bool = false, - rule1MinTrailingSilence: Float = 2.4, - rule2MinTrailingSilence: Float = 1.2, - rule3MinUtteranceLength: Float = 30, - decodingMethod: String = "greedy_search", - maxActivePaths: Int = 4 -) -> SherpaOnnxOnlineRecognizerConfig{ + featConfig: SherpaOnnxFeatureConfig, + modelConfig: SherpaOnnxOnlineTransducerModelConfig, + lmConfig: SherpaOnnxOnlineLMConfig, + enableEndpoint: Bool = false, + rule1MinTrailingSilence: Float = 2.4, + rule2MinTrailingSilence: Float = 1.2, + rule3MinUtteranceLength: Float = 30, + decodingMethod: String = "greedy_search", + maxActivePaths: Int = 4 +) -> SherpaOnnxOnlineRecognizerConfig { return SherpaOnnxOnlineRecognizerConfig( feat_config: featConfig, model_config: modelConfig, + lm_config: lmConfig, decoding_method: toCPointer(decodingMethod), max_active_paths: Int32(maxActivePaths), enable_endpoint: enableEndpoint ? 1 : 0, @@ -152,7 +163,8 @@ class SherpaOnnxRecognizer { /// Get the decoding results so far func getResult() -> SherpaOnnxOnlineRecongitionResult { - let result: UnsafeMutablePointer? = GetOnlineStreamResult(recognizer, stream) + let result: UnsafeMutablePointer? = GetOnlineStreamResult( + recognizer, stream) return SherpaOnnxOnlineRecongitionResult(result: result) } diff --git a/swift-api-examples/decode-file.swift b/swift-api-examples/decode-file.swift index dd5edfc5d..56828b93d 100644 --- a/swift-api-examples/decode-file.swift +++ b/swift-api-examples/decode-file.swift @@ -13,34 +13,47 @@ extension AVAudioPCMBuffer { } func run() { - let encoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" - let decoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" - let joiner = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" + let encoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" + let decoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + let joiner = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" let tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" + let lm = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/with-state-epoch-999-avg-1.onnx" let modelConfig = sherpaOnnxOnlineTransducerModelConfig( encoder: encoder, decoder: decoder, joiner: joiner, tokens: tokens, - numThreads: 2) + numThreads: 2, + debug: 0 + ) let featConfig = sherpaOnnxFeatureConfig( sampleRate: 16000, featureDim: 80 ) + + let lmConfig = sherpaOnnxOnlineLMConfig( + model: lm, + scale: 0.5 + ) var config = sherpaOnnxOnlineRecognizerConfig( - featConfig: featConfig, - modelConfig: modelConfig, - enableEndpoint: false, - decodingMethod: "modified_beam_search", - maxActivePaths: 4 + featConfig: featConfig, + modelConfig: modelConfig, + lmConfig: lmConfig, + enableEndpoint: false, + decodingMethod: "modified_beam_search", + /* decodingMethod: "greedy_search", */ + maxActivePaths: 4 ) - let recognizer = SherpaOnnxRecognizer(config: &config) - let filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav" + let filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/5.wav" let fileURL: NSURL = NSURL(fileURLWithPath: filePath) let audioFile = try! AVAudioFile(forReading: fileURL as URL) @@ -60,7 +73,7 @@ func run() { recognizer.acceptWaveform(samples: tailPadding) recognizer.inputFinished() - while (recognizer.isReady()) { + while recognizer.isReady() { recognizer.decode() }