diff --git a/parlai/agents/fid/fid.py b/parlai/agents/fid/fid.py index 5ae75c0a2e6..28ef2638360 100644 --- a/parlai/agents/fid/fid.py +++ b/parlai/agents/fid/fid.py @@ -409,7 +409,7 @@ def get_retrieved_knowledge(self, message: Message): for doc_idx in range(n_docs_in_message): doc_content = message[consts.RETRIEVED_DOCS][doc_idx] for sel_sentc in selected_sentences: - if sel_sentc in doc_content: + if sel_sentc in doc_content and doc_idx not in already_added_doc_idx: retrieved_docs.append( self._extract_doc_from_message(message, doc_idx) ) diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index a6407af0ced..6592d7a9134 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -1305,10 +1305,17 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None self.n_docs = opt['n_docs'] self._query_ids = dict() self._saved_docs = dict() + self._largest_seen_idx = -1 super().__init__(opt, dictionary, shared=shared) def add_retrieve_doc(self, query: str, retrieved_docs: List[Document]): - new_idx = len(self._query_ids) + self._largest_seen_idx += 1 + new_idx = self._largest_seen_idx + if new_idx in self._query_ids.values() or new_idx in self._saved_docs: + raise RuntimeError( + "Nonunique new_idx created in add_retrieve_doc in ObservationEchoRetriever \n" + "this might return the same set of docs for two distinct queries" + ) self._query_ids[query] = new_idx self._saved_docs[new_idx] = retrieved_docs or [ BLANK_DOC for _ in range(self.n_docs) @@ -1320,6 +1327,11 @@ def tokenize_query(self, query: str) -> List[int]: def get_delimiter(self) -> str: return self._delimiter + def _clear_mapping(self): + self._query_ids = dict() + self._saved_docs = dict() + self._largest_seen_idx = -1 + def retrieve_and_score( self, query: torch.LongTensor ) -> Tuple[List[List[Document]], torch.Tensor]: @@ -1336,6 +1348,10 @@ def retrieve_and_score( retrieved_doc_scores = retrieved_doc_scores.repeat(batch_size, 1).to( query.device ) + + # empty the 2 mappings after each retrieval + self._clear_mapping() + return retrieved_docs, retrieved_doc_scores diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index a9d897f24af..0b691e5c714 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -924,6 +924,4 @@ def __init__(self, opt: Opt, shared: TShared = None): class BlenderBot2WizIntGoldDocRetrieverFiDAgent( WizIntGoldDocRetrieverFiDAgent, BlenderBot2FidAgent ): - def _set_query_vec(self, observation: Message) -> Message: - self.show_observation_to_echo_retriever(observation) - super()._set_query_vec(observation) + pass