Skip to content

Commit

Permalink
feat: azure llm and embedding models (#592)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Icemap authored Jan 15, 2025
1 parent f70022e commit 8cff7e5
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 0 deletions.
17 changes: 17 additions & 0 deletions backend/app/rag/embeddings/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class EmbeddingProvider(str, enum.Enum):
GITEEAI = "giteeai"
LOCAL = "local"
OPENAI_LIKE = "openai_like"
AZURE_OPENAI = "azure_openai"


class EmbeddingProviderOption(BaseModel):
Expand Down Expand Up @@ -123,6 +124,22 @@ class EmbeddingProviderOption(BaseModel):
credentials_type="str",
default_credentials="****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.AZURE_OPENAI,
provider_display_name="Azure OpenAI",
provider_description="Azure OpenAI is a cloud-based AI service that provides a suite of AI models and tools for developers to build intelligent applications.",
provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service",
default_embedding_model="text-embedding-3-small",
embedding_model_description="Before using this option, you need to deploy an Azure OpenAI API and model, see https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource.",
default_config={
"azure_endpoint": "https://<your-resource-name>.openai.azure.com/",
"api_version": "<your-api-version>"
},
credentials_display_name="Azure OpenAI API Key",
credentials_description="The API key of Azure OpenAI",
credentials_type="str",
default_credentials="****",
),
EmbeddingProviderOption(
provider=EmbeddingProvider.LOCAL,
provider_display_name="Local Embedding",
Expand Down
8 changes: 8 additions & 0 deletions backend/app/rag/embeddings/resolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional

from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from sqlmodel import Session

from llama_index.core.base.embeddings.base import BaseEmbedding
Expand Down Expand Up @@ -65,6 +67,12 @@ def get_embed_model(
api_key=credentials,
**config,
)
case EmbeddingProvider.AZURE_OPENAI:
return AzureOpenAIEmbedding(
model=model,
api_key=credentials,
**config,
)
case EmbeddingProvider.OPENAI_LIKE:
return OpenAILikeEmbedding(
model=model,
Expand Down
19 changes: 19 additions & 0 deletions backend/app/rag/llms/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class LLMProvider(str, enum.Enum):
BEDROCK = "bedrock"
OLLAMA = "ollama"
GITEEAI = "giteeai"
AZURE_OPENAI = "azure_openai"


class LLMProviderOption(BaseModel):
Expand Down Expand Up @@ -147,4 +148,22 @@ class LLMProviderOption(BaseModel):
"aws_region_name": "us-west-2",
},
),
LLMProviderOption(
provider=LLMProvider.AZURE_OPENAI,
provider_display_name="Azure OpenAI",
provider_description="Azure OpenAI is a cloud-based AI service that provides access to OpenAI's advanced language models.",
provider_url="https://azure.microsoft.com/en-us/products/ai-services/openai-service",
default_llm_model="gpt-4o",
llm_model_description="",
config_description="Refer to this document https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart to have more information about the Azure OpenAI API.",
default_config={
"azure_endpoint": "https://<your-resource-name>.openai.azure.com/",
"api_version": "<your-api-version>",
"engine": "<your-deployment-name>",
},
credentials_display_name="Azure OpenAI API Key",
credentials_description="The API key of Azure OpenAI",
credentials_type="str",
default_credentials="****",
),
]
7 changes: 7 additions & 0 deletions backend/app/rag/llms/resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Optional
from llama_index.core.llms.llm import LLM
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
Expand Down Expand Up @@ -86,6 +87,12 @@ def get_llm(
api_key=credentials,
**config,
)
case LLMProvider.AZURE_OPENAI:
return AzureOpenAI(
model=model,
api_key=credentials,
**config,
)
case _:
raise ValueError(f"Got unknown LLM provider: {provider}")

Expand Down
11 changes: 11 additions & 0 deletions backend/app/utils/dspy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import hashlib
from typing import Any, Literal

from llama_index.llms.azure_openai import AzureOpenAI

import dspy
import requests
from dsp.modules.lm import LM
Expand Down Expand Up @@ -83,6 +85,15 @@ def get_dspy_lm_by_llama_llm(llama_llm: BaseLLM) -> dspy.LM:
max_tokens=llama_llm.context_window,
num_ctx=llama_llm.context_window,
)
elif type(llama_llm) is AzureOpenAI:
return dspy.AzureOpenAI(
model=llama_llm.model,
max_tokens=llama_llm.max_tokens or 4096,
api_key=llama_llm.api_key,
api_base=enforce_trailing_slash(llama_llm.azure_endpoint),
api_version=llama_llm.api_version,
deployment_id=llama_llm.engine,
)
else:
raise ValueError(f"Got unknown LLM provider: {llama_llm.__class__.__name__}")

