Skip to content

Add Java VertexAI bidi compile tests #6903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import com.google.common.util.concurrent.ListenableFuture;
import com.google.firebase.ai.FirebaseAI;
import com.google.firebase.ai.GenerativeModel;
import com.google.firebase.ai.LiveGenerativeModel;
import com.google.firebase.ai.java.ChatFutures;
import com.google.firebase.ai.java.GenerativeModelFutures;
import com.google.firebase.ai.java.LiveModelFutures;
import com.google.firebase.ai.java.LiveSessionFutures;
import com.google.firebase.ai.type.BlockReason;
import com.google.firebase.ai.type.Candidate;
import com.google.firebase.ai.type.Citation;
Expand All @@ -32,25 +35,36 @@
import com.google.firebase.ai.type.FileDataPart;
import com.google.firebase.ai.type.FinishReason;
import com.google.firebase.ai.type.FunctionCallPart;
import com.google.firebase.ai.type.FunctionResponsePart;
import com.google.firebase.ai.type.GenerateContentResponse;
import com.google.firebase.ai.type.GenerationConfig;
import com.google.firebase.ai.type.HarmCategory;
import com.google.firebase.ai.type.HarmProbability;
import com.google.firebase.ai.type.HarmSeverity;
import com.google.firebase.ai.type.ImagePart;
import com.google.firebase.ai.type.InlineDataPart;
import com.google.firebase.ai.type.LiveContentResponse;
import com.google.firebase.ai.type.LiveGenerationConfig;
import com.google.firebase.ai.type.MediaData;
import com.google.firebase.ai.type.ModalityTokenCount;
import com.google.firebase.ai.type.Part;
import com.google.firebase.ai.type.PromptFeedback;
import com.google.firebase.ai.type.PublicPreviewAPI;
import com.google.firebase.ai.type.ResponseModality;
import com.google.firebase.ai.type.SafetyRating;
import com.google.firebase.ai.type.SpeechConfig;
import com.google.firebase.ai.type.TextPart;
import com.google.firebase.ai.type.UsageMetadata;
import com.google.firebase.ai.type.Voices;
import com.google.firebase.concurrent.FirebaseExecutors;
import java.util.Calendar;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import kotlin.OptIn;
import kotlinx.serialization.json.JsonElement;
import kotlinx.serialization.json.JsonNull;
import kotlinx.serialization.json.JsonObject;
import org.junit.Assert;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
Expand All @@ -59,13 +73,36 @@
/**
* Tests in this file exist to be compiled, not invoked
*/
@OptIn(markerClass = PublicPreviewAPI.class)
public class JavaCompileTests {

public void initializeJava() throws Exception {
FirebaseAI vertex = FirebaseAI.getInstance();
GenerativeModel model = vertex.generativeModel("fake-model-name");
GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig());
LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig());
GenerativeModelFutures futures = GenerativeModelFutures.from(model);
LiveModelFutures liveFutures = LiveModelFutures.from(live);
testFutures(futures);
testLiveFutures(liveFutures);
}

private GenerationConfig getConfig() {
return new GenerationConfig.Builder().build();
// TODO b/406558430 GenerationConfig.Builder.setParts returns void
}

private LiveGenerationConfig getLiveConfig() {
return new LiveGenerationConfig.Builder()
.setTopK(10)
.setTopP(11.0F)
.setTemperature(32.0F)
.setCandidateCount(1)
.setMaxOutputTokens(0xCAFEBABE)
.setFrequencyPenalty(1.0F)
.setPresencePenalty(2.0F)
.setResponseModality(ResponseModality.AUDIO)
.setSpeechConfig(new SpeechConfig(Voices.AOEDE))
.build();
}

private void testFutures(GenerativeModelFutures futures) throws Exception {
Expand Down Expand Up @@ -236,4 +273,62 @@ public void validateUsageMetadata(UsageMetadata metadata) {
}
}
}

private void testLiveFutures(LiveModelFutures futures) throws Exception {
LiveSessionFutures session = futures.connect().get();
session
.receive()
.subscribe(
new Subscriber<LiveContentResponse>() {
@Override
public void onSubscribe(Subscription s) {
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(LiveContentResponse response) {
validateLiveContentResponse(response);
}

@Override
public void onError(Throwable t) {
// Ignore
}

@Override
public void onComplete() {
// Also ignore
}
});

session.send("Fake message");
session.send(new Content.Builder().addText("Fake message").build());

byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE};
session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl")));

FunctionResponsePart functionResponse =
new FunctionResponsePart("myFunction", new JsonObject(Map.of()));
session.sendFunctionResponse(List.of(functionResponse, functionResponse));

session.startAudioConversation(part -> functionResponse);
session.startAudioConversation();
session.stopAudioConversation();
session.stopReceiving();
session.close();
}

