Skip to content

Commit

Permalink
refactor: Nested search params in ES config (#120)
Browse files Browse the repository at this point in the history
* refactor: Nested search params in ES config

Co-authored-by: filipe oliveira <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: Remove extra config vars

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: Remove extra comment

* feat: Add keyword, text, and float index types in ES

---------

Co-authored-by: filipe oliveira <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 15, 2024
1 parent 2ffe5e2 commit 5f2121e
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 70 deletions.
28 changes: 24 additions & 4 deletions engine/clients/elasticsearch/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
ELASTIC_PORT = 9200
ELASTIC_INDEX = "bench"
ELASTIC_USER = "elastic"
ELASTIC_PASSWORD = "passwd"
import os

from elasticsearch import Elasticsearch

ELASTIC_PORT = int(os.getenv("ELASTIC_PORT", 9200))
ELASTIC_INDEX = os.getenv("ELASTIC_INDEX", "bench")
ELASTIC_USER = os.getenv("ELASTIC_USER", "elastic")
ELASTIC_PASSWORD = os.getenv("ELASTIC_PASSWORD", "passwd")


def get_es_client(host, connection_params):
init_params = {
"verify_certs": False,
"retry_on_timeout": True,
"ssl_show_warn": False,
**connection_params,
}
client = Elasticsearch(
f"http://{host}:{ELASTIC_PORT}",
basic_auth=(ELASTIC_USER, ELASTIC_PASSWORD),
**init_params,
)
assert client.ping()
return client
34 changes: 8 additions & 26 deletions engine/clients/elasticsearch/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
from engine.base_client import IncompatibilityError
from engine.base_client.configure import BaseConfigurator
from engine.base_client.distances import Distance
from engine.clients.elasticsearch.config import (
ELASTIC_INDEX,
ELASTIC_PASSWORD,
ELASTIC_PORT,
ELASTIC_USER,
)
from engine.clients.elasticsearch.config import ELASTIC_INDEX, get_es_client


class ElasticConfigurator(BaseConfigurator):
Expand All @@ -20,24 +15,15 @@ class ElasticConfigurator(BaseConfigurator):
}
INDEX_TYPE_MAPPING = {
"int": "long",
"keyword": "keyword",
"text": "text",
"float": "double",
"geo": "geo_point",
}

def __init__(self, host, collection_params: dict, connection_params: dict):
super().__init__(host, collection_params, connection_params)
init_params = {
**{
"verify_certs": False,
"request_timeout": 90,
"retry_on_timeout": True,
},
**connection_params,
}
self.client = Elasticsearch(
f"http://{host}:{ELASTIC_PORT}",
basic_auth=(ELASTIC_USER, ELASTIC_PASSWORD),
**init_params,
)
self.client = get_es_client(host, connection_params)

def clean(self):
try:
Expand All @@ -60,7 +46,7 @@ def recreate(self, dataset: Dataset, collection_params):
"index": {
"number_of_shards": 1,
"number_of_replicas": 0,
"refresh_interval": -1,
"refresh_interval": -1, # no refresh is required because we index all the data at once
}
},
mappings={
Expand All @@ -72,12 +58,8 @@ def recreate(self, dataset: Dataset, collection_params):
"index": True,
"similarity": self.DISTANCE_MAPPING[dataset.config.distance],
"index_options": {
**{
"type": "hnsw",
"m": 16,
"ef_construction": 100,
},
**collection_params.get("index_options"),
"type": "hnsw",
**collection_params["index_options"],
},
},
**self._prepare_fields_config(dataset),
Expand Down
23 changes: 3 additions & 20 deletions engine/clients/elasticsearch/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
from elasticsearch import Elasticsearch

from engine.base_client.search import BaseSearcher
from engine.clients.elasticsearch.config import (
ELASTIC_INDEX,
ELASTIC_PASSWORD,
ELASTIC_PORT,
ELASTIC_USER,
)
from engine.clients.elasticsearch.config import ELASTIC_INDEX, get_es_client
from engine.clients.elasticsearch.parser import ElasticConditionParser


Expand All @@ -29,20 +24,8 @@ def get_mp_start_method(cls):
return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"

@classmethod
def init_client(cls, host, distance, connection_params: dict, search_params: dict):
init_params = {
**{
"verify_certs": False,
"request_timeout": 90,
"retry_on_timeout": True,
},
**connection_params,
}
cls.client: Elasticsearch = Elasticsearch(
f"http://{host}:{ELASTIC_PORT}",
basic_auth=(ELASTIC_USER, ELASTIC_PASSWORD),
**init_params,
)
def init_client(cls, host, _distance, connection_params: dict, search_params: dict):
cls.client = get_es_client(host, connection_params)
cls.search_params = search_params

@classmethod
Expand Down
23 changes: 3 additions & 20 deletions engine/clients/elasticsearch/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
from elasticsearch import Elasticsearch

from engine.base_client.upload import BaseUploader
from engine.clients.elasticsearch.config import (
ELASTIC_INDEX,
ELASTIC_PASSWORD,
ELASTIC_PORT,
ELASTIC_USER,
)
from engine.clients.elasticsearch.config import ELASTIC_INDEX, get_es_client


class ClosableElastic(Elasticsearch):
Expand All @@ -27,20 +22,8 @@ def get_mp_start_method(cls):
return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"

@classmethod
def init_client(cls, host, distance, connection_params, upload_params):
init_params = {
**{
"verify_certs": False,
"request_timeout": 90,
"retry_on_timeout": True,
},
**connection_params,
}
cls.client = Elasticsearch(
f"http://{host}:{ELASTIC_PORT}",
basic_auth=(ELASTIC_USER, ELASTIC_PASSWORD),
**init_params,
)
def init_client(cls, host, _distance, connection_params, upload_params):
cls.client = get_es_client(host, connection_params)
cls.upload_params = upload_params

@classmethod
Expand Down

0 comments on commit 5f2121e

Please sign in to comment.