Skip to content

Commit

Permalink
v2.1 java (#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksyeo1010 authored and ErisMik committed Dec 5, 2024
1 parent 2c53a7a commit c489285
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 18 deletions.
8 changes: 8 additions & 0 deletions binding/java/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ When done resources have to be released explicitly:
cheetah.delete();
```

### Language Model

The Cheetah Java SDK comes preloaded with a default English language model (`.pv` file).
Default models for other supported languages can be found in [lib/common](../../lib/common).

Create custom language models using the [Picovoice Console](https://console.picovoice.ai/). Here you can train
language models with custom vocabulary and boost words in the existing vocabulary.

## Demo App

For example usage refer to our [Java demos](../../demo/java).
3 changes: 2 additions & 1 deletion binding/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {

ext {
PUBLISH_GROUP_ID = 'ai.picovoice'
PUBLISH_VERSION = '2.0.2'
PUBLISH_VERSION = '2.1.0'
PUBLISH_ARTIFACT_ID = 'cheetah-java'
}

Expand Down Expand Up @@ -84,6 +84,7 @@ if (file("${rootDir}/publish-mavencentral.gradle").exists()) {
}

dependencies {
testImplementation 'com.google.code.gson:gson:2.10.1'
testImplementation 'org.junit.jupiter:junit-jupiter:5.4.2'
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
}
Expand Down
164 changes: 150 additions & 14 deletions binding/java/test/ai/picovoice/cheetah/CheetahTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022-2023 Picovoice Inc.
Copyright 2022-2024 Picovoice Inc.
You may not use this file except in compliance with the license. A copy of the license is
located in the "LICENSE" file accompanying this source.
Expand All @@ -12,14 +12,22 @@

package ai.picovoice.cheetah;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.junit.jupiter.api.Test;

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.stream.Stream;


Expand All @@ -32,6 +40,103 @@
public class CheetahTest {
private final String accessKey = System.getProperty("pvTestingAccessKey");

/**
 * Appends a language suffix to a resource base name.
 *
 * @param s        base name to decorate
 * @param language language code; {@code "en"} leaves the name untouched
 * @return {@code s} for English, otherwise {@code s + "_" + language}
 */
private static String appendLanguage(String s, String language) {
    // English resources carry no suffix; every other language is "<name>_<code>".
    return language.equals("en") ? s : s + "_" + language;
}

/**
 * Computes the word-level Levenshtein (edit) distance between two token sequences.
 * Tokens are compared case-insensitively; insertions, deletions and substitutions
 * each cost 1.
 *
 * @param transcript tokens produced by the engine
 * @param reference  expected tokens
 * @return minimum number of single-token edits turning {@code transcript} into {@code reference}
 */
private static int levenshteinDistance(String[] transcript, String[] reference) {
    final int rows = transcript.length + 1;
    final int cols = reference.length + 1;
    // cost[r][c] = distance between the first r transcript tokens and the first c reference tokens.
    final int[][] cost = new int[rows][cols];

    // Turning r tokens into an empty sequence takes r deletions.
    for (int r = 0; r < rows; r++) {
        cost[r][0] = r;
    }
    // Building c tokens from an empty sequence takes c insertions.
    for (int c = 0; c < cols; c++) {
        cost[0][c] = c;
    }

    for (int r = 1; r < rows; r++) {
        final String word = transcript[r - 1];
        for (int c = 1; c < cols; c++) {
            final int diagonal = cost[r - 1][c - 1];
            if (word.equalsIgnoreCase(reference[c - 1])) {
                // Matching tokens extend the previous alignment at no cost.
                cost[r][c] = diagonal;
                continue;
            }
            final int cheapestNeighbour = Math.min(diagonal, Math.min(cost[r - 1][c], cost[r][c - 1]));
            cost[r][c] = cheapestNeighbour + 1;
        }
    }
    return cost[rows - 1][cols - 1];
}

/**
 * Computes the word error rate of a transcript against a reference: the word-level
 * edit distance divided by the number of reference words.
 *
 * <p>Both strings are trimmed before being split on whitespace. Without the trim,
 * {@code String.split} yields an empty first token for input that starts with
 * whitespace (counted as a spurious word error), and a whitespace-only reference
 * splits to a zero-length array, turning the division into Infinity/NaN.
 * An empty string still splits to a single empty token, so the divisor is never zero.
 *
 * @param transcript text produced by the engine
 * @param reference  expected text
 * @return edit distance between the word sequences divided by the reference word count
 */
private static float getErrorRate(String transcript, String reference) {
    String[] transcriptWords = transcript.trim().split("\\s+");
    String[] referenceWords = reference.trim().split("\\s+");
    int distance = levenshteinDistance(transcriptWords, referenceWords);

    return (float) distance / (float) referenceWords.length;
}

/**
 * Reads the language test cases from {@code ../../resources/.test/test_data.json}
 * (resolved against the {@code user.dir} system property).
 *
 * <p>Each element of {@code tests.language_tests} is converted into one
 * {@link ProcessTestData}, preserving the order in the file.
 *
 * @return one entry per element of the {@code language_tests} array
 * @throws IOException if the JSON file cannot be read
 */
private static ProcessTestData[] loadProcessTestData() throws IOException {
    final Path jsonPath = Paths.get(System.getProperty("user.dir"))
            .resolve("../../resources/.test")
            .resolve("test_data.json");
    final byte[] rawBytes = Files.readAllBytes(jsonPath);
    final String jsonText = new String(rawBytes, StandardCharsets.UTF_8);

    final JsonObject root = JsonParser.parseString(jsonText).getAsJsonObject();
    final JsonArray languageTests = root.getAsJsonObject("tests").getAsJsonArray("language_tests");

    final int caseCount = languageTests.size();
    final ProcessTestData[] cases = new ProcessTestData[caseCount];
    for (int caseIdx = 0; caseIdx < caseCount; caseIdx++) {
        final JsonObject entry = languageTests.get(caseIdx).getAsJsonObject();

        // Scalar fields of the test case.
        final String language = entry.get("language").getAsString();
        final String audioFile = entry.get("audio_file").getAsString();
        final String transcript = entry.get("transcript").getAsString();
        final float errorRate = entry.get("error_rate").getAsFloat();

        // Copy the JSON string array into a plain String[].
        final JsonArray marks = entry.getAsJsonArray("punctuations");
        final int markCount = marks.size();
        final String[] punctuations = new String[markCount];
        for (int markIdx = 0; markIdx < markCount; markIdx++) {
            punctuations[markIdx] = marks.get(markIdx).getAsString();
        }

        cases[caseIdx] = new ProcessTestData(language, audioFile, transcript, punctuations, errorRate);
    }
    return cases;
}

/**
 * Supplies the arguments for the parameterized {@code process} test.
 *
 * <p>Every loaded test case is emitted twice, in order: first with automatic
 * punctuation disabled, then with it enabled.
 *
 * @return stream of (language, audio file, transcript, punctuations,
 *         enableAutomaticPunctuation, error rate) argument tuples
 * @throws IOException if the test data file cannot be read
 */
private static Stream<Arguments> processTestProvider() throws IOException {
    // Order matters: the "false" variant precedes the "true" variant for each case.
    final boolean[] punctuationModes = {false, true};

    final ArrayList<Arguments> argumentList = new ArrayList<>();
    for (ProcessTestData testCase : loadProcessTestData()) {
        for (boolean enablePunctuation : punctuationModes) {
            argumentList.add(Arguments.of(
                    testCase.language,
                    testCase.audioFile,
                    testCase.transcript,
                    testCase.punctuations,
                    enablePunctuation,
                    testCase.errorRate));
        }
    }
    return argumentList.stream();
}

@Test
void getVersion() throws CheetahException {
Cheetah cheetah = new Cheetah.Builder()
Expand Down Expand Up @@ -84,22 +189,33 @@ void getErrorStack() {
}
}

@ParameterizedTest(name = "test transcribe with automatic punctuation set to ''{0}''")
@MethodSource("transcribeProvider")
void transcribe(boolean enableAutomaticPunctuation, String referenceTranscript) throws Exception {
@ParameterizedTest(name = "test process data for ''{0}'' with punctuation ''{4}''")
@MethodSource("processTestProvider")
void process(
String language,
String testAudioFile,
String referenceTranscript,
String[] punctuations,
boolean enableAutomaticPunctuation,
float targetErrorRate) throws Exception {
String modelPath = Paths.get(System.getProperty("user.dir"))
.resolve(String.format("../../lib/common/%s.pv", appendLanguage("cheetah_params", language)))
.toString();

Cheetah cheetah = new Cheetah.Builder()
.setAccessKey(accessKey)
.setModelPath(modelPath)
.setEnableAutomaticPunctuation(enableAutomaticPunctuation)
.build();

int frameLen = cheetah.getFrameLength();
String audioFilePath = Paths.get(System.getProperty("user.dir"))
.resolve("../../resources/audio_samples/test.wav")
.resolve(String.format("../../resources/audio_samples/%s", testAudioFile))
.toString();
File testAudioPath = new File(audioFilePath);

AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(testAudioPath);
assertEquals(audioInputStream.getFormat().getFrameRate(), 16000);
assertEquals(16000, audioInputStream.getFormat().getFrameRate());

int byteDepth = audioInputStream.getFormat().getFrameSize();
byte[] pcm = new byte[frameLen * byteDepth];
Expand All @@ -116,17 +232,37 @@ void transcribe(boolean enableAutomaticPunctuation, String referenceTranscript)
}
CheetahTranscript finalTranscriptObj = cheetah.flush();
transcript.append(finalTranscriptObj.getTranscript());
assertEquals(referenceTranscript, transcript.toString());

cheetah.delete();

String normalizedTranscript = referenceTranscript;
if (!enableAutomaticPunctuation) {
for (String punctuation : punctuations) {
normalizedTranscript = normalizedTranscript.replace(punctuation, "");
}
}

assertTrue(getErrorRate(transcript.toString(), normalizedTranscript) < targetErrorRate);
}

private static Stream<Arguments> transcribeProvider() {
return Stream.of(
Arguments.of(true,
"Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."),
Arguments.of(false,
"Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel")
);
/**
 * Immutable holder for one language test case: which audio to run, what
 * transcript to expect, and how much error to tolerate.
 */
private static class ProcessTestData {
    /** Language code of the test case. */
    public final String language;
    /** Name of the audio file to process. */
    public final String audioFile;
    /** Reference transcript for the audio. */
    public final String transcript;
    /** Punctuation strings associated with the reference transcript. */
    public final String[] punctuations;
    /** Error-rate threshold for the test case. */
    public final float errorRate;

    /**
     * Stores the given values as-is; the {@code punctuations} array is not copied.
     *
     * @param language     language code
     * @param audioFile    audio file name
     * @param transcript   reference transcript
     * @param punctuations punctuation strings for the transcript
     * @param errorRate    error-rate threshold
     */
    public ProcessTestData(
            String language,
            String audioFile,
            String transcript,
            String[] punctuations,
            float errorRate) {
        this.errorRate = errorRate;
        this.punctuations = punctuations;
        this.transcript = transcript;
        this.audioFile = audioFile;
        this.language = language;
    }
}
}
9 changes: 6 additions & 3 deletions demo/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ plugins {

repositories {
mavenCentral()
maven {
url 'https://s01.oss.sonatype.org/content/repositories/aipicovoice-1351/'
}
}

sourceSets {
Expand All @@ -15,14 +18,14 @@ sourceSets {
}

dependencies {
implementation 'ai.picovoice:cheetah-java:2.0.2'
implementation 'ai.picovoice:cheetah-java:2.1.0'
implementation 'commons-cli:commons-cli:1.4'
}

jar {
manifest {
attributes "Main-Class": "ai.picovoice.cheetahdemo.MicDemo",
"Class-Path": "cheetah-2.0.2.jar;commons-cli-1.4.jar"
"Class-Path": "cheetah-2.1.0.jar;commons-cli-1.4.jar"
}
from sourceSets.main.output
exclude "**/FileDemo.class"
Expand All @@ -33,7 +36,7 @@ jar {
task fileDemoJar(type: Jar) {
manifest {
attributes "Main-Class": "ai.picovoice.cheetahdemo.FileDemo",
"Class-Path": "cheetah-2.0.2.jar;commons-cli-1.4.jar"
"Class-Path": "cheetah-2.1.0.jar;commons-cli-1.4.jar"
}
from sourceSets.main.output
exclude "**/MicDemo.class"
Expand Down
File renamed without changes.

0 comments on commit c489285

Please sign in to comment.