Skip to content

Commit

Permalink
add genai lib, mock gemini class with it
Browse files Browse the repository at this point in the history
  • Loading branch information
pld committed Dec 22, 2024
1 parent 2132410 commit c5f62a5
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 159 deletions.
2 changes: 2 additions & 0 deletions android/gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fhir-sdk-engine = "1.1.0-preview2-SNAPSHOT"
fhir-sdk-knowledge = "0.1.0-alpha03-preview5-rc2-SNAPSHOT"
fhir-sdk-workflow = "0.1.0-alpha04-preview10-rc1-SNAPSHOT"
fragment-ktx = "1.8.3"
generativeai = "0.9.0"
glide = "4.16.0"
googleCloudSpeech = "2.5.2"
gradle = "8.3.2"
Expand Down Expand Up @@ -129,6 +130,7 @@ fhir-sdk-common = { group = "org.smartregister", name = "common", version.ref =
foundation = { group = "androidx.compose.foundation", name = "foundation", version.ref = "compose-ui" }
fragment-ktx = { group = "androidx.fragment", name = "fragment-ktx", version.ref = "fragment-ktx" }
fragment-testing = { group = "androidx.fragment", name = "fragment-testing", version.ref = "fragment-ktx" }
generativeai = { module = "com.google.ai.client.generativeai:generativeai", version.ref = "generativeai" }
glide = { group = "com.github.bumptech.glide", name = "glide", version.ref = "glide" }
gms-play-services-location = { group = "com.google.android.gms", name = "play-services-location", version.ref = "playServicesLocation" }
google-cloud-speech = { module = "com.google.cloud:google-cloud-speech", version.ref = "googleCloudSpeech" }
Expand Down
3 changes: 3 additions & 0 deletions android/quest/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,10 @@ dependencies {
implementation(libs.androidx.fragment.compose)
implementation(libs.bundles.cameraX)
implementation(libs.log4j)

// AI dependencies
implementation(libs.google.cloud.speech)
implementation(libs.generativeai)

// Annotation processors
kapt(libs.hilt.compiler)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,48 @@
/*
* Copyright 2021-2024 Ona Systems, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.smartregister.fhircore.quest.ui.speechtoform

import okhttp3.*
import okhttp3.MediaType.Companion.toMediaType
import org.json.JSONObject
import java.io.IOException
import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.HarmCategory
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.generationConfig

class GeminiClient(private val apiKey: String) {
private val client = OkHttpClient()
private val baseUrl = "https://generativeai.googleapis.com/v1/models/gemini-1.5-flash:generateContent"

fun generateContent(prompt: String): String {
val requestBody = JSONObject().put("prompt", prompt).toString()

val request = Request.Builder()
.url(baseUrl)
.addHeader("Authorization", "Bearer $apiKey")
.post(RequestBody.create("application/json".toMediaType(), requestBody))
.build()

client.newCall(request).execute().use { response ->
if (!response.isSuccessful) throw IOException("Unexpected code $response")

val jsonResponse = JSONObject(response.body?.string() ?: "")
return jsonResponse.getString("text")
}
}
// model usage
// https://developer.android.com/ai/google-ai-client-sdk
private val model =
GenerativeModel(
model = "gemini-1.5-flash-001",
// todo actually add the API key
apiKey = "BuildConfig.apikey",
generationConfig =
generationConfig {
temperature = 0.15f
topK = 32
topP = 1f
maxOutputTokens = 4096
},
safetySettings =
listOf(
SafetySetting(HarmCategory.HARASSMENT, BlockThreshold.MEDIUM_AND_ABOVE),
SafetySetting(HarmCategory.HATE_SPEECH, BlockThreshold.MEDIUM_AND_ABOVE),
SafetySetting(HarmCategory.SEXUALLY_EXPLICIT, BlockThreshold.MEDIUM_AND_ABOVE),
SafetySetting(HarmCategory.DANGEROUS_CONTENT, BlockThreshold.MEDIUM_AND_ABOVE),
),
)
}
Original file line number Diff line number Diff line change
@@ -1,43 +1,63 @@
/*
* Copyright 2021-2024 Ona Systems, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.smartregister.fhircore.quest.ui.speechtoform

import org.hl7.fhir.r4.model.Questionnaire
import org.hl7.fhir.r4.model.QuestionnaireResponse
import java.io.File
import java.util.logging.Logger
import org.hl7.fhir.r4.model.Questionnaire
import org.hl7.fhir.r4.model.QuestionnaireResponse

class SpeechToForm(
private val speechToText: SpeechToText,
private val textToForm: TextToForm
private val speechToText: SpeechToText,
private val textToForm: TextToForm,
) {

private val logger = Logger.getLogger(SpeechToForm::class.java.name)

/**
* Reads an audio file, transcribes it, and generates a FHIR QuestionnaireResponse.
*
* @param audioFile The input audio file to process.
* @param questionnaire The FHIR Questionnaire used to generate the response.
* @return The generated QuestionnaireResponse, or null if the process fails.
*/
fun processAudioToQuestionnaireResponse(audioFile: File, questionnaire: Questionnaire): QuestionnaireResponse? {
logger.info("Starting audio transcription process...")

// Step 1: Transcribe audio to text
val tempTextFile = speechToText.transcribeAudioToText(audioFile)
if (tempTextFile == null) {
logger.severe("Failed to transcribe audio.")
return null
}
logger.info("Transcription successful. File path: ${tempTextFile.absolutePath}")

// Step 2: Generate QuestionnaireResponse from the transcript
val questionnaireResponse = textToForm.generateQuestionnaireResponse(tempTextFile, questionnaire)
if (questionnaireResponse == null) {
logger.severe("Failed to generate QuestionnaireResponse.")
return null
}

logger.info("QuestionnaireResponse generated successfully.")
return questionnaireResponse
private val logger = Logger.getLogger(SpeechToForm::class.java.name)

/**
* Reads an audio file, transcribes it, and generates a FHIR QuestionnaireResponse.
*
* @param audioFile The input audio file to process.
* @param questionnaire The FHIR Questionnaire used to generate the response.
* @return The generated QuestionnaireResponse, or null if the process fails.
*/
fun processAudioToQuestionnaireResponse(
audioFile: File,
questionnaire: Questionnaire,
): QuestionnaireResponse? {
logger.info("Starting audio transcription process...")

// Step 1: Transcribe audio to text
val tempTextFile = speechToText.transcribeAudioToText(audioFile)
if (tempTextFile == null) {
logger.severe("Failed to transcribe audio.")
return null
}
logger.info("Transcription successful. File path: ${tempTextFile.absolutePath}")

// Step 2: Generate QuestionnaireResponse from the transcript
val questionnaireResponse =
textToForm.generateQuestionnaireResponse(tempTextFile, questionnaire)
if (questionnaireResponse == null) {
logger.severe("Failed to generate QuestionnaireResponse.")
return null
}

logger.info("QuestionnaireResponse generated successfully.")
return questionnaireResponse
}
}
Original file line number Diff line number Diff line change
@@ -1,57 +1,75 @@
/*
* Copyright 2021-2024 Ona Systems, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.smartregister.fhircore.quest.ui.speechtoform

import com.google.cloud.speech.v1.RecognitionAudio
import com.google.cloud.speech.v1.RecognitionConfig
import com.google.cloud.speech.v1.RecognitionConfig.AudioEncoding
import com.google.cloud.speech.v1.SpeechClient
import com.google.cloud.speech.v1.SpeechRecognitionResult
import com.google.cloud.speech.v1.RecognitionConfig.AudioEncoding
import java.io.File
import java.util.logging.Logger

class SpeechToText {

private val logger = Logger.getLogger(SpeechToText::class.java.name)

/**
* Transcribes an audio file to text using Google Cloud Speech-to-Text API and writes it to a
* temporary file.
*
* @param audioFile The audio file to be transcribed.
* @return The temporary file containing the transcribed text.
*/
fun transcribeAudioToText(audioFile: File): File? {
var tempFile: File? = null

SpeechClient.create().use { speechClient ->
val audioBytes = audioFile.readBytes()

// Build the recognition audio
val recognitionAudio = RecognitionAudio.newBuilder()
.setContent(com.google.protobuf.ByteString.copyFrom(audioBytes))
.build()

// Configure recognition settings
val config = RecognitionConfig.newBuilder()
.setEncoding(AudioEncoding.LINEAR16)
.setSampleRateHertz(16000)
.setLanguageCode("en-US")
.build()

// Perform transcription
val response = speechClient.recognize(config, recognitionAudio)
val transcription = response.resultsList.joinToString(" ") {
result: SpeechRecognitionResult ->
result.alternativesList[0].transcript
}

logger.info("Transcription: $transcription")

// Write transcription to a temporary file
tempFile = File.createTempFile("transcription", ".txt")
tempFile?.writeText(transcription)

logger.info("Transcription written to temporary file. ")
private val logger = Logger.getLogger(SpeechToText::class.java.name)

/**
* Transcribes an audio file to text using Google Cloud Speech-to-Text API and writes it to a
* temporary file.
*
* @param audioFile The audio file to be transcribed.
* @return The temporary file containing the transcribed text.
*/
fun transcribeAudioToText(audioFile: File): File? {
var tempFile: File? = null

SpeechClient.create().use { speechClient ->
val audioBytes = audioFile.readBytes()

// Build the recognition audio
val recognitionAudio =
RecognitionAudio.newBuilder()
.setContent(com.google.protobuf.ByteString.copyFrom(audioBytes))
.build()

// Configure recognition settings
val config =
RecognitionConfig.newBuilder()
.setEncoding(AudioEncoding.LINEAR16)
.setSampleRateHertz(16000)
.setLanguageCode("en-US")
.build()

// Perform transcription
val response = speechClient.recognize(config, recognitionAudio)
val transcription =
response.resultsList.joinToString(" ") { result: SpeechRecognitionResult ->
result.alternativesList[0].transcript
}
return tempFile

logger.info("Transcription: $transcription")

// Write transcription to a temporary file
tempFile = File.createTempFile("transcription", ".txt")
tempFile?.writeText(transcription)

logger.info("Transcription written to temporary file. ")
}
return tempFile
}
}
Loading

0 comments on commit c5f62a5

Please sign in to comment.