diff --git a/induce_retrieve_pipeline/embed_full_personas.py b/induce_retrieve_pipeline/embed_full_personas.py index 8102459..2992675 100644 --- a/induce_retrieve_pipeline/embed_full_personas.py +++ b/induce_retrieve_pipeline/embed_full_personas.py @@ -10,26 +10,32 @@ "routine_habit": "here is what i regularly or consistently do", "goal_plan": "here is what i will do or achieve in the future", "experience": "here is what i did in the past", - "relationship": "related to other people or social groups" + "relationship": "related to other people or social groups", } + def full_persona_triplets_to_sentences(full_persona_json): full_persona_sentences = {} for dialogue_id, personas in full_persona_json.items(): full_persona_sentences[dialogue_id] = {} for persona_id, persona_triplets in personas.items(): sentences = [] - for t in persona_triplets: - sentences.append(f"{t[0]}. {map_relation[t[1]]}: {t[2]}") - full_persona_sentences[dialogue_id][persona_id] = sentences + for key, item in persona_triplets.items(): + if key == 'triple': + for t in item: + if t[1] in map_relation.keys(): + sentences.append( + f"{t[0]}. {map_relation[t[1]]}: {t[2]}") + full_persona_sentences[dialogue_id][persona_id] = sentences return full_persona_sentences + def embed_full_personas( full_persona_sentences, gpu_id=None, embedding_dim=768, model_id='all-mpnet-base-v2' - ): +): model = SentenceTransformer(model_id) for dialogue_id, personas in tqdm.tqdm(full_persona_sentences.items()): dialogue_embeddings = {} @@ -37,7 +43,7 @@ def embed_full_personas( if len(persona_sentences) > 0: embeddings = model.encode( persona_sentences, convert_to_tensor=True, device=f'cuda:{gpu_id}' - ) + ) if embedding_dim < embeddings.size(1): embeddings = embeddings[:, :embedding_dim] embeddings = embeddings.cpu().numpy() @@ -47,6 +53,7 @@ def embed_full_personas( full_persona_sentences[dialogue_id].update(dialogue_embeddings) return full_persona_sentences + def _clean_relations_in_triplets(full_persona_json): for dialogue_id, personas in full_persona_json.items(): for persona_id, persona_triplets in personas.items(): @@ -60,14 +67,17 @@ def _clean_relations_in_triplets(full_persona_json): def main(): print('Full persona triplets to sentences...') full_personas_json = json.load(open(FULL_PERSONAS_PATH, 'r')) + print('Cleaning relations...') full_personas_json = _clean_relations_in_triplets(full_personas_json) - full_persona_sentences = full_persona_triplets_to_sentences(full_personas_json) + full_persona_sentences = full_persona_triplets_to_sentences( + full_personas_json) pkl.dump( full_persona_sentences, open(os.path.join(SAVE_DIR, 'full_persona_sentences.pkl'), 'wb') ) print('Computing embeddings...') - full_persona_embeddings = embed_full_personas(full_persona_sentences, gpu_id=GPU) + full_persona_embeddings = embed_full_personas( + full_persona_sentences, gpu_id=GPU) pkl.dump( full_persona_embeddings, open(os.path.join(SAVE_DIR, 'full_persona_embeddings.pkl'), 'wb') @@ -80,4 +90,3 @@ def main(): SAVE_DIR = 'pickled_stuff' os.makedirs(SAVE_DIR, exist_ok=True) main() -