diff --git a/sdk/python/foundation-models/system/inference/image-text-embeddings/text-to-image-retrieval.ipynb b/sdk/python/foundation-models/system/inference/image-text-embeddings/text-to-image-retrieval.ipynb new file mode 100644 index 00000000000..47c21e35f87 --- /dev/null +++ b/sdk/python/foundation-models/system/inference/image-text-embeddings/text-to-image-retrieval.ipynb @@ -0,0 +1,598 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Text-to-Image Retrieval using Online Endpoints and Indexes in Azure AI Search\n", + "\n", + "This example shows how to perform text-to-image search with an Azure AI Search index and a deployed `embeddings` type model.\n", + "\n", + "### Task\n", + "The text-to-image retrieval task is to select, from a collection of images, those that are semantically related to a text query.\n", + " \n", + "### Model\n", + "Models that can perform the `embeddings` task are tagged with `embeddings`. We will use the `OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32` model in this notebook. If you don't find a model that suits your scenario or domain, you can discover and [import models from HuggingFace hub](../../import/import_model_into_registry.ipynb) and then use them for inference. \n", + "\n", + "### Inference data\n", + "We will use the [fridgeObjects](https://cvbp-secondary.z19.web.core.windows.net/datasets/image_classification/fridgeObjects.zip) dataset.\n", + "\n", + "\n", + "### Outline\n", + "1. Set up prerequisites\n", + "2. Prepare data for inference\n", + "3. Deploy the model to an online endpoint for real time inference\n", + "4. Create a search service and index\n", + "5. Populate the index with image embeddings\n", + "6. Query the index with text embeddings and visualize results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Set up prerequisites\n", + "* Install dependencies\n", + "* Connect to AzureML Workspace. Learn more at [set up SDK authentication](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication?tabs=sdk). Replace `<WORKSPACE_NAME>`, `<RESOURCE_GROUP>` and `<SUBSCRIPTION_ID>` below.\n", + "* Connect to `azureml` system registry" + ] + },
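+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If your environment is missing any of the libraries used below, install them first. This is a minimal sketch; the package list is assumed from the imports in this notebook.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install the packages this notebook imports (assumed list; skip any you already have)\n", + "%pip install azure-ai-ml azure-identity requests tqdm matplotlib Pillow\n" + ] + },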
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from azure.ai.ml import MLClient\n", + "from azure.identity import (\n", + "    DefaultAzureCredential,\n", + "    InteractiveBrowserCredential,\n", + ")\n", + "import time\n", + "\n", + "try:\n", + "    credential = DefaultAzureCredential()\n", + "    credential.get_token(\"https://management.azure.com/.default\")\n", + "except Exception as ex:\n", + "    credential = InteractiveBrowserCredential()\n", + "\n", + "try:\n", + "    workspace_ml_client = MLClient.from_config(credential)\n", + "    subscription_id = workspace_ml_client.subscription_id\n", + "    resource_group = workspace_ml_client.resource_group_name\n", + "    workspace_name = workspace_ml_client.workspace_name\n", + "except Exception as ex:\n", + "    print(ex)\n", + "    # Enter details of your AML workspace\n", + "    subscription_id = \"<SUBSCRIPTION_ID>\"\n", + "    resource_group = \"<RESOURCE_GROUP>\"\n", + "    workspace_name = \"<WORKSPACE_NAME>\"\n", + "workspace_ml_client = MLClient(\n", + "    credential, subscription_id, resource_group, workspace_name\n", + ")\n", + "\n", + "# The models are available in the AzureML system registry, \"azureml\"\n", + "registry_ml_client = MLClient(\n", + "    credential,\n", + "    subscription_id,\n", + "    resource_group,\n", + "    registry_name=\"azureml\",\n", + ")\n", + "# Generate a unique timestamp that can be used for names and versions that need to be unique\n", + "timestamp = str(int(time.time()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Prepare data for inference\n", + "\n", + "We will use the [fridgeObjects](https://cvbp-secondary.z19.web.core.windows.net/datasets/image_classification/fridgeObjects.zip) dataset. The fridge object dataset is stored in a directory.\n",
+ "There are four different folders inside:\n", + "- /water_bottle\n", + "- /milk_bottle\n", + "- /carton\n", + "- /can\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import urllib.request\n", + "from zipfile import ZipFile\n", + "\n", + "# Change to a different location if you prefer\n", + "dataset_parent_dir = \"./data\"\n", + "\n", + "# Create the data folder if it doesn't exist\n", + "os.makedirs(dataset_parent_dir, exist_ok=True)\n", + "\n", + "# Download data\n", + "download_url = \"https://cvbp-secondary.z19.web.core.windows.net/datasets/image_classification/fridgeObjects.zip\"\n", + "\n", + "# Extract current dataset name from dataset url\n", + "dataset_name = os.path.split(download_url)[-1].split(\".\")[0]\n", + "# Get dataset path for later use\n", + "dataset_dir = os.path.join(dataset_parent_dir, dataset_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the data zip file path\n", + "data_file = os.path.join(dataset_parent_dir, f\"{dataset_name}.zip\")\n", + "\n", + "# Download the dataset\n", + "urllib.request.urlretrieve(download_url, filename=data_file)\n", + "\n", + "# Extract files\n", + "with ZipFile(data_file, \"r\") as zip:\n", + "    print(\"extracting files...\")\n", + "    zip.extractall(path=dataset_parent_dir)\n", + "    print(\"done\")\n", + "# Delete the zip file\n", + "os.remove(data_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image\n", + "\n", + "sample_image = os.path.join(dataset_dir, \"milk_bottle\", \"99.jpg\")\n", + "Image(filename=sample_image)" + ] + },
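+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a quick sanity check, count the images in each class folder. This is a small sketch that assumes the folder layout produced by the extraction step above.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Count the .jpg files in each class folder (sanity check on the extracted dataset)\n", + "for folder in sorted(os.listdir(dataset_dir)):\n", + "    folder_path = os.path.join(dataset_dir, folder)\n", + "    if os.path.isdir(folder_path):\n", + "        num_images = len(\n", + "            [f for f in os.listdir(folder_path) if f.endswith(\".jpg\")]\n", + "        )\n", + "        print(f\"{folder}: {num_images} images\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Deploy the model to an online endpoint for real time inference\n", + "Online endpoints give a durable REST API that can be used to integrate with applications that need to use the model." + ] + },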
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"OpenAI-CLIP-Image-Text-Embeddings-vit-base-patch32\"\n", + "foundation_model = registry_ml_client.models.get(name=model_name, label=\"latest\")\n", + "print(\n", + "    f\"\\n\\nUsing model name: {foundation_model.name}, version: {foundation_model.version}, id: {foundation_model.id} for inferencing\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "from azure.ai.ml.entities import (\n", + "    ManagedOnlineEndpoint,\n", + "    ManagedOnlineDeployment,\n", + ")\n", + "\n", + "# Endpoint names need to be unique in a region, hence using a timestamp to create a unique endpoint name\n", + "timestamp = int(time.time())\n", + "online_endpoint_name = \"clip-embeddings-\" + str(timestamp)\n", + "# Create an online endpoint\n", + "endpoint = ManagedOnlineEndpoint(\n", + "    name=online_endpoint_name,\n", + "    description=\"Online endpoint for \"\n", + "    + foundation_model.name\n", + "    + \", for image-text-embeddings task\",\n", + "    auth_mode=\"key\",\n", + ")\n", + "workspace_ml_client.begin_create_or_update(endpoint).wait()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from azure.ai.ml.entities import OnlineRequestSettings, ProbeSettings\n", + "\n", + "deployment_name = \"embeddings-mlflow-deploy\"\n", + "\n", + "# Create a deployment\n", + "demo_deployment = ManagedOnlineDeployment(\n", + "    name=deployment_name,\n", + "    endpoint_name=online_endpoint_name,\n", + "    model=foundation_model.id,\n", + "    instance_type=\"Standard_NC6s_v3\",  # GPU instance type; use a CPU instance type like Standard_DS3_v2 for lower cost but slower inference\n", + "    instance_count=1,\n", + "    request_settings=OnlineRequestSettings(\n", + "        max_concurrent_requests_per_instance=1,\n", + "        request_timeout_ms=90000,\n", + "        max_queue_wait_ms=500,\n", + "    ),\n", + "    liveness_probe=ProbeSettings(\n", + "        failure_threshold=49,\n", + "        success_threshold=1,\n", + "        timeout=299,\n", + "        period=180,\n", + "        initial_delay=180,\n", + "    ),\n", + "    readiness_probe=ProbeSettings(\n", + "        failure_threshold=10,\n", + "        success_threshold=1,\n", + "        timeout=10,\n", + "        period=10,\n", + "        initial_delay=10,\n", + "    ),\n", + ")\n", + "workspace_ml_client.online_deployments.begin_create_or_update(demo_deployment).wait()\n", + "endpoint.traffic = {deployment_name: 100}\n", + "workspace_ml_client.begin_create_or_update(endpoint).result()" + ] + },
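+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before indexing the whole dataset, it can help to smoke-test the deployment with a single image. This is a minimal sketch; the request payload format and the `image_features` response field are the same ones used by the indexing cells later in this notebook, and the request file name is arbitrary.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import base64\n", + "import json\n", + "\n", + "# Build a one-image request file for the deployed embeddings model\n", + "with open(sample_image, \"rb\") as f:\n", + "    sample_b64 = base64.encodebytes(f.read()).decode(\"utf-8\")\n", + "with open(\"smoke_test_request.json\", \"wt\") as f:\n", + "    json.dump(\n", + "        {\"input_data\": {\"columns\": [\"image\", \"text\"], \"data\": [[sample_b64, \"\"]]}}, f\n", + "    )\n", + "\n", + "# Invoke the endpoint and check the embedding dimension\n", + "# (CLIP ViT-B/32 outputs 512-dimensional embeddings, matching the index definition below)\n", + "response = workspace_ml_client.online_endpoints.invoke(\n", + "    endpoint_name=online_endpoint_name,\n", + "    deployment_name=deployment_name,\n", + "    request_file=\"smoke_test_request.json\",\n", + ")\n", + "print(\"Embedding dimension:\", len(json.loads(response)[0][\"image_features\"]))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Create a search service and index" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Follow the instructions [here](https://learn.microsoft.com/en-us/azure/search/search-create-service-portal) to create a search service using the Azure Portal. Then, run the code below to create a search index." + ] + },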
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SEARCH_SERVICE_NAME = \"<SEARCH_SERVICE_NAME>\"\n", + "SERVICE_ADMIN_KEY = \"<SERVICE_ADMIN_KEY>\"\n", + "\n", + "INDEX_NAME = \"fridge-objects-index\"\n", + "API_VERSION = \"2023-07-01-Preview\"\n", + "CREATE_INDEX_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes?api-version={api_version}\".format(\n", + "    search_service_name=SEARCH_SERVICE_NAME, api_version=API_VERSION\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "create_request = {\n", + "    \"name\": INDEX_NAME,\n", + "    \"fields\": [\n", + "        {\n", + "            \"name\": \"id\",\n", + "            \"type\": \"Edm.String\",\n", + "            \"key\": True,\n", + "            \"searchable\": True,\n", + "            \"retrievable\": True,\n", + "            \"filterable\": True,\n", + "        },\n", + "        {\n", + "            \"name\": \"filename\",\n", + "            \"type\": \"Edm.String\",\n", + "            \"searchable\": True,\n", + "            \"filterable\": True,\n", + "            \"sortable\": True,\n", + "            \"retrievable\": True,\n", + "        },\n", + "        {\n", + "            \"name\": \"imageEmbeddings\",\n", + "            \"type\": \"Collection(Edm.Single)\",\n", + "            \"searchable\": True,\n", + "            \"retrievable\": True,\n", + "            \"dimensions\": 512,\n", + "            \"vectorSearchConfiguration\": \"my-vector-config\",\n", + "        },\n", + "    ],\n", + "    \"vectorSearch\": {\n", + "        \"algorithmConfigurations\": [\n", + "            {\n", + "                \"name\": \"my-vector-config\",\n", + "                \"kind\": \"hnsw\",\n", + "                \"hnswParameters\": {\n", + "                    \"m\": 4,\n", + "                    \"efConstruction\": 400,\n", + "                    \"efSearch\": 500,\n", + "                    \"metric\": \"cosine\",\n", + "                },\n", + "            }\n", + "        ]\n", + "    },\n", + "}\n", + "response = requests.post(\n", + "    CREATE_INDEX_REQUEST_URL,\n", + "    json=create_request,\n", + "    headers={\"api-key\": SERVICE_ADMIN_KEY},\n", + ")" + ] + },
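+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A quick check on the response (a sketch; the create index REST call is expected to return `201 Created` on success):\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Print the status of the create index request; surface the error payload if it failed\n", + "print(\"Create index response status:\", response.status_code)\n", + "if response.status_code != 201:\n", + "    print(response.text)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Populate the index with image embeddings\n", + "\n", + "Submit requests with image data to the online endpoint to get image embeddings. Add the image embeddings to the search index." + ] + },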
+ { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import base64\n", + "\n", + "_REQUEST_FILE_NAME = \"request.json\"\n", + "\n", + "\n", + "def read_image(image_path):\n", + "    with open(image_path, \"rb\") as f:\n", + "        return f.read()\n", + "\n", + "\n", + "def make_request_images(image_path):\n", + "    # Write the request payload for a single image to _REQUEST_FILE_NAME\n", + "    request_json = {\n", + "        \"input_data\": {\n", + "            \"columns\": [\"image\", \"text\"],\n", + "            \"data\": [[base64.encodebytes(read_image(image_path)).decode(\"utf-8\"), \"\"]],\n", + "        }\n", + "    }\n", + "\n", + "    with open(_REQUEST_FILE_NAME, \"wt\") as f:\n", + "        json.dump(request_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ADD_DATA_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/index?api-version={api_version}\".format(\n", + "    search_service_name=SEARCH_SERVICE_NAME,\n", + "    index_name=INDEX_NAME,\n", + "    api_version=API_VERSION,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.auto import tqdm\n", + "\n", + "image_paths = [\n", + "    os.path.join(dp, f)\n", + "    for dp, dn, filenames in os.walk(dataset_dir)\n", + "    for f in filenames\n", + "    if os.path.splitext(f)[1] == \".jpg\"\n", + "]\n", + "\n", + "MAX_RETRIES = 3\n", + "\n", + "for idx, image_path in enumerate(tqdm(image_paths)):\n", + "    ID = idx\n", + "    FILENAME = image_path\n", + "\n", + "    # Write the request file for this image, then get its embedding from the endpoint\n", + "    make_request_images(image_path)\n", + "\n", + "    response = None\n", + "    request_failed = False\n", + "    IMAGE_EMBEDDING = None\n", + "    for r in range(MAX_RETRIES):\n", + "        try:\n", + "            response = workspace_ml_client.online_endpoints.invoke(\n", + "                endpoint_name=online_endpoint_name,\n", + "                deployment_name=deployment_name,\n", + "                request_file=_REQUEST_FILE_NAME,\n", + "            )\n", + "            response = json.loads(response)\n", + "            IMAGE_EMBEDDING = response[0][\"image_features\"]\n", + "            break\n", + "        except Exception as e:\n", + "            print(f\"Unable to get embeddings for image {FILENAME}: {e}\")\n", + "            print(response)\n", + "            if r == MAX_RETRIES - 1:\n", + "                print(f\"attempt {r} failed, reached retry limit\")\n", + "                request_failed = True\n", + "            else:\n", + "                print(f\"attempt {r} failed, retrying\")\n", + "\n", + "    # Add the embedding to the index\n", + "    if IMAGE_EMBEDDING is not None:\n", + "        add_data_request = {\n", + "            \"value\": [\n", + "                {\n", + "                    \"id\": str(ID),\n", + "                    \"filename\": FILENAME,\n", + "                    \"imageEmbeddings\": IMAGE_EMBEDDING,\n", + "                    \"@search.action\": \"upload\",\n", + "                }\n", + "            ]\n", + "        }\n", + "        response = requests.post(\n", + "            ADD_DATA_REQUEST_URL,\n", + "            json=add_data_request,\n", + "            headers={\"api-key\": SERVICE_ADMIN_KEY},\n", + "        )" + ] + },
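+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Optionally, confirm that the documents were uploaded. This is a sketch using the documents count REST endpoint (the API version is assumed to match the one used above; newly added documents can take a few seconds to become visible).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Count the documents currently in the index (may lag the uploads by a few seconds)\n", + "COUNT_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/$count?api-version={api_version}\".format(\n", + "    search_service_name=SEARCH_SERVICE_NAME,\n", + "    index_name=INDEX_NAME,\n", + "    api_version=API_VERSION,\n", + ")\n", + "response = requests.get(COUNT_REQUEST_URL, headers={\"api-key\": SERVICE_ADMIN_KEY})\n", + "print(\"Documents in index:\", response.text)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. 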
Query the index with text embeddings and visualize results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TEXT_QUERY = \"a photo of a milk bottle\"\n", + "K = 5 # number of results to retrieve" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 6.1 Get the text embeddings for the query using the online endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def make_request_text(text_sample):\n", + " request_json = {\n", + " \"input_data\": {\n", + " \"columns\": [\"image\", \"text\"],\n", + " \"data\": [[\"\", text_sample]],\n", + " }\n", + " }\n", + "\n", + " with open(_REQUEST_FILE_NAME, \"wt\") as f:\n", + " json.dump(request_json, f)\n", + "\n", + "\n", + "make_request_text(TEXT_QUERY)\n", + "response = workspace_ml_client.online_endpoints.invoke(\n", + " endpoint_name=online_endpoint_name,\n", + " deployment_name=deployment_name,\n", + " request_file=_REQUEST_FILE_NAME,\n", + ")\n", + "response = json.loads(response)\n", + "QUERY_TEXT_EMBEDDING = response[0][\"text_features\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 6.2 Send the text embeddings as a query to the search index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "QUERY_REQUEST_URL = \"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/search?api-version={api_version}\".format(\n", + " search_service_name=SEARCH_SERVICE_NAME,\n", + " index_name=INDEX_NAME,\n", + " api_version=API_VERSION,\n", + ")\n", + "\n", + "\n", + "search_request = {\n", + " \"vectors\": [{\"value\": QUERY_TEXT_EMBEDDING, \"fields\": \"imageEmbeddings\", \"k\": K}],\n", + " \"select\": \"filename\",\n", + "}\n", + "\n", + "\n", + "response = requests.post(\n", + " QUERY_REQUEST_URL, json=search_request, headers={\"api-key\": SERVICE_ADMIN_KEY}\n", + ")\n", + "neighbors = json.loads(response.text)[\"value\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 6.3 Visualize Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from PIL import Image\n", + "\n", + "K1, K2 = 3, 4\n", + "\n", + "\n", + "def make_pil_image(image_path):\n", + " pil_image = Image.open(image_path)\n", + " return pil_image\n", + "\n", + "\n", + "_, axes = plt.subplots(nrows=K1 + 1, ncols=K2, figsize=(64, 64))\n", + "for i in range(K1 + 1):\n", + " for j in range(K2):\n", + " axes[i, j].axis(\"off\")\n", + "\n", + "i, j = 0, 0\n", + "\n", + "for neighbor in neighbors:\n", + " pil_image = make_pil_image(neighbor[\"filename\"])\n", + " axes[i, j].imshow(np.asarray(pil_image), aspect=\"auto\")\n", + " axes[i, j].text(1, 1, \"{:.4f}\".format(neighbor[\"@search.score\"]), fontsize=32)\n", + "\n", + " j += 1\n", + " if j == K2:\n", + " i += 1\n", + " j = 0" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rc_133", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}