diff --git a/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java b/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java index 5e363ed95b2..b369bcd6d07 100644 --- a/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java +++ b/firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java @@ -20,8 +20,11 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.firebase.ai.FirebaseAI; import com.google.firebase.ai.GenerativeModel; +import com.google.firebase.ai.LiveGenerativeModel; import com.google.firebase.ai.java.ChatFutures; import com.google.firebase.ai.java.GenerativeModelFutures; +import com.google.firebase.ai.java.LiveModelFutures; +import com.google.firebase.ai.java.LiveSessionFutures; import com.google.firebase.ai.type.BlockReason; import com.google.firebase.ai.type.Candidate; import com.google.firebase.ai.type.Citation; @@ -32,25 +35,36 @@ import com.google.firebase.ai.type.FileDataPart; import com.google.firebase.ai.type.FinishReason; import com.google.firebase.ai.type.FunctionCallPart; +import com.google.firebase.ai.type.FunctionResponsePart; import com.google.firebase.ai.type.GenerateContentResponse; +import com.google.firebase.ai.type.GenerationConfig; import com.google.firebase.ai.type.HarmCategory; import com.google.firebase.ai.type.HarmProbability; import com.google.firebase.ai.type.HarmSeverity; import com.google.firebase.ai.type.ImagePart; import com.google.firebase.ai.type.InlineDataPart; +import com.google.firebase.ai.type.LiveContentResponse; +import com.google.firebase.ai.type.LiveGenerationConfig; +import com.google.firebase.ai.type.MediaData; import com.google.firebase.ai.type.ModalityTokenCount; import com.google.firebase.ai.type.Part; import com.google.firebase.ai.type.PromptFeedback; +import com.google.firebase.ai.type.PublicPreviewAPI; +import com.google.firebase.ai.type.ResponseModality; import com.google.firebase.ai.type.SafetyRating; +import com.google.firebase.ai.type.SpeechConfig; import com.google.firebase.ai.type.TextPart; import com.google.firebase.ai.type.UsageMetadata; +import com.google.firebase.ai.type.Voices; import com.google.firebase.concurrent.FirebaseExecutors; import java.util.Calendar; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; +import kotlin.OptIn; import kotlinx.serialization.json.JsonElement; import kotlinx.serialization.json.JsonNull; +import kotlinx.serialization.json.JsonObject; import org.junit.Assert; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -59,13 +73,36 @@ /** * Tests in this file exist to be compiled, not invoked */ +@OptIn(markerClass = PublicPreviewAPI.class) public class JavaCompileTests { public void initializeJava() throws Exception { FirebaseAI vertex = FirebaseAI.getInstance(); - GenerativeModel model = vertex.generativeModel("fake-model-name"); + GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig()); + LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig()); GenerativeModelFutures futures = GenerativeModelFutures.from(model); + LiveModelFutures liveFutures = LiveModelFutures.from(live); testFutures(futures); + testLiveFutures(liveFutures); + } + + private GenerationConfig getConfig() { + return new GenerationConfig.Builder().build(); + // TODO b/406558430 GenerationConfig.Builder.setParts returns void + } + + private LiveGenerationConfig getLiveConfig() { + return new LiveGenerationConfig.Builder() + .setTopK(10) + .setTopP(11.0F) + .setTemperature(32.0F) + .setCandidateCount(1) + .setMaxOutputTokens(0xCAFEBABE) + .setFrequencyPenalty(1.0F) + .setPresencePenalty(2.0F) + .setResponseModality(ResponseModality.AUDIO) + .setSpeechConfig(new SpeechConfig(Voices.AOEDE)) + .build(); } private void testFutures(GenerativeModelFutures futures) throws Exception { @@ -236,4 +273,62 @@ public void validateUsageMetadata(UsageMetadata metadata) { } } } + + private void testLiveFutures(LiveModelFutures futures) throws Exception { + LiveSessionFutures session = futures.connect().get(); + session + .receive() + .subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(LiveContentResponse response) { + validateLiveContentResponse(response); + } + + @Override + public void onError(Throwable t) { + // Ignore + } + + @Override + public void onComplete() { + // Also ignore + } + }); + + session.send("Fake message"); + session.send(new Content.Builder().addText("Fake message").build()); + + byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE}; + session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl"))); + + FunctionResponsePart functionResponse = + new FunctionResponsePart("myFunction", new JsonObject(Map.of())); + session.sendFunctionResponse(List.of(functionResponse, functionResponse)); + + session.startAudioConversation(part -> functionResponse); + session.startAudioConversation(); + session.stopAudioConversation(); + session.stopReceiving(); + session.close(); + } + + private void validateLiveContentResponse(LiveContentResponse response) { + // int status = response.getStatus(); + // Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL()); + // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED()); + // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE()); + // TODO b/412743328 LiveContentResponse.Status inaccessible for Java users + Content data = response.getData(); + if (data != null) { + validateContent(data); + } + String text = response.getText(); + validateFunctionCalls(response.getFunctionCalls()); + } } diff --git a/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java b/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java index 066e672ffb8..cf71db18798 100644 --- a/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java +++ b/firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java @@ -21,8 +21,11 @@ import com.google.firebase.concurrent.FirebaseExecutors; import com.google.firebase.vertexai.FirebaseVertexAI; import com.google.firebase.vertexai.GenerativeModel; +import com.google.firebase.vertexai.LiveGenerativeModel; import com.google.firebase.vertexai.java.ChatFutures; import com.google.firebase.vertexai.java.GenerativeModelFutures; +import com.google.firebase.vertexai.java.LiveModelFutures; +import com.google.firebase.vertexai.java.LiveSessionFutures; import com.google.firebase.vertexai.type.BlockReason; import com.google.firebase.vertexai.type.Candidate; import com.google.firebase.vertexai.type.Citation; @@ -33,24 +36,33 @@ import com.google.firebase.vertexai.type.FileDataPart; import com.google.firebase.vertexai.type.FinishReason; import com.google.firebase.vertexai.type.FunctionCallPart; +import com.google.firebase.vertexai.type.FunctionResponsePart; import com.google.firebase.vertexai.type.GenerateContentResponse; +import com.google.firebase.vertexai.type.GenerationConfig; import com.google.firebase.vertexai.type.HarmCategory; import com.google.firebase.vertexai.type.HarmProbability; import com.google.firebase.vertexai.type.HarmSeverity; import com.google.firebase.vertexai.type.ImagePart; import com.google.firebase.vertexai.type.InlineDataPart; +import com.google.firebase.vertexai.type.LiveContentResponse; +import com.google.firebase.vertexai.type.LiveGenerationConfig; +import com.google.firebase.vertexai.type.MediaData; import com.google.firebase.vertexai.type.ModalityTokenCount; import com.google.firebase.vertexai.type.Part; import com.google.firebase.vertexai.type.PromptFeedback; +import com.google.firebase.vertexai.type.ResponseModality; import com.google.firebase.vertexai.type.SafetyRating; +import com.google.firebase.vertexai.type.SpeechConfig; import com.google.firebase.vertexai.type.TextPart; import com.google.firebase.vertexai.type.UsageMetadata; +import com.google.firebase.vertexai.type.Voices; import java.util.Calendar; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import kotlinx.serialization.json.JsonElement; import kotlinx.serialization.json.JsonNull; +import kotlinx.serialization.json.JsonObject; import org.junit.Assert; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -63,9 +75,31 @@ public class JavaCompileTests { public void initializeJava() throws Exception { FirebaseVertexAI vertex = FirebaseVertexAI.getInstance(); - GenerativeModel model = vertex.generativeModel("fake-model-name"); + GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig()); + LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig()); GenerativeModelFutures futures = GenerativeModelFutures.from(model); + LiveModelFutures liveFutures = LiveModelFutures.from(live); testFutures(futures); + testLiveFutures(liveFutures); + } + + private GenerationConfig getConfig() { + return new GenerationConfig.Builder().build(); + // TODO b/406558430 GenerationConfig.Builder.setParts returns void + } + + private LiveGenerationConfig getLiveConfig() { + return new LiveGenerationConfig.Builder() + .setTopK(10) + .setTopP(11.0F) + .setTemperature(32.0F) + .setCandidateCount(1) + .setMaxOutputTokens(0xCAFEBABE) + .setFrequencyPenalty(1.0F) + .setPresencePenalty(2.0F) + .setResponseModality(ResponseModality.AUDIO) + .setSpeechConfig(new SpeechConfig(Voices.AOEDE)) + .build(); } private void testFutures(GenerativeModelFutures futures) throws Exception { @@ -236,4 +270,62 @@ public void validateUsageMetadata(UsageMetadata metadata) { } } } + + private void testLiveFutures(LiveModelFutures futures) throws Exception { + LiveSessionFutures session = futures.connect().get(); + session + .receive() + .subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(LiveContentResponse response) { + validateLiveContentResponse(response); + } + + @Override + public void onError(Throwable t) { + // Ignore + } + + @Override + public void onComplete() { + // Also ignore + } + }); + + session.send("Fake message"); + session.send(new Content.Builder().addText("Fake message").build()); + + byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE}; + session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl"))); + + FunctionResponsePart functionResponse = + new FunctionResponsePart("myFunction", new JsonObject(Map.of())); + session.sendFunctionResponse(List.of(functionResponse, functionResponse)); + + session.startAudioConversation(part -> functionResponse); + session.startAudioConversation(); + session.stopAudioConversation(); + session.stopReceiving(); + session.close(); + } + + private void validateLiveContentResponse(LiveContentResponse response) { + // int status = response.getStatus(); + // Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL()); + // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED()); + // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE()); + // TODO b/412743328 LiveContentResponse.Status inaccessible for Java users + Content data = response.getData(); + if (data != null) { + validateContent(data); + } + String text = response.getText(); + validateFunctionCalls(response.getFunctionCalls()); + } }