Skip to content

Commit

Permalink
minor name changes
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 27, 2024
1 parent d2d5caa commit 9708a28
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
13 changes: 9 additions & 4 deletions folktexts/cli/eval_feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setup_arg_parser() -> ArgumentParser:
# List of command-line arguments, with type and helper string
cli_args = [
("--model", str, "[str] Model name or path to model saved on disk"),
("--task-name", str, "[str] Name of the ACS task to run the experiment on", False, DEFAULT_TASK_NAME),
("--task", str, "[str] Name of the ACS task to run the experiment on", False, DEFAULT_TASK_NAME),
("--results-dir", str, "[str] Directory under which this experiment's results will be saved", False, DEFAULT_RESULTS_DIR),
("--data-dir", str, "[str] Root folder to find datasets on", False, DEFAULT_DATA_DIR),
("--models-dir", str, "[str] Root folder to find huggingface models on", False, DEFAULT_MODELS_DIR),
Expand Down Expand Up @@ -132,8 +132,13 @@ def main():
logging.getLogger().setLevel(logging.INFO)

# Load model and tokenizer
model_folder_path = get_model_folder_path(model_name=args.model, root_dir=args.models_dir)
model, tokenizer = load_model_tokenizer(model_folder_path)
model_folder_path = Path(get_model_folder_path(model_name=args.model, root_dir=args.models_dir))
if model_folder_path.exists() and model_folder_path.is_dir():
logging.info(f"Loading model from {model_folder_path.as_posix()}")
model, tokenizer = load_model_tokenizer(model_folder_path)
else:
logging.info(f"Loading model from {Path(args.model).resolve().as_posix()}")
model, tokenizer = load_model_tokenizer(args.model)

# Set-up results directory
results_dir = Path(args.results_dir) / Path(model_folder_path).name
Expand All @@ -142,7 +147,7 @@ def main():

# Load Task and Dataset
from folktexts.acs import ACSTaskMetadata
task = ACSTaskMetadata.get_task(args.task_name)
task = ACSTaskMetadata.get_task(args.task)

from folktexts.acs import ACSDataset
dataset = ACSDataset.make_from_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def setup_arg_parser() -> argparse.ArgumentParser:
)

parser.add_argument(
"--task-name",
"--task",
type=str,
help="[string] ACS task name to run experiments on - can provide multiple!",
required=False,
Expand Down Expand Up @@ -198,7 +198,7 @@ def main():
# with `setup_arg_parser().convert_arg_line_to_args(extra_kwargs)` !!!

models = args.model or LLM_MODELS
tasks = args.task_name or ACS_TASKS
tasks = args.task or ACS_TASKS

# Load experiment from JSON file if provided
if args.experiment_json:
Expand Down

0 comments on commit 9708a28

Please sign in to comment.