Skip to content

Commit

Permalink
Merge pull request #35 from BiomedSciAI/fix_base_model_override
Browse files Browse the repository at this point in the history
Reposition base model override from model config
  • Loading branch information
yoavkt authored Aug 18, 2024
2 parents 36f15c8 + 71a4eec commit d291ff7
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions scripts/run_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,6 @@ def main(
description_builder = load_class(**model_dict["descriptor"])
else:
description_builder = None
if "base_model" in model_dict:
base_model = load_class(**model_dict["base_model"])
else:
base_model = LogisticRegression(max_iter=5000, n_jobs=-1)
if scoring_type == "category":
scoring = (
"roc_auc_ovr_weighted",
Expand All @@ -195,6 +191,7 @@ def main(
"f1_weighted",
)
cv = 5
base_model = LogisticRegression(max_iter=5000, n_jobs=-1)
elif scoring_type == "regression":
scoring = (
"r2",
Expand Down Expand Up @@ -236,6 +233,8 @@ def main(
post_processing = model_dict["post_processing"]
else:
post_processing = "average"
if "base_model" in model_dict:
base_model = load_class(**model_dict["base_model"])
encoder = load_class(**model_dict["encoder"])
if ";" in task_name:
sub_task = task_name.split(";")[1]
Expand Down

0 comments on commit d291ff7

Please sign in to comment.