Support Distributed Training in alf #913
Comments
@breakds FYI, two reference papers I came across a while ago (RL scenario): https://openreview.net/pdf?id=H1gX8C4YPr Although they were proposed for multi-machine training, our multi-GPU single-machine case is a special, simpler case. Alternatively, refer to PyTorch's official multi-GPU support (general DL scenario): |
Thanks @hnyu for the references! |
DD-PPO, the first paper: its main idea is to early-stop the slow simulations during rollout with a batched environment (potentially distributed over different machines in a cluster), and to use the full experience from some of the environments and partial experience from the early-stopped ones during training in each iteration. I think we can borrow these ideas in the near future. As a first step, I will look into how PyTorch's `DataParallel` can be used in each training iteration. |
Currently I am hitting two problems with `DataParallel` (discussed below). |
I think multi-gpu only makes sense for a large mini-batch with intensive computation. What is your setup? |
Yep, I think that is what happened. I was testing with the following toy network:

```python
import torch.nn as nn


class Network(nn.Module):
    def __init__(self, input_size, output_size):
        super(Network, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 384)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(384, 64)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(64, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h = self.fc1(x)
        h = self.relu1(h)
        h = self.fc2(h)
        h = self.relu2(h)
        h = self.fc3(h)
        h = self.relu3(h)
        h = self.fc4(h)
        h = self.sigmoid(h)
        return h
```

I realized that this is probably too small, because even with a large batch the per-GPU computation does not outweigh the scatter/replicate/gather overhead. I am now trying to fix the issue in No.2 so that I can experiment on an actual network that is used in alf. |
Our expected scenario for multi-GPU is image inputs with a large batch size, so you could try dummy image inputs instead. Besides running time, another scenario is splitting SGD memory consumption across multiple cards, when one card is not enough. |
That makes a lot of sense. Thanks for the suggestions and clarification! |
I was using `nn.DataParallel` to wrap the actor network, and hit the following error:

```
Traceback (most recent call last):
  File "/nix/store/4s0h5aawbap3xhldxhcijvl26751qrjr-python3-3.8.9/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/nix/store/4s0h5aawbap3xhldxhcijvl26751qrjr-python3-3.8.9/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/breakds/projects/alf/alf/bin/experiment/dp_network_experiment.py", line 42, in <module>
    action_distribution, actor_state = actor_network(observation, state=())
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    return self.gather(outputs, self.output_device)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 180, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 76, in gather
    res = gather_map(outputs)
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/nix/store/1nhxgafz45v9sivabxw0aqr0dvpyw1nc-python3-3.8.9-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'Categorical' object is not iterable
```

With some investigation, I realized that it fails because `DataParallel`'s gather step only knows how to recursively collect tensors (and dicts/sequences of tensors) from each GPU, while the actor network returns a `Categorical` distribution object, which `gather_map` cannot iterate over.

This problem happens at the last step, "Gather". I will use a slightly modified network to continue the experiment and work around this. However, the final solution should make multi-GPU as transparent as possible so that it is convenient to use; directly applying `nn.DataParallel` to alf's networks is not enough. |
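For illustration, one way to make the output gatherable is to return the distribution's parameters (plain tensors) from `forward()` and rebuild the distribution after the gather. The wrapper below is a hypothetical sketch, not the actual modification used later in this thread:

```python
import torch.distributions as td
import torch.nn as nn


class LogitsOnlyWrapper(nn.Module):
    """Hypothetical wrapper: return plain tensors so nn.DataParallel.gather()
    can concatenate the per-GPU outputs along the batch dimension."""

    def __init__(self, actor_network):
        super().__init__()
        self._actor_network = actor_network

    def forward(self, observation, state=()):
        action_distribution, state = self._actor_network(observation, state=state)
        # Assumes the distribution is a Categorical parameterized by logits.
        return action_distribution.logits, state


def gathered_distribution(logits):
    # Rebuild the distribution on the output device from the gathered logits.
    return td.Categorical(logits=logits)
```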
After slightly modifying the network (as mentioned above) to work around the gather problem, I ran the following experiment:

```python
import torch
import torch.nn as nn

import alf
from alf.networks import ActorDistributionNetwork
from alf.tensor_specs import BoundedTensorSpec

import functools
import time

if __name__ == '__main__':
    alf.set_default_device('cuda')

    CONV_LAYER_PARAMS = ((32, 8, 4), (64, 4, 2), (64, 3, 1))

    actor_network_cls = functools.partial(
        ActorDistributionNetwork,
        fc_layer_params=(512, ),
        conv_layer_params=CONV_LAYER_PARAMS)

    actor_network = nn.DataParallel(actor_network_cls(
        input_tensor_spec=BoundedTensorSpec(
            shape=(4, 150, 150), dtype=torch.float32, minimum=0., maximum=1.),
        action_spec=BoundedTensorSpec(
            shape=(), dtype=torch.int64, minimum=0, maximum=3)))

    start_time = time.time()
    for i in range(1000):
        observation = torch.rand(640, 4, 150, 150)
        action_distribution, actor_state = actor_network(observation, state=())
    print(f'{time.time() - start_time} seconds elapsed')
```

I can see the load being distributed to 2 cards (as well as the memory being distributed). However, compared to running the same piece of code on a single 3080 without `DataParallel`, the 2-card version is actually slower.
This almost renders `DataParallel` useless for our purposes. |
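One general caveat when timing GPU code like the loop above (a hedged note, not necessarily the explanation for the numbers here): CUDA kernels are launched asynchronously, so `time.time()` differences are only meaningful if the GPU is synchronized before reading the clock. A hypothetical helper:

```python
import time

import torch


def timed_forward(net, observation, iterations=1000):
    """Hypothetical timing helper: without the explicit synchronize() calls,
    the measurement may capture only kernel-launch time, not GPU execution."""
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iterations):
        net(observation)
    torch.cuda.synchronize()  # wait for all queued GPU work to finish
    return time.time() - start
```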
The inefficiency of DataParallel seems unreasonable. There must be something wrong going on. |
Or maybe this is by design; I can try to look into where the time is being spent. |
According to https://pytorch.org/tutorials/intermediate/ddp_tutorial.html, DataParallel might be even slower than DistributedDataParallel |
Yep, I can see that the GIL issue makes sense. @hnyu and I chatted about this today, and I agree with Haonan that we might want to adjust our goal and go for a slightly more complicated (i.e., possibly requiring structural updates) customized solution. We can chat more about this tomorrow. |
|
This is part of the effort to address #913. A sub-task requires extracting the worker logic out of the class (for some reason keeping it inside prevents `multiprocessing` from working correctly). Without this change, `multiprocessing.Process` just gets stuck on `start()`.
This is part of the effort to unblock #913. Two reasons for this change: 1. `worker` does not rely on `ProcessEnvironment` at all, so it is cleaner to make it independent of `ProcessEnvironment`. 2. If it stays as a member method of `ProcessEnvironment`, `multiprocessing.Process` will get stuck on `start()` if the parent process is also a `multiprocessing.Process`, for reasons unknown (tried investigating but haven't figured it out).
It turns out that my hypothesis that GPU 0 and GPU 1 run sequentially was wrong. I made two mistakes in my toy example. The good news is that after fixing them, the toy example

```python
start_time = time.time()
for i in range(2500):
    observation = torch.rand(batch_size, 4, 84, 84, device=rank)
    action_distribution, actor_state = actor_network(observation, state=())
    action = action_distribution.sample()
    reward = torch.rand(batch_size, device=rank)
    loss = - torch.mean(action_distribution.log_prob(action) * reward)
    loss.backward()
    if i % 100 == 0:
        print(f'iteration {i} - {time.time() - start_time} seconds elapsed on device {rank}')
print(f'{time.time() - start_time} seconds elapsed on device {rank}')
```

proves that the two GPUs do run in parallel, for a batch size of 1024 per process. The bad news is that this still does not explain the slowdown observed earlier. |
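For context, the loop above assumes a per-process setup roughly like the following sketch (the master address/port and the spawn entry point are assumptions; the network construction mirrors the earlier `DataParallel` experiment):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from alf.networks import ActorDistributionNetwork
from alf.tensor_specs import BoundedTensorSpec


def run(rank, world_size, batch_size=1024):
    # One process per GPU; `rank` doubles as the CUDA device id here.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    actor_network = ActorDistributionNetwork(
        input_tensor_spec=BoundedTensorSpec(
            shape=(4, 84, 84), dtype=torch.float32, minimum=0., maximum=1.),
        action_spec=BoundedTensorSpec(
            shape=(), dtype=torch.int64, minimum=0, maximum=3),
        fc_layer_params=(512, ),
        conv_layer_params=((32, 8, 4), (64, 4, 2), (64, 3, 1)))
    actor_network = DDP(actor_network.to(rank), device_ids=[rank])

    # ... the timing loop from the comment above runs here ...

    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```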
I think for this toy example, you can also try more complex CNN architectures like `ResnetEncodingNetwork` to see the gain. |
It turns out the problem is still that one of the two processes throws an exception, but that exception is not observed.

**About "the exception is not observed"**: I actually researched (experimented with) how exceptions in nested subprocesses work yesterday, and thought I had solved this problem. Sadly, there are still cases where such an exception is raised silently. Normally I would expect it to show up in the terminal, because I explicitly catch it and print it in the offending process.

**About the exception itself**: The exception looks like this:
After going DDP, we would like newly created tensors to be placed on a process-dependent default device (i.e., the device id matching the process's rank). I have patched quite a few places to make this happen.
|
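As a minimal sketch of what a process-dependent default device can look like in each DDP worker (the actual patches inside alf are not shown here; `torch.cuda.set_device` is a standard PyTorch call, and `alf.set_default_device` is used earlier in this thread):

```python
import torch

import alf


def setup_process_device(rank):
    """Hypothetical helper: make tensors created without an explicit
    `device=` land on this process's own GPU."""
    torch.cuda.set_device(rank)      # the unqualified 'cuda' device now refers to cuda:<rank>
    alf.set_default_device('cuda')   # alf's tensor creation follows the same default
```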
Acknowledged. Will try it later. Thanks! |
More updates: with some other small problems fixed, I am now able to train with 2 GPUs under the DDP wrapper:
The synchronization is there too. The trained result cannot be played yet (which is expected); I will take a closer look at the checkpoints. Meanwhile, I will start to think about a cleaner implementation. |
Outline of the plan for next steps, after a discussion on 2021/07/23: TODO Productionize DDP over ALF [3/11]
|
**Update**: I found that after training for 10 minutes, the two processes start to behave as "not synchronized". I think I have some misunderstanding of how DDP works. |
Had a discussion with @emailweixu while reading the DDP code, and we figured out why the above approach (directly wrapping in DDP) does not work as expected.

**How DDP works**: roughly, one training iteration with `w = DDP(m)` goes through these steps:

1. `w.forward(x)` runs `m.forward(x)` and arms the gradient-reduction bookkeeping for the upcoming backward pass (parameters are also broadcast from rank 0 when `w` is constructed, so all processes start identical).
2. During `backward()` on the output of `w.forward()`, each parameter's gradient is bucketed and all-reduced (averaged) across processes as it becomes ready.
3. Each process applies the identical averaged gradients in its optimizer step, so the model replicas stay synchronized.

This explains why the wrapping above does not keep the processes synchronized: our training loop computes its losses without going through `w.forward()`, so step 2 never happens. The next idea to try is to wrap whatever produces the gradients (the loss-producing computation) so that it is exposed through a `forward()`. |
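A self-contained toy (not alf code; the port number is an assumption) illustrating the property above: gradients are synchronized only for backward passes that flow from the DDP wrapper's own `forward()` output.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def demo(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    torch.manual_seed(rank)  # give each process different inputs
    model = DDP(nn.Linear(4, 1))

    # Case 1: the loss comes from DDP's own forward(); during backward() the
    # parameter gradients are all-reduced, so every process ends up with the
    # same averaged gradients even though the inputs differ per process.
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    # Case 2: calling the wrapped module directly bypasses DDP.forward(), so
    # no synchronization is armed for this backward pass and the gradients
    # stay local to each process -- which is how the replicas drift apart.
    model.zero_grad()
    loss = model.module(torch.randn(8, 4)).sum()
    loss.backward()

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(demo, args=(2,), nprocs=2, join=True)
```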
Now, with a DDP wrapper applied to the unroll step:

```python
class RLAlgorithm:
    def activate_ddp(self, rank: int):
        self.__dict__['_unroll_performer'] = DDP(UnrollPerformer(self), device_ids=[rank])
```

Note that assigning via `self.__dict__` bypasses `nn.Module.__setattr__`, so the DDP-wrapped `UnrollPerformer` (which itself holds a reference to `self`) is not registered as a submodule. Verified that both GPUs are being utilized, and they are synchronized:
|
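The actual `UnrollPerformer` lives in the linked PR; as a rough sketch of the shape of the idea (the names and signature here are assumptions, not the real implementation), it is an `nn.Module` whose `forward()` performs the unroll, so that DDP's gradient hooks cover everything computed during rollout:

```python
import torch.nn as nn


class UnrollPerformer(nn.Module):
    """Sketch: expose the algorithm's unroll as a forward() for DDP to wrap."""

    def __init__(self, algorithm):
        super().__init__()
        # Registering the algorithm as a submodule makes its parameters
        # visible to DistributedDataParallel.
        self.inner_algorithm = algorithm

    def forward(self, unroll_length):
        # Whatever the unroll produces becomes DDP's forward output, so the
        # subsequent backward pass triggers gradient synchronization.
        return self.inner_algorithm.unroll(unroll_length)
```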
The remaining problem is that when DDP is turned on, the time consumed by each training iteration increases significantly compared to the single-process, non-DDP run on the same physical machine. A preliminary investigation narrowed down where the extra time is being spent. |
I was able to further hunt down the cause; the major contributor to the more-than-10x increase in time consumption is the handling of observations. In particular, I think this is because `obs = alf.nest.find_field(experience, "observation")` has a different dimension: comparing the dual-GPU DDP version against the single-GPU, single-process version, the shape of the observation differs. Apparently some of the transformations were not applied in the dual-GPU version; they are supposed to downsample the observation. |
To me this looks more like a bug in obtaining the input tensors. Usually we don't have a "downsampling" transformer in ALF; the env is directly responsible for resizing images. So probably you are using two different envs/wrappers. Also, the number of image channels is usually 3, or 3n with FrameStacker. |
Thanks for the help, Haonan. I am slowly digging into that. Let me check the environment. |
With some more debugging, I found that the problem is due to "failing to apply `DMAtariPreprocessing`". In the conf file we have:

```python
alf.config(
    'suite_gym.load',
    gym_env_wrappers=[gym_wrappers.DMAtariPreprocessing],
    # Default max episode steps for all games
    #
    # Per DQN paper setting, 18000 frames assuming frameskip = 4
    max_episode_steps=4500)
```

With the Python debugger, I can see that these configured values are not picked up when the environment is actually created. So this is likely a configuration-loading problem; I will need to read more to understand what's happening here. |
And after a few hours of poking around and investigating, I finally found why the configuration is not respected. I'll summarize my discovery here:

Note that there are two hierarchies of subprocesses here. In order for a subprocess to inherit the in-memory configuration set up by `alf.config`, it has to be created with the `fork` start method.

In the single-GPU setup this is not a problem, because there is only one hierarchy of subprocesses: the top-level process starts the environment processes with the default start method, which is `fork` on Linux.

However, in the dual-GPU/dual-process setup, in order for DDP to work, the two training processes are created with the `spawn` method, and the environment processes created under them were no longer being forked, so the configuration was lost.

The solution is simple once the above is figured out: explicitly use the `fork` context when creating the environment processes.

```python
ctx = multiprocessing.get_context('fork')
self._process = ctx.Process(...)
```
|
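To illustrate the start-method point with a standalone toy (not alf code): a forked child sees a copy of the parent's in-memory module state, while a spawned child re-imports modules from scratch and therefore misses anything configured at runtime.

```python
import multiprocessing as mp

CONFIG = {'applied': False}  # stands in for alf's module-level configuration


def child(method):
    # 'fork' copies the parent's memory, so the mutation below is visible;
    # 'spawn' re-imports this module fresh, so the child sees the default.
    print(method, '->', CONFIG['applied'])


if __name__ == '__main__':
    CONFIG['applied'] = True   # analogous to calling alf.config(...) in the parent
    for method in ('fork', 'spawn'):   # 'fork' is only available on POSIX
        ctx = mp.get_context(method)
        p = ctx.Process(target=child, args=(method,))
        p.start()
        p.join()
```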
With the above problem fixed, training now runs under the dual-GPU setup. I haven't started working on the shenanigans of checkpoint/summary/metrics, so the curves might look a bit messy, but the result looks similar to the performance of a single GPU, within a similar time. (Note that I am running 32 environments for each process in the dual-process, dual-GPU setup.) |
After comparing the two runs: the DDP version is indeed faster, but not by a large margin. |
However, I discovered another problem: when DDP is on, even though the log files are generated, nothing is written to them. I will need to look at this as well. |
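Not necessarily the cause here, but a common pattern in DDP programs is to restrict summary/log writing to a single process, since otherwise every rank opens its own writer (and a misplaced rank check can silently skip writing altogether). A hedged sketch of such a guard:

```python
import torch.distributed as dist


def is_chief():
    """Return True only for the process that should write summaries/logs."""
    if not dist.is_available() or not dist.is_initialized():
        return True  # non-distributed run: always write
    return dist.get_rank() == 0


# Usage sketch (hypothetical writer/variables):
# if is_chief():
#     summary_writer.add_scalar('loss', loss.item(), global_step)
```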
On-policy algorithms can now enjoy DDP. The next step is to add full support for off-policy algorithms as well. |
…zonRobotics#938) This is part of the effort to address HorizonRobotics#913. A sub-task requires extracting the worker logic out of the class (for some reason keeping it inside prevents `multiprocessing` from working correctly). Without this change, `multiprocessing.Process` just gets stuck on `start()`.
…orizonRobotics#939) This is part of the effort to unblock HorizonRobotics#913. Two reasons for this change: 1. `worker` does not rely on `ProcessEnvironment` at all, so it is cleaner to make it independent of `ProcessEnvironment`. 2. If it stays as a member method of `ProcessEnvironment`, `multiprocessing.Process` will get stuck on `start()` if the parent process is also a `multiprocessing.Process`, for reasons unknown (tried investigating but haven't figured it out).
…nd multi GPU training (HorizonRobotics#913) (HorizonRobotics#944)
* [REFACTOR] train.py to consolidate common logic for both single GPU and multi GPU training
* Address Wei's comments
* Address Haonan's comments
* Specify authoritative url and port as well
* Remove unused Optional typing
…tics#951)
* Add UnrollPerformer as the module being wrapped by DistributedDataParallel
* Enable DDP for on policy RLTrainer
As discussed with @emailweixu, it would be nice to have alf support multi-GPU training. The goals are: