Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.1 java #363

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions binding/java/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ When done resources have to be released explicitly:
cheetah.delete();
```

### Language Model

The Cheetah Java SDK comes preloaded with a default English language model (`.pv` file).
Default models for other supported languages can be found in [lib/common](../../lib/common).

Create custom language models using the [Picovoice Console](https://console.picovoice.ai/). Here you can train
language models with custom vocabulary and boost words in the existing vocabulary.

## Demo App

For example usage refer to our [Java demos](../../demo/java).
3 changes: 2 additions & 1 deletion binding/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {

ext {
PUBLISH_GROUP_ID = 'ai.picovoice'
PUBLISH_VERSION = '2.0.2'
PUBLISH_VERSION = '2.1.0'
PUBLISH_ARTIFACT_ID = 'cheetah-java'
}

Expand Down Expand Up @@ -84,6 +84,7 @@ if (file("${rootDir}/publish-mavencentral.gradle").exists()) {
}

dependencies {
testImplementation 'com.google.code.gson:gson:2.10.1'
testImplementation 'org.junit.jupiter:junit-jupiter:5.4.2'
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
}
Expand Down
164 changes: 150 additions & 14 deletions binding/java/test/ai/picovoice/cheetah/CheetahTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022-2023 Picovoice Inc.
Copyright 2022-2024 Picovoice Inc.

You may not use this file except in compliance with the license. A copy of the license is
located in the "LICENSE" file accompanying this source.
Expand All @@ -12,14 +12,22 @@

package ai.picovoice.cheetah;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.junit.jupiter.api.Test;

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.stream.Stream;


Expand All @@ -32,6 +40,103 @@
public class CheetahTest {
private final String accessKey = System.getProperty("pvTestingAccessKey");

private static String appendLanguage(String s, String language) {
if (language.equals("en")) {
return s;
}
return s + "_" + language;
}

private static int levenshteinDistance(String[] transcript, String[] reference) {
int m = transcript.length;
int n = reference.length;
int[][] dp = new int[m + 1][n + 1];

for (int i = 0; i <= m; i++) {
dp[i][0] = i;
}

for (int j = 0; j <= n; j++) {
dp[0][j] = j;
}

for (int i = 1; i <= m; i++) {
for (int j = 1; j <= n; j++) {
if (transcript[i - 1].equalsIgnoreCase(reference[j - 1])) {
dp[i][j] = dp[i - 1][j - 1];
} else {
dp[i][j] = 1 + Math.min(dp[i - 1][j - 1], Math.min(dp[i - 1][j], dp[i][j - 1]));
}
}
}
return dp[m][n];
}

private static float getErrorRate(String transcript, String reference) {
String[] transcriptWords = transcript.split("\\s+");
String[] referenceWords = reference.split("\\s+");
int distance = levenshteinDistance(transcriptWords, referenceWords);

return (float) distance / (float) referenceWords.length;
}

private static ProcessTestData[] loadProcessTestData() throws IOException {
final Path testDataPath = Paths.get(System.getProperty("user.dir"))
.resolve("../../resources/.test")
.resolve("test_data.json");
final String testDataContent = new String(Files.readAllBytes(testDataPath), StandardCharsets.UTF_8);
final JsonObject testDataJson = JsonParser.parseString(testDataContent).getAsJsonObject();

final JsonArray testParameters = testDataJson
.getAsJsonObject("tests")
.getAsJsonArray("language_tests");

final ProcessTestData[] processTestData = new ProcessTestData[testParameters.size()];
for (int i = 0; i < testParameters.size(); i++) {
final JsonObject testData = testParameters.get(i).getAsJsonObject();
final String language = testData.get("language").getAsString();
final String testAudioFile = testData.get("audio_file").getAsString();
final String transcript = testData.get("transcript").getAsString();
final float errorRate = testData.get("error_rate").getAsFloat();

final JsonArray punctuationsJson = testData.getAsJsonArray("punctuations");
final String[] punctuations = new String[punctuationsJson.size()];
for (int j = 0; j < punctuationsJson.size(); j++) {
punctuations[j] = punctuationsJson.get(j).getAsString();
}
processTestData[i] = new ProcessTestData(
language,
testAudioFile,
transcript,
punctuations,
errorRate);
}
return processTestData;
}

private static Stream<Arguments> processTestProvider() throws IOException {
final ProcessTestData[] processTestData = loadProcessTestData();
final ArrayList<Arguments> testArgs = new ArrayList<>();
for (ProcessTestData processTestDataItem : processTestData) {
testArgs.add(Arguments.of(
processTestDataItem.language,
processTestDataItem.audioFile,
processTestDataItem.transcript,
processTestDataItem.punctuations,
false,
processTestDataItem.errorRate));
testArgs.add(Arguments.of(
processTestDataItem.language,
processTestDataItem.audioFile,
processTestDataItem.transcript,
processTestDataItem.punctuations,
true,
processTestDataItem.errorRate));
}

return testArgs.stream();
}

@Test
void getVersion() throws CheetahException {
Cheetah cheetah = new Cheetah.Builder()
Expand Down Expand Up @@ -84,22 +189,33 @@ void getErrorStack() {
}
}

@ParameterizedTest(name = "test transcribe with automatic punctuation set to ''{0}''")
@MethodSource("transcribeProvider")
void transcribe(boolean enableAutomaticPunctuation, String referenceTranscript) throws Exception {
@ParameterizedTest(name = "test process data for ''{0}'' with punctuation ''{4}''")
@MethodSource("processTestProvider")
void process(
String language,
String testAudioFile,
String referenceTranscript,
String[] punctuations,
boolean enableAutomaticPunctuation,
float targetErrorRate) throws Exception {
String modelPath = Paths.get(System.getProperty("user.dir"))
.resolve(String.format("../../lib/common/%s.pv", appendLanguage("cheetah_params", language)))
.toString();

Cheetah cheetah = new Cheetah.Builder()
.setAccessKey(accessKey)
.setModelPath(modelPath)
.setEnableAutomaticPunctuation(enableAutomaticPunctuation)
.build();

int frameLen = cheetah.getFrameLength();
String audioFilePath = Paths.get(System.getProperty("user.dir"))
.resolve("../../resources/audio_samples/test.wav")
.resolve(String.format("../../resources/audio_samples/%s", testAudioFile))
.toString();
File testAudioPath = new File(audioFilePath);

AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(testAudioPath);
assertEquals(audioInputStream.getFormat().getFrameRate(), 16000);
assertEquals(16000, audioInputStream.getFormat().getFrameRate());

int byteDepth = audioInputStream.getFormat().getFrameSize();
byte[] pcm = new byte[frameLen * byteDepth];
Expand All @@ -116,17 +232,37 @@ void transcribe(boolean enableAutomaticPunctuation, String referenceTranscript)
}
CheetahTranscript finalTranscriptObj = cheetah.flush();
transcript.append(finalTranscriptObj.getTranscript());
assertEquals(referenceTranscript, transcript.toString());

cheetah.delete();

String normalizedTranscript = referenceTranscript;
if (!enableAutomaticPunctuation) {
for (String punctuation : punctuations) {
normalizedTranscript = normalizedTranscript.replace(punctuation, "");
}
}

assertTrue(getErrorRate(transcript.toString(), normalizedTranscript) < targetErrorRate);
}

private static Stream<Arguments> transcribeProvider() {
return Stream.of(
Arguments.of(true,
"Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."),
Arguments.of(false,
"Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel")
);
private static class ProcessTestData {
public final String language;
public final String audioFile;
public final String transcript;
public final String[] punctuations;
public final float errorRate;

public ProcessTestData(
String language,
String audioFile,
String transcript,
String[] punctuations,
float errorRate) {
this.language = language;
this.audioFile = audioFile;
this.transcript = transcript;
this.punctuations = punctuations;
this.errorRate = errorRate;
}
}
}
9 changes: 6 additions & 3 deletions demo/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ plugins {

repositories {
mavenCentral()
maven {
url 'https://s01.oss.sonatype.org/content/repositories/aipicovoice-1351/'
}
}

sourceSets {
Expand All @@ -15,14 +18,14 @@ sourceSets {
}

dependencies {
implementation 'ai.picovoice:cheetah-java:2.0.2'
implementation 'ai.picovoice:cheetah-java:2.1.0'
implementation 'commons-cli:commons-cli:1.4'
}

jar {
manifest {
attributes "Main-Class": "ai.picovoice.cheetahdemo.MicDemo",
"Class-Path": "cheetah-2.0.2.jar;commons-cli-1.4.jar"
"Class-Path": "cheetah-2.1.0.jar;commons-cli-1.4.jar"
}
from sourceSets.main.output
exclude "**/FileDemo.class"
Expand All @@ -33,7 +36,7 @@ jar {
task fileDemoJar(type: Jar) {
manifest {
attributes "Main-Class": "ai.picovoice.cheetahdemo.FileDemo",
"Class-Path": "cheetah-2.0.2.jar;commons-cli-1.4.jar"
"Class-Path": "cheetah-2.1.0.jar;commons-cli-1.4.jar"
}
from sourceSets.main.output
exclude "**/MicDemo.class"
Expand Down
Loading