Skip to content

Commit

Permalink
security improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
placerda committed Dec 17, 2024
1 parent 95b4804 commit 9901228
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 365 deletions.
2 changes: 0 additions & 2 deletions local.settings.json.template
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@
"DB_USERNAME":"",
"DB_TOP_K":"3",
"DB_MAX_TOKENS":"1000",
"SECURITY_HUB_ENDPOINT": "",
"SECURITY_HUB_CHECK": "false",
"SECURITY_HUB_HATE_THRESHHOLD": "0",
"SECURITY_HUB_SELFHARM_THRESHHOLD": "0",
"SECURITY_HUB_SEXUAL_THRESHHOLD": "0",
Expand Down
13 changes: 7 additions & 6 deletions orc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ async def main(req: func.HttpRequest) -> func.HttpResponse:
req_body = req.get_json()
conversation_id = req_body.get('conversation_id')
question = req_body.get('question')
client_principal_id = req_body.get('client_principal_id')
client_principal_name = req_body.get('client_principal_name')
if not client_principal_id or client_principal_id == '':
client_principal_id = '00000000-0000-0000-0000-000000000000'
client_principal_name = 'anonymous'

# Get client principal information
client_principal_id = req_body.get('client_principal_id', '00000000-0000-0000-0000-000000000000')
client_principal_name = req_body.get('client_principal_name', 'anonymous')
client_group_names = req_body.get('client_group_names', '')
client_principal = {
'id': client_principal_id,
'name': client_principal_name
'name': client_principal_name,
'group_names': client_group_names
}

if question:
Expand Down
4 changes: 2 additions & 2 deletions orc/code_orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
if SECURITY_HUB_CHECK:
SECURITY_HUB_THRESHOLDS=[get_possitive_int_or_default(os.environ.get("SECURITY_HUB_HATE_THRESHHOLD"), 0),get_possitive_int_or_default(os.environ.get("SECURITY_HUB_SELFHARM_THRESHHOLD"), 0),get_possitive_int_or_default(os.environ.get("SECURITY_HUB_SEXUAL_THRESHHOLD"), 0),get_possitive_int_or_default(os.environ.get("SECURITY_HUB_VIOLENCE_THRESHHOLD"), 0)]

async def get_answer(history,client_principal_id):
async def get_answer(history, security_ids):

#############################
# INITIALIZATION
Expand Down Expand Up @@ -220,7 +220,7 @@ async def get_answer(history,client_principal_id):
#run search retrieval function
retrievalPlugin= await retrievalPluginTask
if(SEARCH_RETRIEVAL):
search_function_result = await kernel.invoke(retrievalPlugin["VectorIndexRetrieval"], sk.KernelArguments(input=search_query,apim_key=apim_key,client_principal_id=client_principal_id))
search_function_result = await kernel.invoke(retrievalPlugin["VectorIndexRetrieval"], sk.KernelArguments(input=search_query,apim_key=apim_key,security_ids=security_ids))
formatted_sources = search_function_result.value[:100].replace('\n', ' ')
escaped_sources = escape_xml_characters(search_function_result.value)
search_sources=escaped_sources
Expand Down
24 changes: 18 additions & 6 deletions orc/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from azure.cosmos.aio import CosmosClient
from datetime import datetime
from shared.util import format_answer
from azure.identity.aio import DefaultAzureCredential
from azure.identity.aio import ManagedIdentityCredential, AzureCliCredential, ChainedTokenCredential
import orc.code_orchestration as code_orchestration

# logging level
Expand All @@ -28,9 +28,19 @@
ANSWER_FORMAT = "html" # html, markdown, none

async def get_credentials():
async with DefaultAzureCredential() as credential:
async with ChainedTokenCredential(
ManagedIdentityCredential(),
AzureCliCredential()
) as credential:
return credential

def generate_security_ids(client_principal):
security_ids = 'anonymous'
if client_principal is not None:
group_names = client_principal['group_names']
security_ids = f"{client_principal['id']}" + (f",{group_names}" if group_names else "")
return security_ids

async def run(conversation_id, ask, client_principal):

start_time = time.time()
Expand All @@ -47,7 +57,10 @@ async def run(conversation_id, ask, client_principal):
# get conversation

#credential = get_credentials()
async with DefaultAzureCredential() as credential:
async with ChainedTokenCredential(
ManagedIdentityCredential(),
AzureCliCredential()
) as credential:
async with CosmosClient(AZURE_DB_URI, credential=credential) as db_client:
db = db_client.get_database_client(database=AZURE_DB_NAME)
container = db.get_container_client('conversations')
Expand All @@ -68,10 +81,9 @@ async def run(conversation_id, ask, client_principal):

# 2) get answer and sources

client_principal_id = client_principal['id']
# get rag answer and sources
logging.info(f"[orchestrator] executing RAG retrieval using code orchestration")
answer_dict = await code_orchestration.get_answer(history,client_principal_id)
security_ids = generate_security_ids(client_principal)
answer_dict = await code_orchestration.get_answer(history, security_ids)

# 3) update and save conversation (containing history and conversation data)

Expand Down
28 changes: 20 additions & 8 deletions orc/plugins/Retrieval/native_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing_extensions import Annotated
from azure.cognitiveservices.search.customsearch import CustomSearchClient
from msrest.authentication import CognitiveServicesCredentials
from azure.identity.aio import DefaultAzureCredential
from azure.identity.aio import ManagedIdentityCredential, AzureCliCredential, ChainedTokenCredential
import aiohttp
import asyncio

