From e78fcb676a936aa255f8258417dd3ac9e1cd2ef7 Mon Sep 17 00:00:00 2001
From: Maximilian Berr
Date: Tue, 20 Aug 2019 06:14:28 +0300
Subject: [PATCH] Upgrade get_dataset.tokenize() to multiprocessing ability

get_dataset.tokenize() is too slow on a single CPU. It is therefore upgraded
to multiprocessing by implementing the multiprocessing target function
worker_tokenize(args_list). Additionally, a multiprocessing debug logger
mp_logger is added, together with logger.debug() and mp_logger.debug()
messages, to track progress in the Python console.
---
 utils.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 60 insertions(+), 4 deletions(-)

diff --git a/utils.py b/utils.py
index 6061889..ebad34e 100644
--- a/utils.py
+++ b/utils.py
@@ -3,10 +3,10 @@
 # LICENSE file in the root directory of this source tree.
 import json
 import logging
+import multiprocessing as mp
 import os
 import tarfile
 import tempfile
-
 import torch
 
 from pytorch_pretrained_bert import cached_path
@@ -15,6 +15,11 @@ HF_FINETUNED_MODEL = "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/finetuned_chatbot_gpt.tar.gz"
 
 logger = logging.getLogger(__file__)
+logger.setLevel(level=logging.DEBUG)
+mp.log_to_stderr(level=logging.DEBUG)
+mp_logger = mp.get_logger()
+mp_logger.setLevel(level=logging.DEBUG)
+
 
 def download_pretrained_model():
     """ Download and extract finetuned model from S3 """
@@ -27,6 +32,31 @@ def download_pretrained_model():
     return tempdir
 
 
+def worker_tokenize(args_list):
+    """Target function for multiprocessing text encoding. All input args are packed into a single list as a
+    workaround, because worker_tokenize() calls itself recursively with the constant tokenizer as one argument.
+
+    IMPORTANT: This function has to be implemented globally (outside of get_dataset()) to avoid the
+    multiprocessing error "AttributeError: Can't pickle local object 'get_dataset.<locals>.worker_tokenize'".
+
+    Args:
+        args_list: [obj, tokenizer] as a workaround for the recursive self-calls of the function."""
+    obj = args_list[0]
+    tokenizer = args_list[1]
+    if isinstance(obj, str):
+        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
+    if isinstance(obj, dict):
+        worker_tokenize._dict_key_calls += 1
+        mp_logger.debug(
+            'Encoding {}. obj.keys() = {}, obj.items().__len__() = {}'.format(worker_tokenize._dict_key_calls,
+                                                                              obj.keys(), obj.items().__len__()))
+        return dict((n, worker_tokenize([o, tokenizer])) for n, o in obj.items())
+    return list(worker_tokenize([o, tokenizer]) for o in obj)
+
+
+worker_tokenize._dict_key_calls = 0
+
+
 def get_dataset(tokenizer, dataset_path, dataset_cache=None):
     """ Get PERSONACHAT from S3 """
     dataset_path = dataset_path or PERSONACHAT_URL
@@ -39,19 +69,42 @@ def get_dataset(tokenizer, dataset_path, dataset_cache=None):
     personachat_file = cached_path(dataset_path)
     with open(personachat_file, "r", encoding="utf-8") as f:
         dataset = json.loads(f.read())
-    logger.info("Tokenize and encode the dataset")
+
     def tokenize(obj):
         if isinstance(obj, str):
             return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
         if isinstance(obj, dict):
+            tokenize.dict_key_calls += 1
+            logger.debug(
+                'Encoding {}. obj.keys() = {}, obj.items().__len__() = {}'.format(tokenize.dict_key_calls,
+                                                                                  obj.keys(),
+                                                                                  obj.items().__len__()))
             return dict((n, tokenize(o)) for n, o in obj.items())
-        return list(tokenize(o) for o in obj)
+        min_samples_for_multiprocessing = 100
+        if obj.__len__() > min_samples_for_multiprocessing:
+            logger.debug(' Encoding VERY LONG list of obj.__len__() = {}'.format(obj.__len__()))
+            logger.debug(' Encoding list with multiprocessing...')
+            # functools.partial does not work here because the tokenizer has to be handed recursively
+            # together with obj to worker_tokenize again. As a workaround for not knowing how to handle
+            # the star operator for a possible dict output and **kwargs input, args_list is used instead.
+            with mp.Pool(processes=mp.cpu_count() - 1) as pool:
+                results = pool.map(func=worker_tokenize,
+                                   iterable=[[o, tokenizer] for o in obj])
+            return results
+        else:
+            logger.debug(' Encoding list of obj.__len__() = {}'.format(obj.__len__()))
+            return list(tokenize(o) for o in obj)
+
+    tokenize.dict_key_calls = 0
+
     dataset = tokenize(dataset)
+    # dataset = tokenize(dataset)
     if dataset_cache:
         torch.save(dataset, dataset_cache)
     return dataset
 
+
 def get_dataset_personalities(tokenizer, dataset_path, dataset_cache=None):
     """ Get personalities from PERSONACHAT """
     dataset_path = dataset_path or PERSONACHAT_URL
@@ -66,14 +119,16 @@ def get_dataset_personalities(tokenizer, dataset_path, dataset_cache=None):
         personachat = json.loads(f.read())
 
     logger.info("Tokenize and encode the dataset")
+
    def tokenize(obj):
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        if isinstance(obj, dict):
            return dict((n, tokenize(o)) for n, o in obj.items())
        return list(tokenize(o) for o in obj)
+
    personachat = tokenize(personachat)
-    torch.save(personachat, dataset_cache)
+    # torch.save(personachat, dataset_cache)
 
     logger.info("Filter personalities")
     personalities = []
@@ -84,6 +139,7 @@ def tokenize(obj):
     logger.info("Gathered {} personalities".format(len(personalities)))
     return personalities
 
+
 class AttrDict(dict):
     def __init__(self, *args, **kwargs):
         super(AttrDict, self).__init__(*args, **kwargs)
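
Below is a minimal, self-contained sketch of the pattern the patch relies on: the pool target lives at module level so that multiprocessing can pickle it, and the constant tokenizer is packed into each [obj, tokenizer] list because pool.map() passes exactly one argument per item. FakeTokenizer, worker() and encode() are illustrative stand-ins only, not names from the patch; the real code uses the pytorch_pretrained_bert tokenizer together with worker_tokenize() and get_dataset.tokenize() above.

import multiprocessing as mp


class FakeTokenizer:
    """Hypothetical stand-in exposing the two methods the worker needs."""

    def tokenize(self, text):
        return text.lower().split()

    def convert_tokens_to_ids(self, tokens):
        # Toy deterministic mapping so results are identical across worker processes.
        return [sum(token.encode("utf-8")) for token in tokens]


def worker(args_list):
    """Module-level pool target: picklable, unlike a closure defined inside another function."""
    obj, tokenizer = args_list
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return {key: worker([value, tokenizer]) for key, value in obj.items()}
    return [worker([item, tokenizer]) for item in obj]


def encode(obj, tokenizer, min_samples_for_multiprocessing=100):
    """Parallelize only when the top-level list is long enough to amortize pool start-up."""
    if isinstance(obj, list) and len(obj) > min_samples_for_multiprocessing:
        with mp.Pool(processes=max(1, mp.cpu_count() - 1)) as pool:
            return pool.map(worker, [[item, tokenizer] for item in obj])
    return worker([obj, tokenizer])


if __name__ == "__main__":
    data = [{"utterance": "hello world how are you"}] * 200
    encoded = encode(data, FakeTokenizer())
    print(len(encoded), encoded[0])

Running the script prints the number of encoded samples and one encoded dict; the cpu_count() - 1 pool size mirrors the patch and leaves one core free for the parent process.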