Skip to content

Commit

Permalink
align names to rage2e params
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Epstein committed Sep 21, 2023
1 parent f938906 commit 865cff3
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ To run retriever only eval
(make sure you have the checkpoints in the project root)

```bash
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
python dalm/eval/eval_retriever_only.py --dataset_path qa_pairs_test.csv --retriever_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
```

For the e2e eval

```bash
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path rag_e2e_checkpoints/retriever --generator_peft_model_path rag_e2e_checkpoints/generator
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path rag_e2e_checkpoints/retriever --generator_peft_model_path rag_e2e_checkpoints/generator
```


Expand Down
4 changes: 2 additions & 2 deletions dalm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def train_rag_e2e(

@cli.command()
def train_retriever_only(
model_name_or_path: Annotated[
retriever_name_or_path: Annotated[
str, typer.Argument(help="Path to the model or identifier from huggingface.co/models.", show_default=False)
],
dataset_path: Annotated[
Expand Down Expand Up @@ -238,7 +238,7 @@ def train_retriever_only(
"""End-to-end train an in-domain model, including the retriever and generator"""
train_retriever(
dataset_or_path=dataset_path,
model_name_or_path=model_name_or_path,
retriever_name_or_path=retriever_name_or_path,
dataset_passage_col_name=dataset_passage_col_name,
dataset_query_col_name=dataset_query_col_name,
query_max_len=query_max_len,
Expand Down
4 changes: 2 additions & 2 deletions dalm/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ To run retriever only eval
(make sure you have the checkpoints in the project root)

```bash
python dalm/eval/eval_retriever_only.py --dataset_path qa_paits_test.csv --retriever_model_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
python dalm/eval/eval_retriever_only.py --dataset_path qa_paits_test.csv --retriever_name_or_path "BAAI/bge-large-en" --passage_column_name Abstract --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints
```

For the e2e eval

```bash
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_model_name_or_path "BAAI/bge-large-en" --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints
python dalm/eval/eval_rag.py --dataset_path qa_pairs_test_2.csv --retriever_name_or_path "BAAI/bge-large-en" --generator_name_or_path "meta-llama/Llama-2-7b-hf" --passage_column_name Abstract --query_column_name Question --answer_column_name Answer --evaluate_generator --query_batch_size 5 --retriever_peft_model_path retriever_only_checkpoints --generator_peft_model_path generator_only_checkpoints
```
6 changes: 3 additions & 3 deletions dalm/eval/eval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def parse_args() -> Namespace:
),
)
parser.add_argument(
"--retriever_model_name_or_path",
"--retriever_name_or_path",
type=str,
help="Path to pretrained retriever model or model identifier from huggingface.co/models.",
required=True,
)
parser.add_argument(
"--generator_model_name_or_path",
"--generator_name_or_path",
type=str,
help="Path to pretrained generator model or model identifier from huggingface.co/models.",
required=True,
Expand Down Expand Up @@ -141,7 +141,7 @@ def main() -> None:

# rag retriver and the generator (don't load new peft layers no need)
rag_model = AutoModelForRagE2E(
args.retriever_model_name_or_path, args.generator_model_name_or_path, get_peft=False, use_bnb=False
args.retriever_name_or_path, args.generator_name_or_path, get_peft=False, use_bnb=False
)

# load the test dataset
Expand Down
4 changes: 2 additions & 2 deletions dalm/eval/eval_retriever_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def parse_args() -> Namespace:
),
)
parser.add_argument(
"--retriever_model_name_or_path",
"--retriever_name_or_path",
type=str,
help="Path to pretrained retriever model or model identifier from huggingface.co/models.",
required=True,
Expand Down Expand Up @@ -104,7 +104,7 @@ def main() -> None:
SELECTED_TORCH_DTYPE: Final[torch.dtype] = torch.float16 if args.torch_dtype == "float16" else torch.bfloat16

# rag retriver and the generator (don't load new peft layers no need)
retriever_model = AutoModelForSentenceEmbedding(args.retriever_model_name_or_path, get_peft=False, use_bnb=False)
retriever_model = AutoModelForSentenceEmbedding(args.retriever_name_or_path, get_peft=False, use_bnb=False)

# load the test dataset
test_dataset = (
Expand Down
10 changes: 5 additions & 5 deletions dalm/training/retriever_only/train_retriever_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def parse_args() -> Namespace:
),
)
parser.add_argument(
"--model_name_or_path",
"--retriever_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
Expand Down Expand Up @@ -163,7 +163,7 @@ def parse_args() -> Namespace:


def train_retriever(
model_name_or_path: str,
retriever_name_or_path: str,
dataset_or_path: str | Dataset,
dataset_passage_col_name: str = "Abstract",
dataset_query_col_name: str = "Question",
Expand Down Expand Up @@ -220,7 +220,7 @@ def train_retriever(
os.makedirs(output_dir, exist_ok=True)
accelerator.wait_for_everyone()

model = AutoModelForSentenceEmbedding(model_name_or_path, use_bnb=True, get_peft=use_peft)
model = AutoModelForSentenceEmbedding(retriever_name_or_path, use_bnb=True, get_peft=use_peft)
tokenizer = model.tokenizer

# dataset download and preprocessing
Expand Down Expand Up @@ -417,7 +417,7 @@ def main() -> None:
args = parse_args()
train_retriever(
dataset_or_path=args.dataset_path,
model_name_or_path=args.model_name_or_path,
retriever_name_or_path=args.retriever_name_or_path,
dataset_passage_col_name=args.dataset_passage_col_name,
dataset_query_col_name=args.dataset_query_col_name,
query_max_len=args.query_max_len,
Expand Down Expand Up @@ -449,5 +449,5 @@ def main() -> None:


# python contrastive_train/peft_lora_constrastive_learning.py --dataset_path "xxxx.csv" \
# --model_name_or_path "BAAI/bge-small-en" --output_dir "./retriever_only_checkpoints" --use_peft \
# --retriever_name_or_path "BAAI/bge-small-en" --output_dir "./retriever_only_checkpoints" --use_peft \
# --with_tracking --report_to all --per_device_train_batch_size 30
4 changes: 2 additions & 2 deletions experiments/llama-index-10k/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ dalm train-rag-e2e \

And eval
```
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever --embed_dim 384
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever --embed_dim 384
*************
Retriever results:
Expand All @@ -80,7 +80,7 @@ dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.

and eval
```
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ --embed_dim 384
python ../../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ --embed_dim 384
*************
Retriever results:
Expand Down

0 comments on commit 865cff3

Please sign in to comment.