diff --git a/src/marqo/core/embed/embed.py b/src/marqo/core/embed/embed.py index a04b35732..bead85221 100644 --- a/src/marqo/core/embed/embed.py +++ b/src/marqo/core/embed/embed.py @@ -117,12 +117,9 @@ def embed_content( # Vectorise the queries with RequestMetricsStore.for_request().time(f"embed.vector_inference_full_pipeline"): - try: - qidx_to_vectors: Dict[Qidx, List[float]] = tensor_search.run_vectorise_pipeline( - temp_config, queries, device - ) - except s2_inference_errors.MediaDownloadError as e: - raise api_exceptions.InvalidArgError(message=str(e)) from e + qidx_to_vectors: Dict[Qidx, List[float]] = tensor_search.run_vectorise_pipeline( + temp_config, queries, device + ) embeddings: List[List[float]] = list(qidx_to_vectors.values()) diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py index 6b48b6fab..8d893e823 100644 --- a/src/marqo/tensor_search/tensor_search.py +++ b/src/marqo/tensor_search/tensor_search.py @@ -2073,12 +2073,14 @@ def run_vectorise_pipeline(config: Config, queries: List[BulkSearchQueryEntity], """Run the query vectorisation process Raise: - MediaDownloadError: If the media cannot be downloaded. This error is raised before the vectorisation process. - api_exceptions.InvalidArgError: If the vectorisation process fails. + api_exceptions.InvalidArgError: If the vectorisation process fails or if the media cannot be downloaded. """ # Prepend the prefixes to the queries if it exists (output should be of type List[BulkSearchQueryEntity]) - prefixed_queries = add_prefix_to_queries(queries) + try: + prefixed_queries = add_prefix_to_queries(queries) + except s2_inference_errors.MediaDownloadError as e: + raise api_exceptions.InvalidArgError(message=str(e)) from e # 1. Pre-process inputs ready for s2_inference.vectorise # we can still use qidx_to_job. But the jobs structure may need to be different @@ -2176,10 +2178,7 @@ def _vector_text_search( )] with RequestMetricsStore.for_request().time(f"search.vector_inference_full_pipeline"): - try: - qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, device) - except s2_inference_errors.MediaDownloadError as e: - raise api_exceptions.InvalidArgError(message=str(e)) from e + qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, device) vectorised_text = list(qidx_to_vectors.values())[0] marqo_query = MarqoTensorQuery( diff --git a/tests/tensor_search/integ_tests/test_embed.py b/tests/tensor_search/integ_tests/test_embed.py index d12bea6c1..5591dbd7c 100644 --- a/tests/tensor_search/integ_tests/test_embed.py +++ b/tests/tensor_search/integ_tests/test_embed.py @@ -764,7 +764,29 @@ def test_embed_private_image_proper_error_raised(self): marqo_config=self.config, index_name=index_name, embedding_request=EmbedRequest( content=test_content - ) + ), + device="cpu" ) self.assertIn("Error downloading media file", str(e.exception)) - self.assertIn("403 Client Error", str(e.exception)) \ No newline at end of file + self.assertIn("403 Client Error", str(e.exception)) + + def test_embed_invalid_image_proper_error_raised(self): + """Test that a proper 400 error is raised when trying to embed an invalid image url.""" + test_content_lists = [ + ("https://a-dummy-image-url.jpg", "a single invalid image url"), + (["https://a-dummy-image-url.jpg", "test"], + "a list of content with an invalid image url") + ] + + for index_name in [self.unstructured_default_image_index.name, self.structured_default_image_index.name]: + for test_content, msg in test_content_lists: + with self.subTest(f"{index_name} - {msg}"): + with self.assertRaises(InvalidArgError) as e: + embed_res = embed( + marqo_config=self.config, index_name=index_name, + embedding_request=EmbedRequest( + content=test_content + ), + device="cpu" + ) + self.assertIn("Error vectorising content", str(e.exception)) \ No newline at end of file diff --git a/tests/tensor_search/integ_tests/test_search_combined.py b/tests/tensor_search/integ_tests/test_search_combined.py index c042b4d09..d40b9b932 100644 --- a/tests/tensor_search/integ_tests/test_search_combined.py +++ b/tests/tensor_search/integ_tests/test_search_combined.py @@ -1014,7 +1014,7 @@ def test_lexical_search_DoesNotErrorWithEscapedQuotes(self): self.assertEqual(set(expected_ids), {hit['_id'] for hit in res['hits']}) def test_search_private_image_return_proper_error(self): - """A test to ensure that InvalidArgumentError is raised when searching for an image in a private index.""" + """A test to ensure that InvalidArgumentError is raised when searching for a private image.""" test_queries_list = [ ("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "A private image"), ({"https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small": 1, "test": 1}, @@ -1029,4 +1029,21 @@ def test_search_private_image_return_proper_error(self): text=query, config=self.config, index_name=index_name.name, ) self.assertIn("Error downloading media file", str(e.exception)) - self.assertIn("403 Client Error", str(e.exception)) \ No newline at end of file + self.assertIn("403 Client Error", str(e.exception)) + + def test_search_invalid_image_url_image_return_proper_error(self): + """A test to ensure that InvalidArgumentError is raised when searching for an invalid image url.""" + test_queries_list = [ + ("https://a-dummy-image-url.jpg", "A invalid image"), + ({"https://a-dummy-image-url.jpg": 1, "test": 1}, + "A invalid image in the dictionary") + ] + + for index_name in [self.structured_default_image_index, self.unstructured_default_image_index]: + for query, msg in test_queries_list: + with self.subTest(f"{index_name} - {query}"): + with self.assertRaises(api_exceptions.InvalidArgError) as e: + tensor_search.search( + text=query, config=self.config, index_name=index_name.name, + ) + self.assertIn("Error vectorising content", str(e.exception)) \ No newline at end of file