Skip to content

Commit

Permalink
Merge pull request #101 from AssemblyAI/niels/close-properly
Browse files Browse the repository at this point in the history
Terminate streaming session properly
  • Loading branch information
Swimburger authored Apr 18, 2024
2 parents 478309b + eecc79a commit b44142a
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 15 deletions.
21 changes: 15 additions & 6 deletions sample-app/src/main/java/sample/App.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
import com.assemblyai.api.resources.lemur.requests.LemurTaskParams;
import com.assemblyai.api.resources.lemur.types.LemurTaskResponse;
import com.assemblyai.api.resources.realtime.types.AudioEncoding;
import com.assemblyai.api.resources.realtime.types.SessionInformation;
import com.assemblyai.api.resources.transcripts.requests.*;
import com.assemblyai.api.resources.transcripts.types.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

public final class App {

public static void main(String... args) throws IOException, InterruptedException {
public static void main(String... args) throws IOException, InterruptedException, ExecutionException {
AssemblyAI client = AssemblyAI.builder()
.apiKey(System.getenv("ASSEMBLYAI_API_KEY"))
.build();
Expand Down Expand Up @@ -86,18 +89,24 @@ public static void main(String... args) throws IOException, InterruptedException
TranscriptList transcripts = client.transcripts().list();
System.out.println("List transcript. " + transcripts);

RealtimeTranscriber realtimeTranscriber = RealtimeTranscriber.builder()
try (RealtimeTranscriber realtimeTranscriber = RealtimeTranscriber.builder()
.apiKey(System.getenv("ASSEMBLYAI_API_KEY"))
.encoding(AudioEncoding.PCM_S16LE)
.onSessionBegins(System.out::println)
.onPartialTranscript(System.out::println)
.onFinalTranscript(System.out::println)
.onError((err) -> System.out.println(err.getMessage()))
.onClose((code, reason) -> System.out.printf("%s: %s", code, reason))
.build();
realtimeTranscriber.connect();
streamFile("sample-app/src/main/resources/gore-short.wav", realtimeTranscriber);
realtimeTranscriber.close();
.onSessionInformation(System.out::println)
.build()) {
realtimeTranscriber.connect();
streamFile("sample-app/src/main/resources/gore-short.wav", realtimeTranscriber);
Future<SessionInformation> closeFuture = realtimeTranscriber.closeWithSessionTermination();
SessionInformation info = closeFuture.get();
// Force exit is necessary for some reason.
// The program will end after a while, but not immediately as it should.
System.exit(0);
}
}

public static void streamFile(String filePath, RealtimeTranscriber realtimeTranscriber) {
Expand Down
100 changes: 91 additions & 9 deletions src/main/java/com/assemblyai/api/RealtimeTranscriber.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

Expand Down Expand Up @@ -40,7 +42,11 @@ public final class RealtimeTranscriber implements AutoCloseable {
private final Consumer<Throwable> onError;
private final BiConsumer<Integer, String> onClose;
private final RealtimeMessageVisitor realtimeMessageVisitor;
private final Consumer<SessionInformation> onSessionInformation;
private WebSocket webSocket;
private SessionInformation sessionInformation;
private CompletableFuture<SessionInformation> sessionTerminatedFuture;
private boolean isConnected;

private RealtimeTranscriber(
String apiKey,
Expand All @@ -55,6 +61,7 @@ private RealtimeTranscriber(
Consumer<FinalTranscript> onFinalTranscript,
Consumer<RealtimeTranscript> onTranscript,
Consumer<Throwable> onError,
Consumer<SessionInformation> onSessionInformation,
BiConsumer<Integer, String> onClose) {
this.apiKey = apiKey;
this.token = token;
Expand All @@ -68,6 +75,7 @@ private RealtimeTranscriber(
this.onFinalTranscript = onFinalTranscript;
this.onTranscript = onTranscript;
this.onError = onError;
this.onSessionInformation = onSessionInformation;
this.onClose = onClose;
this.realtimeMessageVisitor = new RealtimeMessageVisitor();
}
Expand All @@ -83,6 +91,10 @@ public void connect() {
if (disablePartialTranscripts) {
url += "&disable_partial_transcripts=true";
}

// always set so it can be return from closeWithSessionTermination
url += "&enable_extra_session_information=true";

if (wordBoost.isPresent() && !wordBoost.get().isEmpty()) {
try {
url += "&word_boost=" + ObjectMappers.JSON_MAPPER.writeValueAsString(wordBoost.get());
Expand Down Expand Up @@ -144,15 +156,33 @@ public void configureEndUtteranceSilenceThreshold(int threshold) {
));
}

public Future<SessionInformation> closeWithSessionTermination() {
this.sessionTerminatedFuture = new CompletableFuture<SessionInformation>();
this.webSocket.send("{\"terminate_session\":true}");
sessionTerminatedFuture.whenComplete((sessionInformation1, throwable) -> this.closeSocket());
return this.sessionTerminatedFuture;
}

/**
* Closes the websocket connection.
* Closes the websocket connection immediately, without waiting for session termination.
* Use closeWithSessionTermination() if possible.
*
* @see #closeWithSessionTermination
* Terminate the session, wait for session termination, and then close the connection.
*/
@Override
public void close() {
boolean closed = this.webSocket.close(1000, "Shutting down");
if (!closed) {
this.webSocket.cancel();
if (isConnected) {
this.webSocket.send("{\"terminate_session\":true}");
}
this.closeSocket();
}

private void closeSocket() {
if(webSocket == null) return;
this.webSocket.close(1000, "Shutting down");
this.webSocket.cancel();
this.webSocket = null;
}

public static RealtimeTranscriber.Builder builder() {
Expand All @@ -174,6 +204,7 @@ public static final class Builder {
private Consumer<RealtimeTranscript> onTranscript;
private Consumer<Throwable> onError;
private BiConsumer<Integer, String> onClose;
private Consumer<SessionInformation> onSessionInformation;

/**
* Sets the AssemblyAI API key used to authenticate the RealtimeTranscriber
Expand Down Expand Up @@ -323,6 +354,19 @@ public RealtimeTranscriber.Builder onError(Consumer<Throwable> onError) {
return this;
}

/**
* Sets onSessionInformation
*
* @param onSessionInformation an event handler for the session information event.
* This message is sent at the end of the session, before the SessionTerminated message.
* Defaults to a noop.
* @return this
*/
public RealtimeTranscriber.Builder onSessionInformation(Consumer<SessionInformation> onSessionInformation) {
this.onSessionInformation = onSessionInformation;
return this;
}

/**
* Sets onClose
*
Expand Down Expand Up @@ -351,6 +395,7 @@ public RealtimeTranscriber build() {
onFinalTranscript,
onTranscript,
onError,
onSessionInformation,
onClose);
}
}
Expand All @@ -364,6 +409,7 @@ public Listener(Consumer<Response> onOpen) {

@Override
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
isConnected = true;
if (onOpen != null) {
onOpen.accept(response);
}
Expand All @@ -372,12 +418,29 @@ public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
@Override
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
try {
RealtimeMessage realtimeMessage = ObjectMappers.JSON_MAPPER.readValue(text, RealtimeMessage.class);
try {
realtimeMessage.visit(realtimeMessageVisitor);
} catch (IllegalStateException ignored) {
// when a new message is added to the API, this should not throw an exception
RealtimeBaseMessage baseMessage = ObjectMappers.parseOrThrow(text, RealtimeBaseMessage.class);
MessageType messageType = baseMessage.getMessageType();
if (messageType == MessageType.SESSION_BEGINS) {
realtimeMessageVisitor.visit(
ObjectMappers.JSON_MAPPER.readValue(text, SessionBegins.class)
);
} else if (messageType == MessageType.PARTIAL_TRANSCRIPT) {
realtimeMessageVisitor.visit(
ObjectMappers.JSON_MAPPER.readValue(text, PartialTranscript.class)
);
} else if (messageType == MessageType.FINAL_TRANSCRIPT) {
realtimeMessageVisitor.visit(
ObjectMappers.JSON_MAPPER.readValue(text, FinalTranscript.class)
);
} else if (messageType == MessageType.SESSION_INFORMATION) {
realtimeMessageVisitor.visit(
ObjectMappers.JSON_MAPPER.readValue(text, SessionInformation.class)
);
} else if (messageType == MessageType.SESSION_TERMINATED) {
realtimeMessageVisitor.visit((SessionTerminated) null);
}
// Intentionally don't throw an exception for unknown message type.
// New message types shouldn't cause this to break.
} catch (JsonProcessingException e) {
if (onError == null) return;
onError.accept(e);
Expand All @@ -386,6 +449,7 @@ public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {

@Override
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) {
isConnected = false;
if (onError == null) return;
onError.accept(t);
}
Expand All @@ -399,6 +463,12 @@ public void onClosing(@NotNull WebSocket webSocket, int code, String reason) {
onClose.accept(code, reason);
super.onClosing(webSocket, code, reason);
}

@Override
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
isConnected = false;
super.onClosed(webSocket, code, reason);
}
}

private final class RealtimeMessageVisitor implements RealtimeMessage.Visitor<Void> {
Expand All @@ -423,8 +493,20 @@ public Void visit(FinalTranscript value) {
return null;
}

@Override
public Void visit(SessionInformation value) {
sessionInformation = value;
if (onSessionInformation == null) return null;
onSessionInformation.accept(value);
return null;
}


@Override
public Void visit(SessionTerminated value) {
if (sessionTerminatedFuture != null) {
sessionTerminatedFuture.complete(sessionInformation);
}
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -41,6 +42,8 @@ public <T> T visit(Visitor<T> visitor) {
return visitor.visit((SessionTerminated) this.value);
} else if (this.type == 4) {
return visitor.visit((RealtimeError) this.value);
} else if (this.type == 5) {
return visitor.visit((SessionInformation) this.value);
}
throw new IllegalStateException("Failed to visit value. This should never happen.");
}
Expand Down Expand Up @@ -85,6 +88,10 @@ public static RealtimeMessage of(RealtimeError value) {
return new RealtimeMessage(value, 4);
}

public static RealtimeMessage of(SessionInformation value) {
return new RealtimeMessage(value, 5);
}

public interface Visitor<T> {
T visit(SessionBegins value);

Expand All @@ -95,6 +102,8 @@ public interface Visitor<T> {
T visit(SessionTerminated value);

T visit(RealtimeError value);

T visit(SessionInformation value);
}

static final class Deserializer extends StdDeserializer<RealtimeMessage> {
Expand Down

0 comments on commit b44142a

Please sign in to comment.