Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update embed_full_personas.py for missing keys in map_relation #2

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions induce_retrieve_pipeline/embed_full_personas.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,34 +10,40 @@
"routine_habit": "here is what i regularly or consistently do",
"goal_plan": "here is what i will do or achieve in the future",
"experience": "here is what i did in the past",
"relationship": "related to other people or social groups"
"relationship": "related to other people or social groups",
}


def full_persona_triplets_to_sentences(full_persona_json):
    """Turn every persona's triplets into natural-language sentences.

    Each triplet ``t`` is rendered as
    ``"{t[0]}. {map_relation[t[1]]}: {t[2]}"``. Triplets whose relation
    ``t[1]`` has no entry in ``map_relation`` are skipped instead of
    raising ``KeyError``.

    Args:
        full_persona_json: mapping ``dialogue_id -> {persona_id: record}``.
            A record is either a dict whose ``'triple'`` entry holds the
            list of triplets, or a plain list of triplets (the layout the
            earlier revision of this function iterated directly).
            NOTE(review): any other keys of a dict record are ignored
            here - confirm that is intended.

    Returns:
        dict: ``dialogue_id -> {persona_id: [sentence, ...]}``. Every
        persona id gets an entry, possibly an empty list.
    """
    full_persona_sentences = {}
    for dialogue_id, personas in full_persona_json.items():
        dialogue_sentences = {}
        for persona_id, persona_triplets in personas.items():
            if isinstance(persona_triplets, dict):
                # Only the 'triple' entry carries triplets; a missing or
                # null entry is treated as "no triplets".
                triplets = persona_triplets.get('triple') or []
            else:
                triplets = persona_triplets
            # Always record an entry so each persona id is present in the
            # result even when none of its relations could be mapped.
            dialogue_sentences[persona_id] = [
                f"{t[0]}. {map_relation[t[1]]}: {t[2]}"
                for t in triplets
                if t[1] in map_relation
            ]
        full_persona_sentences[dialogue_id] = dialogue_sentences
    return full_persona_sentences


def embed_full_personas(
full_persona_sentences,
gpu_id=None,
embedding_dim=768,
model_id='all-mpnet-base-v2'
):
):
model = SentenceTransformer(model_id)
for dialogue_id, personas in tqdm.tqdm(full_persona_sentences.items()):
dialogue_embeddings = {}
for persona_id, persona_sentences in personas.items():
if len(persona_sentences) > 0:
embeddings = model.encode(
persona_sentences, convert_to_tensor=True, device=f'cuda:{gpu_id}'
)
)
if embedding_dim < embeddings.size(1):
embeddings = embeddings[:, :embedding_dim]
embeddings = embeddings.cpu().numpy()
Expand All @@ -47,6 +53,7 @@ def embed_full_personas(
full_persona_sentences[dialogue_id].update(dialogue_embeddings)
return full_persona_sentences


def _clean_relations_in_triplets(full_persona_json):
for dialogue_id, personas in full_persona_json.items():
for persona_id, persona_triplets in personas.items():
Expand All @@ -60,14 +67,17 @@ def _clean_relations_in_triplets(full_persona_json):
def main():
print('Full persona triplets to sentences...')
full_personas_json = json.load(open(FULL_PERSONAS_PATH, 'r'))
print('Cleaning relations...')
full_personas_json = _clean_relations_in_triplets(full_personas_json)
full_persona_sentences = full_persona_triplets_to_sentences(full_personas_json)
full_persona_sentences = full_persona_triplets_to_sentences(
full_personas_json)
pkl.dump(
full_persona_sentences,
open(os.path.join(SAVE_DIR, 'full_persona_sentences.pkl'), 'wb')
)
print('Computing embeddings...')
full_persona_embeddings = embed_full_personas(full_persona_sentences, gpu_id=GPU)
full_persona_embeddings = embed_full_personas(
full_persona_sentences, gpu_id=GPU)
pkl.dump(
full_persona_embeddings,
open(os.path.join(SAVE_DIR, 'full_persona_embeddings.pkl'), 'wb')
Expand All @@ -80,4 +90,3 @@ def main():
SAVE_DIR = 'pickled_stuff'
os.makedirs(SAVE_DIR, exist_ok=True)
main()