Expand Down
3 changes: 3 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ dependencies = [
"llama-index-postprocessor-xinference-rerank>=0.2.0",
"llama-index-postprocessor-bedrock-rerank>=0.3.0",
"llama-index-llms-vertex>=0.4.2",
"socksio>=1.0.0",
"llama-index-llms-azure-openai>=0.3.0",
"llama-index-embeddings-azure-openai>=0.3.0",
]
readme = "README.md"
requires-python = ">= 3.8"
Expand Down
27 changes: 27 additions & 0 deletions backend/requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ argon2-cffi-bindings==21.2.0
asyncmy==0.2.9
attrs==23.2.0
# via aiohttp
azure-core==1.32.0
# via azure-identity
azure-identity==1.19.0
# via llama-index-llms-azure-openai
backoff==2.2.1
# via dspy-ai
# via langfuse
Expand Down Expand Up @@ -103,6 +107,8 @@ colorama==0.4.6
colorlog==6.8.2
# via optuna
cryptography==42.0.8
# via azure-identity
# via msal
# via pyjwt
dataclasses-json==0.6.7
# via langchain-community
Expand Down Expand Up @@ -249,6 +255,7 @@ httpx==0.27.0
# via langsmith
# via llama-cloud
# via llama-index-core
# via llama-index-llms-azure-openai
# via ollama
# via openai
httpx-oauth==0.14.1
Expand Down Expand Up @@ -326,13 +333,15 @@ llama-index-core==0.12.10.post1
# via llama-index
# via llama-index-agent-openai
# via llama-index-cli
# via llama-index-embeddings-azure-openai
# via llama-index-embeddings-bedrock
# via llama-index-embeddings-cohere
# via llama-index-embeddings-jinaai
# via llama-index-embeddings-ollama
# via llama-index-embeddings-openai
# via llama-index-indices-managed-llama-cloud
# via llama-index-llms-anthropic
# via llama-index-llms-azure-openai
# via llama-index-llms-bedrock
# via llama-index-llms-gemini
# via llama-index-llms-ollama
Expand All @@ -349,24 +358,29 @@ llama-index-core==0.12.10.post1
# via llama-index-readers-file
# via llama-index-readers-llama-parse
# via llama-parse
llama-index-embeddings-azure-openai==0.3.0
llama-index-embeddings-bedrock==0.4.0
llama-index-embeddings-cohere==0.4.0
llama-index-embeddings-jinaai==0.4.0
llama-index-embeddings-ollama==0.5.0
llama-index-embeddings-openai==0.3.1
# via llama-index
# via llama-index-cli
# via llama-index-embeddings-azure-openai
llama-index-indices-managed-llama-cloud==0.6.3
# via llama-index
llama-index-llms-anthropic==0.6.3
# via llama-index-llms-bedrock
llama-index-llms-azure-openai==0.3.0
# via llama-index-embeddings-azure-openai
llama-index-llms-bedrock==0.3.3
llama-index-llms-gemini==0.4.2
llama-index-llms-ollama==0.5.0
llama-index-llms-openai==0.3.13
# via llama-index
# via llama-index-agent-openai
# via llama-index-cli
# via llama-index-llms-azure-openai
# via llama-index-llms-openai-like
# via llama-index-multi-modal-llms-openai
# via llama-index-program-openai
Expand Down Expand Up @@ -407,6 +421,11 @@ marshmallow==3.21.3
# via dataclasses-json
mdurl==0.1.2
# via markdown-it-py
msal==1.31.1
# via azure-identity
# via msal-extensions
msal-extensions==1.2.0
# via azure-identity
multidict==6.0.5
# via aiohttp
# via yarl
Expand Down Expand Up @@ -498,6 +517,7 @@ pluggy==1.5.0
# via pytest
portalocker==2.10.1
# via deepeval
# via msal-extensions
pre-commit==4.0.1
prometheus-client==0.20.0
# via flower
Expand Down Expand Up @@ -565,6 +585,7 @@ pygments==2.18.0
# via rich
pyjwt==2.8.0
# via fastapi-users
# via msal
pymysql==1.1.1
pyparsing==3.1.2
# via httplib2
Expand Down Expand Up @@ -617,6 +638,7 @@ regex==2024.5.15
# via tiktoken
# via transformers
requests==2.32.3
# via azure-core
# via cohere
# via datasets
# via deepeval
Expand All @@ -630,6 +652,7 @@ requests==2.32.3
# via langchain-community
# via langsmith
# via llama-index-core
# via msal
# via requests-toolbelt
# via tiktoken
# via transformers
Expand All @@ -653,13 +676,15 @@ shapely==2.0.6
shellingham==1.5.4
# via typer
six==1.16.0
# via azure-core
# via markdownify
# via python-dateutil
sniffio==1.3.1
# via anthropic
# via anyio
# via httpx
# via openai
socksio==1.0.0
soupsieve==2.5
# via beautifulsoup4
sqlalchemy==2.0.30
Expand Down Expand Up @@ -716,6 +741,8 @@ types-requests==2.32.0.20240712
typing-extensions==4.12.2
# via alembic
# via anthropic
# via azure-core
# via azure-identity
# via cohere
# via fastapi
# via fastapi-pagination
Expand Down
Loading

0 comments on commit 8cff7e5

Please sign in to comment.