Skip to content

Commit

Permalink
fix saving engine names
Browse files Browse the repository at this point in the history
  • Loading branch information
generall committed Mar 16, 2024
1 parent f112b6b commit 5f8ebee
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
15 changes: 13 additions & 2 deletions engine/base_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class BaseClient:
def __init__(
self,
name: str, # name of the experiment
engine: str, # name of the engine
configurator: BaseConfigurator,
uploader: BaseUploader,
searchers: List[BaseSearcher],
Expand All @@ -28,6 +29,7 @@ def __init__(
self.configurator = configurator
self.uploader = uploader
self.searchers = searchers
self.engine = engine

def save_search_results(
self, dataset_name: str, results: dict, search_id: int, search_params: dict
Expand All @@ -40,7 +42,15 @@ def save_search_results(
result_path = RESULTS_DIR / experiments_file
with open(result_path, "w") as out:
out.write(
json.dumps({"params": search_params, "results": results}, indent=2)
json.dumps({
"params": {
"dataset": dataset_name,
"experiment": self.name,
"engine": self.engine,
**search_params
},
"results": results
}, indent=2)
)
return result_path

Expand All @@ -53,7 +63,8 @@ def save_upload_results(
with open(RESULTS_DIR / experiments_file, "w") as out:
upload_stats = {
"params": {
"engine": self.name,
"experiment": self.name,
"engine": self.engine,
"dataset": dataset_name,
**upload_params
},
Expand Down
3 changes: 3 additions & 0 deletions engine/clients/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@
class ClientFactory(ABC):
def __init__(self, host):
self.host = host
self.engine = None

def _create_configurator(self, experiment) -> BaseConfigurator:
self.engine = experiment["engine"]
engine_configurator_class = ENGINE_CONFIGURATORS[experiment["engine"]]
engine_configurator = engine_configurator_class(
self.host,
Expand Down Expand Up @@ -103,6 +105,7 @@ def _create_searchers(self, experiment) -> List[BaseSearcher]:
def build_client(self, experiment):
return BaseClient(
name=experiment["name"],
engine=experiment["engine"],
configurator=self._create_configurator(experiment),
uploader=self._create_uploader(experiment),
searchers=self._create_searchers(experiment),
Expand Down

0 comments on commit 5f8ebee

Please sign in to comment.