Skip to content

Commit

Permalink
Download SpringAI project as zip file (#37)
Browse files Browse the repository at this point in the history
* Build SpringAI template and provide downloaded zip
* Documentation
  • Loading branch information
corradodebari authored Nov 7, 2024
1 parent df24730 commit a140c05
Show file tree
Hide file tree
Showing 11 changed files with 1,450 additions and 2 deletions.
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ sonar-project.properties
__pycache__/
*.py[cod]
*$py.class
env.sh
spring_ai/src/main/resources/data/sandbox_settings.json
spring_ai/target/**
spring_ai/create_user.sql
spring_ai/drop.sql
start.sh
spring_ai/env.sh
245 changes: 244 additions & 1 deletion app/src/modules/st_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

import json
import copy
import io
import zipfile
import tempfile
import shutil
import os

# Streamlit
import streamlit as st
Expand All @@ -18,8 +23,13 @@
import modules.help as custom_help
from modules.chatbot import ChatCmd

logger = logging_config.logging.getLogger("modules.st_common")

from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_cohere import CohereEmbeddings

logger = logging_config.logging.getLogger("modules.st_common")

def clear_initialized():
"""Reset the initialization of the ChatBot"""
Expand Down Expand Up @@ -90,6 +100,7 @@ def reset_rag():

# Set the RAG Prompt
set_prompt()
state.show_download_spring_ai = False


def initialize_rag():
Expand Down Expand Up @@ -238,6 +249,223 @@ def initialize_chatbot(ll_model):
return cmd


def get_yaml_obaas(session_state_json, provider):
    """Render the ``application-obaas.yml`` Spring Boot profile for the SpringAI export.

    Args:
        session_state_json: Settings dictionary. Reads ``rag_params``,
            ``ll_model``, ``ll_model_config``, ``lm_instr_config`` and
            ``lm_instr_prompt``.
        provider: ``"ollama"`` selects the Ollama section; any other value
            selects the OpenAI section.

    Returns:
        The YAML document as a string.

    Raises:
        KeyError: If a required settings key is missing.
    """
    # NOTE: the OBaaS deployment additionally expects a Eureka section:
    #   eureka:
    #     instance:
    #       hostname: ${spring.application.name}
    #       preferIpAddress: true
    #     client:
    #       service-url:
    #         defaultZone: ${eureka.service-url}
    #       fetch-registry: true
    #       register-with-eureka: true
    #       enabled: true
    rag_params = session_state_json["rag_params"]
    chat_model = session_state_json["ll_model"]
    chat_cfg = session_state_json["ll_model_config"][chat_model]
    embedding_model = rag_params["model"]
    distance_metric = rag_params["distance_metric"]

    vector_table, _ = utilities.get_vs_table(
        embedding_model,
        rag_params["chunk_size"],
        rag_params["chunk_overlap"],
        distance_metric,
    )
    instr_context = session_state_json["lm_instr_config"][session_state_json["lm_instr_prompt"]]["prompt"]
    # A JSON string literal is also a valid YAML double-quoted scalar, so this
    # escapes quotes, backslashes and newlines that would otherwise corrupt
    # the document when the prompt is free text.
    context_instr = json.dumps(instr_context, ensure_ascii=False)

    # Each tuning entry is indexed with [0]; presumably the first element is
    # the currently selected value -- TODO confirm against the settings schema.
    temperature = chat_cfg["temperature"][0]
    frequency_penalty = chat_cfg["frequency_penalty"][0]
    presence_penalty = chat_cfg["presence_penalty"][0]
    max_tokens = chat_cfg["max_tokens"][0]
    top_p = chat_cfg["top_p"][0]

    yaml_base = f"""
server:
  servlet:
    context-path: /v1
spring:
  datasource:
    url: ${{spring.datasource.url}}
    username: ${{spring.datasource.username}}
    password: ${{spring.datasource.password}}
  ai:
    vectorstore:
      oracle:
        distance-type: {distance_metric}
        remove-existing-vector-store-table: True
        initialize-schema: True
        index-type: None
"""

    yaml_base_aims = f"""
aims:
  context_instr: {context_instr}
  vectortable:
    name: {vector_table}
  rag_params:
    search_type: Similarity
    top_k: {rag_params.get("top_k", 4)}
"""

    # Only the selected provider section is rendered, so a key that is
    # relevant to the other provider (e.g. "url") cannot fail the export.
    if provider == "ollama":
        provider_yaml = f"""
    ollama:
      base-url: "http://ollama.ollama.svc.cluster.local:11434"
      chat:
        options:
          top-p: {top_p}
          presence-penalty: {presence_penalty}
          frequency-penalty: {frequency_penalty}
          temperature: {temperature}
          num-predict: {max_tokens}
          model: \"{chat_model}\"
      embedding:
        options:
          model: \"{embedding_model}\"
"""
    else:
        provider_yaml = f"""
    openai:
      base-url: \"{chat_cfg["url"]}\"
      api-key:
      chat:
        options:
          temperature: {temperature}
          frequencyPenalty: {frequency_penalty}
          presencePenalty: {presence_penalty}
          maxTokens: {max_tokens}
          topP: {top_p}
          model: {chat_model}
      embedding:
        options:
          model: {embedding_model}
"""

    return yaml_base + provider_yaml + yaml_base_aims

def get_yaml_env(session_state_json, provider):
    """Build the ``env.sh`` launcher script for the exported SpringAI project.

    The script exports the model, database and RAG settings as environment
    variables and then starts the application with the Maven profile named
    after ``provider``.

    Args:
        session_state_json: Settings dictionary. Reads ``rag_params``,
            ``ll_model``, ``ll_model_config``, ``lm_instr_config``,
            ``lm_instr_prompt`` and ``db_config``.
        provider: ``"ollama"`` fills the ``OLLAMA_*``/``OL_*`` variables; any
            other value fills the ``OPENAI_*``/``OP_*`` variables. The unused
            family is exported empty.

    Returns:
        The shell script as a string. It contains the database password, so
        it must not be written to logs.

    Raises:
        KeyError: If a required settings key is missing.
    """
    import shlex  # function-scope import: shlex is not imported at module level

    # Model whose configured URL supplies OLLAMA_BASE_URL for non-Ollama exports.
    default_ollama_model = "llama3.1"

    def sh(value):
        """Quote *value* so the shell reads it back verbatim."""
        return shlex.quote(str(value))

    rag_params = session_state_json["rag_params"]
    embedding_model = rag_params["model"]
    chat_model = session_state_json["ll_model"]
    model_configs = session_state_json["ll_model_config"]
    chat_cfg = model_configs[chat_model]
    db_config = session_state_json["db_config"]

    instr_context = session_state_json["lm_instr_config"][session_state_json["lm_instr_prompt"]]["prompt"]
    vector_table, _ = utilities.get_vs_table(
        embedding_model,
        rag_params["chunk_size"],
        rag_params["chunk_overlap"],
        rag_params["distance_metric"],
    )

    logger.info("ll_model selected: %s", chat_model)

    # Branch on the provider the caller resolved instead of one hard-coded
    # model name, so every Ollama model is exported as Ollama.
    is_ollama = provider == "ollama"

    # Each tuning entry is indexed with [0]; presumably the first element is
    # the currently selected value -- TODO confirm against the settings schema.
    tuning = [
        ("TEMPERATURE", chat_cfg["temperature"][0]),
        ("FREQUENCY_PENALTY", chat_cfg["frequency_penalty"][0]),
        ("PRESENCE_PENALTY", chat_cfg["presence_penalty"][0]),
        ("MAX_TOKENS", chat_cfg["max_tokens"][0]),
        ("TOP_P", chat_cfg["top_p"][0]),
    ]

    if is_ollama:
        active_prefix, idle_prefix = "OL", "OP"
        ollama_url = chat_cfg["url"]
        lines = [
            'export OPENAI_CHAT_MODEL=""',
            'export OPENAI_EMBEDDING_MODEL=""',
            'export OPENAI_URL=""',
            f"export OLLAMA_CHAT_MODEL={sh(chat_model)}",
            f"export OLLAMA_EMBEDDING_MODEL={sh(embedding_model)}",
        ]
    else:
        active_prefix, idle_prefix = "OP", "OL"
        # The default Ollama model may not be configured at all; export an
        # empty URL rather than failing an OpenAI-only export.
        ollama_url = model_configs.get(default_ollama_model, {}).get("url", "")
        openai_url = chat_cfg["url"]
        lines = [
            f"export OPENAI_CHAT_MODEL={sh(chat_model)}",
            f"export OPENAI_EMBEDDING_MODEL={sh(embedding_model)}",
            f"export OPENAI_URL={sh(openai_url)}",
            'export OLLAMA_CHAT_MODEL=""',
            'export OLLAMA_EMBEDDING_MODEL=""',
        ]

    lines += [f"export {active_prefix}_{name}={value}" for name, value in tuning]
    lines += [f"export {idle_prefix}_{name}=" for name, _ in tuning]

    jdbc_url = "jdbc:oracle:thin:@" + str(db_config["dsn"])
    db_user = db_config["user"]
    db_password = db_config["password"]
    distance_metric = rag_params["distance_metric"]
    top_k = rag_params.get("top_k", 4)

    lines += [
        # Left unquoted on purpose: resolved by the shell when the script runs.
        "export SPRING_AI_OPENAI_API_KEY=$OPENAI_API_KEY",
        f"export DB_DSN={sh(jdbc_url)}",
        f"export DB_USERNAME={sh(db_user)}",
        f"export DB_PASSWORD={sh(db_password)}",
        f"export DISTANCE_TYPE={sh(distance_metric)}",
        f"export OLLAMA_BASE_URL={sh(ollama_url)}",
        f"export CONTEXT_INSTR={sh(instr_context)}",
        f"export TOP_K={top_k}",
        f"export VECTOR_STORE={sh(vector_table)}",
        f"export PROVIDER={provider}",
        f"mvn spring-boot:run -P {provider}",
    ]

    # Deliberately log only metadata: the script body holds DB credentials.
    logger.info("Generated env.sh for provider: %s", provider)

    return "\n" + "\n".join(lines) + "\n"


def create_zip(state_dict_filt, provider):
    """Package the SpringAI template plus generated configuration as a zip.

    The archive contains the template's ``src`` tree, the Maven wrapper files,
    ``pom.xml`` and ``README.md``, together with a generated ``env.sh`` and
    ``src/main/resources/application-obaas.yml``.

    Args:
        state_dict_filt: Settings dictionary passed to ``get_yaml_env`` and
            ``get_yaml_obaas``.
        provider: Provider name passed to both generators.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 holding the zip archive.

    Raises:
        FileNotFoundError: If the template directory or one of the expected
            top-level files is missing.
    """
    top_level_files = ["mvnw", "mvnw.cmd", "pom.xml", "README.md"]
    # Resolved against the current working directory, which is logged below
    # to help diagnose a wrong launch location.
    source_dir_root = "../../spring_ai"
    source_dir = "../../spring_ai/src"

    logger.info("Local dir : %s", os.getcwd())

    # os.walk() silently yields nothing for a missing directory, so fail
    # loudly instead of handing out an archive without sources.
    if not os.path.isdir(source_dir):
        raise FileNotFoundError(f"SpringAI template sources not found: {os.path.abspath(source_dir)}")

    # Render the generated files first so a settings error surfaces before
    # any archive work is done.
    generated = {
        "env.sh": get_yaml_env(state_dict_filt, provider),
        "src/main/resources/application-obaas.yml": get_yaml_obaas(state_dict_filt, provider),
    }

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
        for item in top_level_files:
            zip_file.write(os.path.join(source_dir_root, item), item)

        for foldername, _subfolders, filenames in os.walk(source_dir):
            for filename in filenames:
                file_path = os.path.join(foldername, filename)
                # Archive paths are relative to the project root, e.g. "src/main/...".
                arcname = os.path.relpath(file_path, source_dir_root).replace(os.sep, "/")
                if arcname in generated:
                    # The generated version replaces the template copy;
                    # writing both would create a duplicate zip entry.
                    continue
                zip_file.write(file_path, arcname)

        for arcname, content in generated.items():
            zip_file.writestr(arcname, content)

    zip_buffer.seek(0)
    return zip_buffer

# Check if the conf is full ollama or openai, currently supported for springai export
def check_hybrid_conf(session_state_json):
    """Classify the selected chat/embedding model pair by provider.

    Looks up the selected embedding model (``rag_params.model``) and chat
    model (``ll_model``) in the catalogs returned by ``meta.embedding_models()``
    and ``meta.ll_models()`` and compares their ``api`` entries.

    Args:
        session_state_json: Settings dictionary with ``rag_params`` and
            ``ll_model``.

    Returns:
        ``"openai"`` when the embedding API's string form contains ``OpenAI``
        and the chat API equals ``"OpenAI"``; ``"ollama"`` when it contains
        ``Ollama`` and the chat API equals ``"ChatOllama"``; ``"hybrid"`` in
        every other case, including when either model is not in its catalog.
    """
    known_embeddings = meta.embedding_models()
    known_chat_models = meta.ll_models()

    emb_entry = known_embeddings.get(session_state_json["rag_params"].get("model"))
    chat_entry = known_chat_models.get(session_state_json["ll_model"])
    logger.info("Model: %s", session_state_json["ll_model"])
    logger.info("Embedding Model embModel: %s", emb_entry)
    logger.info("Chat Model: %s", chat_entry)

    # Without both catalog entries the pairing cannot be classified.
    if chat_entry is None or emb_entry is None:
        return "hybrid"

    emb_api_text = str(emb_entry["api"])
    logger.info("Embeddings to string: %s", emb_api_text)

    if "OpenAI" in emb_api_text and chat_entry["api"] == "OpenAI":
        logger.info("RETURN openai")
        logger.info("chatModel[api]: %s", chat_entry["api"])
        logger.info("Embedding Model embModel[api]: %s", emb_entry["api"])
        return "openai"

    if "Ollama" in emb_api_text and chat_entry["api"] == "ChatOllama":
        logger.info("RETURN ollama")
        logger.info("chatModel[api]: %s", chat_entry["api"])
        logger.info("Embedding Model embModel[api]: %s", emb_entry["api"])
        return "ollama"

    return "hybrid"



###################################
# Language Model Sidebar
###################################
Expand Down Expand Up @@ -582,10 +810,25 @@ def empty_key(obj):
state_dict_filt = empty_key(state_dict_filt)
session_state_json = json.dumps(state_dict_filt, indent=4)
# Only allow exporting settings if tools/admin is enabled


if not state.disable_tools and not state.disable_admin:
st.sidebar.download_button(
label="Download Settings",
data=session_state_json,
file_name="sandbox_settings.json",
use_container_width=True,
)


state.provider= check_hybrid_conf(state_dict_filt)
logger.info("Provider type: %s", state.provider)

if not state.disable_tools and not state.disable_admin and state.provider != "hybrid":
st.sidebar.download_button(
label="Download SpringAI",
data=create_zip(state_dict_filt, state.provider), # Generate zip on the fly
file_name="spring_ai.zip", # Zip file name
mime="application/zip", # Mime type for zip file
use_container_width=True,
)
Loading

0 comments on commit a140c05

Please sign in to comment.