Skip to content

Commit

Permalink
Add tests for invalid image url
Browse files Browse the repository at this point in the history
  • Loading branch information
wanliAlex committed Oct 31, 2024
1 parent 3f946c3 commit 9054c1b
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 17 deletions.
9 changes: 3 additions & 6 deletions src/marqo/core/embed/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,9 @@ def embed_content(

# Vectorise the queries
with RequestMetricsStore.for_request().time(f"embed.vector_inference_full_pipeline"):
try:
qidx_to_vectors: Dict[Qidx, List[float]] = tensor_search.run_vectorise_pipeline(
temp_config, queries, device
)
except s2_inference_errors.MediaDownloadError as e:
raise api_exceptions.InvalidArgError(message=str(e)) from e
qidx_to_vectors: Dict[Qidx, List[float]] = tensor_search.run_vectorise_pipeline(
temp_config, queries, device
)

embeddings: List[List[float]] = list(qidx_to_vectors.values())

Expand Down
13 changes: 6 additions & 7 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2073,12 +2073,14 @@ def run_vectorise_pipeline(config: Config, queries: List[BulkSearchQueryEntity],
"""Run the query vectorisation process
Raise:
MediaDownloadError: If the media cannot be downloaded. This error is raised before the vectorisation process.
api_exceptions.InvalidArgError: If the vectorisation process fails.
api_exceptions.InvalidArgError: If the vectorisation process fails or if the media cannot be downloaded.
"""

# Prepend the prefixes to the queries if it exists (output should be of type List[BulkSearchQueryEntity])
prefixed_queries = add_prefix_to_queries(queries)
try:
prefixed_queries = add_prefix_to_queries(queries)
except s2_inference_errors.MediaDownloadError as e:
raise api_exceptions.InvalidArgError(message=str(e)) from e

# 1. Pre-process inputs ready for s2_inference.vectorise
# we can still use qidx_to_job. But the jobs structure may need to be different
Expand Down Expand Up @@ -2176,10 +2178,7 @@ def _vector_text_search(
)]

with RequestMetricsStore.for_request().time(f"search.vector_inference_full_pipeline"):
try:
qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, device)
except s2_inference_errors.MediaDownloadError as e:
raise api_exceptions.InvalidArgError(message=str(e)) from e
qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, device)
vectorised_text = list(qidx_to_vectors.values())[0]

marqo_query = MarqoTensorQuery(
Expand Down
26 changes: 24 additions & 2 deletions tests/tensor_search/integ_tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,29 @@ def test_embed_private_image_proper_error_raised(self):
marqo_config=self.config, index_name=index_name,
embedding_request=EmbedRequest(
content=test_content
)
),
device="cpu"
)
self.assertIn("Error downloading media file", str(e.exception))
self.assertIn("403 Client Error", str(e.exception))
self.assertIn("403 Client Error", str(e.exception))

def test_embed_invalid_image_proper_error_raised(self):
"""Test that a proper 400 error is raised when trying to embed an invalid image url."""
test_content_lists = [
("https://a-dummy-image-url.jpg", "a single invalid image url"),
(["https://a-dummy-image-url.jpg", "test"],
"a list of content with an invalid image url")
]

for index_name in [self.unstructured_default_image_index.name, self.structured_default_image_index.name]:
for test_content, msg in test_content_lists:
with self.subTest(f"{index_name} - {msg}"):
with self.assertRaises(InvalidArgError) as e:
embed_res = embed(
marqo_config=self.config, index_name=index_name,
embedding_request=EmbedRequest(
content=test_content
),
device="cpu"
)
self.assertIn("Error vectorising content", str(e.exception))
21 changes: 19 additions & 2 deletions tests/tensor_search/integ_tests/test_search_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ def test_lexical_search_DoesNotErrorWithEscapedQuotes(self):
self.assertEqual(set(expected_ids), {hit['_id'] for hit in res['hits']})

def test_search_private_image_return_proper_error(self):
"""A test to ensure that InvalidArgumentError is raised when searching for an image in a private index."""
"""A test to ensure that InvalidArgumentError is raised when searching for a private image."""
test_queries_list = [
("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "A private image"),
({"https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small": 1, "test": 1},
Expand All @@ -1029,4 +1029,21 @@ def test_search_private_image_return_proper_error(self):
text=query, config=self.config, index_name=index_name.name,
)
self.assertIn("Error downloading media file", str(e.exception))
self.assertIn("403 Client Error", str(e.exception))
self.assertIn("403 Client Error", str(e.exception))

def test_search_invalid_image_url_image_return_proper_error(self):
"""A test to ensure that InvalidArgumentError is raised when searching for an invalid image url."""
test_queries_list = [
("https://a-dummy-image-url.jpg", "A invalid image"),
({"https://a-dummy-image-url.jpg": 1, "test": 1},
"A invalid image in the dictionary")
]

for index_name in [self.structured_default_image_index, self.unstructured_default_image_index]:
for query, msg in test_queries_list:
with self.subTest(f"{index_name} - {query}"):
with self.assertRaises(api_exceptions.InvalidArgError) as e:
tensor_search.search(
text=query, config=self.config, index_name=index_name.name,
)
self.assertIn("Error vectorising content", str(e.exception))

0 comments on commit 9054c1b

Please sign in to comment.