From 5c372ddc48498a50ea1fb8f53364644f6fbefcaa Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 15:41:29 +1100
Subject: [PATCH] Fix image download headers regression and fix png image issue
(#1022)
This includes three bug fixes:
- Fix the image download headers regression so users can provide a media download header to download private images. This also applies to other modalities including video and audio.
- Fix a bug where LanguageBind models cannot encode PNG images.
- Fix a bug where Marqo cannot take more than 2 modalities in weighted queries. Marqo can now take weighted queries with a mixture of all modalities supported by the model.
---
src/marqo/api/models/add_docs_objects.py | 20 ++
src/marqo/api/models/embed_request.py | 32 ++-
src/marqo/core/embed/embed.py | 13 +-
.../embedding_models/abstract_clip_model.py | 24 +-
.../embedding_models/image_download.py | 236 ------------------
.../embedding_models/open_clip_model.py | 4 +-
src/marqo/core/inference/image_download.py | 20 +-
src/marqo/core/models/add_docs_params.py | 4 +-
src/marqo/core/search/hybrid_search.py | 7 +-
.../semi_structured_add_document_handler.py | 5 +-
.../unstructured_add_document_handler.py | 7 +-
.../core/vespa_index/add_documents_handler.py | 1 -
src/marqo/s2_inference/clip_utils.py | 52 ++--
.../languagebind/image/processing_image.py | 5 +
.../s2_inference/multimodal_model_load.py | 42 ++--
src/marqo/s2_inference/s2_inference.py | 67 +++--
src/marqo/tensor_search/add_docs.py | 149 ++++++-----
src/marqo/tensor_search/api.py | 4 +-
src/marqo/tensor_search/models/api_models.py | 25 +-
src/marqo/tensor_search/models/search.py | 43 +++-
.../streaming_media_processor.py | 40 ++-
src/marqo/tensor_search/tensor_search.py | 175 ++++++-------
src/marqo/tensor_search/web/api_utils.py | 16 +-
tests/conftest.py | 4 +-
tests/marqo_test.py | 15 ++
tests/s2_inference/test_image_downloading.py | 8 +-
tests/s2_inference/test_vectorise.py | 3 +-
.../test_add_documents_combined.py | 221 ++++++++++++++--
tests/tensor_search/integ_tests/test_embed.py | 10 +-
.../integ_tests/test_search_combined.py | 6 +-
...test_add_documents_use_existing_tensors.py | 2 +-
tests/tensor_search/test_api_utils.py | 18 +-
.../test_image_download_headers.py | 24 +-
.../tensor_search/test_modalities_download.py | 67 +++--
tests/tensor_search/test_search.py | 2 +-
35 files changed, 770 insertions(+), 601 deletions(-)
delete mode 100644 src/marqo/core/inference/embedding_models/image_download.py
diff --git a/src/marqo/api/models/add_docs_objects.py b/src/marqo/api/models/add_docs_objects.py
index 5d7a33695..ad5c2b81d 100644
--- a/src/marqo/api/models/add_docs_objects.py
+++ b/src/marqo/api/models/add_docs_objects.py
@@ -22,6 +22,7 @@ class Config:
tensorFields: Optional[List] = None
useExistingTensors: bool = False
imageDownloadHeaders: dict = Field(default_factory=dict)
+ mediaDownloadHeaders: Optional[dict] = None
modelAuth: Optional[ModelAuth] = None
mappings: Optional[dict] = None
documents: Union[Sequence[Union[dict, Any]], np.ndarray]
@@ -38,3 +39,22 @@ def validate_thread_counts(cls, values):
if media_count is not None and image_count != read_env_vars_and_defaults_ints(EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST):
raise ValueError("Cannot set both imageDownloadThreadCount and mediaDownloadThreadCount")
return values
+
+ @root_validator(skip_on_failure=True)
+ def _validate_image_download_headers_and_media_download_headers(cls, values):
+ """Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
+
+ If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
+ rest of the code.
+
+ imageDownloadHeaders is deprecated and will be removed in the future.
+ """
+ image_download_headers = values.get('imageDownloadHeaders')
+ media_download_headers = values.get('mediaDownloadHeaders')
+ if image_download_headers and media_download_headers:
+ raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
+ "'imageDownloadHeaders' is deprecated and will be removed in the future. "
+ "Use mediaDownloadHeaders instead.")
+ if image_download_headers:
+ values['mediaDownloadHeaders'] = image_download_headers
+ return values
diff --git a/src/marqo/api/models/embed_request.py b/src/marqo/api/models/embed_request.py
index 9ca47422e..c1373da6d 100644
--- a/src/marqo/api/models/embed_request.py
+++ b/src/marqo/api/models/embed_request.py
@@ -6,18 +6,21 @@
import pydantic
from typing import Union, List, Dict, Optional, Any
+from pydantic import Field, root_validator
+
from marqo.tensor_search.models.private_models import ModelAuth
-from marqo.tensor_search.models.api_models import BaseMarqoModel
+from marqo.base_model import MarqoBaseModel
from marqo.core.embed.embed import EmbedContentType
-class EmbedRequest(BaseMarqoModel):
+class EmbedRequest(MarqoBaseModel):
# content can be a single query or list of queries. Queries can be a string or a dictionary.
content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]]
- image_download_headers: Optional[Dict] = None
+ image_download_headers: Optional[Dict] = Field(default=None, alias="imageDownloadHeaders")
+ mediaDownloadHeaders: Optional[Dict] = None
modelAuth: Optional[ModelAuth] = None
- content_type: Optional[EmbedContentType] = EmbedContentType.Query
+ content_type: Optional[EmbedContentType] = Field(default=EmbedContentType.Query, alias="contentType")
@pydantic.validator('content')
def validate_content(cls, value):
@@ -47,4 +50,23 @@ def validate_content(cls, value):
else:
raise ValueError("Embed content should be a string, a dictionary, or a list of strings or dictionaries")
- return value
\ No newline at end of file
+ return value
+
+ @root_validator(skip_on_failure=True)
+ def _validate_image_download_headers_and_media_download_headers(cls, values):
+ """Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
+
+ If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
+ rest of the code.
+
+ imageDownloadHeaders is deprecated and will be removed in the future.
+ """
+ image_download_headers = values.get('imageDownloadHeaders')
+ media_download_headers = values.get('mediaDownloadHeaders')
+ if image_download_headers and media_download_headers:
+ raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
+ "'imageDownloadHeaders' is deprecated and will be removed in the future. "
+ "Use mediaDownloadHeaders instead.")
+ if image_download_headers:
+ values['mediaDownloadHeaders'] = image_download_headers
+ return values
\ No newline at end of file
diff --git a/src/marqo/core/embed/embed.py b/src/marqo/core/embed/embed.py
index 29d6fcf54..4730ddcd5 100644
--- a/src/marqo/core/embed/embed.py
+++ b/src/marqo/core/embed/embed.py
@@ -34,11 +34,12 @@ def validate_default_device(cls, value):
return value
def embed_content(
- self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
- index_name: str, device: str = None, image_download_headers: Optional[Dict] = None,
- model_auth: Optional[ModelAuth] = None,
- content_type: Optional[EmbedContentType] = EmbedContentType.Query
- ) -> Dict:
+ self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
+ index_name: str, device: str = None,
+ media_download_headers: Optional[Dict] = None,
+ model_auth: Optional[ModelAuth] = None,
+ content_type: Optional[EmbedContentType] = EmbedContentType.Query
+ ) -> Dict:
"""
Use the index's model to embed the content
@@ -105,7 +106,7 @@ def embed_content(
BulkSearchQueryEntity(
q=content_entry,
index=marqo_index,
- image_download_headers=image_download_headers,
+ mediaDownloadHeaders=media_download_headers,
modelAuth=model_auth,
text_query_prefix=prefix
# TODO: Check if it's fine that we leave out the other parameters
diff --git a/src/marqo/core/inference/embedding_models/abstract_clip_model.py b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
index 1b2a33b23..42b8c2d8c 100644
--- a/src/marqo/core/inference/embedding_models/abstract_clip_model.py
+++ b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
@@ -7,7 +7,7 @@
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.core.inference.embedding_models.abstract_embedding_model import AbstractEmbeddingModel
-from marqo.core.inference.embedding_models.image_download import (_is_image, format_and_load_CLIP_images,
+from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.s2_inference.logger import get_logger
from marqo.s2_inference.types import *
@@ -50,11 +50,11 @@ def encode_text(self, inputs: Union[str, List[str]], normalize: bool = True) ->
pass
@abstractmethod
- def encode_image(self, inputs, normalize: bool = True, image_download_headers: dict = None) -> np.ndarray:
+ def encode_image(self, inputs, normalize: bool = True, media_download_headers: dict = None) -> np.ndarray:
pass
- def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
- default: str = 'text', normalize=True, **kwargs) -> np.ndarray:
+ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> np.ndarray:
+ default = "text"
infer = kwargs.pop('infer', True)
if infer and _is_image(inputs):
is_image = True
@@ -68,8 +68,8 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("image_download_headers", dict())
- return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
+ media_download_headers = kwargs.get("media_download_headers", dict())
+ return self.encode_image(inputs, normalize=normalize, media_download_headers=media_download_headers)
else:
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)
@@ -85,27 +85,27 @@ def normalize(outputs):
return outputs.norm(dim=-1, keepdim=True)
def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
- image_download_headers: Optional[Dict] = None) -> Tensor:
+ media_download_headers: Optional[Dict] = None) -> Tensor:
"""Preprocess the input image to be ready for the model.
Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
- image_download_headers (Optional[Dict]): headers for the image download
+ media_download_headers (Optional[Dict]): headers for the image download
Return:
Tensor: the processed image tensor with shape (batch_size, channel, n_px, n_px)
"""
if self.model is None:
self.load()
- if image_download_headers is None:
- image_download_headers = dict()
+ if media_download_headers is None:
+ media_download_headers = dict()
# default to batch encoding
if isinstance(images, list):
image_input: List[Union[ImageType, Tensor]] \
- = format_and_load_CLIP_images(images, image_download_headers)
+ = format_and_load_CLIP_images(images, media_download_headers)
else:
- image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, image_download_headers)]
+ image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, media_download_headers)]
image_input_processed: Tensor = torch.stack([self.preprocess(_img).to(self.device) \
if not isinstance(_img, torch.Tensor) else _img \
diff --git a/src/marqo/core/inference/embedding_models/image_download.py b/src/marqo/core/inference/embedding_models/image_download.py
deleted file mode 100644
index 65c158e20..000000000
--- a/src/marqo/core/inference/embedding_models/image_download.py
+++ /dev/null
@@ -1,236 +0,0 @@
-import os
-from io import BytesIO
-
-import certifi
-import numpy as np
-import pycurl
-import requests
-import torch
-import validators
-from PIL import Image, UnidentifiedImageError
-from requests.utils import requote_uri
-
-from marqo import marqo_docs
-from marqo.api.exceptions import InternalError
-from marqo.s2_inference.errors import ImageDownloadError
-from marqo.s2_inference.types import *
-from marqo.tensor_search.telemetry import RequestMetrics
-
-# TODO Merge this with the one in clip_utils in the future refactoring
-
-DEFAULT_HEADERS = {'User-Agent': 'Marqobot/1.0'}
-
-
-def get_allowed_image_types():
- return {'.jpg', '.png', '.bmp', '.jpeg'}
-
-
-def _is_image(inputs: Union[str, List[Union[str, ImageType, ndarray]]]) -> bool:
- # some logic to determine if something is an image or not
- # assume the batch is the same type
- # maybe we use something like this https://github.com/ahupp/python-magic
-
- _allowed = get_allowed_image_types()
-
- # we assume the batch is this way if a list
- # otherwise apply over each element
- if isinstance(inputs, list):
-
- if len(inputs) == 0:
- raise UnidentifiedImageError("received empty list, expected at least one element.")
-
- thing = inputs[0]
- else:
- thing = inputs
-
- # if it is a string, determine if it is a local file or url
- if isinstance(thing, str):
- name, extension = os.path.splitext(thing.lower())
-
- # if it has the correct extension, asssume yes
- if extension in _allowed:
- return True
-
- # if it is a local file without extension, then raise an error
- if os.path.isfile(thing):
- # we could also read the first part of the file and infer
- raise UnidentifiedImageError(
- f"local file [{thing}] extension {extension} does not match allowed file types of {_allowed}")
- else:
- # if it is not a local file and does not have an extension
- # check if url
- if validators.url(thing):
- return True
- else:
- return False
-
- # if it is an array, then it is an image
- elif isinstance(thing, (ImageType, ndarray, Tensor)):
- return True
- else:
- raise UnidentifiedImageError(f"expected type Image or str for inputs but received type {type(thing)}")
-
-
-def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], image_download_headers: dict) -> List[
- ImageType]:
- """takes in a list of strings, arrays or urls and either loads and/or converts to PIL
- for the clip model
-
- Args:
- images (List[Union[str, np.ndarray, ImageType]]): list of file locations or arrays (can be mixed)
-
- Raises:
- TypeError: _description_
-
- Returns:
- List[ImageType]: list of PIL images
- """
- if not isinstance(images, list):
- raise TypeError(f"expected list but received {type(images)}")
-
- results = []
- for image in images:
- results.append(format_and_load_CLIP_image(image, image_download_headers))
-
- return results
-
-
-def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
- image_download_headers: dict) -> Union[ImageType, Tensor]:
- """standardizes the input to be a PIL image
-
- Args:
- image (Union[str, np.ndarray, ImageType, Tensor]): can be a local file, url, array or a tensor
-
- Raises:
- ValueError: _description_
- TypeError: _description_
-
- Returns:
- standardized the image:
- ImageType: PIL image if input is a string, an array or a PIL image
- Tensor: torch tensor if input is a torch tensor
- """
- # check for the input type
- if isinstance(image, str):
- img = load_image_from_path(image, image_download_headers)
- elif isinstance(image, np.ndarray):
- img = Image.fromarray(image.astype('uint8'), 'RGB')
- elif isinstance(image, torch.Tensor):
- img = image
- elif isinstance(image, ImageType):
- img = image
- else:
- raise UnidentifiedImageError(f"input of type {type(image)} "
- f"did not match allowed types of str, np.ndarray, ImageType, Tensor")
-
- return img
-
-
-def load_image_from_path(image_path: str, image_download_headers: dict, timeout_ms=3000,
- metrics_obj: Optional[RequestMetrics] = None) -> ImageType:
- """Loads an image into PIL from a string path that is either local or a url
-
- Args:
- image_path (str): Local or remote path to image.
- image_download_headers (dict): header for the image download
- timeout_ms (int): timeout (in milliseconds), for the whole request
- Raises:
- ValueError: If the local path is invalid, and is not a url
- UnidentifiedImageError: If the image is irretrievable or unprocessable.
-
- Returns:
- ImageType: In-memory PIL image.
- """
- if os.path.isfile(image_path):
- img = Image.open(image_path)
- elif validators.url(image_path):
- if metrics_obj is not None:
- metrics_obj.start(f"image_download.{image_path}")
- try:
- img_io: BytesIO = download_image_from_url(image_path, image_download_headers, timeout_ms)
- img = Image.open(img_io)
- except ImageDownloadError as e:
- raise UnidentifiedImageError(str(e)) from e
- finally:
- if metrics_obj is not None:
- metrics_obj.stop(f"image_download.{image_path}")
- else:
- raise UnidentifiedImageError(f"Input str of {image_path} is not a local file or a valid url. "
- f"If you are using Marqo Cloud, please note that images can only be downloaded "
- f"from a URL and local files are not supported. "
- f"If you are running Marqo in a Docker container, you will need to use a Docker "
- f"volume so that your container can access host files. "
- f"For more information, please refer to: "
- f"{marqo_docs.indexing_images()}")
-
- return img
-
-
-def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
- """Download an image from a URL and return a PIL image using pycurl.
-
- Args:
- image_path (str): URL to the image.
- image_download_headers (dict): Headers for the image download.
- timeout_ms (int): Timeout in milliseconds, for the whole request.
-
- Returns:
- buffer (BytesIO): The image as a BytesIO object.
-
- Raises:
- ImageDownloadError: If the image download fails.
- """
-
- if not isinstance(timeout_ms, int):
- raise InternalError(f"timeout must be an integer but received {timeout_ms} of type {type(timeout_ms)}")
-
- try:
- encoded_url = encode_url(image_path)
- except UnicodeEncodeError as e:
- raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. "
- f"The url could not be encoded properly. Original error: {e}")
- buffer = BytesIO()
- c = pycurl.Curl()
- c.setopt(pycurl.CAINFO, certifi.where())
- c.setopt(pycurl.URL, encoded_url)
- c.setopt(pycurl.WRITEDATA, buffer)
- c.setopt(pycurl.TIMEOUT_MS, timeout_ms)
- c.setopt(pycurl.FOLLOWLOCATION, 1)
-
- headers = DEFAULT_HEADERS.copy()
- headers.update(image_download_headers)
- c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
-
- try:
- c.perform()
- if c.getinfo(pycurl.RESPONSE_CODE) != 200:
- raise ImageDownloadError(f"image url `{image_path}` returned {c.getinfo(pycurl.RESPONSE_CODE)}")
- except pycurl.error as e:
- raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. "
- f"The original error is: {e}")
- finally:
- c.close()
- buffer.seek(0)
- return buffer
-
-
-def encode_url(url: str) -> str:
- """
- Encode a URL to a valid format with only ASCII characters and reserved characters using percent-encoding.
-
- In version 2.8, we replaced the requests library with pycurl for image downloads. Consequently, we need to implement
- the URL encoding function ourselves. This function replicates the encoding behavior of the
- 'requests.utils.requote_uri' function from the requests library.
-
- Args:
- url (str): The URL to encode.
-
- Returns:
- str: The encoded URL.
-
- Raises:
- UnicodeEncodeError: If the URL cannot be encoded properly.
-
- """
- return requests.utils.requote_uri(url)
diff --git a/src/marqo/core/inference/embedding_models/open_clip_model.py b/src/marqo/core/inference/embedding_models/open_clip_model.py
index e79cb9feb..fdc050316 100644
--- a/src/marqo/core/inference/embedding_models/open_clip_model.py
+++ b/src/marqo/core/inference/embedding_models/open_clip_model.py
@@ -247,10 +247,10 @@ def _download_from_repo(self):
return model_file_path
def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]],
- image_download_headers: Optional[Dict] = None,
+ media_download_headers: Optional[Dict] = None,
normalize=True) -> FloatTensor:
- self.image_input_processed: Tensor = self._preprocess_images(images, image_download_headers)
+ self.image_input_processed: Tensor = self._preprocess_images(images, media_download_headers)
with torch.no_grad():
if self.device.startswith("cuda"):
diff --git a/src/marqo/core/inference/image_download.py b/src/marqo/core/inference/image_download.py
index 65c158e20..9cebb5948 100644
--- a/src/marqo/core/inference/image_download.py
+++ b/src/marqo/core/inference/image_download.py
@@ -71,7 +71,7 @@ def _is_image(inputs: Union[str, List[Union[str, ImageType, ndarray]]]) -> bool:
raise UnidentifiedImageError(f"expected type Image or str for inputs but received type {type(thing)}")
-def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], image_download_headers: dict) -> List[
+def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], media_download_headers: dict) -> List[
ImageType]:
"""takes in a list of strings, arrays or urls and either loads and/or converts to PIL
for the clip model
@@ -90,13 +90,13 @@ def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], im
results = []
for image in images:
- results.append(format_and_load_CLIP_image(image, image_download_headers))
+ results.append(format_and_load_CLIP_image(image, media_download_headers))
return results
def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
- image_download_headers: dict) -> Union[ImageType, Tensor]:
+ media_download_headers: dict) -> Union[ImageType, Tensor]:
"""standardizes the input to be a PIL image
Args:
@@ -113,7 +113,7 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
"""
# check for the input type
if isinstance(image, str):
- img = load_image_from_path(image, image_download_headers)
+ img = load_image_from_path(image, media_download_headers)
elif isinstance(image, np.ndarray):
img = Image.fromarray(image.astype('uint8'), 'RGB')
elif isinstance(image, torch.Tensor):
@@ -127,13 +127,13 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
return img
-def load_image_from_path(image_path: str, image_download_headers: dict, timeout_ms=3000,
+def load_image_from_path(image_path: str, media_download_headers: dict, timeout_ms=3000,
metrics_obj: Optional[RequestMetrics] = None) -> ImageType:
"""Loads an image into PIL from a string path that is either local or a url
Args:
image_path (str): Local or remote path to image.
- image_download_headers (dict): header for the image download
+ media_download_headers (dict): header for the image download
timeout_ms (int): timeout (in milliseconds), for the whole request
Raises:
ValueError: If the local path is invalid, and is not a url
@@ -148,7 +148,7 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
if metrics_obj is not None:
metrics_obj.start(f"image_download.{image_path}")
try:
- img_io: BytesIO = download_image_from_url(image_path, image_download_headers, timeout_ms)
+ img_io: BytesIO = download_image_from_url(image_path, media_download_headers, timeout_ms)
img = Image.open(img_io)
except ImageDownloadError as e:
raise UnidentifiedImageError(str(e)) from e
@@ -167,12 +167,12 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
return img
-def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
+def download_image_from_url(image_path: str, media_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
"""Download an image from a URL and return a PIL image using pycurl.
Args:
image_path (str): URL to the image.
- image_download_headers (dict): Headers for the image download.
+ media_download_headers (dict): Headers for the image download.
timeout_ms (int): Timeout in milliseconds, for the whole request.
Returns:
@@ -199,7 +199,7 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)
headers = DEFAULT_HEADERS.copy()
- headers.update(image_download_headers)
+ headers.update(media_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
try:
diff --git a/src/marqo/core/models/add_docs_params.py b/src/marqo/core/models/add_docs_params.py
index 557bf166b..66cf12185 100644
--- a/src/marqo/core/models/add_docs_params.py
+++ b/src/marqo/core/models/add_docs_params.py
@@ -31,7 +31,7 @@ class AddDocsParams(BaseModel):
device: Device used to carry out the document update, if `None` is given, it will be determined by
EnvVars.MARQO_BEST_AVAILABLE_DEVICE
image_download_thread_count: number of threads used to concurrently download images
- image_download_headers: headers to authenticate image download
+ media_download_headers: headers to authenticate media download requests
mappings: a dictionary used to handle all the object field content in the doc,
e.g., multimodal_combination field
model_auth: an object used to authorise downloading an object from a datastore
@@ -53,7 +53,7 @@ class Config:
image_download_thread_count: int = Field(default_factory=lambda: read_env_vars_and_defaults_ints(
EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST))
media_download_thread_count: Optional[int]
- image_download_headers: dict = Field(default_factory=dict)
+ media_download_headers: Optional[dict] = None
use_existing_tensors: bool = False
mappings: Optional[dict] = None
model_auth: Optional[ModelAuth] = None
diff --git a/src/marqo/core/search/hybrid_search.py b/src/marqo/core/search/hybrid_search.py
index 3bc2e4ead..9ee8a5264 100644
--- a/src/marqo/core/search/hybrid_search.py
+++ b/src/marqo/core/search/hybrid_search.py
@@ -33,7 +33,7 @@ def search(
offset: int = 0, ef_search: Optional[int] = None, approximate: bool = True,
searchable_attributes: Iterable[str] = None, filter_string: str = None, device: str = None,
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None,
- image_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
+ media_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
score_modifiers: Optional[ScoreModifierLists] = None, model_auth: Optional[ModelAuth] = None,
highlights: bool = False, text_query_prefix: Optional[str] = None,
hybrid_parameters: HybridParameters = None) -> Dict:
@@ -51,7 +51,8 @@ def search(
verbose: if 0 - nothing is printed. if 1 - data is printed without vectors, if 2 - full
objects are printed out
attributes_to_retrieve: if set, only returns these fields
- image_download_headers: headers for downloading images
+ media_download_headers: headers for downloading media
+
context: a dictionary to allow custom vectors in search
score_modifiers: a dictionary to modify the score based on field values, should be None for hybrid search
model_auth: Authorisation details for downloading a model (if required)
@@ -151,7 +152,7 @@ def search(
q=query_text_vectorise, searchableAttributes=searchable_attributes, searchMethod=SearchMethod.HYBRID,
limit=result_count,
offset=offset, showHighlights=False, filter=filter_string, attributesToRetrieve=attributes_to_retrieve,
- boost=boost, image_download_headers=image_download_headers, context=context, scoreModifiers=score_modifiers,
+ boost=boost, mediaDownloadHeaders=media_download_headers, context=context, scoreModifiers=score_modifiers,
index=marqo_index, modelAuth=model_auth, text_query_prefix=text_query_prefix,
hybridParameters=hybrid_parameters
)]
diff --git a/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py b/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py
index c82ae3fc1..74ffb4073 100644
--- a/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py
+++ b/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py
@@ -41,7 +41,10 @@ def __init__(self, marqo_index: SemiStructuredMarqoIndex, add_docs_params: AddDo
def _handle_field(self, marqo_doc, field_name, field_content):
self._validate_field(field_name, field_content)
- text_field_type = self._infer_field_type(field_content)
+ text_field_type = self._infer_field_type(
+ field_content,
+ media_download_headers=self.add_docs_params.media_download_headers
+ )
content = self.tensor_fields_container.collect(marqo_doc[MARQO_DOC_ID], field_name,
field_content, text_field_type)
marqo_doc[field_name] = content
diff --git a/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py b/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
index 31c3b300c..7915455aa 100644
--- a/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
+++ b/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
@@ -65,17 +65,18 @@ def _validate_doc(self, doc):
def _handle_field(self, marqo_doc, field_name, field_content):
self._validate_field(field_name, field_content)
- text_field_type = self._infer_field_type(field_content)
+ text_field_type = self._infer_field_type(field_content, self.add_docs_params.media_download_headers)
content = self.tensor_fields_container.collect(marqo_doc[MARQO_DOC_ID], field_name,
field_content, text_field_type)
marqo_doc[field_name] = content
- def _infer_field_type(self, field_content: Any) -> Optional[FieldType]:
+ def _infer_field_type(self, field_content: Any, media_download_headers: Optional[Dict] = None) \
+ -> Optional[FieldType]:
if not isinstance(field_content, str):
return None
try:
- modality = infer_modality(field_content)
+ modality = infer_modality(field_content, media_download_headers)
if not self.marqo_index.treat_urls_and_pointers_as_media and modality in [Modality.AUDIO, Modality.VIDEO]:
modality = Modality.TEXT
diff --git a/src/marqo/core/vespa_index/add_documents_handler.py b/src/marqo/core/vespa_index/add_documents_handler.py
index b181d35b7..8133abd4d 100644
--- a/src/marqo/core/vespa_index/add_documents_handler.py
+++ b/src/marqo/core/vespa_index/add_documents_handler.py
@@ -421,4 +421,3 @@ def _field_type_chunker_map(self, media_repo):
FieldType.VideoPointer: AudioVideoChunker(media_repo=media_repo),
}
return chunkers
-
diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index 342e6d849..d1fd5c684 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -67,7 +67,7 @@ def _get_transform(n_px: int, image_mean: List[float] = None, image_std: List[fl
])
-def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], image_download_headers: dict) -> List[
+def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], media_download_headers: dict) -> List[
ImageType]:
"""takes in a list of strings, arrays or urls and either loads and/or converts to PIL
for the clip model
@@ -86,18 +86,18 @@ def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], im
results = []
for image in images:
- results.append(format_and_load_CLIP_image(image, image_download_headers))
+ results.append(format_and_load_CLIP_image(image, media_download_headers))
return results
-def load_image_from_path(image_path: str, image_download_headers: dict, timeout_ms=3000,
+def load_image_from_path(image_path: str, media_download_headers: dict, timeout_ms=3000,
metrics_obj: Optional[RequestMetrics] = None) -> ImageType:
"""Loads an image into PIL from a string path that is either local or a url
Args:
image_path (str): Local or remote path to image.
- image_download_headers (dict): header for the image download
+ media_download_headers (dict): header for the image download
timeout_ms (int): timeout (in milliseconds), for the whole request
Raises:
ValueError: If the local path is invalid, and is not a url
@@ -112,7 +112,7 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
if metrics_obj is not None:
metrics_obj.start(f"image_download.{image_path}")
try:
- img_io: BytesIO = download_image_from_url(image_path, image_download_headers, timeout_ms)
+ img_io: BytesIO = download_image_from_url(image_path, media_download_headers, timeout_ms)
img = Image.open(img_io)
except ImageDownloadError as e:
raise UnidentifiedImageError(str(e)) from e
@@ -145,12 +145,12 @@ def validate_url(url: str) -> bool:
-def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
+def download_image_from_url(image_path: str, media_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
"""Download an image from a URL and return a PIL image using pycurl.
Args:
image_path (str): URL to the image.
- image_download_headers (dict): Headers for the image download.
+ media_download_headers (dict): Headers for the image download.
timeout_ms (int): Timeout in milliseconds, for the whole request.
Returns:
@@ -177,7 +177,9 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)
headers = DEFAULT_HEADERS.copy()
- headers.update(image_download_headers)
+ if media_download_headers is None:
+ media_download_headers = dict()
+ headers.update(media_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
try:
@@ -215,7 +217,7 @@ def encode_url(url: str) -> str:
def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
- image_download_headers: dict) -> Union[ImageType, Tensor]:
+ media_download_headers: dict) -> Union[ImageType, Tensor]:
"""standardizes the input to be a PIL image
Args:
@@ -232,7 +234,7 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
"""
# check for the input type
if isinstance(image, str):
- img = load_image_from_path(image, image_download_headers)
+ img = load_image_from_path(image, media_download_headers)
elif isinstance(image, np.ndarray):
img = Image.fromarray(image.astype('uint8'), 'RGB')
elif isinstance(image, torch.Tensor):
@@ -418,27 +420,27 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT
return self._convert_output(outputs)
def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
- image_download_headers: Optional[Dict] = None) -> Tensor:
+ media_download_headers: Optional[Dict] = None) -> Tensor:
"""Preprocess the input image to be ready for the model.
Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
- image_download_headers (Optional[Dict]): headers for the image download
+ media_download_headers (Optional[Dict]): headers for the image download
Return:
Tensor: the processed image tensor with shape (batch_size, channel, n_px, n_px)
"""
if self.model is None:
self.load()
- if image_download_headers is None:
- image_download_headers = dict()
+ if media_download_headers is None:
+ media_download_headers = dict()
# default to batch encoding
if isinstance(images, list):
image_input: List[Union[ImageType, Tensor]] \
- = format_and_load_CLIP_images(images, image_download_headers)
+ = format_and_load_CLIP_images(images, media_download_headers)
else:
- image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, image_download_headers)]
+ image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, media_download_headers)]
image_input_processed: Tensor = torch.stack([self.preprocess(_img).to(self.device) \
if not isinstance(_img, torch.Tensor) else _img \
@@ -446,18 +448,18 @@ def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, Image
return image_input_processed
def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
- normalize=True, image_download_headers: Optional[Dict] = None) -> FloatTensor:
+ normalize=True, media_download_headers: Optional[Dict] = None) -> FloatTensor:
"""Encode the input image to a tensor representation.
Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
normalize (bool): whether to normalize the output tensor
- image_download_headers (Optional[Dict]): headers for the image download
+ media_download_headers (Optional[Dict]): headers for the image download
Return:
FloatTensor: the encoded image tensor with shape (batch_size, embedding_dim)
"""
- self.image_input_processed: Tensor = self._preprocess_images(images, image_download_headers)
+ self.image_input_processed: Tensor = self._preprocess_images(images, media_download_headers)
with torch.no_grad():
outputs = self.model.encode_image(self.image_input_processed)
@@ -485,8 +487,8 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("image_download_headers", dict())
- return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
+ media_download_headers = kwargs.get("media_download_headers", dict())
+ return self.encode_image(inputs, normalize=normalize, media_download_headers=media_download_headers)
else:
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)
@@ -571,16 +573,16 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT
return self._convert_output(outputs)
def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]],
- normalize=True, image_download_headers: Optional[dict] = None) -> FloatTensor:
+ normalize=True, media_download_headers: Optional[dict] = None) -> FloatTensor:
if self.visual_model is None:
self.load()
- if image_download_headers is None:
- image_download_headers = dict()
+ if media_download_headers is None:
+ media_download_headers = dict()
# default to batch encoding
if isinstance(images, list):
- image_input = format_and_load_CLIP_images(images, image_download_headers)
+ image_input = format_and_load_CLIP_images(images, media_download_headers)
else:
-            image_input = [format_and_load_CLIP_image(images, {})]
+            image_input = [format_and_load_CLIP_image(images, media_download_headers)]  # forward headers like the list branch
diff --git a/src/marqo/s2_inference/languagebind/image/processing_image.py b/src/marqo/s2_inference/languagebind/image/processing_image.py
index 7a3d7c396..90f80b155 100644
--- a/src/marqo/s2_inference/languagebind/image/processing_image.py
+++ b/src/marqo/s2_inference/languagebind/image/processing_image.py
@@ -13,10 +13,15 @@ def make_list_of_images(x):
return x
+def _convert_to_rgb(image):
+ return image.convert("RGB")
+
+
def get_image_transform(config):
config = config.vision_config
transform = transforms.Compose(
[
+ _convert_to_rgb,
transforms.ToTensor(),
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 173630c22..593c45bf5 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -11,10 +11,11 @@
from pydantic import BaseModel
from enum import Enum
from abc import ABC, abstractmethod
-from typing import List, Dict, Any, Union
+from typing import List, Dict, Any, Union, Optional
from PIL.Image import Image
import torch
from urllib.parse import quote
+from marqo.core.inference.image_download import DEFAULT_HEADERS
from marqo.s2_inference.multimodal_model_load import *
@@ -109,15 +110,15 @@ def preprocessor(self, modality):
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
return self.encoder.preprocessor(modality)
- def encode(self, content, modality, **kwargs):
+ def encode(self, content, modality, media_download_headers: Optional[Dict]=None, **kwargs):
if self.encoder is None:
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
- return self.encoder.encode(content, modality, **kwargs)
+        return self.encoder.encode(content, modality, media_download_headers=media_download_headers, **kwargs)  # by keyword: encoders differ in positional order
class ModelEncoder(ABC):
@abstractmethod
- def encode(self, content, modality, **kwargs):
+ def encode(self, content, modality, media_download_headers, **kwargs):
pass
@@ -125,13 +126,20 @@ class DefaultEncoder(ModelEncoder):
def __init__(self, model):
self.model = model
- def encode(self, content, modality, **kwargs):
- return self.model.encode(content, **kwargs)
+ def encode(self, content, modality, media_download_headers, **kwargs):
+ return self.model.encode(content, modality=modality, media_download_headers=media_download_headers, **kwargs)
@contextmanager
-def fetch_content_sample(url, sample_size=10240): # 10 KB
- response = requests.get(url, stream=True)
+def fetch_content_sample(url, media_download_headers: Optional[dict] = None, sample_size=10240): # 10 KB
+ # It's ok to pass None to requests.get() for headers and it won't change the default headers
+ """Fetch a sample of the content from the URL.
+
+ Raises:
+        HTTPError: If the response has an error status code (4xx or 5xx)
+ """
+ response = requests.get(url, stream=True, headers=media_download_headers)
+ response.raise_for_status()
buffer = io.BytesIO()
try:
for chunk in response.iter_content(chunk_size=min(sample_size, 8192)):
@@ -145,7 +153,7 @@ def fetch_content_sample(url, sample_size=10240): # 10 KB
response.close()
-def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
+def infer_modality(content: Union[str, List[str], bytes], media_download_headers: Optional[dict] = None) -> Modality:
"""
Infer the modality of the content. Video, audio, image or text.
"""
@@ -155,7 +163,6 @@ def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
# Encode the URL
encoded_url = encode_url(content)
-
extension = encoded_url.split('.')[-1].lower()
if extension in ['jpg', 'jpeg', 'png', 'gif', 'webp']:
return Modality.IMAGE
@@ -163,11 +170,10 @@ def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
return Modality.VIDEO
elif extension in ['mp3', 'wav', 'ogg']:
return Modality.AUDIO
-
if validate_url(encoded_url):
# Use context manager to handle content sample
try:
- with fetch_content_sample(encoded_url) as sample:
+ with fetch_content_sample(encoded_url, media_download_headers) as sample:
mime = magic.from_buffer(sample.read(), mime=True)
if mime.startswith('image/'):
return Modality.IMAGE
@@ -249,7 +255,7 @@ def preprocessor(self, modality):
return self._preprocessors.get(modality)
- def encode(self, content, modality, normalize=True, **kwargs):
+ def encode(self, content, modality, normalize=True, media_download_headers: Optional[Dict]=None, **kwargs):
inputs = {}
if modality == Modality.TEXT:
@@ -267,7 +273,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
with open(temp_filename, 'wb') as f:
f.write(content)
elif isinstance(content, str) and "http" in content:
- self._download_content(content, temp_filename)
+ self._download_content(content, temp_filename, media_download_headers)
else:
return self.encode([content], modality=Modality.TEXT)
@@ -278,7 +284,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
if isinstance(content, str) and "http" in content:
suffix = ".mp4" if modality == Modality.VIDEO else ".wav"
with self._temp_file(suffix) as temp_filename:
- self._download_content(content, temp_filename)
+ self._download_content(content, temp_filename, media_download_headers)
preprocessed_content = self.preprocessor(modality)([temp_filename], return_tensors='pt')
inputs[modality.value] = to_device(preprocessed_content, self.model.device)['pixel_values']
@@ -286,7 +292,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
# If media has already been preprocessed
inputs[modality.value] = to_device(content[0], self.model.device)['pixel_values']
elif isinstance(content[0], str) and 'http' in content[0]:
- return self.encode(content[0], modality=modality)
+ return self.encode(content[0], modality=modality, media_download_headers=media_download_headers)
else:
raise ValueError(f"Unsupported {modality.value} content type: {type(content)}, content: {content}")
@@ -300,11 +306,11 @@ def encode(self, content, modality, normalize=True, **kwargs):
return embeddings.cpu().numpy()
- def _download_content(self, url, filename):
+ def _download_content(self, url, filename, media_download_headers: Optional[Dict]=None):
# 3 seconds for images, 20 seconds for audio and video
timeout_ms = 3000 if filename.endswith(('.png', '.jpg', '.jpeg')) else 20000
- buffer = download_image_from_url(url, {}, timeout_ms)
+ buffer = download_image_from_url(url, media_download_headers, timeout_ms)
with open(filename, 'wb') as f:
f.write(buffer.getvalue())
diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index 55f2d7781..19ce24c8e 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -45,11 +45,13 @@
cache_type=read_env_vars_and_defaults(EnvVars.MARQO_INFERENCE_CACHE_TYPE))
-def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[bytes]],
- model_properties: dict = None,
- device: str = None, normalize_embeddings: bool = get_default_normalization(),
- model_auth: ModelAuth = None, enable_cache: bool = False, modality: Modality = Modality.TEXT,
- **kwargs, ) -> List[List[float]]:
+
+def vectorise(
+ model_name: str, content: Union[str, List[str], List[Image], List[bytes]],
+ model_properties: dict = None,
+ device: str = None, normalize_embeddings: bool = get_default_normalization(),
+ model_auth: ModelAuth = None, enable_cache: bool = False, modality: Modality = Modality.TEXT,
+ media_download_headers: Optional[Dict] = None, **kwargs) -> List[List[float]]:
if not device:
raise InternalError(message=f"vectorise (internal function) cannot be called without setting device!")
@@ -64,27 +66,36 @@ def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[
model = _available_models[model_cache_key][AvailableModelsKey.model]
if _marqo_inference_cache.is_enabled() and enable_cache:
- return _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality,
+ media_download_headers, **kwargs)
else:
- return _vectorise_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
-
+ return _vectorise_without_cache(model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs)
-def _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
+def _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs):
if isinstance(content, str):
vectorised = _marqo_inference_cache.get(model_cache_key, content)
if vectorised is None:
- vectorised = _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ vectorised = _encode_without_cache(
+ model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs
+ )
_marqo_inference_cache.set(model_cache_key, content, vectorised[0])
else:
vectorised = _convert_cached_embeddings_to_output(vectorised)
return vectorised
elif isinstance(content, list):
- return _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_list_with_cache(
+ model, model_cache_key, content, normalize_embeddings, modality,
+ media_download_headers,
+ **kwargs
+ )
else:
raise TypeError(f"Unsupported content type: {type(content).__name__}")
-
-def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
+def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs):
contents_to_vectorise = []
cached_output = []
@@ -100,8 +111,10 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
contents_to_vectorise.append(content_item)
if contents_to_vectorise:
- vectorised_outputs = _encode_without_cache(model_cache_key, contents_to_vectorise, normalize_embeddings,
- modality, **kwargs)
+ vectorised_outputs = _encode_without_cache(
+ model_cache_key, contents_to_vectorise, normalize_embeddings, modality,
+ media_download_headers, **kwargs
+ )
# Cache the vectorised outputs
for content_item, vectorised_output in zip(contents_to_vectorise, vectorised_outputs):
if isinstance(content_item, str):
@@ -115,19 +128,25 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
return vectorised_outputs
-def _vectorise_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
- return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
+def _vectorise_without_cache(
+ model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
+ normalize_embeddings: bool, modality: Modality, media_download_headers,
+ **kwargs) -> List[List[float]]:
+ return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, media_download_headers, **kwargs)
def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
+ normalize_embeddings: bool, modality: Modality, media_download_headers: Optional[Dict]=None,
+ **kwargs) -> List[List[float]]:
try:
model = _available_models[model_cache_key][AvailableModelsKey.model]
encoder = get_encoder(model)
if isinstance(content, str):
- vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
+ vectorised = model.encode(
+ content, normalize=normalize_embeddings, modality=modality,
+ media_download_headers=media_download_headers, **kwargs
+ )
elif isinstance(content, (torch.Tensor, torch.FloatTensor)):
vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
else:
@@ -140,9 +159,10 @@ def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], L
# TODO maybe the infer parameter can be replaced by modality
infer = kwargs.pop('infer', False if modality == Modality.TEXT else True)
- encoded_batch = encoder.encode(batch, modality=modality, normalize=normalize_embeddings,
- infer=infer, **kwargs)
-
+ encoded_batch = encoder.encode(
+ batch, modality=modality, normalize=normalize_embeddings,
+ media_download_headers=media_download_headers, infer = infer, **kwargs)
+
vector_batches.append(_convert_tensor_to_numpy(encoded_batch))
if not vector_batches or all(len(batch) == 0 for batch in vector_batches):
@@ -233,7 +253,6 @@ def load_multimodal_model_and_get_preprocessors(model_name: str, model_propertie
}
return model, preprocessors
- # return getattr(_available_models[model_cache_key][AvailableModelsKey.model], "preprocess", None)
def _get_max_vectorise_batch_size() -> int:
diff --git a/src/marqo/tensor_search/add_docs.py b/src/marqo/tensor_search/add_docs.py
index a8dbded3d..9906075a7 100644
--- a/src/marqo/tensor_search/add_docs.py
+++ b/src/marqo/tensor_search/add_docs.py
@@ -39,10 +39,9 @@
def threaded_download_and_preprocess_content(allocated_docs: List[dict],
media_repo: dict,
tensor_fields: List[str],
- image_download_headers: dict,
device: str = None,
media_field_types_mapping: Optional[Dict[str, FieldType]] = None,
- download_headers: Optional[Dict] = None, # Optional for now
+ media_download_headers: Optional[Dict] = None,
metric_obj: Optional[RequestMetrics] = None,
preprocessors: Optional[Dict[str, Compose]] = None,
marqo_index_type: Optional[IndexType] = None,
@@ -59,7 +58,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
image_repo: dictionary that will be mutated by this thread. It will add PIL images
as values and the URLs as keys
tensor_fields: A tuple of tensor_fields. Images will be downloaded for these fields only.
- image_download_headers: A dict of headers for image download. Can be used
+        media_download_headers: A dict of headers for media (image/audio/video) download. Can be used
to authenticate image downloads
force_download: If True, skip the _is_image check and download the fields as images.
Side Effects:
@@ -93,7 +92,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
if isinstance(doc[field], str) or force_download:
try:
- inferred_modality = infer_modality(doc[field])
+ inferred_modality = infer_modality(doc[field], media_download_headers)
except MediaDownloadError as e:
if is_structured_index and media_field_types_mapping[field] == FieldType.ImagePointer:
# Continue processing for structured indexes with image fields
@@ -118,7 +117,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
try:
- media_repo[doc[field]] = clip_utils.load_image_from_path(doc[field], image_download_headers,
+ media_repo[doc[field]] = clip_utils.load_image_from_path(doc[field], media_download_headers,
timeout_ms=int(
TIMEOUT_SECONDS * 1000),
metrics_obj=metric_obj)
@@ -166,9 +165,12 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
try:
- processed_chunks = download_and_chunk_media(doc[field], device, download_headers, inferred_modality,
- marqo_index_type, marqo_index_model, preprocessors,
- audio_preprocessing, video_preprocessing)
+ processed_chunks = download_and_chunk_media(
+ url=doc[field], device=device, modality=inferred_modality,
+ marqo_index_type=marqo_index_type, marqo_index_model=marqo_index_model,
+ preprocessors=preprocessors, audio_preprocessing=audio_preprocessing,
+ video_preprocessing=video_preprocessing, media_download_headers=media_download_headers
+ )
media_repo[doc[field]] = processed_chunks
except (ffmpeg.Error, S2InferenceError) as e:
logger.error(f"Error processing {inferred_modality} file: {str(e)}")
@@ -188,7 +190,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
try:
media_repo[sub_field] = clip_utils.load_image_from_path(
sub_field,
- image_download_headers,
+ media_download_headers,
timeout=TIMEOUT_SECONDS,
metrics_obj=metric_obj
)
@@ -198,13 +200,17 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
-def download_and_chunk_media(url: str, device: str, headers: dict, modality: Modality, marqo_index_type: IndexType, marqo_index_model: Model,
+def download_and_chunk_media(url: str, device: str, modality: Modality, marqo_index_type: IndexType, marqo_index_model: Model,
preprocessors: Preprocessors, audio_preprocessing: AudioPreProcessing = None,
- video_preprocessing: VideoPreProcessing = None) -> List[Dict[str, torch.Tensor]]:
+ video_preprocessing: VideoPreProcessing = None,
+ media_download_headers: Optional[Dict] = None) -> List[Dict[str, torch.Tensor]]:
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB in bytes
- processor = StreamingMediaProcessor(url, device, headers, modality, marqo_index_type, marqo_index_model, preprocessors,
- audio_preprocessing, video_preprocessing)
+ processor = StreamingMediaProcessor(
+ url=url, device=device, modality=modality, marqo_index_type=marqo_index_type, marqo_index_model=marqo_index_model,
+ preprocessors=preprocessors, audio_preprocessing=audio_preprocessing, video_preprocessing=video_preprocessing,
+ media_download_headers=media_download_headers
+ )
if processor.total_size > MAX_FILE_SIZE:
raise ValueError(
@@ -222,24 +228,24 @@ def download_and_preprocess_multimedia_content(
) -> ContextManager[dict]:
thread_count = _determine_thread_count(marqo_index, add_docs_params)
- media_repo = process_batch(docs=docs,
- thread_count=thread_count,
- tensor_fields=list(media_field_types_mapping.keys()),
- media_field_types_mapping=media_field_types_mapping,
- image_download_headers=add_docs_params.image_download_headers,
- download_headers=None, # TODO verify if this is used
- marqo_index_type=marqo_index.type,
- device=add_docs_params.device,
- marqo_index_model=marqo_index.model,
- model_name=marqo_index.model.name,
- model_properties=marqo_index.model.properties,
- normalize_embeddings=marqo_index.normalize_embeddings,
- model_auth=add_docs_params.model_auth,
- patch_method_exists=marqo_index.image_preprocessing.patch_method is not None,
- audio_preprocessing=marqo_index.audio_preprocessing,
- video_preprocessing=marqo_index.video_preprocessing,
- force_download=False, # TODO verify if this is used
- )
+ media_repo = process_batch(
+ docs=docs,
+ thread_count=thread_count,
+ tensor_fields=list(media_field_types_mapping.keys()),
+ media_field_types_mapping=media_field_types_mapping,
+ media_download_headers=add_docs_params.media_download_headers,
+ marqo_index_type=marqo_index.type,
+ device=add_docs_params.device,
+ marqo_index_model=marqo_index.model,
+ model_name=marqo_index.model.name,
+ model_properties=marqo_index.model.properties,
+ normalize_embeddings=marqo_index.normalize_embeddings,
+ model_auth=add_docs_params.model_auth,
+ patch_method_exists=marqo_index.image_preprocessing.patch_method is not None,
+ audio_preprocessing=marqo_index.audio_preprocessing,
+ video_preprocessing=marqo_index.video_preprocessing,
+ force_download=False, # TODO verify if this is used
+ )
try:
yield media_repo
@@ -289,11 +295,10 @@ def _determine_thread_count(marqo_index: MarqoIndex, add_docs_params: AddDocsPar
@contextmanager
def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_fields: List[str],
- image_download_headers: dict,
model_name: str,
normalize_embeddings: bool,
media_field_types_mapping: Optional[Dict[str, FieldType]],
- download_headers: Optional[Dict] = None, # Optional for now
+ media_download_headers: Optional[Dict] = None,
model_properties: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None,
device: Optional[str] = None,
@@ -305,11 +310,24 @@ def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_
force_download: bool = False
) -> ContextManager[dict]:
media_repo = {} # for image/video/audio
- media_repo = process_batch(docs, thread_count, tensor_fields, image_download_headers,
- model_name, normalize_embeddings, force_download,
- media_field_types_mapping, download_headers, model_properties, model_auth,
- device, patch_method_exists, marqo_index_type, marqo_index_model,
- audio_preprocessing, video_preprocessing)
+ media_repo = process_batch(
+ docs = docs,
+ thread_count = thread_count,
+ tensor_fields = tensor_fields,
+ model_name = model_name,
+ normalize_embeddings = normalize_embeddings,
+ force_download = force_download,
+ media_field_types_mapping = media_field_types_mapping,
+ media_download_headers = media_download_headers,
+ model_properties = model_properties,
+ model_auth = model_auth,
+ device = device,
+ patch_method_exists = patch_method_exists,
+ marqo_index_type = marqo_index_type,
+ marqo_index_model = marqo_index_model,
+ audio_preprocessing = audio_preprocessing,
+ video_preprocessing = video_preprocessing
+ )
try:
yield media_repo
@@ -322,14 +340,18 @@ def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_
pass
-def process_batch(docs: List[dict], thread_count: int, tensor_fields: List[str],
- image_download_headers: dict, model_name: str, normalize_embeddings: bool,
- force_download: bool, media_field_types_mapping: Optional[Dict[str, FieldType]],
- download_headers: Optional[Dict], model_properties: Optional[Dict],
- model_auth: Optional[ModelAuth], device: Optional[str],
- patch_method_exists: bool, marqo_index_type: Optional[IndexType], marqo_index_model: Optional[Model],
- audio_preprocessing: Optional[AudioPreProcessing] = None,
- video_preprocessing: Optional[VideoPreProcessing] = None) -> dict:
+def process_batch(
+ docs: List[dict], thread_count: int, tensor_fields: List[str],
+ model_name: str, normalize_embeddings: bool,
+ force_download: bool, media_field_types_mapping: Optional[Dict[str, FieldType]],
+ model_properties: Optional[Dict],
+ model_auth: Optional[ModelAuth], device: Optional[str],
+ patch_method_exists: bool, marqo_index_type: Optional[IndexType], marqo_index_model: Optional[Model],
+ media_download_headers: Optional[Dict] = None,
+ audio_preprocessing: Optional[AudioPreProcessing] = None,
+ video_preprocessing: Optional[VideoPreProcessing] = None
+) -> dict:
+
docs_per_thread = math.ceil(len(docs) / thread_count)
copied = copy.deepcopy(docs)
@@ -349,25 +371,26 @@ def process_batch(docs: List[dict], thread_count: int, tensor_fields: List[str],
# Consider replacing below with:
# thread_allocated_docs = [copied[i: i + docs_per_thread] for i in range(0, len(copied), docs_per_thread)]
thread_allocated_docs = [copied[i: i + docs_per_thread] for i in range(len(copied))[::docs_per_thread]]
- download_headers = download_headers if download_headers else {}
with ThreadPoolExecutor(max_workers=len(thread_allocated_docs)) as executor:
- futures = [executor.submit(threaded_download_and_preprocess_content,
- allocation,
- media_repo,
- tensor_fields,
- image_download_headers,
- device,
- media_field_types_mapping,
- download_headers,
- m[i],
- preprocessors,
- marqo_index_type,
- marqo_index_model,
- audio_preprocessing,
- video_preprocessing,
- force_download)
- for i, allocation in enumerate(thread_allocated_docs)]
+ futures = [
+ executor.submit(
+ threaded_download_and_preprocess_content,
+ allocation,
+ media_repo,
+ tensor_fields,
+ device,
+ media_field_types_mapping,
+ media_download_headers,
+ m[i],
+ preprocessors,
+ marqo_index_type,
+ marqo_index_model,
+ audio_preprocessing,
+ video_preprocessing,
+ force_download)
+ for i, allocation in enumerate(thread_allocated_docs)
+ ]
# Unhandled exceptions will be raised here.
# We only raise the first exception if there are multiple exceptions
diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py
index 556ae4ed6..8386b8382 100644
--- a/src/marqo/tensor_search/api.py
+++ b/src/marqo/tensor_search/api.py
@@ -277,7 +277,7 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api
reranker=search_query.reRanker,
filter=search_query.filter, device=device,
attributes_to_retrieve=search_query.attributesToRetrieve, boost=search_query.boost,
- image_download_headers=search_query.image_download_headers,
+ media_download_headers = search_query.mediaDownloadHeaders,
context=search_query.context,
score_modifiers=search_query.scoreModifiers,
model_auth=search_query.modelAuth,
@@ -334,7 +334,7 @@ def embed(embedding_request: EmbedRequest, index_name: str, device: str = Depend
return marqo_config.embed.embed_content(
content=embedding_request.content,
index_name=index_name, device=device,
- image_download_headers=embedding_request.image_download_headers,
+ media_download_headers=embedding_request.mediaDownloadHeaders,
model_auth=embedding_request.modelAuth,
content_type=embedding_request.content_type
)
diff --git a/src/marqo/tensor_search/models/api_models.py b/src/marqo/tensor_search/models/api_models.py
index 0c853e557..3f4bccd97 100644
--- a/src/marqo/tensor_search/models/api_models.py
+++ b/src/marqo/tensor_search/models/api_models.py
@@ -7,7 +7,7 @@
from typing import Union, List, Dict, Optional
import pydantic
-from pydantic import BaseModel, root_validator, validator
+from pydantic import BaseModel, root_validator, validator, Field
from marqo.base_model import ImmutableStrictBaseModel
from marqo.core.models.hybrid_parameters import HybridParameters
@@ -47,7 +47,8 @@ class SearchQuery(BaseMarqoModel):
filter: str = None
attributesToRetrieve: Union[None, List[str]] = None
boost: Optional[Dict] = None
- image_download_headers: Optional[Dict] = None
+    imageDownloadHeaders: Optional[Dict] = Field(default=None, alias="image_download_headers")  # deprecated; explicit None default
+ mediaDownloadHeaders: Optional[Dict] = None
context: Optional[SearchContext] = None
scoreModifiers: Optional[ScoreModifierLists] = None
modelAuth: Optional[ModelAuth] = None
@@ -68,6 +69,26 @@ def _preprocess_search_method(cls, value):
else:
return value
+ @root_validator(skip_on_failure=True)
+ def _validate_image_download_headers_and_media_download_headers(cls, values):
+ """Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
+
+ If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
+ rest of the code.
+
+ imageDownloadHeaders is deprecated and will be removed in the future.
+ """
+ image_download_headers = values.get('imageDownloadHeaders')
+ media_download_headers = values.get('mediaDownloadHeaders')
+ if image_download_headers and media_download_headers:
+ raise ValueError("Cannot set both imageDownloadHeaders(image_download_headers) and mediaDownloadHeaders. "
+ "'imageDownloadHeaders'(image_download_headers) is deprecated and will be removed in the future. "
+ "Use mediaDownloadHeaders instead.")
+ if image_download_headers:
+ values['mediaDownloadHeaders'] = image_download_headers
+ return values
+
+
@root_validator(pre=False, skip_on_failure=True)
def validate_query_and_context(cls, values):
"""Validate that one of query and context are present for tensor/hybrid search, or just the query for lexical search.
diff --git a/src/marqo/tensor_search/models/search.py b/src/marqo/tensor_search/models/search.py
index 4d8c5a74c..0a7101eb8 100644
--- a/src/marqo/tensor_search/models/search.py
+++ b/src/marqo/tensor_search/models/search.py
@@ -4,7 +4,9 @@
from typing import Any, Union, List, Dict, Optional, NewType, Literal
from marqo.api.exceptions import InvalidArgError
+from marqo.core.models import MarqoQuery
from marqo.tensor_search.models.private_models import ModelAuth
+from marqo.s2_inference.multimodal_model_load import Modality
Qidx = NewType('Qidx', int) # Indicates the position of a search query in a bulk search request
JHash = NewType('JHash', int) # hash of a VectoriseJob. Used for quick access of VectorisedJobs
@@ -26,25 +28,25 @@ class VectorisedJobs(BaseModel):
content: List[Union[str, List[str]]]
device: str
normalize_embeddings: bool
- image_download_headers: Optional[Dict]
- content_type: Literal['text', 'media']
+ media_download_headers: Optional[Dict]
model_auth: Optional[ModelAuth]
+ modality: Modality
def __hash__(self):
return self.groupby_key() + hash(json.dumps(self.content, sort_keys=True))
def groupby_key(self) -> JHash:
return VectorisedJobs.get_groupby_key(self.model_name, self.model_properties, self.device,
- self.normalize_embeddings, self.content_type,
- self.image_download_headers)
+ self.normalize_embeddings, self.modality,
+ self.media_download_headers)
@staticmethod
def get_groupby_key(model_name: str, model_properties: Dict[str, Any], device: str,
- normalize_embeddings: bool, content_type: str, image_download_headers: Optional[Dict]) -> JHash:
+ normalize_embeddings: bool, modality: str, media_download_headers: Optional[Dict]) -> JHash:
return JHash(hash(model_name) + hash(json.dumps(model_properties, sort_keys=True))
+ hash(device) + hash(normalize_embeddings)
- + hash(content_type)
- + hash(json.dumps(image_download_headers, sort_keys=True))
+ + hash(modality)
+ + hash(json.dumps(media_download_headers, sort_keys=True))
)
def add_content(self, content: List[Union[str, List[str]]]) -> VectorisedJobPointer:
@@ -75,4 +77,29 @@ def __init__(self, **data):
def check_vector_length(cls, v):
if not (1 <= len(v) <= 64):
raise InvalidArgError('The number of tensors must be between 1 and 64')
- return v
\ No newline at end of file
+ return v
+
+
+class QueryContent(BaseModel):
+ content: str
+ modality: Modality
+
+
+class QueryContentCollector(BaseModel):
+ queries: List[QueryContent]
+ @property
+ def text_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.TEXT]
+
+ @property
+ def image_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.IMAGE]
+
+ @property
+ def video_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.VIDEO]
+
+ @property
+ def audio_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.AUDIO]
+
\ No newline at end of file
diff --git a/src/marqo/tensor_search/streaming_media_processor.py b/src/marqo/tensor_search/streaming_media_processor.py
index 72b75de3c..56d285637 100644
--- a/src/marqo/tensor_search/streaming_media_processor.py
+++ b/src/marqo/tensor_search/streaming_media_processor.py
@@ -15,15 +15,15 @@
from marqo.core.models.marqo_index import *
from marqo.s2_inference.multimodal_model_load import Modality
from marqo.tensor_search.models.preprocessors_model import Preprocessors
+from marqo.core.exceptions import InternalError
class StreamingMediaProcessor:
- def __init__(self, url: str, device: str, headers: Dict[str, str], modality: Modality, marqo_index_type: IndexType,
+ def __init__(self, url: str, device: str, modality: Modality, marqo_index_type: IndexType,
marqo_index_model: Model, preprocessors: Preprocessors, audio_preprocessing: AudioPreProcessing = None,
- video_preprocessing: VideoPreProcessing = None):
+ video_preprocessing: VideoPreProcessing = None, media_download_headers: Optional[Dict[str, str]]= None):
self.url = url
self.device = device
- self.headers = headers
self.modality = modality
self.marqo_index_type = marqo_index_type
self.marqo_index_model = marqo_index_model
@@ -31,6 +31,8 @@ def __init__(self, url: str, device: str, headers: Dict[str, str], modality: Mod
self.video_preprocessing = video_preprocessing
self.preprocessors = preprocessors
self.preprocessor = self.preprocessors[modality]
+ self.media_download_headers = self._convert_headers_to_cli_format(media_download_headers)
+
self.total_size, self.duration = self._fetch_file_metadata()
self._set_split_parameters(modality)
@@ -56,6 +58,25 @@ def _log_initialization_details(self):
# print(f"from StreamingMediaProcessor, self.duration: {self.duration}")
pass
+ def _convert_headers_to_cli_format(self, raw_media_download_headers: Optional[Dict] = None) -> str:
+ """
+ A helper function to convert the media download headers into a format that can be passed to ffmpeg in
+ subprocess calls.
+
+ Examples:
+ If the headers are {"key1": "value1", "key2": "value2"}, the function will return a string
+ "key1: value1\r\nkey2: value2"
+
+ Returns:
+ str: The headers as "key: value" pairs joined by CRLF. An empty string if the headers are None or an empty dict.
+ """
+ if raw_media_download_headers is None or raw_media_download_headers == {}:
+ return ""
+ elif not isinstance(raw_media_download_headers, dict):
+ raise InternalError("media_download_headers should be a dictionary")
+ return "\r\n".join([f"{key}: {value}" for key, value in raw_media_download_headers.items()])
+
+
def _fetch_file_metadata(self):
start_time = time.time()
@@ -64,9 +85,12 @@ def _fetch_file_metadata(self):
'v': 'error',
'show_entries': 'format=size,duration',
'of': 'json',
- 'probesize': '256K' # Probe only the first 256KB
+ 'probesize': '256K', # Probe only the first 256KB
}
+ if self.media_download_headers:
+ probe_options['headers'] = self.media_download_headers
+
probe = ffmpeg.probe(self.url, **probe_options)
size = int(probe['format'].get('size', 0))
@@ -105,7 +129,13 @@ def process_media(self) -> List[Dict[str, torch.Tensor]]:
try:
# Use ffmpeg-python to process the chunk
- stream = ffmpeg.input(self.url, ss=chunk_start, t=chunk_end - chunk_start)
+ if self.media_download_headers:
+ stream = ffmpeg.input(
+ self.url, ss=chunk_start, t=chunk_end - chunk_start,
+ headers=self.media_download_headers
+ )
+ else:
+ stream = ffmpeg.input(self.url, ss=chunk_start, t=chunk_end - chunk_start)
if self.modality == Modality.VIDEO:
stream = ffmpeg.output(stream, output_file, vcodec='libx264', acodec='aac', **{'f': 'mp4'})
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index a15646254..f040f899b 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -90,7 +90,7 @@
from marqo.tensor_search.models.delete_docs_objects import MqDeleteDocsRequest
from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.search import Qidx, JHash, SearchContext, VectorisedJobs, VectorisedJobPointer, \
- SearchContextTensor
+ SearchContextTensor, QueryContentCollector, QueryContent
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo.vespa.exceptions import VespaStatusError
@@ -186,7 +186,7 @@ def _add_documents_unstructured(config: Config, add_docs_params: AddDocsParams,
docs=docs,
thread_count=media_download_thread_count,
tensor_fields=tensor_fields_and_multimodal_subfields,
- image_download_headers=add_docs_params.image_download_headers,
+ media_download_headers=add_docs_params.media_download_headers,
model_name=marqo_index.model.name,
normalize_embeddings=marqo_index.normalize_embeddings,
media_field_types_mapping=None,
@@ -709,7 +709,7 @@ def _add_documents_structured(config: Config, add_docs_params: AddDocsParams, ma
docs=docs,
thread_count=media_download_thread_count,
tensor_fields=media_fields,
- image_download_headers=add_docs_params.image_download_headers,
+ media_download_headers=add_docs_params.media_download_headers,
# add non image download headers in the future
model_name=marqo_index.model.name,
normalize_embeddings=marqo_index.normalize_embeddings,
@@ -1465,7 +1465,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
reranker: Union[str, Dict] = None, filter: Optional[str] = None,
attributes_to_retrieve: Optional[List[str]] = None,
device: str = None, boost: Optional[Dict] = None,
- image_download_headers: Optional[Dict] = None,
+ media_download_headers: Optional[Dict] = None,
context: Optional[SearchContext] = None,
score_modifiers: Optional[ScoreModifierLists] = None,
model_auth: Optional[ModelAuth] = None,
@@ -1493,7 +1493,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
device: May be none, we calculate default device here
num_highlights: number of highlights to return for each doc
boost: boosters to re-weight the scores of individual fields
- image_download_headers: headers for downloading images
+ media_download_headers: headers to use when downloading media
context: a dictionary to allow custom vectors in search, for tensor search only
score_modifiers: a dictionary to modify the score based on field values, for tensor search only
model_auth: Authorisation details for downloading a model (if required)
@@ -1583,7 +1583,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
ef_search=ef_search, approximate=approximate, searchable_attributes=searchable_attributes,
filter_string=filter, device=selected_device, attributes_to_retrieve=attributes_to_retrieve,
boost=boost,
- image_download_headers=image_download_headers, context=context, score_modifiers=score_modifiers,
+ media_download_headers=media_download_headers, context=context, score_modifiers=score_modifiers,
model_auth=model_auth, highlights=highlights, text_query_prefix=text_query_prefix
)
elif search_method.upper() == SearchMethod.HYBRID:
@@ -1594,7 +1594,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
ef_search=ef_search, approximate=approximate, searchable_attributes=searchable_attributes,
filter_string=filter, device=selected_device, attributes_to_retrieve=attributes_to_retrieve,
boost=boost,
- image_download_headers=image_download_headers, context=context, score_modifiers=score_modifiers,
+ media_download_headers=media_download_headers, context=context, score_modifiers=score_modifiers,
model_auth=model_auth, highlights=highlights, text_query_prefix=text_query_prefix,
hybrid_parameters=hybrid_parameters
)
@@ -1735,37 +1735,39 @@ def _lexical_search(
return gathered_docs
-def construct_vector_input_batches(query: Optional[Union[str, Dict]], index_info: MarqoIndex) \
- -> Tuple[List[str], List[str]]:
+def construct_vector_input_batches(query: Optional[Union[str, Dict]], media_download_headers: Optional[Dict] = None) \
+ -> QueryContentCollector:
"""Splits images from text in a single query (either a query string, or dict of weighted strings).
Args:
query: a string query, or a dict of weighted strings.
- index_info: used to determine whether URLs should be treated as images
+ media_download_headers: headers to use when downloading media
Returns:
- A tuple of string batches. The first is text content the second is image content.
+ A QueryContentCollector object holding each query string paired with its inferred modality.
"""
# TODO - infer this from model
- treat_urls_as_media = True
-
+ query_content_list = []
if isinstance(query, str):
- if treat_urls_as_media and validate_url(query):
- return [], [query, ]
- else:
- return [query, ], []
+ query_content_list.append(
+ QueryContent(
+ content=query,
+ modality=infer_modality(query, media_download_headers=media_download_headers)
+ )
+ )
elif isinstance(query, dict): # is dict:
- ordered_queries = list(query.items())
- if treat_urls_as_media:
- text_queries = [k for k, _ in ordered_queries if not _is_image(k)]
- image_queries = [k for k, _ in ordered_queries if _is_image(k)]
- return text_queries, image_queries
- else:
- return [k for k, _ in ordered_queries], []
+ for query, weights in query.items():
+ query_content_list.append(
+ QueryContent(
+ content=query,
+ modality=infer_modality(query, media_download_headers=media_download_headers)
+ )
+ )
elif query is None:
- return [], []
+ pass
else:
raise ValueError(f"Incorrect type for query: {type(query).__name__}")
+ return QueryContentCollector(queries = query_content_list)
def gather_documents_from_response(response: QueryResult, marqo_index: MarqoIndex, highlights: bool,
@@ -1800,7 +1802,7 @@ def unstructured_index_attributes_to_retrieve(marqo_doc: Dict[str, Any], attribu
def assign_query_to_vector_job(
q: BulkSearchQueryEntity, jobs: Dict[JHash, VectorisedJobs],
- grouped_content: Tuple[List[str], List[str], List[str], List[str]],
+ grouped_content: QueryContentCollector,
index_info: MarqoIndex, device: str) -> List[VectorisedJobPointer]:
"""
For a individual query, assign its content (to be vectorised) to a vector job. If none exist with the correct
@@ -1819,34 +1821,39 @@ def assign_query_to_vector_job(
Returns:
A list of pointers to the location in a vector job that will have its vectorised content.
"""
- if len(grouped_content) != 2:
- raise RuntimeError(
- "assign_query_to_vector_job() expects param `grouped_content` with 2 elems. Instead received"
- f" `grouped_content` with {len(grouped_content)} elems")
ptrs = []
- for i, content in enumerate(grouped_content):
- content_type = ['text', 'media'][i]
- vector_job = VectorisedJobs(
- model_name=index_info.model.name,
- model_properties=index_info.model.get_properties(),
- content=content,
- device=device,
- normalize_embeddings=index_info.normalize_embeddings,
- image_download_headers=q.image_download_headers,
- content_type=content_type,
- model_auth=q.modelAuth
- )
- # If exists, add content to vector job. Otherwise create new
- if jobs.get(vector_job.groupby_key()) is not None:
- j = jobs.get(vector_job.groupby_key())
- ptrs.append(j.add_content(content))
- else:
- jobs[vector_job.groupby_key()] = vector_job
- ptrs.append(VectorisedJobPointer(
- job_hash=vector_job.groupby_key(),
- start_idx=0,
- end_idx=len(vector_job.content)
- ))
+ content_lists_by_modality = [
+ grouped_content.text_queries,
+ grouped_content.image_queries,
+ grouped_content.audio_queries,
+ grouped_content.video_queries,
+ ]
+
+ for i, list_of_queries_by_modalities in enumerate(content_lists_by_modality):
+ if len(list_of_queries_by_modalities) > 0:
+ content: List[str] = [query.content for query in list_of_queries_by_modalities]
+ modality: Modality = list_of_queries_by_modalities[0].modality
+ vector_job = VectorisedJobs(
+ model_name=index_info.model.name,
+ model_properties=index_info.model.get_properties(),
+ content=content,
+ device=device,
+ normalize_embeddings=index_info.normalize_embeddings,
+ media_download_headers=q.mediaDownloadHeaders,
+ model_auth=q.modelAuth,
+ modality = modality
+ )
+ # If exists, add content to vector job. Otherwise create new
+ if jobs.get(vector_job.groupby_key()) is not None:
+ j = jobs.get(vector_job.groupby_key())
+ ptrs.append(j.add_content(content))
+ else:
+ jobs[vector_job.groupby_key()] = vector_job
+ ptrs.append(VectorisedJobPointer(
+ job_hash=vector_job.groupby_key(),
+ start_idx=0,
+ end_idx=len(vector_job.content)
+ ))
return ptrs
@@ -1865,9 +1872,8 @@ def create_vector_jobs(queries: List[BulkSearchQueryEntity], config: Config, dev
qidx_to_job: Dict[Qidx, List[VectorisedJobPointer]] = dict()
jobs: Dict[JHash, VectorisedJobs] = {}
for i, q in enumerate(queries):
- q = queries[i]
# split images, from text:
- to_be_vectorised: Tuple[List[str], List[str]] = construct_vector_input_batches(q.q, q.index)
+ to_be_vectorised: QueryContentCollector = construct_vector_input_batches(q.q, q.mediaDownloadHeaders)
qidx_to_job[i] = assign_query_to_vector_job(q, jobs, to_be_vectorised, q.index, device)
return qidx_to_job, jobs
@@ -1882,12 +1888,15 @@ def vectorise_jobs(jobs: List[VectorisedJobs]) -> Dict[JHash, Dict[str, List[flo
# TODO: Handle exception for single job, and allow others to run.
try:
if v.content:
- modality = infer_modality(v.content[0] if isinstance(v.content, list) else v.content)
+ modality = infer_modality(
+ v.content[0] if isinstance(v.content, list) else v.content,
+ media_download_headers=v.media_download_headers
+ )
vectors = s2_inference.vectorise(
model_name=v.model_name, model_properties=v.model_properties,
content=v.content, device=v.device,
normalize_embeddings=v.normalize_embeddings,
- image_download_headers=v.image_download_headers,
+ media_download_headers=v.media_download_headers,
model_auth=v.model_auth,
enable_cache=True,
modality=modality
@@ -1940,11 +1949,12 @@ def get_query_vectors_from_jobs(
if ordered_queries:
# multiple queries. We have to weight and combine them:
vectorised_ordered_queries = [
- (get_content_vector(
+ (
+ get_content_vector(
possible_jobs=qidx_to_job[qidx],
- jobs=jobs,
job_to_vectors=job_to_vectors,
- content=content),
+ content=content
+ ),
weight,
content
) for content, weight in ordered_queries
@@ -1975,7 +1985,6 @@ def get_query_vectors_from_jobs(
# result[qidx] = vectors[0]
result[qidx] = get_content_vector(
possible_jobs=qidx_to_job.get(qidx, []),
- jobs=jobs,
job_to_vectors=job_to_vectors,
content=q.q
)
@@ -1984,14 +1993,16 @@ def get_query_vectors_from_jobs(
return result
-def get_content_vector(possible_jobs: List[VectorisedJobPointer], job_to_vectors: Dict[JHash, Dict[str, List[float]]],
- jobs: Dict[JHash, VectorisedJobs], content: str) -> List[float]:
+def get_content_vector(
+ possible_jobs: List[VectorisedJobPointer],
+ job_to_vectors: Dict[JHash, Dict[str, List[float]]],
+ content: str
+) -> List[float]:
"""finds the vector associated with a piece of content
Args:
possible_jobs: The jobs where the target vector may reside
- treat_urls_as_media: an index_parameter that indicates whether content should be treated as image, audio, video
- if it has a URL structure
+ job_to_vectors: Maps each job hash to a dict from that job's content strings to their vectors
content: The content to search
Returns:
@@ -1999,15 +2010,10 @@ def get_content_vector(possible_jobs: List[VectorisedJobPointer], job_to_vectors
Raises runtime error if is not found
"""
- content_type = 'text' if infer_modality(content) == Modality.TEXT else 'media'
-
not_found_error = RuntimeError(f"get_content_vector(): could not find corresponding vector for content `{content}`")
for vec_job_pointer in possible_jobs:
- if jobs[vec_job_pointer.job_hash].content_type == content_type:
- try:
- return job_to_vectors[vec_job_pointer.job_hash][content]
- except KeyError:
- raise not_found_error
+ if content in job_to_vectors[vec_job_pointer.job_hash]:
+ return job_to_vectors[vec_job_pointer.job_hash][content]
raise not_found_error
@@ -2019,19 +2025,20 @@ def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSear
if q.q is None:
prefixed_q = q.q
elif isinstance(q.q, str):
- if _is_image(q.q):
- prefixed_q = q.q
- else:
+ modality = infer_modality(q.q, q.mediaDownloadHeaders)
+ if modality == Modality.TEXT:
prefixed_q = f"{text_query_prefix}{q.q}"
+ else:
+ prefixed_q = q.q
else: # q.q is dict
prefixed_q = {}
for key, value in q.q.items():
# Apply prefix if key is not an image or if index does not treat URLs and pointers as images
- if _is_image(key):
- prefixed_q[key] = value
- else:
+ modality = infer_modality(key, q.mediaDownloadHeaders)
+ if modality == Modality.TEXT:
prefixed_q[f"{text_query_prefix}{key}"] = value
-
+ else:
+ prefixed_q[key] = value
new_query_object = BulkSearchQueryEntity(
q=prefixed_q,
searchableAttributes=q.searchableAttributes,
@@ -2042,7 +2049,7 @@ def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSear
filter=q.filter,
attributesToRetrieve=q.attributesToRetrieve,
boost=q.boost,
- image_download_headers=q.image_download_headers,
+ mediaDownloadHeaders=q.mediaDownloadHeaders,
context=q.context,
scoreModifiers=q.scoreModifiers,
index=q.index,
@@ -2087,7 +2094,7 @@ def _vector_text_search(
ef_search: Optional[int] = None, approximate: bool = True,
searchable_attributes: Iterable[str] = None, filter_string: str = None, device: str = None,
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None,
- image_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
+ media_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
score_modifiers: Optional[ScoreModifierLists] = None, model_auth: Optional[ModelAuth] = None,
highlights: bool = False, text_query_prefix: Optional[str] = None) -> Dict:
"""
@@ -2104,7 +2111,7 @@ def _vector_text_search(
verbose: if 0 - nothing is printed. if 1 - data is printed without vectors, if 2 - full
objects are printed out
attributes_to_retrieve: if set, only returns these fields
- image_download_headers: headers for downloading images
+ media_download_headers: headers for downloading media
context: a dictionary to allow custom vectors in search
score_modifiers: a dictionary to modify the score based on field values, for tensor search only
model_auth: Authorisation details for downloading a model (if required)
@@ -2153,7 +2160,7 @@ def _vector_text_search(
queries = [BulkSearchQueryEntity(
q=query, searchableAttributes=searchable_attributes, searchMethod=SearchMethod.TENSOR, limit=result_count,
offset=offset, showHighlights=False, filter=filter_string, attributesToRetrieve=attributes_to_retrieve,
- boost=boost, image_download_headers=image_download_headers, context=context, scoreModifiers=score_modifiers,
+ boost=boost, mediaDownloadHeaders=media_download_headers, context=context, scoreModifiers=score_modifiers,
index=marqo_index, modelAuth=model_auth, text_query_prefix=text_query_prefix
)]
@@ -2365,7 +2372,7 @@ def vectorise_multimodal_combination_field_unstructured(field: str,
model_name=marqo_index.model.name,
model_properties=marqo_index.model.properties, content=prefixed_text_content_to_vectorise,
device=device, normalize_embeddings=normalize_embeddings,
- infer=False, model_auth=model_auth, modality=Modality.TEXT
+ infer=True, model_auth=model_auth, modality=Modality.TEXT
)
vectors_list.extend(text_vectors)
diff --git a/src/marqo/tensor_search/web/api_utils.py b/src/marqo/tensor_search/web/api_utils.py
index 0c2ab4d4e..1db7cea68 100644
--- a/src/marqo/tensor_search/web/api_utils.py
+++ b/src/marqo/tensor_search/web/api_utils.py
@@ -50,27 +50,27 @@ def translate_api_device(device: Optional[str]) -> Optional[str]:
f"Acceptable device types: {acceptable_devices}")
-def decode_image_download_headers(image_download_headers: Optional[str] = None) -> dict:
+def decode_media_download_headers(media_download_headers: Optional[str] = None) -> dict:
"""Decodes an image download header string into a Python dict
Args:
- image_download_headers: JSON-serialised, URL encoded header dictionary
+ media_download_headers: JSON-serialised, URL encoded header dictionary
Returns:
- image_download_headers as a dict
+ media_download_headers as a dict
Raises:
InvalidArgError if there is trouble parsing the dictionary
"""
- if not image_download_headers:
+ if not media_download_headers:
return dict()
else:
try:
- as_str = urllib.parse.unquote_plus(image_download_headers)
+ as_str = urllib.parse.unquote_plus(media_download_headers)
as_dict = json.loads(as_str)
return as_dict
except json.JSONDecodeError as e:
- raise InvalidArgError(f"Error parsing image_download_headers. Message: {e}")
+ raise InvalidArgError(f"Error parsing media_download_headers. Message: {e}")
def decode_query_string_model_auth(model_auth: Optional[str] = None) -> Optional[ModelAuth]:
@@ -130,14 +130,14 @@ def add_docs_params_orchestrator(index_name: str, body: Union[AddDocsBodyParams,
tensor_fields = body.tensorFields
use_existing_tensors = body.useExistingTensors
model_auth = body.modelAuth
- image_download_headers = body.imageDownloadHeaders
+ media_download_headers = body.mediaDownloadHeaders
image_download_thread_count = body.imageDownloadThreadCount
text_chunk_prefix = body.textChunkPrefix
return AddDocsParams(
index_name=index_name, docs=docs,
device=device, tensor_fields=tensor_fields,
- use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers,
+ use_existing_tensors=use_existing_tensors, media_download_headers=media_download_headers,
image_download_thread_count=image_download_thread_count,
mappings=mappings, model_auth=model_auth, text_chunk_prefix=text_chunk_prefix,
batch_vectorisation_mode=body.batchVectorisationMode,
diff --git a/tests/conftest.py b/tests/conftest.py
index 36d1b9617..93d52e8ed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,9 +18,9 @@ def pytest_collection_modifyitems(config, items):
skip_cpu_only = pytest.mark.skip(reason="skip in --largemodel mode when cpu_only is present")
if config.getoption("--largemodel"):
- # --largemodel given in cli: do not skip largemodel tests, skip cpu_only tests
+ # --largemodel given in cli: only run tests that have largemodel marker
for item in items:
- if "cpu_only" in item.keywords:
+ if "largemodel" not in item.keywords:
item.add_marker(skip_cpu_only)
else:
for item in items:
diff --git a/tests/marqo_test.py b/tests/marqo_test.py
index 8d66c2a86..25edff9d7 100644
--- a/tests/marqo_test.py
+++ b/tests/marqo_test.py
@@ -36,6 +36,21 @@ class TestImageUrls(str, Enum):
HIPPO_STATUE = 'https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_statue_small.png'
+class TestAudioUrls(str, Enum):
+ __test__ = False
+ AUDIO1 = "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-100032-A-0.wav"
+ AUDIO2 = "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-115545-C-48.wav"
+ AUDIO3 = "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-119125-A-45.wav"
+
+
+class TestVideoUrls(str, Enum):
+ __test__ = False
+ VIDEO1 = "https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/--_S9IDQPLg_000135_000145.mp4"
+ VIDEO2 = "https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4"
+ VIDEO3 = "https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/--mI_-gaZLk_000018_000028.mp4"
+
+
+
class MarqoTestCase(unittest.TestCase):
indexes = []
diff --git a/tests/s2_inference/test_image_downloading.py b/tests/s2_inference/test_image_downloading.py
index 89f88200f..29a214024 100644
--- a/tests/s2_inference/test_image_downloading.py
+++ b/tests/s2_inference/test_image_downloading.py
@@ -53,12 +53,12 @@ def test_download_image_from_url_handleDifferentUrlsCorrectly(self):
for url, expected, msg in self.test_cases:
with self.subTest(url=url, expected=expected, msg=msg):
with self.assertRaises(ImageDownloadError) as cm:
- download_image_from_url(image_path=url + ".jpg", image_download_headers={})
+ download_image_from_url(image_path=url + ".jpg", media_download_headers={})
def test_download_image_from_url_handlesUrlRequiringUserAgentHeader(self):
url_requiring_user_agent_header = "https://docs.marqo.ai/2.0.0/Examples/marqo.jpg"
try:
- download_image_from_url(image_path=url_requiring_user_agent_header, image_download_headers={})
+ download_image_from_url(image_path=url_requiring_user_agent_header, media_download_headers={})
except Exception as e:
self.fail(f"Exception was raised when downloading {url_requiring_user_agent_header}: {e}")
@@ -77,7 +77,7 @@ def test_download_image_from_url_mergesDefaultHeadersWithCustomHeaders(self, moc
for (headers, expected_headers, msg) in test_cases:
with self.subTest(headers=headers, expected_headers=expected_headers, msg=msg):
- download_image_from_url('http://example.com/image.jpg', image_download_headers=headers)
+ download_image_from_url('http://example.com/image.jpg', media_download_headers=headers)
mock_curl_instance.setopt.assert_called_with(pycurl.HTTPHEADER, expected_headers)
def test_download_image_from_url_handlesRedirection(self):
@@ -88,5 +88,5 @@ def test_download_image_from_url_handlesRedirection(self):
])
with MockHttpServer(app).run_in_thread() as base_url:
- result = download_image_from_url(f'{base_url}/missing_image.jpg', image_download_headers={})
+ result = download_image_from_url(f'{base_url}/missing_image.jpg', media_download_headers={})
self.assertEqual(result.getvalue(), image_content)
diff --git a/tests/s2_inference/test_vectorise.py b/tests/s2_inference/test_vectorise.py
index 6e51446b0..5ccd1bde4 100644
--- a/tests/s2_inference/test_vectorise.py
+++ b/tests/s2_inference/test_vectorise.py
@@ -240,7 +240,8 @@ def test_vectorise_single_content_item(self):
result = s2_inference.vectorise(model_name='mock_model', content=single_content,
model_properties=self.mock_model_props, device="cpu")
- self.mock_model.encode.assert_called_once_with(single_content, normalize=True, modality=Modality.TEXT)
+ self.mock_model.encode.assert_called_once_with(single_content, normalize=True, modality=Modality.TEXT,
+ media_download_headers=None)
self.assertIsInstance(result, list)
self.assertEqual(len(result), 1)
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index 21c3b90b4..64df57273 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -1,24 +1,18 @@
import os
import unittest.mock
+import unittest.mock
import uuid
from unittest import mock
from unittest.mock import patch
import PIL
-import numpy as np
-
import numpy as np
import pytest
-
-
-import PIL
import requests
import torch
-from more_itertools import flatten
from torch import Tensor
-import unittest.mock
-
+from marqo.core.models.add_docs_params import AddDocsParams, BatchVectorisationMode
from marqo.core.models.marqo_get_documents_by_id_response import MarqoGetDocumentsByIdsResponse
from marqo.core.models.marqo_index import *
from marqo.core.models.marqo_index_request import FieldRequest
@@ -27,10 +21,7 @@
from marqo.tensor_search import add_docs
from marqo.tensor_search import streaming_media_processor
from marqo.tensor_search import tensor_search
-from marqo.core.models.add_docs_params import AddDocsParams, BatchVectorisationMode
-from tests.marqo_test import MarqoTestCase, TestImageUrls
-from marqo.s2_inference.multimodal_model_load import infer_modality
-from marqo.tensor_search import streaming_media_processor
+from tests.marqo_test import MarqoTestCase, TestImageUrls, TestAudioUrls, TestVideoUrls
class TestAddDocumentsCombined(MarqoTestCase):
@@ -577,7 +568,7 @@ def test_imageDownloadWithoutPreprocessor(self):
allocated_docs=[test_doc],
media_repo=media_repo,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
)
@@ -597,7 +588,7 @@ def test_imageDownloadWithPreprocessor(self):
allocated_docs=[test_doc],
media_repo=media_repo,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
preprocessors={'image': lambda x: torch.randn(3, 224, 224)},
device='cpu',
marqo_index_type=IndexType.Unstructured,
@@ -619,7 +610,7 @@ def run():
{"Title": "frog", "Desc": "blah"}, {"Title": "Dog", "Loc": "https://google.com/my_dog.png"}],
media_repo=media_repo,
tensor_fields=['Title', 'Desc', 'Loc'],
- image_download_headers={},
+ media_download_headers={},
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
)
@@ -708,7 +699,7 @@ def test_threaded_download_images_non_tensor_field(self):
allocated_docs=docs,
media_repo=media_repo,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
)
@@ -760,7 +751,7 @@ def test_download_images_non_tensor_field(self):
docs=docs,
thread_count=20,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
model_name="ViT-B/32",
normalize_embeddings=True,
model_properties=model_properties,
@@ -842,13 +833,13 @@ def test_process_media_chunk_calculation(self, mock_temp_dir, mock_ffmpeg):
processor = streaming_media_processor.StreamingMediaProcessor(
url='http://example.com/video.mp4',
device='cpu',
- headers={},
modality=streaming_media_processor.Modality.VIDEO,
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
audio_preprocessing=unittest.mock.Mock(),
video_preprocessing=unittest.mock.Mock(),
- preprocessors={'video': unittest.mock.Mock()}
+ preprocessors={'video': unittest.mock.Mock()},
+ media_download_headers={},
)
# Set arbitrary values
@@ -1092,4 +1083,194 @@ def test_textIndexEmbeddingsUnnormalized(self):
embeddings = get_res['results'][0]['_tensor_facets'][0]['_embedding']
norm = np.linalg.norm(np.array(embeddings))
- self.assertTrue(norm - 1.0 > 1e-5, f"Embedding norm is {norm}")
\ No newline at end of file
+ self.assertTrue(norm - 1.0 > 1e-5, f"Embedding norm is {norm}")
+
+ def test_add_private_images_proper_error_returned(self):
+ """Adding private images without media download headers must fail per document.
+
+ No media_download_headers are passed to AddDocsParams, so for both the structured and the
+ unstructured index the response is expected to flag errors, and each of the two items is
+ expected to have status 400 and a message containing "403".
+ """
+ test_indexes = [self.structured_marqo_index_name, self.unstructured_marqo_index_name]
+ # Two URLs on the same host that differ only in the ".png" extension.
+ documents = [
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
+ "text_field_1": "A private image with a png extension",
+ "_id": "1"
+ },
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ "text_field_1": "A private image without an extension",
+ "_id": "2"
+ }
+ ]
+ for index_name in test_indexes:
+ # tensor_fields is only supplied for the unstructured index; the structured index gets None.
+ # NOTE(review): "my_combination_field" has no entry in mappings and no value in the documents
+ # below - confirm it is intentional.
+ tensor_fields = ["multimodal_field", "my_combination_field"] if (
+ index_name == self.unstructured_marqo_index_name) else None
+ # multimodal_field combines the image and the text field with equal weights.
+ mappings = {
+ "multimodal_field":
+ {
+ "type": "multimodal_combination",
+ "weights": {"image_field_1": 1.0, "text_field_1": 1.0}
+ }
+ }
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields,
+ mappings=mappings
+ )
+ )
+ # Both documents are expected to be rejected individually.
+ self.assertTrue(res.errors)
+ items = res.items
+ self.assertEqual(2, len(items))
+ for item in items:
+ self.assertEqual(400, item.status)
+ # NOTE(review): "403" is presumably the HTTP status of the failed download surfacing
+ # in the item message - confirm against the download error format.
+ self.assertIn("403", item.message)
+
+ def test_add_private_images_success(self):
+ """Adding private images succeeds when media_download_headers are supplied.
+
+ The same two documents as in the error-path test are added with a
+ "marqo_media_header" download header, and the response must report no errors.
+ Only the unstructured index is exercised.
+ """
+ # TODO(review): the structured index (self.structured_marqo_index_name) is left out of
+ # test_indexes here - confirm whether it should be covered as well.
+ test_indexes = [self.unstructured_marqo_index_name, ]
+ # Two URLs on the same host that differ only in the ".png" extension.
+ documents = [
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
+ "text_field_1": "A private image with a png extension",
+ "_id": "1"
+ },
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ "text_field_1": "A private image without an extension",
+ "_id": "2"
+ }
+ ]
+ for index_name in test_indexes:
+ # With test_indexes as defined above, the "else None" branch is never taken.
+ tensor_fields = ["image_field_1", "multimodal_field"] if (
+ index_name == self.unstructured_marqo_index_name) else None
+ # multimodal_field combines the image and the text field with equal weights.
+ mappings = {
+ "multimodal_field":
+ {
+ "type": "multimodal_combination",
+ "weights": {"image_field_1": 1.0, "text_field_1": 1.0}
+ }
+ }
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields,
+ # NOTE(review): presumably the image host only serves these URLs when this
+ # header is present - confirm against the test bucket configuration.
+ media_download_headers={"marqo_media_header": "media_header_test_key"},
+ mappings=mappings
+ )
+ )
+ self.assertFalse(res.errors)
+
+
+@pytest.mark.largemodel
+class TestLanguageBindModelAddDocumentCombined(MarqoTestCase):
+ """Integration tests for add_documents on indexes using the LanguageBind model.
+
+ Both a structured and an unstructured index are created with the
+ "LanguageBind/Video_V1.5_FT_Audio_FT_Image" model, and documents mixing text, image,
+ audio and video fields are added to each of them.
+ """
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ """Create the structured and unstructured LanguageBind indexes, then clear loaded models."""
+ super().setUpClass()
+
+ # Structured index: one field per modality plus a multimodal combination over all four,
+ # each dependent field weighted 1.0.
+ structured_language_bind_index = cls.structured_marqo_index_request(
+ name="structured_image_index" + str(uuid.uuid4()).replace('-', ''),
+ fields=[
+ FieldRequest(name="text_field_1", type=FieldType.Text,
+ features=[FieldFeature.Filter, FieldFeature.LexicalSearch]),
+ FieldRequest(name="image_field_1", type=FieldType.ImagePointer),
+ FieldRequest(name="audio_field_1", type=FieldType.AudioPointer),
+ FieldRequest(name="video_field_1", type=FieldType.VideoPointer),
+ FieldRequest(
+ name="multimodal_field",
+ type=FieldType.MultimodalCombination,
+ dependent_fields={
+ "image_field_1": 1.0,
+ "text_field_1": 1.0,
+ "audio_field_1": 1.0,
+ "video_field_1": 1.0,
+ }
+ )
+ ],
+ model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
+ tensor_fields=["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"],
+ )
+
+ # Unstructured index: same model, with both URL-handling flags enabled.
+ unstructured_language_bind_index = cls.unstructured_marqo_index_request(
+ name="unstructured_image_index" + str(uuid.uuid4()).replace('-', ''),
+ model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
+ treat_urls_and_pointers_as_images=True,
+ treat_urls_and_pointers_as_media=True
+ )
+
+ cls.indexes = cls.create_indexes([structured_language_bind_index, unstructured_language_bind_index])
+
+ cls.structured_language_bind_index_name = structured_language_bind_index.name
+ cls.unstructured_language_bind_index_name = unstructured_language_bind_index.name
+
+ s2_inference.clear_loaded_models()
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ """Run the base-class teardown, then clear loaded models again."""
+ super().tearDownClass()
+ s2_inference.clear_loaded_models()
+
+ def test_language_bind_model_can_add_all_media_modalities(self):
+ """Test to ensure that the LanguageBind model can add all media types to the index"""
+ # One document carrying a text, an image, an audio and a video field.
+ documents = [
+ {
+ "text_field_1": "This is a test text",
+ "image_field_1": TestImageUrls.IMAGE1.value,
+ "audio_field_1": TestAudioUrls.AUDIO1.value,
+ "video_field_1": TestVideoUrls.VIDEO1.value,
+ "_id": "1"
+ }
+ ]
+ for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ # tensor_fields is only supplied for the unstructured index; the structured index gets None.
+ # NOTE(review): "multimodal_field" is listed but no mappings are passed to AddDocsParams,
+ # so how it is treated for the unstructured index is not visible here - confirm.
+ tensor_fields = ["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"] \
+ if index_name == self.unstructured_language_bind_index_name else None
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields
+ )
+ )
+ self.assertFalse(res.errors)
+
+ def test_language_bind_model_can_add_all_private_media_modalities(self):
+ """Test that image, audio and video URLs can be added when media_download_headers are supplied.
+
+ Two documents are added to both indexes: one whose media URLs carry file extensions
+ and one whose URLs have none. The response must report no errors.
+ """
+ documents = [
+ { # With extensions
+ "text_field_1": "This is a test text",
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
+ "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
+ "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
+ "_id": "1"
+ },
+ {
+ # No extensions
+ "text_field_1": "This is a test text",
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark",
+ "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress",
+ "_id": "2"
+ }
+ ]
+ for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ # tensor_fields is only supplied for the unstructured index; the structured index gets None.
+ tensor_fields = ["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"] \
+ if index_name == self.unstructured_language_bind_index_name else None
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields,
+ # NOTE(review): presumably the media host only serves these URLs when this
+ # header is present - confirm against the test bucket configuration.
+ media_download_headers={"marqo_media_header": "media_header_test_key"}
+ )
+ )
+ self.assertFalse(res.errors)
\ No newline at end of file
diff --git a/tests/tensor_search/integ_tests/test_embed.py b/tests/tensor_search/integ_tests/test_embed.py
index a971ced5d..1e393ad69 100644
--- a/tests/tensor_search/integ_tests/test_embed.py
+++ b/tests/tensor_search/integ_tests/test_embed.py
@@ -523,9 +523,9 @@ def run():
self.assertEqual(embed_res["content"], [image_url])
self.assertTrue(np.allclose(embed_res["embeddings"][0], search_query_embedding))
- def test_embed_with_image_download_headers_and_model_auth(self):
+ def test_embed_with_media_download_headers_and_model_auth(self):
"""
- Ensure that vectorise is called with the correct image_download_headers and model_auth
+ Ensure that vectorise is called with the correct media_download_headers and model_auth
when using the embed endpoint.
"""
for index in [self.unstructured_default_image_index, self.structured_default_image_index]:
@@ -537,7 +537,7 @@ def pass_through_vectorise(*arg, **kwargs):
via mock
Set image download headers and model auth to None so there's no error out.
"""
- kwargs["image_download_headers"] = None
+ kwargs["media_download_headers"] = None
kwargs["model_auth"] = None
return vectorise(*arg, **kwargs)
@@ -549,7 +549,7 @@ def run():
marqo_config=self.config, index_name=index.name,
embedding_request=EmbedRequest(
content=[image_url],
- image_download_headers={"Authorization": "my secret key"},
+ mediaDownloadHeaders={"Authorization": "my secret key"},
modelAuth=ModelAuth(s3=S3Auth(
aws_access_key_id='12345',
aws_secret_access_key='this-is-a-secret'))
@@ -564,7 +564,7 @@ def run():
self.assertEqual(len(call_args), 1)
vectorise_kwargs = call_args[0].kwargs
- self.assertEqual(vectorise_kwargs["image_download_headers"], {"Authorization": "my secret key"})
+ self.assertEqual(vectorise_kwargs["media_download_headers"], {"Authorization": "my secret key"})
self.assertEqual(vectorise_kwargs["model_auth"], ModelAuth(s3=S3Auth(
aws_access_key_id='12345',
aws_secret_access_key='this-is-a-secret')))
diff --git a/tests/tensor_search/integ_tests/test_search_combined.py b/tests/tensor_search/integ_tests/test_search_combined.py
index 14d38ab19..514a92e99 100644
--- a/tests/tensor_search/integ_tests/test_search_combined.py
+++ b/tests/tensor_search/integ_tests/test_search_combined.py
@@ -204,7 +204,7 @@ def test_search_video(self):
documents = [
{"video_field_1": "https://marqo-k400-video-test-dataset.s3.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4", "_id": "1"},
# Replace the audio link with something marqo-hosted
- {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
+ {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
{"image_field_1": TestImageUrls.HIPPO_REALISTIC_LARGE.value, "_id": "3"},
# {"image_field_1": TestImageUrls.HIPPO_REALISTIC.value, "_id": "5"}, # png image with palette is not supported
{"text_field_1": "hello there padawan. Today you will begin your training to be a Jedi", "_id": "4"},
@@ -239,7 +239,7 @@ def test_search_audio(self):
documents = [
{"video_field_1": "https://marqo-k400-video-test-dataset.s3.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4", "_id": "1"},
# Replace the audio link with something marqo-hosted
- {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
+ {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
{"image_field_1": TestImageUrls.HIPPO_REALISTIC_LARGE.value, "_id": "3"},
# {"image_field_1": TestImageUrls.HIPPO_REALISTIC.value, "_id": "5"}, # png file with palette is not supported
{"text_field_1": "hello there padawan. Today you will begin your training to be a Jedi", "_id": "4"},
@@ -262,7 +262,7 @@ def test_search_audio(self):
index_name=index.name,
text="https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3"
)
-
+
# Assertions
self.assertEqual(len(results['hits']), 3) # 3 documents should be returned (limit=3)
self.assertEqual(results['hits'][0]['_id'], "2") # The audio document should be the top result
diff --git a/tests/tensor_search/test_add_documents_use_existing_tensors.py b/tests/tensor_search/test_add_documents_use_existing_tensors.py
index cd9ea8e88..b1febcfc3 100644
--- a/tests/tensor_search/test_add_documents_use_existing_tensors.py
+++ b/tests/tensor_search/test_add_documents_use_existing_tensors.py
@@ -829,7 +829,7 @@ def run():
vectorised_content = [call_kwargs['content'] for call_args, call_kwargs
in mock_vectorise.call_args_list]
- artefact_pil_image = load_image_from_path(artefact_hippo_img, image_download_headers={})
+ artefact_pil_image = load_image_from_path(artefact_hippo_img, media_download_headers={})
expected_to_be_vectorised = [
["this is the updated 1st sentence.", "This is my second"],
["this is a brand new sentence.", "Yes it is"],
diff --git a/tests/tensor_search/test_api_utils.py b/tests/tensor_search/test_api_utils.py
index 437d81654..acb040651 100644
--- a/tests/tensor_search/test_api_utils.py
+++ b/tests/tensor_search/test_api_utils.py
@@ -98,13 +98,13 @@ def test_add_docs_params_orchestrator(self):
# Query parameters should be parsed as default values
non_tensor_fields = []
use_existing_tensors = False
- image_download_headers = dict()
+ media_download_headers = dict()
model_auth = None
mappings = dict()
# Call the function with the arguments
result = add_docs_params_orchestrator(index_name, body, device, non_tensor_fields, mappings,
- model_auth, image_download_headers, use_existing_tensors)
+ model_auth, media_download_headers, use_existing_tensors)
# Assert that the result is as expected
assert isinstance(result, AddDocsParams)
@@ -114,7 +114,7 @@ def test_add_docs_params_orchestrator(self):
assert result.non_tensor_fields == ["field1"]
assert result.use_existing_tensors == True
assert result.docs == [{"test": "doc"}]
- assert result.image_download_headers == {"header1": "value1"}
+ assert result.media_download_headers == {"header1": "value1"}
def test_add_docs_params_orchestrator_deprecated_query_parameters(self):
# Set up the arguments for the function
@@ -126,14 +126,14 @@ def test_add_docs_params_orchestrator_deprecated_query_parameters(self):
device = "test-device"
non_tensor_fields = ["field1"]
use_existing_tensors = True
- image_download_headers = {"header1": "value1"}
+ media_download_headers = {"header1": "value1"}
model_auth = model_auth
mappings = {"map1": "value1"}
auto_refresh = True
# Call the function with the arguments
result = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings,
- model_auth, image_download_headers, use_existing_tensors)
+ model_auth, media_download_headers, use_existing_tensors)
# Assert that the result is as expected
assert isinstance(result, AddDocsParams)
@@ -143,7 +143,7 @@ def test_add_docs_params_orchestrator_deprecated_query_parameters(self):
assert result.non_tensor_fields == ["field1"]
assert result.use_existing_tensors == True
assert result.docs == [{"test": "doc"}]
- assert result.image_download_headers == {"header1": "value1"}
+ assert result.media_download_headers == {"header1": "value1"}
def test_add_docs_params_orchestrator_error(self):
# Test the case where the function should raise an error due to invalid input
@@ -155,7 +155,7 @@ def test_add_docs_params_orchestrator_error(self):
device = "test-device"
non_tensor_fields = ["field1"]
use_existing_tensors = True
- image_download_headers = {"header1": "value1"}
+ media_download_headers = {"header1": "value1"}
model_auth = model_auth
mappings = {"map1": "value1"}
auto_refresh = True
@@ -163,7 +163,7 @@ def test_add_docs_params_orchestrator_error(self):
# Use pytest.raises to check for the error
try:
_ = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings,
- model_auth, image_download_headers, use_existing_tensors)
+ model_auth, media_download_headers, use_existing_tensors)
except InternalError as e:
self.assertIn("Unexpected request body type", str(e))
@@ -181,7 +181,7 @@ def test_add_docs_params_orchestrator_deprecated_query_parameters_error(self):
mappings={"map1": "value1"})
params = {"non_tensor_fields": ["what"], "use_existing_tensors": True,
- "image_download_headers": {"header2": "value2"}, "model_auth": model_auth,
+ "media_download_headers": {"header2": "value2"}, "model_auth": model_auth,
"mappings": {"map2": "value2"}}
for param, value in params.items():
diff --git a/tests/tensor_search/test_image_download_headers.py b/tests/tensor_search/test_image_download_headers.py
index ea692be9e..04c0ef0a7 100644
--- a/tests/tensor_search/test_image_download_headers.py
+++ b/tests/tensor_search/test_image_download_headers.py
@@ -62,11 +62,11 @@ def test_img_download_search(self):
tensor_search.create_vector_index(
config=self.config, index_name=self.index_name_1, index_settings=self.image_index_settings()
)
- image_download_headers = {"Authorization": "some secret key blah"}
+ media_download_headers = {"Authorization": "some secret key blah"}
self.add_documents(config=self.config, add_docs_params=AddDocsParams(
index_name=self.index_name_1, docs=[
{"_id": "1", "image": self.real_img_url}],
- auto_refresh=True, image_download_headers=image_download_headers, device="cpu"))
+ auto_refresh=True, media_download_headers=media_download_headers, device="cpu"))
def pass_through_requests_get(url, *args, **kwargs):
return requests_get(url, *args, **kwargs)
@@ -80,11 +80,11 @@ def pass_through_requests_get(url, *args, **kwargs):
# Perform a vector search
search_res = tensor_search._vector_text_search(
config=self.config, index_name=self.index_name_1,
- result_count=1, query=self.real_img_url, image_download_headers=image_download_headers, device="cpu"
+ result_count=1, query=self.real_img_url, media_download_headers=media_download_headers, device="cpu"
)
# Check if the image URL was called at least once with the correct headers
image_url_called = any(
- call_args[0] == self.real_img_url and call_kwargs.get('headers', None) == image_download_headers
+ call_args[0] == self.real_img_url and call_kwargs.get('headers', None) == media_download_headers
for call_args, call_kwargs in mock_get.call_args_list
)
assert image_url_called, "Image URL not called with the correct headers"
@@ -102,18 +102,18 @@ def pass_through_load_image_from_path(*arg, **kwargs):
@unittest.mock.patch("marqo.s2_inference.clip_utils.load_image_from_path", mock_load_image_from_path)
def run():
- image_download_headers = {"Authorization": "some secret key blah"}
+ media_download_headers = {"Authorization": "some secret key blah"}
# Add a document with an image URL
self.add_documents(config=self.config, add_docs_params=AddDocsParams(
index_name=self.index_name_1, docs=[
{"_id": "1", "image": self.real_img_url}
- ], auto_refresh=True, image_download_headers=image_download_headers, device="cpu"
+ ], auto_refresh=True, media_download_headers=media_download_headers, device="cpu"
))
# Check if load_image_from_path was called with the correct headers
assert len(mock_load_image_from_path.call_args_list) == 1
call_args, call_kwargs = mock_load_image_from_path.call_args_list[0]
- assert image_download_headers in call_args
+ assert media_download_headers in call_args
return True
assert run() is True
@@ -123,14 +123,14 @@ def test_img_download_bulk_search(self):
tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,
index_settings=self.image_index_settings())
test_image_url = self.real_img_url
- image_download_headers = {"Authorization": "some secret key blah"}
+ media_download_headers = {"Authorization": "some secret key blah"}
def pass_through_load_image_from_path(*args, **kwargs):
return load_image_from_path(*args, **kwargs)
def pass_through_requests_get(url, *args, **kwargs):
if url == test_image_url:
- assert kwargs.get('headers', None) == image_download_headers
+ assert kwargs.get('headers', None) == media_download_headers
return requests_get(url, *args, **kwargs)
# Mock the load_image_from_path function
@@ -144,7 +144,7 @@ def pass_through_requests_get(url, *args, **kwargs):
"_id": "1",
"image": test_image_url,
}],
- auto_refresh=True, image_download_headers=image_download_headers, device="cpu"))
+ auto_refresh=True, media_download_headers=media_download_headers, device="cpu"))
# Set up the mock GET
mock_get = unittest.mock.MagicMock()
@@ -155,13 +155,13 @@ def pass_through_requests_get(url, *args, **kwargs):
bulk_search_query = BulkSearchQuery(queries=[{
"index": self.index_name_1,
"q": self.real_img_url,
- "image_download_headers": image_download_headers
+ "media_download_headers": media_download_headers
}])
resp = tensor_search.bulk_search(marqo_config=self.config, query=bulk_search_query)
# Check if the image URL was called at least once with the correct headers
image_url_called = any(
- call_args[0] == test_image_url and call_kwargs.get('headers', None) == image_download_headers
+ call_args[0] == test_image_url and call_kwargs.get('headers', None) == media_download_headers
for call_args, call_kwargs in mock_get.call_args_list
)
assert image_url_called, "Image URL not called with the correct headers"
diff --git a/tests/tensor_search/test_modalities_download.py b/tests/tensor_search/test_modalities_download.py
index 55142a5cf..0335d2a48 100644
--- a/tests/tensor_search/test_modalities_download.py
+++ b/tests/tensor_search/test_modalities_download.py
@@ -1,17 +1,21 @@
import unittest
from unittest.mock import Mock, patch, MagicMock
-from PIL import UnidentifiedImageError
+
+import ffmpeg
+import pytest
import torch
-from marqo.s2_inference.errors import UnsupportedModalityError, S2InferenceError
-from marqo.tensor_search.add_docs import threaded_download_and_preprocess_content
+from PIL import UnidentifiedImageError
+
from marqo.core.models.marqo_index import IndexType, MarqoIndex, FieldType
-from marqo.s2_inference.s2_inference import Modality
-from marqo.s2_inference.models.model_type import ModelType
-from marqo.tensor_search.telemetry import RequestMetricsStore, RequestMetrics
from marqo.s2_inference.errors import MediaDownloadError
-import ffmpeg
+from marqo.s2_inference.errors import UnsupportedModalityError, S2InferenceError
+from marqo.s2_inference.models.model_type import ModelType
+from marqo.s2_inference.s2_inference import Modality
+from marqo.tensor_search.add_docs import threaded_download_and_preprocess_content
+from marqo.tensor_search.telemetry import RequestMetrics
+@pytest.mark.unittest
class TestThreadedDownloadAndPreprocess(unittest.TestCase):
def setUp(self):
@@ -62,7 +66,7 @@ def test_image_unstructured_index(self, mock_infer_modality, mock_load_image):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -83,7 +87,7 @@ def test_image_structured_index(self, mock_infer_modality, mock_load_image):
media_field_types_mapping = {"field1": FieldType.ImagePointer}
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -106,7 +110,7 @@ def test_video_unstructured_index(self, mock_infer_modality, mock_download_and_c
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -130,7 +134,7 @@ def test_audio_structured_index(self, mock_infer_modality, mock_download_and_chu
media_field_types_mapping = {"field1": FieldType.AudioPointer}
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -148,7 +152,7 @@ def test_unsupported_modality(self, mock_infer_modality):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -167,7 +171,7 @@ def test_image_load_error(self, mock_infer_modality, mock_load_image):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -189,7 +193,7 @@ def test_video_processing_error(self, mock_infer_modality, mock_download_and_chu
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -217,7 +221,7 @@ def test_video_and_audio_unstructured_index(self, mock_infer_modality, mock_down
# Call the function
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -230,13 +234,30 @@ def test_video_and_audio_unstructured_index(self, mock_infer_modality, mock_down
# Verify that download_and_chunk_media was called twice
self.assertEqual(mock_download_and_chunk.call_count, 2)
+ print(mock_download_and_chunk.call_args_list)
# Verify the calls to download_and_chunk_media
mock_download_and_chunk.assert_any_call(
- self.mock_video_url, "cpu", None, Modality.VIDEO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
+ url=self.mock_video_url,
+ device='cpu',
+ modality= Modality.VIDEO,
+ marqo_index_type = self.mock_marqo_index.type,
+ marqo_index_model = self.mock_marqo_index.model,
+ preprocessors = None,
+ audio_preprocessing = None,
+ video_preprocessing = None,
+ media_download_headers = {}
)
mock_download_and_chunk.assert_any_call(
- self.mock_audio_url, "cpu", None, Modality.AUDIO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
+ # Second expected call: the audio document. The first assert_any_call already covers
+ # the video document, so this one must check the audio URL and Modality.AUDIO.
+ url=self.mock_audio_url,
+ device='cpu',
+ modality=Modality.AUDIO,
+ marqo_index_type=self.mock_marqo_index.type,
+ marqo_index_model=self.mock_marqo_index.model,
+ preprocessors=None,
+ audio_preprocessing=None,
+ video_preprocessing=None,
+ media_download_headers={}
)
@patch("marqo.tensor_search.add_docs.download_and_chunk_media")
@@ -261,7 +282,7 @@ def test_mismatched_media_fields(self, mock_infer_modality, mock_download_and_ch
]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -291,7 +312,7 @@ def test_invalid_media_fields(self, mock_infer_modality):
mock_infer_modality.side_effect = [Modality.TEXT, Modality.TEXT]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -321,7 +342,7 @@ def test_ffmpeg_error_handling(self, mock_infer_modality, mock_download_and_chun
mock_download_and_chunk.side_effect = ffmpeg.Error("FFmpeg processing error", stdout=b"", stderr=b"")
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -347,7 +368,7 @@ def test_valid_image_processing(self, mock_infer_modality, mock_load_image):
media_field_types_mapping = {"image_field": FieldType.ImagePointer}
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -365,7 +386,7 @@ def test_media_download_error(self, mock_infer_modality):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -392,7 +413,7 @@ def test_audio_with_video_only_model(self, mock_infer_modality, mock_download_an
# Call the function
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
diff --git a/tests/tensor_search/test_search.py b/tests/tensor_search/test_search.py
index c44848c12..0a0fbdc15 100644
--- a/tests/tensor_search/test_search.py
+++ b/tests/tensor_search/test_search.py
@@ -1136,7 +1136,7 @@ def run() -> typing.List[float]:
weighted_vectors = []
for q, weight in multi_query.items():
vec = vectorise(model_name="ViT-B/16", content=[q, ],
- image_download_headers=None, normalize_embeddings=True,
+ media_download_headers=None, normalize_embeddings=True,
device="cpu")[0]
weighted_vectors.append(np.asarray(vec) * weight)