private void validateLiveContentResponse(LiveContentResponse response) {
// int status = response.getStatus();
// Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL());
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED());
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE());
// TODO b/412743328 LiveContentResponse.Status inaccessible for Java users
Content data = response.getData();
if (data != null) {
validateContent(data);
}
String text = response.getText();
validateFunctionCalls(response.getFunctionCalls());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import com.google.firebase.concurrent.FirebaseExecutors;
import com.google.firebase.vertexai.FirebaseVertexAI;
import com.google.firebase.vertexai.GenerativeModel;
import com.google.firebase.vertexai.LiveGenerativeModel;
import com.google.firebase.vertexai.java.ChatFutures;
import com.google.firebase.vertexai.java.GenerativeModelFutures;
import com.google.firebase.vertexai.java.LiveModelFutures;
import com.google.firebase.vertexai.java.LiveSessionFutures;
import com.google.firebase.vertexai.type.BlockReason;
import com.google.firebase.vertexai.type.Candidate;
import com.google.firebase.vertexai.type.Citation;
Expand All @@ -33,24 +36,33 @@
import com.google.firebase.vertexai.type.FileDataPart;
import com.google.firebase.vertexai.type.FinishReason;
import com.google.firebase.vertexai.type.FunctionCallPart;
import com.google.firebase.vertexai.type.FunctionResponsePart;
import com.google.firebase.vertexai.type.GenerateContentResponse;
import com.google.firebase.vertexai.type.GenerationConfig;
import com.google.firebase.vertexai.type.HarmCategory;
import com.google.firebase.vertexai.type.HarmProbability;
import com.google.firebase.vertexai.type.HarmSeverity;
import com.google.firebase.vertexai.type.ImagePart;
import com.google.firebase.vertexai.type.InlineDataPart;
import com.google.firebase.vertexai.type.LiveContentResponse;
import com.google.firebase.vertexai.type.LiveGenerationConfig;
import com.google.firebase.vertexai.type.MediaData;
import com.google.firebase.vertexai.type.ModalityTokenCount;
import com.google.firebase.vertexai.type.Part;
import com.google.firebase.vertexai.type.PromptFeedback;
import com.google.firebase.vertexai.type.ResponseModality;
import com.google.firebase.vertexai.type.SafetyRating;
import com.google.firebase.vertexai.type.SpeechConfig;
import com.google.firebase.vertexai.type.TextPart;
import com.google.firebase.vertexai.type.UsageMetadata;
import com.google.firebase.vertexai.type.Voices;
import java.util.Calendar;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import kotlinx.serialization.json.JsonElement;
import kotlinx.serialization.json.JsonNull;
import kotlinx.serialization.json.JsonObject;
import org.junit.Assert;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
Expand All @@ -63,9 +75,31 @@ public class JavaCompileTests {

public void initializeJava() throws Exception {
FirebaseVertexAI vertex = FirebaseVertexAI.getInstance();
GenerativeModel model = vertex.generativeModel("fake-model-name");
GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig());
LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig());
GenerativeModelFutures futures = GenerativeModelFutures.from(model);
LiveModelFutures liveFutures = LiveModelFutures.from(live);
testFutures(futures);
testLiveFutures(liveFutures);
}

private GenerationConfig getConfig() {
return new GenerationConfig.Builder().build();
// TODO b/406558430 GenerationConfig.Builder.setParts returns void
}

private LiveGenerationConfig getLiveConfig() {
return new LiveGenerationConfig.Builder()
.setTopK(10)
.setTopP(11.0F)
.setTemperature(32.0F)
.setCandidateCount(1)
.setMaxOutputTokens(0xCAFEBABE)
.setFrequencyPenalty(1.0F)
.setPresencePenalty(2.0F)
.setResponseModality(ResponseModality.AUDIO)
.setSpeechConfig(new SpeechConfig(Voices.AOEDE))
.build();
}

private void testFutures(GenerativeModelFutures futures) throws Exception {
Expand Down Expand Up @@ -236,4 +270,62 @@ public void validateUsageMetadata(UsageMetadata metadata) {
}
}
}

private void testLiveFutures(LiveModelFutures futures) throws Exception {
LiveSessionFutures session = futures.connect().get();
session
.receive()
.subscribe(
new Subscriber<LiveContentResponse>() {
@Override
public void onSubscribe(Subscription s) {
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(LiveContentResponse response) {
validateLiveContentResponse(response);
}

@Override
public void onError(Throwable t) {
// Ignore
}

@Override
public void onComplete() {
// Also ignore
}
});

session.send("Fake message");
session.send(new Content.Builder().addText("Fake message").build());

byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE};
session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl")));

FunctionResponsePart functionResponse =
new FunctionResponsePart("myFunction", new JsonObject(Map.of()));
session.sendFunctionResponse(List.of(functionResponse, functionResponse));

session.startAudioConversation(part -> functionResponse);
session.startAudioConversation();
session.stopAudioConversation();
session.stopReceiving();
session.close();
}

private void validateLiveContentResponse(LiveContentResponse response) {
// int status = response.getStatus();
// Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL());
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED());
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE());
// TODO b/412743328 LiveContentResponse.Status inaccessible for Java users
Content data = response.getData();
if (data != null) {
validateContent(data);
}
String text = response.getText();
validateFunctionCalls(response.getFunctionCalls());
}
}
Loading