Expand Down Expand Up @@ -50,8 +50,6 @@
AZURE_SEARCH_FILENAME_COLUMN = os.environ.get("AZURE_SEARCH_FILENAME_COLUMN") or "filepath"
AZURE_SEARCH_TITLE_COLUMN = os.environ.get("AZURE_SEARCH_TITLE_COLUMN") or "title"
AZURE_SEARCH_URL_COLUMN = os.environ.get("AZURE_SEARCH_URL_COLUMN") or "url"
AZURE_SEARCH_TRIMMING = os.environ.get("AZURE_SEARCH_TRIMMING") or "false"
AZURE_SEARCH_TRIMMING = True if AZURE_SEARCH_TRIMMING == "true" else False

# Bing Search Integration Settings
BING_SEARCH_TOP_K = os.environ.get("BING_SEARCH_TOP_K") or "3"
Expand Down Expand Up @@ -105,13 +103,21 @@ async def VectorIndexRetrieval(
self,
input: Annotated[str, "The user question"],
apim_key: Annotated[str, "The key to access the apim endpoint"],
client_principal_id: Annotated[str, "The user client principal id"]
# client_principal_id: Annotated[str, "The user client principal id"]
security_ids: Annotated[str, "Comma separated list string with user security ids"]
) -> Annotated[str, "the output is a string with the search results"]:
search_results = []
search_query = input
search_filter = f"security_id/any(g:search.in(g,'{client_principal_id}'))"
# search_filter = f"security_id/any(g:search.in(g,'{client_principal_id}'))"
search_filter = (
f"metadata_security_id/any(g:search.in(g, '{security_ids}')) "
f"or not metadata_security_id/any()"
)
try:
async with DefaultAzureCredential() as credential:
async with ChainedTokenCredential(
ManagedIdentityCredential(),
AzureCliCredential()
) as credential:
start_time = time.time()
logging.info(f"[sk_retrieval] generating question embeddings. search query: {search_query}")
embeddings_query = await generate_embeddings(search_query,apim_key=apim_key)
Expand Down Expand Up @@ -147,8 +153,14 @@ async def VectorIndexRetrieval(
body["queryType"] = "semantic"
body["semanticConfiguration"] = AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG

if AZURE_SEARCH_TRIMMING:
body["filter"] = search_filter
body["filter"] = search_filter

logging.debug(f"[ai_search] search filter: {search_filter}")

headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {azureSearchKey}'
}

if APIM_ENABLED:
headers = {
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ tiktoken==0.5.2
tenacity==8.2.3
azure-cognitiveservices-search-customsearch==0.3.1
beautifulsoup4==4.12.3
pydantic==2.3.0
pydantic==2.3.0
nbstripout==0.8.1
nbconvert==7.16.4
10 changes: 5 additions & 5 deletions shared/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import urllib.parse
from azure.cosmos.aio import CosmosClient as AsyncCosmosClient
from azure.keyvault.secrets.aio import SecretClient as AsyncSecretClient
from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential
from azure.identity.aio import ManagedIdentityCredential, AzureCliCredential, ChainedTokenCredential
from tenacity import retry, wait_random_exponential, stop_after_attempt
import semantic_kernel as sk
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
Expand Down Expand Up @@ -59,7 +59,7 @@
async def get_secret(secretName):
keyVaultName = os.environ["AZURE_KEY_VAULT_NAME"]
KVUri = f"https://{keyVaultName}.vault.azure.net"
async with AsyncDefaultAzureCredential() as credential:
async with ChainedTokenCredential( ManagedIdentityCredential(), AzureCliCredential()) as credential:
async with AsyncSecretClient(vault_url=KVUri, credential=credential) as client:
retrieved_secret = await client.get_secret(secretName)
value = retrieved_secret.value
Expand Down Expand Up @@ -365,7 +365,7 @@ async def get_aoai_config(model):
}
else:
resource = await get_next_resource(model)
async with AsyncDefaultAzureCredential() as credential:
async with ChainedTokenCredential( ManagedIdentityCredential(), AzureCliCredential()) as credential:
token = await credential.get_token("https://cognitiveservices.azure.com/.default")

if model in ('gpt-35-turbo', 'gpt-35-turbo-16k', 'gpt-4', 'gpt-4-32k','gpt-4o'):
Expand Down Expand Up @@ -394,7 +394,7 @@ async def get_next_resource(model):
return resources[0]
else:
start_time = time.time()
async with AsyncDefaultAzureCredential() as credential:
async with ChainedTokenCredential( ManagedIdentityCredential(), AzureCliCredential()) as credential:
async with AsyncCosmosClient(AZURE_DB_URI, credential) as db_client:
db = db_client.get_database_client(database=AZURE_DB_NAME)
container = db.get_container_client('models')
Expand Down Expand Up @@ -430,7 +430,7 @@ async def get_next_resource(model):

async def get_blocked_list():
blocked_list = []
async with AsyncDefaultAzureCredential() as credential:
async with ChainedTokenCredential( ManagedIdentityCredential(), AzureCliCredential()) as credential:
async with AsyncCosmosClient(AZURE_DB_URI, credential) as db_client:
db = db_client.get_database_client(database=AZURE_DB_NAME)
container = db.get_container_client('guardrails')
Expand Down
Loading

0 comments on commit 9901228

Please sign in to comment.