diff --git a/backend/example.env b/backend/example.env
index f747a94e8..cfc235910 100644
--- a/backend/example.env
+++ b/backend/example.env
@@ -45,3 +45,7 @@ YOUTUBE_TRANSCRIPT_PROXY="https://user:pass@domain:port"
 EFFECTIVE_SEARCH_RATIO=5
 GRAPH_CLEANUP_MODEL="openai_gpt_4o"
 CHUNKS_TO_BE_PROCESSED="50"
+BEDROCK_EMBEDDING_MODEL="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.titan-embed-text-v1"
+LLM_MODEL_CONFIG_bedrock_nova_micro_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-micro-v1:0"
+LLM_MODEL_CONFIG_bedrock_nova_lite_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-lite-v1:0"
+LLM_MODEL_CONFIG_bedrock_nova_pro_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-pro-v1:0"
\ No newline at end of file
diff --git a/backend/src/llm.py b/backend/src/llm.py
index 0a7f74b08..5b667c07b 100644
--- a/backend/src/llm.py
+++ b/backend/src/llm.py
@@ -89,7 +89,7 @@ def get_llm(model: str):
         )
 
         llm = ChatBedrock(
-            client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
+            client=bedrock_client,region_name=region_name, model_id=model_name, model_kwargs=dict(temperature=0)
         )
 
     elif "ollama" in model:
diff --git a/backend/src/shared/common_fn.py b/backend/src/shared/common_fn.py
index 986687e25..6b70475c4 100644
--- a/backend/src/shared/common_fn.py
+++ b/backend/src/shared/common_fn.py
@@ -11,7 +11,8 @@
 import os
 from pathlib import Path
 from urllib.parse import urlparse
-
+import boto3
+from langchain_community.embeddings import BedrockEmbeddings
 
 def check_url_source(source_type, yt_url:str=None, wiki_query:str=None):
     language=''
@@ -77,6 +78,10 @@ def load_embedding_model(embedding_model_name: str):
         )
         dimension = 768
         logging.info(f"Embedding: Using Vertex AI Embeddings , Dimension:{dimension}")
+    elif embedding_model_name == "titan":
+        embeddings = get_bedrock_embeddings()
+        dimension = 1536
+        logging.info(f"Embedding: Using bedrock titan Embeddings , Dimension:{dimension}")
     else:
         embeddings = HuggingFaceEmbeddings(
             model_name="all-MiniLM-L6-v2"#, cache_folder="/embedding_model"
@@ -134,4 +139,38 @@ def last_url_segment(url):
     parsed_url = urlparse(url)
     path = parsed_url.path.strip("/") # Remove leading and trailing slashes
     last_url_segment = path.split("/")[-1] if path else parsed_url.netloc.split(".")[0]
-    return last_url_segment
\ No newline at end of file
+    return last_url_segment
+
+def get_bedrock_embeddings():
+    """
+    Creates and returns a BedrockEmbeddings object using the specified model name.
+    Args:
+        model (str): The name of the model to use for embeddings.
+    Returns:
+        BedrockEmbeddings: An instance of the BedrockEmbeddings class.
+    """
+    try:
+        env_value = os.getenv("BEDROCK_EMBEDDING_MODEL")
+        if not env_value:
+            raise ValueError("Environment variable 'BEDROCK_EMBEDDING_MODEL' is not set.")
+        try:
+            model_name, aws_access_key, aws_secret_key, region_name = env_value.split(",")
+        except ValueError:
+            raise ValueError(
+                "Environment variable 'BEDROCK_EMBEDDING_MODEL' is improperly formatted. "
+                "Expected format: 'model_name,aws_access_key,aws_secret_key,region_name'."
+            )
+        bedrock_client = boto3.client(
+            service_name="bedrock-runtime",
+            region_name=region_name.strip(),
+            aws_access_key_id=aws_access_key.strip(),
+            aws_secret_access_key=aws_secret_key.strip(),
+        )
+        bedrock_embeddings = BedrockEmbeddings(
+            model_id=model_name.strip(),
+            client=bedrock_client
+        )
+        return bedrock_embeddings
+    except Exception as e:
+        print(f"An unexpected error occurred: {e}")
+        raise