Skip to content

Commit

Permalink
get_mode for mixture distribution (#1415)
Browse files Browse the repository at this point in the history
* get_mode for mixture distribution

* Add comments

* Address comments
  • Loading branch information
Haichao-Zhang authored Nov 21, 2022
1 parent 306c874 commit 0f8d0ec
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
12 changes: 12 additions & 0 deletions alf/networks/projection_networks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ def test_mixture_of_gaussian_1d_action(self):
self.assertEqual(dist.batch_shape, (7, ))
x = dist.sample()
self.assertEqual((7, 1), x.shape)
mode = dist_utils.get_mode(dist)
self.assertEqual((7, 1), mode.shape)
mode = dist_utils.get_rmode(dist)
self.assertEqual((7, 1), mode.shape)

def test_mixture_of_stable_normal_2d_action(self):
input_spec = TensorSpec((10, ), torch.float32)
Expand All @@ -431,6 +435,10 @@ def test_mixture_of_stable_normal_2d_action(self):
self.assertEqual(dist.batch_shape, (7, ))
x = dist.sample()
self.assertEqual((7, 2), x.shape)
mode = dist_utils.get_mode(dist)
self.assertEqual((7, 2), mode.shape)
mode = dist_utils.get_rmode(dist)
self.assertEqual((7, 2), mode.shape)

def test_mixture_of_beta_2d_action(self):
input_spec = TensorSpec((10, ), torch.float32)
Expand All @@ -450,6 +458,10 @@ def test_mixture_of_beta_2d_action(self):
self.assertEqual(dist.batch_shape, (7, ))
x = dist.sample()
self.assertEqual((7, 2), x.shape)
mode = dist_utils.get_mode(dist)
self.assertEqual((7, 2), mode.shape)
mode = dist_utils.get_rmode(dist)
self.assertEqual((7, 2), mode.shape)


if __name__ == "__main__":
Expand Down
16 changes: 16 additions & 0 deletions alf/utils/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,15 @@ def get_mode(dist):
torch.argmax(dist.logits, -1), num_classes=dist.logits.shape[-1])
elif isinstance(dist, td.normal.Normal):
mode = dist.mean
elif isinstance(dist, td.MixtureSameFamily):
# Note that this just computes an approximate mode. We use an approximate
# approach to compute the mode, by using the mode of the component
# distribution that has the highest component probability.
# [B]
ind = get_mode(dist.mixture_distribution)
# [B, num_component, d]
component_mode = get_mode(dist.component_distribution)
mode = component_mode[torch.arange(component_mode.shape[0]), ind]
elif isinstance(dist, StableCauchy):
mode = dist.loc
elif isinstance(dist, td.Independent):
Expand Down Expand Up @@ -1196,6 +1205,13 @@ def get_rmode(dist):
"""
if isinstance(dist, td.normal.Normal):
mode = dist.mean
elif isinstance(dist, td.MixtureSameFamily):
# note that for the mixture distribution, there is no gradient back-propagation
# [B]
ind = get_mode(dist.mixture_distribution)
# [B, num_component, d]
component_mode = get_rmode(dist.component_distribution)
mode = component_mode[torch.arange(component_mode.shape[0]), ind]
elif isinstance(dist, StableCauchy):
mode = dist.loc
elif isinstance(dist, Beta) or isinstance(dist, TruncatedDistribution):
Expand Down

0 comments on commit 0f8d0ec

Please sign in to comment.