Skip to content

Commit

Permalink
Implement AI assistant for journaling to facilitate reflection (#23)
Browse files Browse the repository at this point in the history
Resolved: #9
  • Loading branch information
zionhann authored Aug 4, 2024
1 parent c1a826c commit 665b087
Show file tree
Hide file tree
Showing 40 changed files with 1,191 additions and 54 deletions.
468 changes: 468 additions & 0 deletions api-docs.yml

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ dependencies {
implementation("com.fasterxml.jackson.module:jackson-module-kotlin")
implementation("org.jetbrains.kotlin:kotlin-reflect")
implementation("org.springframework.session:spring-session-data-redis")
implementation("com.squareup.retrofit2:retrofit:2.11.0")
implementation("com.squareup.retrofit2:converter-gson:2.11.0")
implementation("com.squareup.okhttp3:logging-interceptor:4.12.0")
developmentOnly("org.springframework.boot:spring-boot-devtools")
runtimeOnly("com.h2database:h2")
runtimeOnly("com.mysql:mysql-connector-j")
Expand Down
36 changes: 36 additions & 0 deletions src/main/kotlin/com/likelionhgu/stepper/chat/ChatController.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.likelionhgu.stepper.chat

import com.likelionhgu.stepper.chat.response.ChatHistoryResponseWrapper
import com.likelionhgu.stepper.chat.response.ChatResponse
import com.likelionhgu.stepper.chat.response.ChatSummaryResponse
import com.likelionhgu.stepper.goal.GoalService
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RestController

@RestController
class ChatController(
private val goalService: GoalService,
private val chatService: ChatService
) {
@PostMapping("/v1/goals/{goalId}/chats")
fun createChat(@PathVariable goalId: Long): ResponseEntity<ChatResponse> {
val goal = goalService.goalInfo(goalId)
val responseBody = chatService.initChat(goal)
return ResponseEntity.status(HttpStatus.CREATED).body(responseBody)
}

@GetMapping("/v1/chats/{chatId}/history")
fun getChatHistory(@PathVariable chatId: String): ChatHistoryResponseWrapper {
return chatService.chatHistoryOf(chatId).let(ChatHistoryResponseWrapper.Companion::of)
}

@PostMapping("/v1/chats/{chatId}/summary")
fun getChatSummary(@PathVariable chatId: String): ResponseEntity<ChatSummaryResponse> {
val responseBody = chatService.generateSummaryOf(chatId)
return ResponseEntity.status(HttpStatus.CREATED).body(responseBody)
}
}
14 changes: 14 additions & 0 deletions src/main/kotlin/com/likelionhgu/stepper/chat/ChatRole.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.likelionhgu.stepper.chat

import com.likelionhgu.stepper.exception.ChatRoleNotSupportedException

enum class ChatRole(
val alias: String
) {
USER("user"), CHATBOT("assistant");

companion object {
fun of(role: String): ChatRole = ChatRole.entries.find { it.alias == role }
?: throw ChatRoleNotSupportedException("Role $role is not supported")
}
}
140 changes: 140 additions & 0 deletions src/main/kotlin/com/likelionhgu/stepper/chat/ChatService.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package com.likelionhgu.stepper.chat

import com.likelionhgu.stepper.chat.response.ChatHistoryResponseWrapper
import com.likelionhgu.stepper.chat.response.ChatHistoryResponseWrapper.ChatHistoryResponse
import com.likelionhgu.stepper.chat.response.ChatResponse
import com.likelionhgu.stepper.chat.response.ChatSummaryResponse
import com.likelionhgu.stepper.exception.FailedAssistantException
import com.likelionhgu.stepper.exception.FailedCompletionException
import com.likelionhgu.stepper.exception.FailedMessageException
import com.likelionhgu.stepper.exception.FailedRunException
import com.likelionhgu.stepper.exception.FailedThreadException
import com.likelionhgu.stepper.goal.Goal
import com.likelionhgu.stepper.openai.OpenAiProperties
import com.likelionhgu.stepper.openai.assistant.AssistantService
import com.likelionhgu.stepper.openai.assistant.message.MessageResponseWrapper
import com.likelionhgu.stepper.openai.assistant.request.AssistantRequest
import com.likelionhgu.stepper.openai.assistant.run.RunRequest
import com.likelionhgu.stepper.openai.assistant.thread.ThreadCreationRequest
import com.likelionhgu.stepper.openai.completion.CompletionService
import com.likelionhgu.stepper.openai.completion.request.CompletionRequest
import com.likelionhgu.stepper.websocket.MessagePayload
import org.slf4j.LoggerFactory
import org.springframework.data.redis.core.StringRedisTemplate
import org.springframework.stereotype.Service
import retrofit2.Call
import retrofit2.Response

@Service
class ChatService(
private val assistantService: AssistantService,
private val completionService: CompletionService,
private val openAiProperties: OpenAiProperties,
private val redisTemplate: StringRedisTemplate
) {

/**
* Initialize a chat with the goal.
*
* @param goal The goal to initialize the chat with.
* @return The thread ID of the chat.
*/
fun initChat(goal: Goal): ChatResponse {
val contents = openAiProperties.assistant.welcomeMessages
val requestBody = ThreadCreationRequest.withDefault(contents, goal.title)

return assistantService.createThread(requestBody).resolve()
?.let(ChatResponse.Companion::of)
?: throw FailedThreadException("Failed to create a thread")
}

/**
* Retrieve chat history of the given chatId.
*
* The chat history consists of messages between the user and the assistant.
*
* @param chatId The chatId to retrieve chat history.
* @return The chat history of the given chatId.
*/
fun chatHistoryOf(chatId: String): MessageResponseWrapper {
return assistantService.listMessagesOf(chatId).resolve()
?: throw FailedThreadException("Failed to get chat history for thread $chatId")
}

fun generateQuestion(chatId: String, message: MessagePayload): ChatHistoryResponse {
addMessageToThread(chatId, message)
runAssistantOn(chatId).also { runId ->
waitUntilRunComplete(chatId, runId)
}
return assistantService.listMessagesOf(chatId).resolve()
?.let(ChatHistoryResponseWrapper.Companion::firstOf)
?: throw FailedMessageException("Failed to retrieve messages of thread $chatId")
}

private fun addMessageToThread(chatId: String, message: MessagePayload) {
logger.info("Adding message to thread $chatId")
assistantService.createMessageOf(chatId, message.toSimpleMessage()).resolve()
?: throw FailedMessageException("Failed to add message to thread $chatId")
}

private fun runAssistantOn(chatId: String): String {
logger.info("Running assistant on thread $chatId")
val run = assistantService.createRunOf(chatId, RunRequest(assistant())).resolve()
?: throw FailedRunException("Failed to run assistant on thread $chatId")
return run.id
}

private fun waitUntilRunComplete(chatId: String, runId: String) {
do {
Thread.sleep(1_000)
val run = assistantService.retrieveRunOf(chatId, runId).resolve()
?: throw FailedRunException("Failed to retrieve run $runId of thread $chatId")
} while (run.status != "completed")
}

private fun assistant(assistantName: String = DEFAULT_ASSISTANT_NAME): String {
val assistantKey = ASSISTANT_REDIS_KEY_PREFIX + assistantName
return redisTemplate.opsForValue().get(assistantKey)
?: fetchAssistantOf(assistantName)
?: createAssistant(assistantName)
}

private fun fetchAssistantOf(assistantName: String): String? {
val res = assistantService.listAssistants().resolve()
?: throw FailedAssistantException("Failed to fetch assistants")

return res.data.find { it.name == assistantName }?.id
}

private fun createAssistant(assistantName: String): String {
with(openAiProperties.assistant) {
val requestBody = AssistantRequest(modelType.id, instructions, assistantName)
return assistantService.createAssistant(requestBody).resolve()
?.let {
redisTemplate.opsForValue().set(ASSISTANT_REDIS_KEY_PREFIX + assistantName, it.id)
it.id
} ?: throw FailedAssistantException("Failed to create assistant")
}
}

fun generateSummaryOf(chatId: String): ChatSummaryResponse {
val chatHistory = chatHistoryOf(chatId).toSimpleMessage()

with(openAiProperties.completion) {
val requestBody = CompletionRequest.of(modelType.id, instructions, chatHistory)
return completionService.createChatCompletion(requestBody).resolve()
?.let(ChatSummaryResponse.Companion::of)
?: throw FailedCompletionException("Failed to create completion for chat $chatId")
}
}

companion object {
private const val ASSISTANT_REDIS_KEY_PREFIX = "openai:assistant:"
private const val DEFAULT_ASSISTANT_NAME = "default"
private val logger = LoggerFactory.getLogger(ChatService::class.java)
}
}

private fun <T> Call<T>.resolve(): T? {
return execute().run(Response<T>::body)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.likelionhgu.stepper.chat.response

import com.likelionhgu.stepper.chat.ChatRole
import com.likelionhgu.stepper.openai.assistant.message.MessageResponseWrapper

data class ChatHistoryResponseWrapper(val messages: List<ChatHistoryResponse>) {
data class ChatHistoryResponse(val messageId: String, val role: ChatRole, val content: String)

companion object {
fun of(response: MessageResponseWrapper): ChatHistoryResponseWrapper {
val messages = response.data.map { message ->
ChatHistoryResponse(
messageId = message.id,
role = ChatRole.of(message.role),
content = message.content.first().text.value
)
}.reversed()
return ChatHistoryResponseWrapper(messages)
}

fun firstOf(response: MessageResponseWrapper): ChatHistoryResponse {
with(response.data.first()) {
return ChatHistoryResponse(
messageId = id,
role = ChatRole.of(role),
content = content.first().text.value
)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.likelionhgu.stepper.chat.response

import com.likelionhgu.stepper.openai.assistant.thread.ThreadResponse

data class ChatResponse(val chatId: String) {

companion object {
fun of(response: ThreadResponse): ChatResponse {
return ChatResponse(response.id)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.likelionhgu.stepper.chat.response

import com.likelionhgu.stepper.openai.completion.response.CompletionResponse

data class ChatSummaryResponse(val content: String) {

companion object {
fun of(response: CompletionResponse): ChatSummaryResponse {
val content = response.choices.first().message.content
return ChatSummaryResponse(content)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.likelionhgu.stepper.exception

class ChatRoleNotSupportedException(message: String) : RuntimeException(message) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.likelionhgu.stepper.exception

class FailedAssistantException(message: String) : RuntimeException(message) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.likelionhgu.stepper.exception

class FailedCompletionException(message: String) : RuntimeException(message) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.likelionhgu.stepper.exception

class FailedMessageException(message: String) : RuntimeException(message) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.likelionhgu.stepper.exception

class FailedRunException(message: String) : RuntimeException(message) {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.likelionhgu.stepper.exception

class FailedThreadException(message: String) : RuntimeException(message) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.likelionhgu.stepper.exception

class ModelNotSupportedException(message: String) : RuntimeException(message) {
}
4 changes: 2 additions & 2 deletions src/main/kotlin/com/likelionhgu/stepper/goal/GoalService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class GoalService(
* @return Goal The `Goal` entity corresponding to the provided ID.
* @throws GoalNotFoundException if no goal is found with the provided ID.
*/
fun goalInfo(goalId: String): Goal {
return goalRepository.findById(goalId.toLong()).getOrNull()
fun goalInfo(goalId: Long): Goal {
return goalRepository.findById(goalId).getOrNull()
?: throw GoalNotFoundException("The goal with the id \"$goalId\" does not exist")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class JournalController(
@PostMapping("/v1/goals/{goalId}/journals")
fun writeJournal(
@AuthenticationPrincipal user: CommonOAuth2Attribute,
@PathVariable goalId: String,
@PathVariable goalId: Long,
@Valid @RequestBody journalRequest: JournalRequest
): ResponseEntity<Unit> {
val member = memberService.memberInfo(user.oauth2UserId)
Expand All @@ -39,7 +39,7 @@ class JournalController(

@GetMapping("/v1/goals/{goalId}/journals")
fun displayJournals(
@PathVariable goalId: String,
@PathVariable goalId: Long,
@RequestParam(required = false) sort: JournalSortType = JournalSortType.NEWEST,
@RequestParam(required = false) q: String?
): ResponseEntity<JournalResponseWrapper> {
Expand Down
17 changes: 17 additions & 0 deletions src/main/kotlin/com/likelionhgu/stepper/openai/ModelType.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.likelionhgu.stepper.openai

import com.likelionhgu.stepper.exception.ModelNotSupportedException

enum class ModelType(val id: String) {
GPT_4O("gpt-4o"),
GPT_4O_MINI("gpt-4o-mini");

companion object {
fun of(model: String): ModelType {
return entries.find { supportedModel ->
val modelId = model.takeIf(String::isNotBlank) ?: GPT_4O_MINI.id
supportedModel.id == modelId
} ?: throw ModelNotSupportedException("Model $model is not supported")
}
}
}
Loading

0 comments on commit 665b087

Please sign in to comment.