Skip to content

Commit

Permalink
map rag model to device in eval rag
Browse files Browse the repository at this point in the history
  • Loading branch information
sagorbrur committed Nov 23, 2023
1 parent e6c3d29 commit a847640
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions dalm/eval/eval_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def evaluate_rag(
)
# peft config and wrapping
rag_model.attach_pre_trained_peft_layers(retriever_peft_model_path, generator_peft_model_path, device)
# mapping rag retriever and generator model to device
if retriever_peft_model_path is None:
rag_model.retriever_model.eval().to(device)
if generator_peft_model_path is None:
rag_model.generator_model.eval().to(device)
unique_passage_dataset, passage_embeddings_array = get_passage_embeddings(
processed_datasets,
passage_column_name,
Expand Down

0 comments on commit a847640

Please sign in to comment.