Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[BlenderBot2] Fix ObservationEchoRetriever (#4428)
Browse files: browse the repository at this point in the history
* add new counter

* bb2 wizintgold doc

* comment

* typo
  • Loading branch information
Jing authored Mar 17, 2022
1 parent d6773a0 commit f773ef7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def get_retrieved_knowledge(self, message: Message):
for doc_idx in range(n_docs_in_message):
doc_content = message[consts.RETRIEVED_DOCS][doc_idx]
for sel_sentc in selected_sentences:
if sel_sentc in doc_content:
if sel_sentc in doc_content and doc_idx not in already_added_doc_idx:
retrieved_docs.append(
self._extract_doc_from_message(message, doc_idx)
)
Expand Down
18 changes: 17 additions & 1 deletion parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,10 +1305,17 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None
self.n_docs = opt['n_docs']
self._query_ids = dict()
self._saved_docs = dict()
self._largest_seen_idx = -1
super().__init__(opt, dictionary, shared=shared)

def add_retrieve_doc(self, query: str, retrieved_docs: List[Document]):
new_idx = len(self._query_ids)
self._largest_seen_idx += 1
new_idx = self._largest_seen_idx
if new_idx in self._query_ids.values() or new_idx in self._saved_docs:
raise RuntimeError(
"Nonunique new_idx created in add_retrieve_doc in ObservationEchoRetriever \n"
"this might return the same set of docs for two distinct queries"
)
self._query_ids[query] = new_idx
self._saved_docs[new_idx] = retrieved_docs or [
BLANK_DOC for _ in range(self.n_docs)
Expand All @@ -1320,6 +1327,11 @@ def tokenize_query(self, query: str) -> List[int]:
def get_delimiter(self) -> str:
return self._delimiter

def _clear_mapping(self):
self._query_ids = dict()
self._saved_docs = dict()
self._largest_seen_idx = -1

def retrieve_and_score(
self, query: torch.LongTensor
) -> Tuple[List[List[Document]], torch.Tensor]:
Expand All @@ -1336,6 +1348,10 @@ def retrieve_and_score(
retrieved_doc_scores = retrieved_doc_scores.repeat(batch_size, 1).to(
query.device
)

# empty the 2 mappings after each retrieval
self._clear_mapping()

return retrieved_docs, retrieved_doc_scores


Expand Down
4 changes: 1 addition & 3 deletions projects/blenderbot2/agents/blenderbot2.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,4 @@ def __init__(self, opt: Opt, shared: TShared = None):
class BlenderBot2WizIntGoldDocRetrieverFiDAgent(
WizIntGoldDocRetrieverFiDAgent, BlenderBot2FidAgent
):
def _set_query_vec(self, observation: Message) -> Message:
self.show_observation_to_echo_retriever(observation)
super()._set_query_vec(observation)
pass

0 comments on commit f773ef7

Please sign in to comment.