64 changes: 60 additions & 4 deletions utils.py
@@ -4,11 +4,11 @@
from datetime import datetime
import json
import logging
import multiprocessing as mp
import os
import tarfile
import tempfile
import socket

import torch

from pytorch_transformers import cached_path
@@ -17,6 +17,11 @@
HF_FINETUNED_MODEL = "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/gpt_personachat_cache.tar.gz"

logger = logging.getLogger(__file__)
logger.setLevel(level=logging.DEBUG)
mp.log_to_stderr(level=logging.DEBUG)
mp_logger = mp.get_logger()
mp_logger.setLevel(level=logging.DEBUG)
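# mp.log_to_stderr() attaches a handler to multiprocessing's internal logger so
# that records emitted inside worker processes show up on stderr; mp.get_logger()
# returns that same logger, which worker_tokenize() below uses for debug output.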


def download_pretrained_model():
""" Download and extract finetuned model from S3 """
@@ -29,6 +34,31 @@ def download_pretrained_model():
    return tempdir


def worker_tokenize(args_list):
    """Target function for multiprocessing text encoding. Both input arguments are
    packed into a single list as a workaround, because worker_tokenize() calls
    itself recursively and the constant tokenizer has to travel along with obj.

    IMPORTANT: This function has to be defined at module level (outside of get_dataset())
    to avoid the multiprocessing error
    "AttributeError: Can't pickle local object 'get_dataset.<locals>.worker_tokenize'".

    Args:
        args_list: [obj, tokenizer], packed into one list so the function can
            recursively call itself with a single argument."""
    obj, tokenizer = args_list
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        worker_tokenize._dict_key_calls += 1
        mp_logger.debug('Encoding {}. obj.keys() = {}, len(obj) = {}'.format(
            worker_tokenize._dict_key_calls, obj.keys(), len(obj)))
        return dict((n, worker_tokenize([o, tokenizer])) for n, o in obj.items())
    return list(worker_tokenize([o, tokenizer]) for o in obj)


# Per-process counter for debug logging; under mp.Pool every worker process
# increments its own copy, so the counts are per-worker rather than global.
worker_tokenize._dict_key_calls = 0
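
# Illustrative sketch (not part of the PR) of the pickling constraint from
# worker_tokenize()'s docstring: mp.Pool serializes its target function by
# qualified name, so only module-level functions can serve as Pool targets.
# Both helpers below are hypothetical and unused elsewhere in this file.
def _square(x):
    return x * x  # module-level: importable by name, hence picklable


def _pickling_demo():
    with mp.Pool(processes=2) as pool:
        print(pool.map(_square, [1, 2, 3]))  # works: prints [1, 4, 9]

        def _local_square(x):  # defined inside a function: not picklable
            return x * x
        # pool.map(_local_square, [1, 2, 3]) would raise:
        # AttributeError: Can't pickle local object '_pickling_demo.<locals>._local_square'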


def get_dataset(tokenizer, dataset_path, dataset_cache=None):
""" Get PERSONACHAT from S3 """
dataset_path = dataset_path or PERSONACHAT_URL
@@ -41,19 +71,42 @@ def get_dataset(tokenizer, dataset_path, dataset_cache=None):
    personachat_file = cached_path(dataset_path)
    with open(personachat_file, "r", encoding="utf-8") as f:
        dataset = json.loads(f.read())

    logger.info("Tokenize and encode the dataset")

    def tokenize(obj):
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        if isinstance(obj, dict):
            tokenize.dict_key_calls += 1
            logger.debug('Encoding {}. obj.keys() = {}, len(obj) = {}'.format(
                tokenize.dict_key_calls, obj.keys(), len(obj)))
            return dict((n, tokenize(o)) for n, o in obj.items())
        min_samples_for_multiprocessing = 100
        if len(obj) > min_samples_for_multiprocessing:
            logger.debug('    Encoding very long list of len(obj) = {}'.format(len(obj)))
            logger.debug('    Encoding list with multiprocessing...')
            # functools.partial does not work because the tokenizer has to be handed
            # recursively together with obj to worker_tokenize again; as a workaround
            # for combining a possible dict output with **kwargs input via the star
            # operator, both arguments are packed into a plain list instead.
            with mp.Pool(processes=mp.cpu_count() - 1) as pool:
                results = pool.map(func=worker_tokenize,
                                   iterable=[[o, tokenizer] for o in obj])
            return results
        else:
            logger.debug('    Encoding list of len(obj) = {}'.format(len(obj)))
            return list(tokenize(o) for o in obj)

    tokenize.dict_key_calls = 0

    dataset = tokenize(dataset)
    # dataset = tokenize(dataset)
Author: absolutely!

Suggested change
    # dataset = tokenize(dataset)
    dataset = tokenize(dataset)

    if dataset_cache:
        torch.save(dataset, dataset_cache)
    return dataset
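
# Usage sketch (assumption, not part of the PR): encoding PERSONACHAT with a
# pytorch_transformers tokenizer. Lists longer than 100 items are farmed out to
# the worker pool; shorter ones are encoded recursively in-process. Because
# tokenize() opens an mp.Pool, callers should run under the usual
# `if __name__ == '__main__':` guard on spawn-based platforms.
#
#     from pytorch_transformers import OpenAIGPTTokenizer
#     tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
#     dataset = get_dataset(tokenizer, dataset_path=None,  # None -> PERSONACHAT_URL
#                           dataset_cache='./dataset_cache')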


def get_dataset_personalities(tokenizer, dataset_path, dataset_cache=None):
""" Get personalities from PERSONACHAT """
dataset_path = dataset_path or PERSONACHAT_URL
@@ -68,14 +121,16 @@ def get_dataset_personalities(tokenizer, dataset_path, dataset_cache=None):
        personachat = json.loads(f.read())

    logger.info("Tokenize and encode the dataset")

    def tokenize(obj):
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        if isinstance(obj, dict):
            return dict((n, tokenize(o)) for n, o in obj.items())
        return list(tokenize(o) for o in obj)

    personachat = tokenize(personachat)
    torch.save(personachat, dataset_cache)
    # torch.save(personachat, dataset_cache)
Author: of course!

Suggested change
    # torch.save(personachat, dataset_cache)
    torch.save(personachat, dataset_cache)


logger.info("Filter personalities")
personalities = []
Expand All @@ -86,6 +141,7 @@ def tokenize(obj):
logger.info("Gathered {} personalities".format(len(personalities)))
return personalities
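
# Usage sketch (assumption, not part of the PR): sampling one tokenized
# personality to condition the chatbot on:
#
#     import random
#     personalities = get_dataset_personalities(tokenizer, dataset_path=None)
#     personality = random.choice(personalities)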


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self  # expose dict keys as attributes
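
# Usage sketch (assumption): AttrDict mirrors dict keys as attributes, so
# hyperparameters can be read with dot access:
#
#     args = AttrDict({'model_checkpoint': 'openai-gpt', 'max_history': 2})
#     args.model_checkpoint  # -> 'openai-gpt'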