Skip to content

Commit

Permalink
go v2.1 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ksyeo1010 committed Dec 2, 2024
1 parent 69550b2 commit 91feeb9
Show file tree
Hide file tree
Showing 13 changed files with 106 additions and 34 deletions.
138 changes: 106 additions & 32 deletions binding/go/cheetah_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2022-2023 Picovoice Inc.
// Copyright 2022-2024 Picovoice Inc.
//
// You may not use this file except in compliance with the license. A copy of the license is
// located in the "LICENSE" file accompanying this source.
Expand All @@ -15,6 +15,7 @@ package cheetah

import (
"encoding/binary"
"encoding/json"
"flag"
"io/ioutil"
"log"
Expand All @@ -23,61 +24,74 @@ import (
"reflect"
"strings"
"testing"

"github.com/agnivade/levenshtein"
)

type TestParameters struct {
type LanguageTests struct {
language string
testAudioFile string
transcript string
punctuations []string
errorRate float32
enableAutomaticPunctuation bool
}

var (
testAccessKey string
cheetah Cheetah
processTestParameters []TestParameters
languageTests []LanguageTests
)

func TestMain(m *testing.M) {

flag.StringVar(&testAccessKey, "access_key", "", "AccessKey for testing")
flag.Parse()

processTestParameters = loadTestData()
languageTests = loadTestData()
os.Exit(m.Run())
}

func loadTestData() []TestParameters {
punctuations := []string{"."}
transcript := "Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
func appendLanguage(s string, language string) string {
if language == "en" {
return s
} else {
return s + "_" + language
}
}

testCaseWithPunctuation := TestParameters{
language: "en",
testAudioFile: "test.wav",
transcript: transcript,
enableAutomaticPunctuation: true,
errorRate: 0.025,
func loadTestData() []LanguageTests {
content, err := ioutil.ReadFile("../../resources/.test/test_data.json")
if err != nil {
log.Fatalf("Could not read test data json: %v", err)
}
processTestParameters = append(processTestParameters, testCaseWithPunctuation)

transcriptWithoutPunctuation := transcript
for _, p := range punctuations {
transcriptWithoutPunctuation = strings.ReplaceAll(transcriptWithoutPunctuation, p, "")
var testData struct {
Tests struct {
LanguageTests []struct {
Language string `json:"language"`
AudioFile string `json:"audio_file"`
Transcript string `json:"transcript"`
Punctuations []string `json:"punctuations"`
ErrorRate float32 `json:"error_rate"`
} `json:"language_tests"`
} `json:"tests"`
}
err = json.Unmarshal(content, &testData)
if err != nil {
log.Fatalf("Could not decode test data json: %v", err)
}

for _, x := range testData.Tests.LanguageTests {
languageTestParameters := LanguageTests{
language: x.Language,
testAudioFile: x.AudioFile,
transcript: x.Transcript,
punctuations: x.Punctuations,
errorRate: x.ErrorRate,
}

testCaseWithoutPunctuation := TestParameters{
language: "en",
testAudioFile: "test.wav",
transcript: transcriptWithoutPunctuation,
enableAutomaticPunctuation: false,
errorRate: 0.025,
languageTests = append(languageTests, languageTestParameters)
}
processTestParameters = append(processTestParameters, testCaseWithoutPunctuation)

return processTestParameters
return languageTests
}

func TestVersion(t *testing.T) {
Expand All @@ -102,15 +116,62 @@ func TestVersion(t *testing.T) {
}
}

func min(a, b int) int {
if a < b {
return a
}
return b
}

func levenshteinDistance(transcriptWords, referenceWords []string) int {
m, n := len(transcriptWords), len(referenceWords)
dp := make([][]int, m+1)
for i := range dp {
dp[i] = make([]int, n+1)
}

for i := 0; i <= m; i++ {
dp[i][0] = i
}
for j := 0; j <= n; j++ {
dp[0][j] = j
}

for i := 1; i <= m; i++ {
for j := 1; j <= n; j++ {
cost := 0
if !strings.EqualFold(transcriptWords[i-1], referenceWords[j-1]) {
cost = 1
}
dp[i][j] = min(dp[i-1][j]+1,
min(dp[i][j-1]+1,
dp[i-1][j-1]+cost))
}
}
return dp[m][n]
}

func getWordErrorRate(transcript, reference string) float32 {
transcriptWords := strings.Fields(transcript)
referenceWords := strings.Fields(reference)

dist := levenshteinDistance(transcriptWords, referenceWords)
return float32(dist) / float32(len(referenceWords))
}

func runProcessTestCase(
t *testing.T,
_ string,
language string,
testAudioFile string,
referenceTranscript string,
punctuations []string,
targetErrorRate float32,
enableAutomaticPunctuation bool) {

modelPath, _ := filepath.Abs(filepath.Join("../../lib/common", appendLanguage("cheetah_params", language)+".pv"))

cheetah = NewCheetah(testAccessKey)
cheetah.ModelPath = modelPath
cheetah.EnableAutomaticPunctuation = enableAutomaticPunctuation
err := cheetah.Init()
if err != nil {
Expand Down Expand Up @@ -157,15 +218,28 @@ func runProcessTestCase(
}
transcript += final

errorRate := float32(levenshtein.ComputeDistance(transcript, referenceTranscript)) / float32(len(referenceTranscript))
var normalizedTranscript = referenceTranscript
if !enableAutomaticPunctuation {
for _, punctuation := range punctuations {
normalizedTranscript = strings.ReplaceAll(normalizedTranscript, punctuation, "")
}
}

errorRate := getWordErrorRate(transcript, normalizedTranscript)
if errorRate >= targetErrorRate {
t.Fatalf("Expected '%f' got '%f'", targetErrorRate, errorRate)
}
}

func TestProcess(t *testing.T) {
for _, test := range processTestParameters {
runProcessTestCase(t, test.language, test.testAudioFile, test.transcript, test.errorRate, test.enableAutomaticPunctuation)
for _, test := range languageTests {
runProcessTestCase(t, test.language, test.testAudioFile, test.transcript, test.punctuations, test.errorRate, false)
}
}

func TestProcessWithPunctuation(t *testing.T) {
for _, test := range languageTests {
runProcessTestCase(t, test.language, test.testAudioFile, test.transcript, test.punctuations, test.errorRate, true)
}
}

Expand Down
Binary file modified binding/go/embedded/lib/common/cheetah_params.pv
Binary file not shown.
Binary file modified binding/go/embedded/lib/linux/x86_64/libpv_cheetah.so
Binary file not shown.
Binary file modified binding/go/embedded/lib/mac/arm64/libpv_cheetah.dylib
Binary file not shown.
Binary file modified binding/go/embedded/lib/mac/x86_64/libpv_cheetah.dylib
Binary file not shown.
Binary file not shown.
Binary file modified binding/go/embedded/lib/raspberry-pi/cortex-a53/libpv_cheetah.so
Binary file not shown.
Binary file not shown.
Binary file modified binding/go/embedded/lib/raspberry-pi/cortex-a72/libpv_cheetah.so
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified binding/go/embedded/lib/windows/amd64/libpv_cheetah.dll
Binary file not shown.
2 changes: 0 additions & 2 deletions binding/go/go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
module github.com/Picovoice/cheetah/binding/go/v2

go 1.16

require github.com/agnivade/levenshtein v1.1.1

0 comments on commit 91feeb9

Please sign in to comment.