From 1a65ad04bc2c05aad048a1dc9a3d6791aa2286e4 Mon Sep 17 00:00:00 2001
From: Stephen Roller
Date: Thu, 16 Jul 2020 15:15:52 -0400
Subject: [PATCH] Distributed Evaluation (#2775)

* Distributed eval.
* More work on distributed_eval.
* Fix mp_eval. Support dumping parallel logs.
* Self feeding change should not have slipped in there.
* Update docstrings.
* Typos.
---
 parlai/scripts/distributed_eval.py      |  51 +++++++++
 parlai/scripts/distributed_train.py     |  49 +--------
 parlai/scripts/eval_model.py            |  32 +++++-
 parlai/scripts/multiprocessing_eval.py  |  81 ++++++++++++++
 parlai/scripts/multiprocessing_train.py |  61 +----------
 parlai/utils/distributed.py             | 135 +++++++++++++++++++++++-
 parlai/utils/logging.py                 |  10 +-
 7 files changed, 302 insertions(+), 117 deletions(-)
 create mode 100644 parlai/scripts/distributed_eval.py
 create mode 100644 parlai/scripts/multiprocessing_eval.py

diff --git a/parlai/scripts/distributed_eval.py b/parlai/scripts/distributed_eval.py
new file mode 100644
index 00000000000..c43d9d9a54d
--- /dev/null
+++ b/parlai/scripts/distributed_eval.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Distributed evaluation script. NOT MEANT TO BE CALLED DIRECTLY BY USER.
+
+This script is meant to be used in conjunction with
+SLURM, which provides environment variables
+describing the environment.
+
+An example sbatch script is below, for a 2-host, 8-GPU setup (16 GPUs total):
+
+.. code-block:: bash
+
+  #!/bin/sh
+  #SBATCH --job-name=distributed_example
+  #SBATCH --output=/path/to/savepoint/stdout.%j
+  #SBATCH --error=/path/to/savepoint/stderr.%j
+  #SBATCH --partition=priority
+  #SBATCH --nodes=2
+  #SBATCH --time=0:10:00
+  #SBATCH --signal=SIGINT
+  #SBATCH --gres=gpu:8
+  #SBATCH --ntasks-per-node=8
+  #SBATCH --mem=64G
+  #SBATCH --cpus-per-task=10
+  srun python -u -m parlai.scripts.distributed_eval \
+    -m seq2seq -t convai2 --dict-file /path/to/dict-file
+"""
+
+import os
+
+import parlai.scripts.eval_model as eval_model
+import parlai.utils.distributed as distributed_utils
+
+
+def main():
+    parser = eval_model.setup_args()
+    parser.add_distributed_training_args()
+    parser.add_argument('--port', type=int, default=61337, help='TCP port number')
+    opt = parser.parse_args(print_args=(os.environ['SLURM_PROCID'] == '0'))
+
+    with distributed_utils.slurm_distributed_context(opt) as opt:
+        return eval_model.eval_model(opt)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/parlai/scripts/distributed_train.py b/parlai/scripts/distributed_train.py
index a573c9a6040..d8868be5cb6 100644
--- a/parlai/scripts/distributed_train.py
+++ b/parlai/scripts/distributed_train.py
@@ -31,14 +31,9 @@
     -m seq2seq -t convai2 --dict-file /path/to/dict-file
 """
 
-import os
-import socket
-import subprocess
-
 import parlai.scripts.train_model as single_train
-import parlai.utils.logging as logging
-from parlai.scripts.multiprocessing_train import multiprocess_train
 from parlai.scripts.script import ParlaiScript
+import parlai.utils.distributed as distributed_utils
 
 
 def setup_args():
@@ -48,52 +43,14 @@ def setup_args():
     return parser
 
 
-def dist_train(opt, node_list):
-    # We can determine the init method automatically for Slurm.
-    try:
-        # Figure out the main host, and which rank we are.
-        hostnames = subprocess.check_output(
-            ['scontrol', 'show', 'hostnames', node_list]
-        )
-        main_host = hostnames.split()[0].decode('utf-8')
-        distributed_rank = int(os.environ['SLURM_PROCID'])
-        if opt.get('model_parallel'):
-            # -1 signals to multiprocessing_train to use all GPUs available.
-            # (A value of None signals to multiprocessing_train to use the GPU
-            # corresponding to the rank.
-            device_id = -1
-        else:
-            device_id = int(os.environ['SLURM_LOCALID'])
-        port = opt['port']
-        logging.info(
-            f'Initializing host {socket.gethostname()} as rank {distributed_rank}, '
-            f'main is {main_host}'
-        )
-        # Begin distributed training
-        multiprocess_train(distributed_rank, opt, port, 0, device_id, main_host)
-    except subprocess.CalledProcessError as e:
-        # scontrol failed
-        raise e
-    except FileNotFoundError:
-        # Slurm is not installed
-        raise RuntimeError('SLURM does not appear to be installed.')
-
-
 class DistributedTrain(ParlaiScript):
     @classmethod
     def setup_args(cls):
         return setup_args()
 
     def run(self):
-        # double check we're using SLURM
-        node_list = os.environ.get('SLURM_JOB_NODELIST')
-        if node_list is None:
-            raise RuntimeError(
-                'Does not appear to be in a SLURM environment. '
-                'You should not call this script directly; '
-                'see launch_distributed.py'
-            )
-        return dist_train(self.opt, node_list)
+        with distributed_utils.slurm_distributed_context(self.opt) as opt:
+            return single_train.TrainLoop(opt).train()
 
 
 if __name__ == '__main__':
diff --git a/parlai/scripts/eval_model.py b/parlai/scripts/eval_model.py
index 800186a1ccf..36733d0354c 100644
--- a/parlai/scripts/eval_model.py
+++ b/parlai/scripts/eval_model.py
@@ -20,7 +20,11 @@
 from parlai.core.params import ParlaiParser, print_announcements
 from parlai.core.agents import create_agent
 from parlai.core.logs import TensorboardLogger
-from parlai.core.metrics import aggregate_named_reports, Metric
+from parlai.core.metrics import (
+    aggregate_named_reports,
+    aggregate_unnamed_reports,
+    Metric,
+)
 from parlai.core.worlds import create_task
 from parlai.utils.misc import TimeLogger, nice_report
 from parlai.utils.world_logging import WorldLogger
@@ -30,6 +34,13 @@
 import json
 import random
 
+from parlai.utils.distributed import (
+    is_primary_worker,
+    all_gather_list,
+    is_distributed,
+    get_rank,
+)
+
 
 def setup_args(parser=None):
     if parser is None:
@@ -85,6 +96,8 @@ def setup_args(parser=None):
 
 
 def _save_eval_stats(opt, report):
+    if not is_primary_worker():
+        return
     report_fname = opt['report_filename']
     if report_fname == '':
         return
@@ -122,6 +135,10 @@ def _eval_single_world(opt, agent, task):
     # max number of examples to evaluate
     max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
     cnt = 0
+    total_cnt = world.num_examples()
+
+    if is_distributed():
+        logging.warn('Progress bar is approximate in distributed mode.')
 
     while not world.epoch_done() and cnt < max_cnt:
         cnt += opt.get('batchsize', 1)
@@ -134,18 +151,22 @@ def _eval_single_world(opt, agent, task):
         if log_time.time() > log_every_n_secs:
             report = world.report()
             text, report = log_time.log(
-                report.get('exs', 0), min(max_cnt, world.num_examples()), report
+                report.get('exs', 0), min(max_cnt, total_cnt), report
             )
             logging.info(text)
 
-    report = world.report()
+    report = aggregate_unnamed_reports(all_gather_list(world.report()))
     world.reset()
 
     if world_logger is not None:
         # dump world acts to file
         world_logger.reset()  # add final acts to logs
         base_outfile = opt['report_filename'].split('.')[0]
-        outfile = base_outfile + f'_{task}_replies.jsonl'
+        if is_distributed():
+            rank = get_rank()
+            outfile = base_outfile + f'_{task}_{rank}_replies.jsonl'
+        else:
+            outfile = base_outfile + f'_{task}_replies.jsonl'
         world_logger.write(outfile, world, file_format=opt['save_format'])
 
     return report
@@ -195,6 +216,7 @@ def eval_model(opt, print_parser=None):
     logging.info(
         f'Finished evaluating tasks {tasks} using datatype {opt.get("datatype")}'
     )
+    print(nice_report(report))
     _save_eval_stats(opt, report)
 
     return report
@@ -206,7 +228,7 @@ def setup_args(cls):
         return setup_args()
 
     def run(self):
-        return eval_model(self.opt)
+        return eval_model(self.opt, print_parser=self.parser)
 
 
 if __name__ == '__main__':
diff --git a/parlai/scripts/multiprocessing_eval.py b/parlai/scripts/multiprocessing_eval.py
new file mode 100644
index 00000000000..43d5ab52330
--- /dev/null
+++ b/parlai/scripts/multiprocessing_eval.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+"""
+Main launch script for single-host, multi-GPU evaluation.
+
+This is a drop-in replacement for eval_model.py. This script will launch N
+subprocesses, each of which runs the full eval loop independently.
+
+Uses torch.nn.parallel.DistributedDataParallel for parallelism. Agents must
+specifically support being wrapped in DistributedDataParallel, but all
+TorchRankerAgents and TorchGeneratorAgents support this.
+"""
+
+import torch
+import random
+import os
+import signal
+import parlai.utils.distributed as distributed_utils
+import parlai.scripts.eval_model as eval_model
+
+
+def multiprocess_eval(
+    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
+):
+    """
+    Run a multiprocessing evaluation.
+
+    Invoked by launch_and_eval, not called directly.
+    """
+    with distributed_utils.distributed_context(
+        rank, opt, port, rank_offset, gpu, hostname
+    ) as opt:
+        return eval_model.eval_model(opt)
+
+
+def launch_and_eval(opt, port):
+    """
+    Spawn worker subprocesses, with the main process also running evaluation.
+ """ + # Launch multiple subprocesses + spawncontext = torch.multiprocessing.spawn( + multiprocess_eval, + # need to give rank offset as 1 to cover the fact that the main + # process is rank 0, but that spawn() doesn't let you control rank + (opt, port, 1), + nprocs=opt['distributed_world_size'] - 1, # main proc will also run loop + join=False, + ) + + try: + retval = multiprocess_eval(0, opt, port) + spawncontext.join() + return retval + except KeyboardInterrupt: + # tell the subprocesses to stop too + for p in spawncontext.processes: + if p.is_alive(): + os.kill(p.pid, signal.SIGINT) + raise + + +def setup_args(): + parser = eval_model.setup_args() + parser.add_distributed_training_args() + parser.set_defaults(distributed_world_size=torch.cuda.device_count()) + return parser + + +def main(): + opt = setup_args().parse_args() + port = random.randint(32000, 48000) + return launch_and_eval(opt, port) + + +if __name__ == '__main__': + main() diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py index a7dbf10f391..192b56dbfb8 100644 --- a/parlai/scripts/multiprocessing_train.py +++ b/parlai/scripts/multiprocessing_train.py @@ -18,74 +18,19 @@ import torch import random -import copy import os import signal -import torch.distributed as dist import parlai.scripts.train_model as single_train import parlai.utils.distributed as distributed_utils -import parlai.utils.logging as logging from parlai.scripts.script import ParlaiScript def multiprocess_train( rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost' ): - """ - Subprocess which initializes distributed training, and begins training. - - This should be launched n times for n GPUs; this is handled either in main - or via srun. - - :param int rank: This process's rank - 1. (Starts at -1 ... n - 2). See comments. - :param opt: command line options - :param int port: A TCP port to use. This will need to be changed to run - multiple distributed training setups on the same machine. - :param int gpu: Which GPU to use. Defaults to using rank and local devices, - but must be manually specified when using many-hosts. - :param str hostname: Hostname of the main server. - """ - # Set per-host options - opt = copy.deepcopy(opt) - # we need to manually adjust the rank differently in multiprocessing - # and distributed train - rank = rank + rank_offset - opt['rank'] = rank - if gpu is None: - # default assumption is local GPUs - gpu = rank % torch.cuda.device_count() - opt['gpu'] = gpu - # make sure we don't just use whatever GPU was saved in the model file - if 'override' not in opt: - opt['override'] = {} - opt['override']['gpu'] = gpu - - # Suppress output of workers except the main host. 
- if opt.get('verbose') or rank != 0: - print_prefix = 'rank:{:3d} |'.format(rank) - else: - print_prefix = None - suppress_output = not opt.get('verbose') and rank != 0 - - with distributed_utils.override_print(suppress_output, print_prefix): - # perform distributed setup, ensuring all hosts are ready - if opt['gpu'] != -1: - torch.cuda.set_device(opt['gpu']) - dist.init_process_group( - backend="nccl", - init_method="tcp://{}:{}".format(hostname, port), - world_size=opt['distributed_world_size'], - rank=rank, - ) - logging.info("Distributed group initialized") - - # manual_seed can be a noop without this - torch.cuda.init() - # make sure all parameters will be in sync - torch.manual_seed(42) - # force a sync so that no one gets ahead, and all are seeded together - distributed_utils.sync_object(None) - + with distributed_utils.distributed_context( + rank, opt, port, rank_offset, gpu, hostname + ) as opt: # Run the actual training return single_train.TrainLoop(opt).train() diff --git a/parlai/utils/distributed.py b/parlai/utils/distributed.py index 407061bc360..54f0fcd55e2 100644 --- a/parlai/utils/distributed.py +++ b/parlai/utils/distributed.py @@ -13,8 +13,12 @@ """ import builtins +import copy +import os import pickle import contextlib +import subprocess +import socket import parlai.utils.logging as logging try: @@ -72,7 +76,7 @@ def num_workers(): def is_primary_worker(): """ - Determine if we are the primary (master) worker. + Determine if we are the primary (rank 0) worker. Returns False if we are a secondary worker. Returns True if we are either (1) not in distributed mode (2) or are the primary (rank 0) worker. @@ -95,8 +99,10 @@ def get_rank(): @contextlib.contextmanager def override_print(suppress=False, prefix=None): """ - Context manager to override the print to suppress or modify output. Recommended - usage is to call this with suppress=True for all non-primary workers, or call with a + Context manager to override the print to suppress or modify output. + + Recommended usage is to call this with suppress=True for all non-primary + workers, or call with a prefix of rank on all workers. >>> with override_print(prefix="rank{}".format(rank)): @@ -285,3 +291,126 @@ def sync_parameters(model: torch.nn.Module) -> bool: ) return True + + +@contextlib.contextmanager +def distributed_context( + rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost' +): + """ + A context which wraps initialization of a distributed/multiprocessing run. + + Every process in the distributed run should launch with this. In true + distributed setting you may wish to use slurm_distributed_context instead. + + :param int rank: + This process's rank, less rank_offset. + :param int rank_offset: + Used as an offset of rank. Used between multiprocessing vs true distributed, + and a hack around torch.multiprocessing.spawn being only used for the + non-primary workers. + :param opt: + command line options + :param int port: + A TCP port to use. This will need to be changed to run multiple + distributed training setups on the same machine. + :param int gpu: + Which GPU to use. Defaults to using rank and local devices, but must be + manually specified when using many-hosts. + :param str hostname: + Hostname of the main server. 
+ """ + # Set per-host options + opt = copy.deepcopy(opt) + # we need to manually adjust the rank differently in multiprocessing + # and distributed train + rank = rank + rank_offset + opt['rank'] = rank + if gpu is None: + # default assumption is local GPUs + gpu = rank % torch.cuda.device_count() + opt['gpu'] = gpu + # make sure we don't just use whatever GPU was saved in the model file + if 'override' not in opt: + opt['override'] = {} + opt['override']['gpu'] = gpu + + # Suppress output of workers except the main host. + if opt.get('verbose') or rank != 0: + print_prefix = 'rank:{:3d} |'.format(rank) + else: + print_prefix = None + suppress_output = not opt.get('verbose') and rank != 0 + + with override_print(suppress_output, print_prefix): + # perform distributed setup, ensuring all hosts are ready + if opt['gpu'] != -1: + torch.cuda.set_device(opt['gpu']) + dist.init_process_group( + backend="nccl", + init_method="tcp://{}:{}".format(hostname, port), + world_size=opt['distributed_world_size'], + rank=rank, + ) + logging.info("Distributed group initialized") + + # manual_seed can be a noop without this + torch.cuda.init() + # make sure all parameters will be in sync + torch.manual_seed(42) + # force a sync so that no one gets ahead, and all are seeded together + sync_object(None) + + yield opt + + +@contextlib.contextmanager +def slurm_distributed_context(opt): + """ + Initialize a distributed context, using the SLURM environment. + + Does some work to read the environment to find a list of participating nodes + and the main node. + + :param opt: + Command line options. + """ + # We can determine the init method automatically for Slurm. + # double check we're using SLURM + node_list = os.environ.get('SLURM_JOB_NODELIST') + if node_list is None: + raise RuntimeError( + 'Does not appear to be in a SLURM environment. ' + 'You should not call this script directly; see launch_distributed.py' + ) + + try: + # Figure out the main host, and which rank we are. + hostnames = subprocess.check_output( + ['scontrol', 'show', 'hostnames', node_list] + ) + main_host = hostnames.split()[0].decode('utf-8') + distributed_rank = int(os.environ['SLURM_PROCID']) + if opt.get('model_parallel'): + # -1 signals to multiprocessing_train to use all GPUs available. + # (A value of None signals to multiprocessing_train to use the GPU + # corresponding to the rank. + device_id = -1 + else: + device_id = int(os.environ['SLURM_LOCALID']) + port = opt['port'] + logging.info( + f'Initializing host {socket.gethostname()} as rank {distributed_rank}, ' + f'main is {main_host}' + ) + # Begin distributed training + with distributed_context( + distributed_rank, opt, port, 0, device_id, main_host + ) as opt: + yield opt + except subprocess.CalledProcessError as e: + # scontrol failed + raise e + except FileNotFoundError: + # Slurm is not installed + raise RuntimeError('SLURM does not appear to be installed.') diff --git a/parlai/utils/logging.py b/parlai/utils/logging.py index dddb98cd7bf..74009db24c3 100644 --- a/parlai/utils/logging.py +++ b/parlai/utils/logging.py @@ -105,15 +105,15 @@ def mute(self): """ Stop logging to stdout. """ - prev_level = self.streamHandler.level - self.streamHandler.level = 9999 - return prev_level + self.prev_level = self.streamHandler.level + self.streamHandler.level = ERROR + return self.prev_level - def unmute(self, level): + def unmute(self): """ Resume logging to stdout. 
""" - self.streamHandler.level = level + self.streamHandler.level = self.prev_level # -----------------------------------