Add device argument for multi-backends access & Ascend NPU support #5321

Open. Wants to merge 2 commits into base: main.
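For orientation, a hypothetical invocation of a training script after this change might look like the sketch below; the flag values and the main() body are illustrative and are not part of the diff.

# Hypothetical usage sketch (not from the diff): launch training on Ascend NPUs
# with the new --device and --num-accelerators flags.
from detectron2.engine import default_argument_parser, launch


def main(args):
    # build the cfg, call default_setup(cfg, args), then train ...
    pass


if __name__ == "__main__":
    # e.g. invoked with: --num-accelerators 8 --device npu
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_accelerators,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )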
41 changes: 39 additions & 2 deletions detectron2/engine/defaults.py
@@ -39,6 +39,7 @@
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils import comm
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.comm import _TORCH_NPU_AVAILABLE
from detectron2.utils.env import seed_all_rng
from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from detectron2.utils.file_io import PathManager
@@ -114,11 +115,21 @@ def default_argument_parser(epilog=None):
"See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
)
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
parser.add_argument(
"--num-accelerators", type=int, default=1, help="number of accelerators *per machine*"
)
parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
parser.add_argument(
"--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
)
# NOTE (cmq): mainly for _distributed_worker in ./detectron2/engine/launch.py, which is
# called before the cfg is loaded; overwrites the device in the config, if given.
parser.add_argument(
"--device",
type=str,
default="cuda",
help="the accelerator, e.g., 'cuda' for Nvidia gpu, 'npu' for Ascend npu",
)

# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
@@ -182,6 +193,9 @@ def default_setup(cfg, args):
Args:
cfg (CfgNode or omegaconf.DictConfig): the full config to be used
args (argparse.NameSpace): the command line arguments to be logged
NOTE (cmq):
call this before cfg.freeze(), as it may modify cfg.MODEL.DEVICE according
to the "--device" argument
"""
output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
if comm.is_main_process() and output_dir:
@@ -203,6 +217,29 @@
)
)

if hasattr(args, "device"):
# set model device as args.device
if cfg.MODEL.DEVICE != args.device:
logger.warning(
"""Switching cfg.MODEL.DEVICE {} to args.device {} now.
Re-specify --device if you deny to this switch""".format(
cfg.MODEL.DEVICE, args.device
)
)
cfg.MODEL.DEVICE = args.device
if "npu" in cfg.MODEL.DEVICE and not _TORCH_NPU_AVAILABLE:
logger.error(
"torch-npu not found, install torch-npu with pip when setting device to {}".format(
cfg.MODEL.DEVICE
)
)
elif "cuda" in cfg.MODEL.DEVICE and not torch.cuda.is_available():
logger.error(
"Cuda not found, ensure set up cuda env when setting device to {}".format(
cfg.MODEL.DEVICE
)
)

if comm.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
@@ -381,7 +418,7 @@ def __init__(self, cfg):

model = create_ddp_model(model, broadcast_buffers=False)
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
model, data_loader, optimizer
model, data_loader, optimizer, device=cfg.MODEL.DEVICE
)

self.scheduler = self.build_lr_scheduler(cfg, optimizer)
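The `_TORCH_NPU_AVAILABLE` flag imported above (and `_find_free_port`, now re-exported from comm) lives in detectron2/utils/comm.py, which is outside the hunks shown here. A minimal sketch of how such a flag could be defined, as an assumption rather than the PR's actual code:

# Assumed sketch, not code from this diff: a feature flag in detectron2/utils/comm.py.
import importlib.util

_TORCH_NPU_AVAILABLE = importlib.util.find_spec("torch_npu") is not None
if _TORCH_NPU_AVAILABLE:
    import torch_npu  # noqa: F401  # importing torch_npu registers the torch.npu namespace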
97 changes: 44 additions & 53 deletions detectron2/engine/launch.py
@@ -6,28 +6,17 @@
import torch.multiprocessing as mp

from detectron2.utils import comm
from detectron2.utils.comm import _TORCH_NPU_AVAILABLE, _find_free_port

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
import socket

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port


def launch(
main_func,
# Should be num_processes_per_machine; renamed from num_gpus_per_machine to cover other accelerators.
num_gpus_per_machine,
num_accelerators_per_machine,
num_machines=1,
machine_rank=0,
dist_url=None,
@@ -37,11 +26,11 @@ def launch(
"""
Launch multi-process or distributed training.
This function must be called on all machines involved in the training.
It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
It will spawn child processes (defined by ``num_accelerators_per_machine``) on each machine.

Args:
main_func: a function that will be called by `main_func(*args)`
num_gpus_per_machine (int): number of processes per machine. When
num_accelerators_per_machine (int): number of processes per machine. When
using GPUs, this should be the number of GPUs.
num_machines (int): the total number of machines
machine_rank (int): the rank of this machine
@@ -51,56 +40,58 @@ def launch(
timeout (timedelta): timeout of the distributed workers
args (tuple): arguments passed to main_func
"""
world_size = num_machines * num_gpus_per_machine
if world_size > 1:
# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes

if dist_url == "auto":
assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)

mp.start_processes(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
args,
timeout,
),
daemon=False,
world_size = num_machines * num_accelerators_per_machine
# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes

if dist_url == "auto":
assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)
else:
main_func(*args)

mp.start_processes(
_distributed_worker,
nprocs=num_accelerators_per_machine,
args=(
main_func,
world_size,
num_accelerators_per_machine,
machine_rank,
dist_url,
args,
timeout,
),
daemon=False,
)


def _distributed_worker(
local_rank,
main_func,
world_size,
num_gpus_per_machine,
num_accelerators_per_machine,
machine_rank,
dist_url,
args,
timeout=DEFAULT_TIMEOUT,
):
has_gpu = torch.cuda.is_available()
if has_gpu:
assert num_gpus_per_machine <= torch.cuda.device_count()
global_rank = machine_rank * num_gpus_per_machine + local_rank
device = args[0].device
dist_backend = "gloo"
if "cuda" in device and torch.cuda.is_available():
assert num_accelerators_per_machine <= torch.cuda.device_count()
dist_backend = "nccl"
elif "npu" in device and _TORCH_NPU_AVAILABLE:
assert num_accelerators_per_machine <= torch.npu.device_count()
dist_backend = "hccl"
global_rank = machine_rank * num_accelerators_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL" if has_gpu else "GLOO",
backend=dist_backend,
init_method=dist_url,
world_size=world_size,
rank=global_rank,
@@ -112,8 +103,8 @@ def _distributed_worker(
raise e

# Setup the local process group.
comm.create_local_process_group(num_gpus_per_machine)
if has_gpu:
comm.create_local_process_group(num_accelerators_per_machine)
if "cuda" in device and torch.cuda.is_available():
torch.cuda.set_device(local_rank)

# synchronize is needed here to prevent a possible timeout after calling init_process_group
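The new `_distributed_worker` picks the process-group backend from the device string; the mapping is summarized in the sketch below, an illustration with assumed device strings such as "cuda", "npu", or "cpu", not code from the diff.

# Sketch of the device-to-backend mapping used above (assumption, not from the diff).
import torch


def _infer_dist_backend(device: str) -> str:
    if "cuda" in device and torch.cuda.is_available():
        return "nccl"  # NVIDIA collective communication library
    if "npu" in device and hasattr(torch, "npu"):
        return "hccl"  # Huawei collective communication library for Ascend NPUs
    return "gloo"  # CPU fallback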
11 changes: 10 additions & 1 deletion detectron2/engine/train_loop.py
@@ -250,6 +250,7 @@ def __init__(
gather_metric_period=1,
zero_grad_before_forward=False,
async_write_metrics=False,
device: torch.device = torch.device("cuda"),
):
"""
Args:
@@ -451,6 +452,7 @@ def __init__(
precision: torch.dtype = torch.float16,
log_grad_scaler: bool = False,
async_write_metrics=False,
device: torch.device = torch.device("cuda"),
):
"""
Args:
@@ -475,15 +477,22 @@ def __init__(
self.grad_scaler = grad_scaler
self.precision = precision
self.log_grad_scaler = log_grad_scaler
self.device = device

def run_step(self):
"""
Implement the AMP training logic.
"""
assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
assert torch.cuda.is_available() or (
hasattr(torch, "npu") and torch.npu.is_available()  # torch.npu exists only when torch_npu is installed
), "[AMPTrainer] CUDA/Ascend NPU is required for AMP training!"

from torch.cuda.amp import autocast

if "npu" in self.device:
from torch.npu.amp import autocast

start = time.perf_counter()
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
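For reference, the device-dependent autocast selection in run_step() can be expressed as a small helper. The sketch below is an assumption-level illustration (torch.cuda.amp.autocast is standard PyTorch; torch.npu.amp.autocast is provided by torch_npu), not code from the diff.

# Sketch (assumption): choose the autocast context manager for the configured device.
def _get_autocast(device: str):
    if "npu" in device:
        from torch.npu.amp import autocast  # only importable when torch_npu is installed
    else:
        from torch.cuda.amp import autocast
    return autocast


# Usage inside a training step, mirroring the diff:
#     with _get_autocast(str(self.device))(dtype=self.precision):
#         loss_dict = self.model(data)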
3 changes: 2 additions & 1 deletion detectron2/model_zoo/model_zoo.py
@@ -7,6 +7,7 @@
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate
from detectron2.modeling import build_model
from detectron2.utils.comm import _TORCH_NPU_AVAILABLE


class _ModelZooUrls:
@@ -196,7 +197,7 @@ def get(config_path, trained: bool = False, device: Optional[str] = None):
model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True)
"""
cfg = get_config(config_path, trained)
if device is None and not torch.cuda.is_available():
if device is None and not torch.cuda.is_available() and not _TORCH_NPU_AVAILABLE:
device = "cpu"
if device is not None and isinstance(cfg, CfgNode):
cfg.MODEL.DEVICE = device
7 changes: 7 additions & 0 deletions detectron2/modeling/meta_arch/build.py
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import torch

from detectron2.utils.comm import _TORCH_NPU_AVAILABLE
from detectron2.utils.logger import _log_api_usage
from detectron2.utils.registry import Registry

@@ -18,6 +19,12 @@ def build_model(cfg):
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``.
"""
# TODO (cmq): Support for dynamic ops in torch-npu is limited; kernel size
# [h=32, w=64] is not supported by the Conv2DBackprop dynamic op.
# Revert this once it is supported.
if "npu" in cfg.MODEL.DEVICE and _TORCH_NPU_AVAILABLE:
torch.npu.set_compile_mode(jit_compile=True)

meta_arch = cfg.MODEL.META_ARCHITECTURE
model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
model.to(torch.device(cfg.MODEL.DEVICE))
38 changes: 34 additions & 4 deletions detectron2/utils/collect_env.py
@@ -8,9 +8,13 @@
from collections import defaultdict
import PIL
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision
from tabulate import tabulate

from detectron2.utils.comm import _TORCH_NPU_AVAILABLE, _find_free_port

__all__ = ["collect_env_info"]


@@ -219,21 +223,32 @@ def collect_env_info():
def test_nccl_ops():
num_gpu = torch.cuda.device_count()
if os.access("/tmp", os.W_OK):
import torch.multiprocessing as mp

dist_url = "file:///tmp/nccl_tmp_file"
print("Testing NCCL connectivity ... this should not hang.")
mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False)
print("NCCL succeeded.")


def _test_nccl_worker(rank, num_gpu, dist_url):
import torch.distributed as dist

dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu)
dist.barrier(device_ids=[rank])


def test_hccl_ops():
num_npu = torch.npu.device_count()

port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
print("Testing HCCL connectivity ... this should not hang.")
mp.spawn(_test_hccl_worker, nprocs=num_npu, args=(num_npu, dist_url), daemon=False)
print("HCCL succeeded.")


def _test_hccl_worker(rank, num_npu, dist_url):
dist.init_process_group(backend="hccl", init_method=dist_url, rank=rank, world_size=num_npu)
dist.barrier(device_ids=[rank])


def main() -> None:
global x
try:
@@ -258,6 +273,21 @@ def main() -> None:
if num_gpu > 1:
test_nccl_ops()

if _TORCH_NPU_AVAILABLE:
num_npu = torch.npu.device_count()
for k in range(num_npu):
device = f"npu:{k}"
try:
x = torch.tensor([1, 2.0], dtype=torch.float32)
x = x.to(device)
except Exception as e:
print(
f"Unable to copy tensor to device={device}: {e}. "
"Your Ascend NPU environment is broken."
)
if num_npu > 1:
test_hccl_ops()


if __name__ == "__main__":
main() # pragma: no cover
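A hedged usage note: the environment report and the new NPU checks can be exercised programmatically; the snippet below uses only names visible in this diff, and the description of main() paraphrases the code above.

# Usage sketch: print the environment report assembled by collect_env_info().
from detectron2.utils.collect_env import collect_env_info

print(collect_env_info())
# Per the diff above, this module's main() additionally copies a test tensor to every
# visible CUDA/NPU device and, when more than one NPU is present, calls test_hccl_ops()
# to verify HCCL connectivity.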