diff --git a/tests/unit/test_match.py b/tests/unit/test_match.py index 8ab70b00..93b4591a 100644 --- a/tests/unit/test_match.py +++ b/tests/unit/test_match.py @@ -102,6 +102,44 @@ def test_match_exclude_dataset(): assert len(res["results"]) == 0, res +def test_match_include_dataset(): + # When querying Putin + query = {"queries": {"vv": EXAMPLE}} + # Using only datasets that do not include Putin + params = { + "algorithm": "name-based", + "include_dataset": ["ae_local_terrorists", "mx_governors"], + } + resp = client.post("/match/default", json=query, params=params) + # We should get a succesful response + assert resp.status_code == 200, resp.text + data = resp.json() + res = data["responses"]["vv"] + # And we should get no matches + assert len(res["results"]) == 0, res + # When using a dataset that includes Putin + params = { + "algorithm": "name-based", + "include_dataset": ["eu_fsf", "ae_local_terrorists"], + } + resp = client.post("/match/default", json=query, params=params) + data = resp.json() + res = data["responses"]["vv"] + # And we should get matches + assert len(res["results"]) > 0, res + # When we exclude the eu_fsf dataset + params = { + "algorithm": "name-based", + "include_dataset": ["eu_fsf", "mx_governors", "ae_local_terrorists"], + "exclude_dataset": "eu_fsf", + } + # We should get no matches + resp = client.post("/match/default", json=query, params=params) + data = resp.json() + res = data["responses"]["vv"] + assert len(res["results"]) == 0, res + + def test_filter_topic(): query = {"queries": {"vv": EXAMPLE}} params = {"algorithm": "name-based", "topics": "crime.cyber"} diff --git a/tests/unit/test_search.py b/tests/unit/test_search.py index 09a65219..dfcbe012 100644 --- a/tests/unit/test_search.py +++ b/tests/unit/test_search.py @@ -68,6 +68,29 @@ def test_search_filter_exclude_dataset(): assert new_total == 0 +def test_search_filter_include_dataset(): + res = client.get("/search/default?q=vladimir putin") + assert res.status_code == 200, res + total = res.json()["total"]["value"] + assert total > 0, total + # When we include a dataset that does not contain Putin or is not available + # in the collection we should get no results + res = client.get("/search/default?q=vladimir putin&include_dataset=mx_senators") + assert res.status_code == 200, res + new_total = res.json()["total"]["value"] + assert new_total == 0 + # When we include a dataset that contains Putin we should get results + res = client.get("/search/default?q=vladimir putin&include_dataset=eu_fsf") + new_total = res.json()["total"]["value"] + assert new_total > 0 + # When using both include and exclude, the exclude should take precedence + res = client.get( + "/search/default?q=vladimir putin&include_dataset=eu_fsf&exclude_dataset=eu_fsf" + ) + new_total = res.json()["total"]["value"] + assert new_total == 0 + + def test_search_filter_changed_since(): ts = datetime.utcnow() + timedelta(days=1) tx = ts.isoformat(sep="T", timespec="minutes") diff --git a/yente/routers/match.py b/yente/routers/match.py index d0776510..dd4d5934 100644 --- a/yente/routers/match.py +++ b/yente/routers/match.py @@ -47,6 +47,9 @@ async def match( title="Lower bound of score for results to be returned at all", ), algorithm: str = Query(settings.DEFAULT_ALGORITHM, title=ALGO_HELP), + include_dataset: List[str] = Query( + [], title="Only include the given datasets in results" + ), exclude_schema: List[str] = Query( [], title="Remove the given types of entities from results" ), @@ -144,6 +147,7 @@ async def match( entity, filters=filters, fuzzy=fuzzy, + include_dataset=include_dataset, exclude_schema=exclude_schema, exclude_dataset=exclude_dataset, changed_since=changed_since, diff --git a/yente/routers/search.py b/yente/routers/search.py index 31e10c31..f6a51698 100644 --- a/yente/routers/search.py +++ b/yente/routers/search.py @@ -53,6 +53,9 @@ async def search( schema: str = Query( settings.BASE_SCHEMA, title="Types of entities that can match the search" ), + include_dataset: List[str] = Query( + [], title="Only include the given datasets in results" + ), exclude_schema: List[str] = Query( [], title="Remove the given types of entities from results" ), @@ -109,6 +112,7 @@ async def search( filters=filters, fuzzy=fuzzy, simple=simple, + include_dataset=include_dataset, exclude_schema=exclude_schema, exclude_dataset=exclude_dataset, changed_since=changed_since, diff --git a/yente/search/queries.py b/yente/search/queries.py index 4c140902..5aef04f2 100644 --- a/yente/search/queries.py +++ b/yente/search/queries.py @@ -21,13 +21,19 @@ def filter_query( dataset: Optional[Dataset] = None, schema: Optional[Schema] = None, filters: FilterDict = {}, + include_dataset: List[str] = [], exclude_schema: List[str] = [], exclude_dataset: List[str] = [], changed_since: Optional[str] = None, ) -> Clause: filterqs: List[Clause] = [] if dataset is not None: - ds = [d for d in dataset.dataset_names if d not in exclude_dataset] + ds = [ + d + for d in dataset.dataset_names + if (len(include_dataset) == 0 or d in include_dataset) + and d not in exclude_dataset + ] filterqs.append({"terms": {"datasets": ds}}) if schema is not None: schemata = schema.matchable_schemata @@ -76,7 +82,7 @@ def names_query(entity: EntityProxy, fuzzy: bool = True) -> List[Clause]: term = {NAME_KEY_FIELD: {"value": key, "boost": 4.0}} shoulds.append({"term": term}) for token in set(index_name_parts(names)): - term = {NAME_PART_FIELD: {"value": token, 'boost': 1.0}} + term = {NAME_PART_FIELD: {"value": token, "boost": 1.0}} shoulds.append({"term": term}) for phoneme in set(phonetic_names(names)): term = {NAME_PHONETIC_FIELD: {"value": phoneme, "boost": 0.8}} @@ -89,6 +95,7 @@ def entity_query( entity: EntityProxy, filters: FilterDict = {}, fuzzy: bool = True, + include_dataset: List[str] = [], exclude_schema: List[str] = [], exclude_dataset: List[str] = [], changed_since: Optional[str] = None, @@ -110,6 +117,7 @@ def entity_query( filters=filters, dataset=dataset, schema=entity.schema, + include_dataset=include_dataset, exclude_schema=exclude_schema, exclude_dataset=exclude_dataset, changed_since=changed_since, @@ -123,6 +131,7 @@ def text_query( filters: FilterDict = {}, fuzzy: bool = False, simple: bool = False, + include_dataset: List[str] = [], exclude_schema: List[str] = [], exclude_dataset: List[str] = [], changed_since: Optional[str] = None, @@ -156,6 +165,7 @@ def text_query( dataset=dataset, schema=schema, filters=filters, + include_dataset=include_dataset, exclude_schema=exclude_schema, exclude_dataset=exclude_dataset, changed_since=changed_since,