Skip to content

Commit

Permalink
Fix image download headers regresison and fix png image issue (#1022)
Browse files Browse the repository at this point in the history
This includes three bug fixes:
- Fix the image download headers regression so users can provide a media download header to download private images. This also applies to other modalities including video and audio.
- Fix a bug that languagebind models can not encode png images.
- Fix a bug that Marqo can not take more than 2 modalities in weighted queries. Marqo can now take weighted queries with a mixture of all modalities supported by the model.
  • Loading branch information
wanliAlex authored Oct 25, 2024
1 parent 58efc24 commit 5c372dd
Show file tree
Hide file tree
Showing 35 changed files with 770 additions and 601 deletions.
20 changes: 20 additions & 0 deletions src/marqo/api/models/add_docs_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Config:
tensorFields: Optional[List] = None
useExistingTensors: bool = False
imageDownloadHeaders: dict = Field(default_factory=dict)
mediaDownloadHeaders: Optional[dict] = None
modelAuth: Optional[ModelAuth] = None
mappings: Optional[dict] = None
documents: Union[Sequence[Union[dict, Any]], np.ndarray]
Expand All @@ -38,3 +39,22 @@ def validate_thread_counts(cls, values):
if media_count is not None and image_count != read_env_vars_and_defaults_ints(EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST):
raise ValueError("Cannot set both imageDownloadThreadCount and mediaDownloadThreadCount")
return values

@root_validator(skip_on_failure=True)
def _validate_image_download_headers_and_media_download_headers(cls, values):
"""Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
rest of the code.
imageDownloadHeaders is deprecated and will be removed in the future.
"""
image_download_headers = values.get('imageDownloadHeaders')
media_download_headers = values.get('mediaDownloadHeaders')
if image_download_headers and media_download_headers:
raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
"'imageDownloadHeaders' is deprecated and will be removed in the future. "
"Use mediaDownloadHeaders instead.")
if image_download_headers:
values['mediaDownloadHeaders'] = image_download_headers
return values
32 changes: 27 additions & 5 deletions src/marqo/api/models/embed_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@
import pydantic
from typing import Union, List, Dict, Optional, Any

from pydantic import Field, root_validator

from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.api_models import BaseMarqoModel
from marqo.base_model import MarqoBaseModel
from marqo.core.embed.embed import EmbedContentType



class EmbedRequest(BaseMarqoModel):
class EmbedRequest(MarqoBaseModel):
# content can be a single query or list of queries. Queries can be a string or a dictionary.
content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]]
image_download_headers: Optional[Dict] = None
image_download_headers: Optional[Dict] = Field(default=None, alias="imageDownloadHeaders")
mediaDownloadHeaders: Optional[Dict] = None
modelAuth: Optional[ModelAuth] = None
content_type: Optional[EmbedContentType] = EmbedContentType.Query
content_type: Optional[EmbedContentType] = Field(default=EmbedContentType.Query, alias="contentType")

@pydantic.validator('content')
def validate_content(cls, value):
Expand Down Expand Up @@ -47,4 +50,23 @@ def validate_content(cls, value):
else:
raise ValueError("Embed content should be a string, a dictionary, or a list of strings or dictionaries")

return value
return value

@root_validator(skip_on_failure=True)
def _validate_image_download_headers_and_media_download_headers(cls, values):
"""Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
rest of the code.
imageDownloadHeaders is deprecated and will be removed in the future.
"""
image_download_headers = values.get('imageDownloadHeaders')
media_download_headers = values.get('mediaDownloadHeaders')
if image_download_headers and media_download_headers:
raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
"'imageDownloadHeaders' is deprecated and will be removed in the future. "
"Use mediaDownloadHeaders instead.")
if image_download_headers:
values['mediaDownloadHeaders'] = image_download_headers
return values
13 changes: 7 additions & 6 deletions src/marqo/core/embed/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def validate_default_device(cls, value):
return value

def embed_content(
self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
index_name: str, device: str = None, image_download_headers: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None,
content_type: Optional[EmbedContentType] = EmbedContentType.Query
) -> Dict:
self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
index_name: str, device: str = None,
media_download_headers: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None,
content_type: Optional[EmbedContentType] = EmbedContentType.Query
) -> Dict:
"""
Use the index's model to embed the content
Expand Down Expand Up @@ -105,7 +106,7 @@ def embed_content(
BulkSearchQueryEntity(
q=content_entry,
index=marqo_index,
image_download_headers=image_download_headers,
mediaDownloadHeaders=media_download_headers,
modelAuth=model_auth,
text_query_prefix=prefix
# TODO: Check if it's fine that we leave out the other parameters
Expand Down
24 changes: 12 additions & 12 deletions src/marqo/core/inference/embedding_models/abstract_clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.core.inference.embedding_models.abstract_embedding_model import AbstractEmbeddingModel
from marqo.core.inference.embedding_models.image_download import (_is_image, format_and_load_CLIP_images,
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.s2_inference.logger import get_logger
from marqo.s2_inference.types import *
Expand Down Expand Up @@ -50,11 +50,11 @@ def encode_text(self, inputs: Union[str, List[str]], normalize: bool = True) ->
pass

@abstractmethod
def encode_image(self, inputs, normalize: bool = True, image_download_headers: dict = None) -> np.ndarray:
def encode_image(self, inputs, normalize: bool = True, media_download_headers: dict = None) -> np.ndarray:
pass

def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
default: str = 'text', normalize=True, **kwargs) -> np.ndarray:
def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> np.ndarray:
default = "text"
infer = kwargs.pop('infer', True)
if infer and _is_image(inputs):
is_image = True
Expand All @@ -68,8 +68,8 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],

if is_image:
logger.debug('image')
image_download_headers = kwargs.get("image_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
media_download_headers = kwargs.get("media_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, media_download_headers=media_download_headers)
else:
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)
Expand All @@ -85,27 +85,27 @@ def normalize(outputs):
return outputs.norm(dim=-1, keepdim=True)

def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
image_download_headers: Optional[Dict] = None) -> Tensor:
media_download_headers: Optional[Dict] = None) -> Tensor:
"""Preprocess the input image to be ready for the model.
Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
image_download_headers (Optional[Dict]): headers for the image download
media_download_headers (Optional[Dict]): headers for the image download
Return:
Tensor: the processed image tensor with shape (batch_size, channel, n_px, n_px)
"""
if self.model is None:
self.load()
if image_download_headers is None:
image_download_headers = dict()
if media_download_headers is None:
media_download_headers = dict()

# default to batch encoding
if isinstance(images, list):
image_input: List[Union[ImageType, Tensor]] \
= format_and_load_CLIP_images(images, image_download_headers)
= format_and_load_CLIP_images(images, media_download_headers)
else:
image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, image_download_headers)]
image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, media_download_headers)]

image_input_processed: Tensor = torch.stack([self.preprocess(_img).to(self.device) \
if not isinstance(_img, torch.Tensor) else _img \
Expand Down
Loading

0 comments on commit 5c372dd

Please sign in to comment.