Skip to content

Commit

Permalink
Improve the performance of processing large search response (#1091)
Browse files Browse the repository at this point in the history
  • Loading branch information
papa99do authored Jan 17, 2025
1 parent f29fe81 commit 8980c62
Show file tree
Hide file tree
Showing 18 changed files with 561 additions and 188 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/largemodel_unit_test_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,14 @@ jobs:
uses: actions/checkout@v3
with:
repository: marqo-ai/marqo-base
ref: 'releases/2.13'
path: marqo-base

- name: Install dependencies
run: |
pip install -r marqo-base/requirements/amd64-gpu-requirements.txt
pip install -r marqo-base/requirements.txt
# override base requirements with marqo requirements, if needed:
pip install -r marqo/requirements.txt
pip install -r marqo/requirements.dev.txt
pip install pytest==7.4.0
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/unit_test_200gb_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ jobs:
uses: actions/checkout@v3
with:
repository: marqo-ai/marqo-base
ref: 'releases/2.13'
path: marqo-base

- name: Install dependencies
run: |
pip install -r marqo-base/requirements/amd64-gpu-requirements.txt
pip install -r marqo-base/requirements.txt
# override base requirements with marqo requirements, if needed:
pip install -r marqo/requirements.txt
pip install -r marqo/requirements.dev.txt
- name: Build Vespa
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ readerwriterlock==1.0.9
kazoo==2.10.0
pycurl==7.45.3
huggingface-hub==0.25.0
jinja2==3.1.4
jinja2==3.1.4
orjson==3.10.14
23 changes: 19 additions & 4 deletions src/marqo/core/inference/tensor_fields_container.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hashlib
import json
from abc import ABC, abstractmethod
from typing import List, Dict, Set, Optional, Any, Generator, Tuple, cast, TypeVar
from typing import List, Dict, Set, Optional, Any, Generator, Tuple, cast, TypeVar, Callable, Union

import numpy as np
from PIL.Image import Image
Expand All @@ -10,7 +10,7 @@

from marqo.core import constants
from marqo.core.constants import MARQO_DOC_ID
from marqo.core.exceptions import AddDocumentsError, ModelError
from marqo.core.exceptions import AddDocumentsError, ModelError, InternalError
from marqo.core.models.marqo_index import FieldType, TextPreProcessing, ImagePreProcessing
from marqo.s2_inference import errors as s2_inference_errors
from marqo.s2_inference import s2_inference
Expand Down Expand Up @@ -493,7 +493,20 @@ def populate_tensor_from_existing_doc(self, existing_marqo_doc: Dict[str, Any],
tensor_content.populate_chunks_and_embeddings(existing_tensor[constants.MARQO_DOC_CHUNKS],
existing_tensor[constants.MARQO_DOC_EMBEDDINGS])

def collect(self, doc_id: str, field_name: str, field_content: Any, field_type: Optional[FieldType]) -> Any:
def collect(self, doc_id: str, field_name: str, field_content: Any,
infer_field_type: Callable[[str, Any], FieldType]) -> Any:
"""
Collect tensor field content from the document if it is a tensor field.
Args:
doc_id: document id
field_name: name of the field
field_content: content of the field
infer_field_type: A callable that takes the field name and the field content (in that order) and
returns the field type
Returns:
The field content
"""
if field_name not in self._tensor_fields and field_name not in self._multimodal_sub_field_reverse_map:
# not tensor fields, no need to collect
return field_content
Expand All @@ -511,6 +524,8 @@ def collect(self, doc_id: str, field_name: str, field_content: Any, field_type:
f'Invalid type {type(field_content)} for tensor field {field_name}'
)

field_type = infer_field_type(field_name, field_content)

self._add_tensor_field_content(
doc_id, field_name, TensorFieldContent(
field_content=field_content,
Expand Down Expand Up @@ -556,4 +571,4 @@ def collect_multi_modal_fields(self, doc_id: str, normalize_embeddings: bool):
is_multimodal_subfield=False,
normalize_embeddings=normalize_embeddings
))
yield field_name, weights
yield field_name, weights
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,8 @@ def __init__(self, marqo_index: SemiStructuredMarqoIndex, add_docs_params: AddDo
self.field_count_config = field_count_config

def _handle_field(self, marqo_doc, field_name, field_content):
self._validate_field(field_name, field_content)
text_field_type = self._infer_field_type(
field_content,
media_download_headers=self.add_docs_params.media_download_headers
)
content = self.tensor_fields_container.collect(marqo_doc[MARQO_DOC_ID], field_name,
field_content, text_field_type)
marqo_doc[field_name] = content

if isinstance(content, str):
super()._handle_field(marqo_doc, field_name, field_content)
if isinstance(marqo_doc[field_name], str):
self._add_lexical_field_to_index(field_name)

def _to_vespa_doc(self, doc: Dict[str, Any]) -> VespaDocument:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def from_vespa_document(cls, document: Dict, marqo_index: SemiStructuredMarqoInd
text_fields[field_name] = fields[field_name]

return cls(id=document[cls._VESPA_DOC_ID],
fixed_fields=SemiStructuredVespaDocumentFields(**fields),
fixed_fields=SemiStructuredVespaDocumentFields.construct(**fields),
tensor_fields=tensor_fields,
text_fields=text_fields,
raw_tensor_score=cls.extract_field(fields, common.VESPA_DOC_HYBRID_RAW_TENSOR_SCORE, None),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,15 @@ def _validate_add_docs_params(self, add_docs_params: AddDocsParams, marqo_index:

def _handle_field(self, marqo_doc, field_name, field_content):
self._validate_field(field_name, field_content)
field_type = self.marqo_index.field_map[field_name].type
content = self.tensor_fields_container.collect(marqo_doc[MARQO_DOC_ID], field_name, field_content, field_type)
content = self.tensor_fields_container.collect(
marqo_doc[MARQO_DOC_ID], field_name, field_content,
self._infer_field_type
)
marqo_doc[field_name] = content

def _infer_field_type(self, field_name: str, field_content: Any) -> FieldType:
    """
    Return the type registered for ``field_name`` in the index's field map.

    ``field_content`` is accepted but not used by this implementation; the
    result depends only on the field name.
    """
    field = self.marqo_index.field_map[field_name]
    return field.type

def _validate_field(self, field_name: str, field_content: Any) -> None:
try:
# TODO extract the validation logic somewhere else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@
from marqo.core.vespa_index.add_documents_handler import AddDocumentsHandler, AddDocumentsError
from marqo.s2_inference.errors import MediaDownloadError
from marqo.s2_inference.multimodal_model_load import infer_modality, Modality

from marqo.vespa.models import VespaDocument
from marqo.vespa.models.get_document_response import Document
from marqo.vespa.vespa_client import VespaClient

# TODO deps to tensor_search needs to be removed
from marqo.tensor_search.constants import ALLOWED_UNSTRUCTURED_FIELD_TYPES
from marqo.tensor_search.validation import validate_custom_vector, \
validate_map_numeric_field
from marqo.vespa.models import VespaDocument
from marqo.vespa.models.get_document_response import Document
from marqo.vespa.vespa_client import VespaClient


class UnstructuredAddDocumentsHandler(AddDocumentsHandler):
Expand Down Expand Up @@ -67,28 +65,52 @@ def _validate_doc(self, doc):

def _handle_field(self, marqo_doc, field_name, field_content):
self._validate_field(field_name, field_content)
text_field_type = self._infer_field_type(field_content, self.add_docs_params.media_download_headers)
content = self.tensor_fields_container.collect(marqo_doc[MARQO_DOC_ID], field_name,
field_content, text_field_type)
content = self.tensor_fields_container.collect(
marqo_doc[MARQO_DOC_ID], field_name, field_content,
self._infer_field_type
)
marqo_doc[field_name] = content

def _infer_field_type(self, field_content: Any, media_download_headers: Optional[Dict] = None) \
-> Optional[FieldType]:
if not isinstance(field_content, str):
return None

try:
modality = infer_modality(field_content, media_download_headers)

if not self.marqo_index.treat_urls_and_pointers_as_media and modality in [Modality.AUDIO, Modality.VIDEO]:
modality = Modality.TEXT

if not self.marqo_index.treat_urls_and_pointers_as_images and modality == Modality.IMAGE:
def _infer_field_type(self, field_name:str, field_content: Any) -> FieldType:
"""Infer the field type based on the field content. This is used for both unstructured and semi-structured
indexes.
We should only infer the field type if the field content is a string.
We only infer the field type if the index is configured to treat URLs and pointers as images or media.
treatUrlsAndPointersAsMedia is a new parameter introduced in Marqo 2.12 to support the new modalities
of video and audio. Here is how it interacts with treatUrlsAndPointersAsImages:
Both False: All content is processed as text only.
treatUrlsAndPointersAsImages True, treatUrlsAndPointersAsMedia False:
Processes URLs and pointers as images
Does not process other media types (video, audio)
treatUrlsAndPointersAsImages False, treatUrlsAndPointersAsMedia True:
Invalid state since this is a conflict.
Both True:
Processes URLs and pointers as various media types (images, videos, audio)
The values of treatUrlsAndPointersAsMedia and treatUrlsAndPointersAsImages are validated in the MarqoIndex class
so we do not need to validate them here.
Args:
field_name: The name of the field (not used for the inference itself).
field_content: The content of the field.
Returns:
The inferred field type.
Raises:
AddDocumentsError: If the modality of the media content cannot be inferred.
"""
if (self.marqo_index.treat_urls_and_pointers_as_images is True or
self.marqo_index.treat_urls_and_pointers_as_media is True):
try:
modality = infer_modality(field_content, self.add_docs_params.media_download_headers)
except MediaDownloadError as err:
raise AddDocumentsError(err.message) from err
if ((self.marqo_index.treat_urls_and_pointers_as_media is False) and modality in
[Modality.AUDIO, Modality.VIDEO]):
modality = Modality.TEXT

return MODALITY_FIELD_TYPE_MAP[modality]
except MediaDownloadError as err:
raise AddDocumentsError(err.message) from err
else:
return FieldType.Text

def _validate_field(self, field_name: str, field_content: Any) -> None:
try:
Expand Down
7 changes: 7 additions & 0 deletions src/marqo/core/vespa_index/add_documents_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ def _handle_field(self, marqo_doc, field_name, field_content) -> None:
"""
pass

@abstractmethod
def _infer_field_type(self, field_name: str, field_content: Any) -> FieldType:
    """
    Infer the field type of a field from its name and content.

    Args:
        field_name: name of the field
        field_content: content of the field
    Returns:
        The FieldType of the field
    """

@abstractmethod
def _handle_multi_modal_fields(self, marqo_doc: Dict[str, Any]) -> None:
"""
Expand Down
5 changes: 3 additions & 2 deletions src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi import Depends, FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, ORJSONResponse
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY

from marqo import config, marqo_docs
Expand Down Expand Up @@ -267,7 +267,7 @@ def create_index(index_name: str, settings: IndexSettings, marqo_config: config.
def search(search_query: SearchQuery, index_name: str, device: str = Depends(api_validation.validate_device),
marqo_config: config.Config = Depends(get_config)):
with RequestMetricsStore.for_request().time(f"POST /indexes/{index_name}/search"):
return tensor_search.search(
result = tensor_search.search(
config=marqo_config, text=search_query.q,
index_name=index_name, highlights=search_query.showHighlights,
searchable_attributes=search_query.searchableAttributes,
Expand All @@ -284,6 +284,7 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api
text_query_prefix=search_query.textQueryPrefix,
hybrid_parameters=search_query.hybridParameters
)
return ORJSONResponse(result)


@app.post("/indexes/{index_name}/recommend")
Expand Down
43 changes: 27 additions & 16 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@
"""
import copy
import json
import traceback
import typing
import uuid
import os
from collections import defaultdict
from contextlib import ExitStack
from timeit import default_timer as timer
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any, Tuple
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any, Tuple, Set

import numpy as np
import psutil
Expand Down Expand Up @@ -1775,16 +1774,26 @@ def gather_documents_from_response(response: QueryResult, marqo_index: MarqoInde
"""
Convert a VespaQueryResponse to a Marqo search response
"""

if (marqo_index.type in [IndexType.Unstructured, IndexType.SemiStructured] and
attributes_to_retrieve is not None):
# Unstructured index and Semi-structured index stores fixed fields (numeric, boolean, string arrays, etc.) in
# combined field. It needs to select attributes after converting vespa doc to marqo doc if
# attributes_to_retrieve is specified
metadata_fields_to_retrieve = {"_id", "_score", "_highlights"}
attributes_to_retrieve_set = set(attributes_to_retrieve).union(metadata_fields_to_retrieve)
else:
# If this set is None, we will return the marqo_doc as is.
attributes_to_retrieve_set = None

vespa_index = vespa_index_factory(marqo_index)
hits = []
for doc in response.hits:
marqo_doc = vespa_index.to_marqo_document(doc.dict(), return_highlights=highlights)
marqo_doc = vespa_index.to_marqo_document(dict(doc), return_highlights=highlights)
marqo_doc['_score'] = doc.relevance

if (marqo_index.type in [IndexType.Unstructured, IndexType.SemiStructured] and
attributes_to_retrieve is not None):
# For an unstructured index, we do the attributes_to_retrieve after search
marqo_doc = unstructured_index_attributes_to_retrieve(marqo_doc, attributes_to_retrieve)
if attributes_to_retrieve_set is not None:
marqo_doc = select_attributes(marqo_doc, attributes_to_retrieve_set)

# Delete chunk data
if constants.MARQO_DOC_TENSORS in marqo_doc:
Expand All @@ -1794,16 +1803,18 @@ def gather_documents_from_response(response: QueryResult, marqo_index: MarqoInde
return {'hits': hits}


def unstructured_index_attributes_to_retrieve(marqo_doc: Dict[str, Any], attributes_to_retrieve: List[str]) -> Dict[
str, Any]:
# attributes_to_retrieve should already be validated at the start of search
attributes_to_retrieve = list(set(attributes_to_retrieve).union({"_id", "_score", "_highlights"}))
return {k: v for k, v in marqo_doc.items() if k in attributes_to_retrieve or
# Please note that numeric map fields are flattened for unstructured or semi-structured indexes.
# Therefore, when filtering on attributes_to_retrieve, we need to also include flattened map fields
# with the specified attributes as prefixes. We keep this behaviour only for compatibility reasons.
any([k.startswith(attribute + ".") for attribute in attributes_to_retrieve])}
def select_attributes(marqo_doc: Dict[str, Any], attributes_to_retrieve_set: Set[str]) -> Dict[str, Any]:
    """
    Keep only the requested attributes of a Marqo document.

    For unstructured and semi-structured indexes all fixed fields (numeric, boolean, string arrays, etc.)
    come back from Vespa even when attributes_to_retrieve is specified, so the selection happens here,
    after the Vespa doc has been converted to a Marqo doc.

    Numeric map fields are flattened for these index types. A key of the form "<prefix>.<rest>" is
    therefore also kept when "<prefix>" (the text before the first dot) is a requested attribute.
    This behaviour is kept only for compatibility reasons.

    Args:
        marqo_doc: the converted Marqo document
        attributes_to_retrieve_set: names of the attributes to keep
    Returns:
        A new dict with the selected entries, in the order they appear in marqo_doc
    """
    selected: Dict[str, Any] = {}
    for name, value in marqo_doc.items():
        if name in attributes_to_retrieve_set:
            selected[name] = value
            continue
        # Flattened map entry: match on the part before the first dot.
        prefix, dot, _ = name.partition('.')
        if dot and prefix in attributes_to_retrieve_set:
            selected[name] = value
    return selected

def assign_query_to_vector_job(
q: BulkSearchQueryEntity, jobs: Dict[JHash, VectorisedJobs],
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.13.4"
__version__ = "2.13.5"

def get_version() -> str:
    """Return the version string stored in the module-level ``__version__``."""
    return str(__version__)
3 changes: 2 additions & 1 deletion src/marqo/vespa/vespa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import httpcore
import httpx
import orjson

import marqo.logging
import marqo.vespa.concurrency as conc
Expand Down Expand Up @@ -245,7 +246,7 @@ def query(self, yql: str, hits: int = 10, ranking: str = None, model_restrict: s

self._query_raise_for_status(resp)

return QueryResult(**resp.json())
return QueryResult(**orjson.loads(resp.text))

def feed_document(self, document: VespaDocument, schema: str, timeout: int = 60) -> FeedDocumentResponse:
"""
Expand Down
Loading

0 comments on commit 8980c62

Please sign in to comment.