
MultiDiscrete action spaces #146

Open · tkelestemur opened this issue Jun 8, 2021 · 10 comments

@tkelestemur (Contributor)
I have a custom environment with a MultiDiscrete action space, which allows controlling an agent with an n-dimensional discrete action space.

In my environment there are 4 dimensions, each with 11 actions. I'm trying to use A2C with a softmax policy. Below is the implementation of the policy and value networks. The policy outputs an [N, 4, 11] tensor, where N is the batch size, and the softmax is applied to the last dimension, so I effectively have 4 action distributions. I thought this would work, but I'm getting the error below.

Do I need to make changes to A2C, or am I doing something wrong?

  File "train_rl.py", line 90, in <module>
    train()
  File "train_rl.py", line 80, in train
    experiments.train_agent_batch(
  File "/home/tarik/venvs/tacto/lib/python3.8/site-packages/pfrl/experiments/train_agent_batch.py", line 86, in train_agent_batch
    agent.batch_observe(obss, rs, dones, resets)
  File "/home/tarik/venvs/tacto/lib/python3.8/site-packages/pfrl/agents/a2c.py", line 224, in batch_observe
    self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)
  File "/home/tarik/venvs/tacto/lib/python3.8/site-packages/pfrl/agents/a2c.py", line 288, in _batch_observe_train
    self.update()
  File "/home/tarik/venvs/tacto/lib/python3.8/site-packages/pfrl/agents/a2c.py", line 183, in update
    action_log_probs = action_log_probs.reshape(
RuntimeError: shape '[5, 2]' is invalid for input of size 40
```python
import torch
import pfrl
from pfrl.policies import SoftmaxCategoricalHead

# Policy head: 44 input features -> logits reshaped to (4 action dims, 11 choices each)
policy = torch.nn.Sequential(
    torch.nn.Linear(44, 128),
    torch.nn.Tanh(),
    torch.nn.Linear(128, 128),
    torch.nn.Tanh(),
    torch.nn.Linear(128, 44),
    torch.nn.Unflatten(1, (4, 11)),
    SoftmaxCategoricalHead(),
)

# Value head: scalar state value
value = torch.nn.Sequential(
    torch.nn.Linear(44, 128),
    torch.nn.Tanh(),
    torch.nn.Linear(128, 128),
    torch.nn.Tanh(),
    torch.nn.Linear(128, 1),
)

model = pfrl.nn.Branched(policy, value)
```
@muupan self-assigned this Jun 9, 2021
@tkelestemur (Contributor, Author)

After some digging into the A2C code, I realized that the log probabilities of the policy need to have the shape (update_steps, num_processes) so that they can properly be multiplied with the advantages. As a quick workaround, we can sum the log probabilities across the dimensions of the action space by changing this line to action_log_probs = pout.log_prob(actions).sum(dim=1), as explained in this paper.

This fixes A2C, but a more general approach to supporting MultiDiscrete action spaces should be considered.
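
For illustration only (not pfrl code), a minimal sketch of why summing per-dimension log-probabilities gives the joint log-probability, assuming the sub-actions are independent:

```python
import torch

# 4 independent 11-way categoricals per sample, as in the policy above
logits = torch.randn(2, 4, 11)                       # (batch, action dims, choices)
dist = torch.distributions.Categorical(logits=logits)
actions = dist.sample()                               # shape (2, 4)
joint_log_prob = dist.log_prob(actions).sum(dim=1)    # shape (2,), one value per sample
```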

@muupan (Member) commented Jun 9, 2021

I guess you can wrap the output of SoftmaxCategoricalHead with torch.distributions.Independent so that your resulting distribution's batch_shape is (N,) and event_shape is (4,).
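
A minimal sketch of this suggestion (the head class below is illustrative, not part of pfrl):

```python
import torch
from torch import nn

class MultiSoftmaxCategoricalHead(nn.Module):
    """Illustrative head: interprets input of shape (N, 4, 11) as logits of
    4 independent 11-way categoricals and returns a single joint distribution."""

    def forward(self, logits):
        # batch_shape (N,), event_shape (4,): log_prob sums over the 4 dimensions
        return torch.distributions.Independent(
            torch.distributions.Categorical(logits=logits), 1
        )
```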

@xylee95 commented Jun 11, 2021

I'm also running into a similar issue with my environment. Even before getting to the update rule, I face the problem of having a multi-discrete action space with a different number of actions along each dimension: for example, dimension 1 has 5 actions, dimension 2 has 3 actions, and dimension 3 has 10 actions.

How would I code up the final layer of the policy in that case? In the issue above, the author could nicely unflatten the tensor into uniform shapes along each dimension, but I'm not aware of a way to do that for multi-discrete action spaces whose dimensions have different sizes.

Also, please let me know if you would rather have me open a new issue for this topic. Thanks!

@muupan (Member) commented Jun 12, 2021

> I'm facing the problem of having a multi-discrete action space with a different number of actions along each dimension. For example, dimension 1 has 5 actions, dimension 2 has 3 actions and dimension 3 has 10 actions.

I think this requires a new subclass of torch.distributions.Distribution that models a joint distribution of multiple categorical distributions of different sizes.
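
A hypothetical sketch of what such a subclass might look like (only sample and log_prob are shown; the class and its details are assumptions, not an existing pfrl or PyTorch API):

```python
import torch
from torch.distributions import Distribution, Categorical

class JointCategorical(Distribution):
    """Hypothetical joint distribution over categoricals of different sizes,
    e.g. built from per-dimension logits of widths (5, 3, 10)."""
    arg_constraints = {}

    def __init__(self, logits_per_dim, validate_args=False):
        # logits_per_dim: list of tensors, each of shape (N, n_actions_i)
        self.dists = [Categorical(logits=l) for l in logits_per_dim]
        batch_shape = self.dists[0].batch_shape
        super().__init__(batch_shape, torch.Size((len(self.dists),)),
                         validate_args=validate_args)

    def sample(self, sample_shape=torch.Size()):
        # one sub-action per dimension, stacked along the last axis
        return torch.stack([d.sample(sample_shape) for d in self.dists], dim=-1)

    def log_prob(self, value):
        # joint log-probability = sum of per-dimension log-probabilities
        return sum(d.log_prob(v) for d, v in
                   zip(self.dists, torch.unbind(value, dim=-1)))
```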

@xylee95 commented Jun 12, 2021

> > I'm facing the problem of having a multi-discrete action space with a different number of actions along each dimension. For example, dimension 1 has 5 actions, dimension 2 has 3 actions and dimension 3 has 10 actions.
>
> I think this requires a new subclass of torch.distributions.Distribution that models a joint distribution of multiple categorical distributions of different sizes.

Yes, I think you're right. I've managed to get something simple working that models individual torch Categorical distributions and then combines them. Thanks a lot, although please do consider including agents that support MultiDiscrete action spaces in the future; I think it would be really helpful.

@muupan (Member) commented Jun 12, 2021

A perhaps easier but less clean workaround is to model it as a joint distribution of same-sized categorical distributions using Independent and Categorical, but set the logits for unused categories to very low values so that they are never sampled.
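
A sketch of this padding workaround, assuming dimensions of sizes (5, 3, 10) (the helper below is illustrative, not a pfrl API):

```python
import torch

action_dims = (5, 3, 10)
max_dim = max(action_dims)

def make_padded_distribution(logits_list):
    """logits_list: tensors of shapes (N, 5), (N, 3), (N, 10)."""
    padded = []
    for logits in logits_list:
        n_pad = max_dim - logits.shape[1]
        # Pad with a very large negative logit so these slots are never sampled
        pad = logits.new_full((logits.shape[0], n_pad), -1e9)
        padded.append(torch.cat([logits, pad], dim=1))
    stacked = torch.stack(padded, dim=1)  # (N, 3, max_dim)
    return torch.distributions.Independent(
        torch.distributions.Categorical(logits=stacked), 1
    )
```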

@tkelestemur (Contributor, Author)

@muupan Thanks, wrapping the output of Categorical with Independent worked fine for same-sized multi-discrete action spaces.

@xylee95, can you share how you managed to get it working with action spaces of different sizes?

I'm currently writing a class based on the MultiCategoricalDistribution from stable-baselines3 and hope to open a PR soon.

@xylee95 commented Jun 15, 2021

@tkelestemur Yes, that is exactly what I did. I wrote a class based on the MultiCategoricalDistribution from stable-baselines3 and changed some of the function names to fit the log_prob calls in the agent. It works fine, but I've only tested it with PPO so far, not with other agents. If you need more details, I'll be happy to share.

@tkelestemur (Contributor, Author)

@xylee95 Can you share your implementation? I've tried to write a subclass of torch.distributions.Distribution but didn't have much success.

@xylee95 commented Jun 29, 2021

@tkelestemur This is my implementation. It is almost a copy-and-paste of the stable-baselines3 code. I did not write it as a subclass of torch.distributions.Distribution; instead, I created a new class that holds a list of torch distributions. It would definitely be much cleaner if written as a subclass of torch.distributions.Distribution.

```python
import torch
import torch.nn as nn


class MultiCategoricalDistribution:
    """Joint distribution over several categorical action dimensions,
    e.g. action_dims = [5, 3, 10]."""

    def __init__(self, action_dims):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim):
        """Create the layer that represents the distribution.
        It outputs the flattened logits of the MultiCategorical distribution;
        probabilities are obtained with a softmax over each sub-space.
        """
        action_logits = nn.Linear(latent_dim, sum(self.action_dims))
        return action_logits

    def proba_distribution(self, action_logits):
        """Create a Categorical distribution for each action dimension."""
        self.distribution = [
            torch.distributions.Categorical(logits=split)
            for split in torch.split(action_logits, tuple(self.action_dims), dim=1)
        ]
        return self

    def log_prob(self, actions):
        """Extract each discrete action, compute its log-probability under the
        corresponding distribution, then sum over dimensions."""
        return torch.stack(
            [
                dist.log_prob(action)
                for dist, action in zip(self.distribution, torch.unbind(actions, dim=1))
            ],
            dim=1,
        ).sum(dim=1)

    def entropy(self):
        """Sum of the entropies of the individual categorical distributions."""
        return torch.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1)

    def sample(self):
        """Sample an action from each individual categorical distribution."""
        return torch.stack([dist.sample() for dist in self.distribution], dim=1)

    def mode(self):
        """Most probable action of each categorical distribution."""
        return torch.stack(
            [torch.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1
        )

    def get_actions(self, deterministic=False):
        """Return actions according to the probability distribution."""
        if deterministic:
            return self.mode()
        return self.sample()

    def actions_from_params(self, action_logits, deterministic=False):
        """Update the probability distribution and return actions."""
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits):
        """Compute actions and their log-probabilities from the logits."""
        actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(actions)
        return actions, log_prob
```
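
For reference, a short usage sketch of the class above (the sizes and batch shape are illustrative):

```python
import torch

action_dims = [5, 3, 10]
dist_maker = MultiCategoricalDistribution(action_dims)
head = dist_maker.proba_distribution_net(latent_dim=128)  # nn.Linear(128, 18)

latent = torch.randn(7, 128)                 # batch of 7 latent feature vectors
dist = dist_maker.proba_distribution(head(latent))
actions = dist.sample()                      # shape (7, 3), one sub-action per dimension
log_prob = dist.log_prob(actions)            # shape (7,)
entropy = dist.entropy()                     # shape (7,)
```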
