diff --git a/alf/networks/projection_networks_test.py b/alf/networks/projection_networks_test.py index dcf39fbb1..66e6e2b7f 100644 --- a/alf/networks/projection_networks_test.py +++ b/alf/networks/projection_networks_test.py @@ -407,6 +407,10 @@ def test_mixture_of_gaussian_1d_action(self): self.assertEqual(dist.batch_shape, (7, )) x = dist.sample() self.assertEqual((7, 1), x.shape) + mode = dist_utils.get_mode(dist) + self.assertEqual((7, 1), mode.shape) + mode = dist_utils.get_rmode(dist) + self.assertEqual((7, 1), mode.shape) def test_mixture_of_stable_normal_2d_action(self): input_spec = TensorSpec((10, ), torch.float32) @@ -431,6 +435,10 @@ def test_mixture_of_stable_normal_2d_action(self): self.assertEqual(dist.batch_shape, (7, )) x = dist.sample() self.assertEqual((7, 2), x.shape) + mode = dist_utils.get_mode(dist) + self.assertEqual((7, 2), mode.shape) + mode = dist_utils.get_rmode(dist) + self.assertEqual((7, 2), mode.shape) def test_mixture_of_beta_2d_action(self): input_spec = TensorSpec((10, ), torch.float32) @@ -450,6 +458,10 @@ def test_mixture_of_beta_2d_action(self): self.assertEqual(dist.batch_shape, (7, )) x = dist.sample() self.assertEqual((7, 2), x.shape) + mode = dist_utils.get_mode(dist) + self.assertEqual((7, 2), mode.shape) + mode = dist_utils.get_rmode(dist) + self.assertEqual((7, 2), mode.shape) if __name__ == "__main__": diff --git a/alf/utils/dist_utils.py b/alf/utils/dist_utils.py index b331129ff..3361bf571 100644 --- a/alf/utils/dist_utils.py +++ b/alf/utils/dist_utils.py @@ -1159,6 +1159,15 @@ def get_mode(dist): torch.argmax(dist.logits, -1), num_classes=dist.logits.shape[-1]) elif isinstance(dist, td.normal.Normal): mode = dist.mean + elif isinstance(dist, td.MixtureSameFamily): + # Note that this computes only an approximate mode: we return the + # mode of the component distribution that has the highest mixture + # weight.
+ # [B] + ind = get_mode(dist.mixture_distribution) + # [B, num_component, d] + component_mode = get_mode(dist.component_distribution) + mode = component_mode[torch.arange(component_mode.shape[0]), ind] elif isinstance(dist, StableCauchy): mode = dist.loc elif isinstance(dist, td.Independent): @@ -1196,6 +1205,13 @@ def get_rmode(dist): """ if isinstance(dist, td.normal.Normal): mode = dist.mean + elif isinstance(dist, td.MixtureSameFamily): + # Note: gradients do not back-propagate through the discrete selection of the mixture component + # [B] + ind = get_mode(dist.mixture_distribution) + # [B, num_component, d] + component_mode = get_rmode(dist.component_distribution) + mode = component_mode[torch.arange(component_mode.shape[0]), ind] elif isinstance(dist, StableCauchy): mode = dist.loc elif isinstance(dist, Beta) or isinstance(dist, TruncatedDistribution):