From bdbcbaa89ff0a4a807c3cfb72f5f9842fc611e34 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 30 Jan 2025 22:46:43 +0530 Subject: [PATCH] Update beta sampling code in augment.py The function `_sample_from_beta(alpha, beta, shape)` in `MixupAndCutmix` class, is not having the same functionality as `numpy.random.beta`. So `tfm.vision.augment.MixupAndCutmix._sample_from_beta(0.2, 0.2, tf.shape( tf.range(10000))).numpy()` is also deviating as well. So suggesting the fix keeping `alpha=alpha, beta=1.0` in `_sample_from_beta`. The reproduced [gist](https://colab.sandbox.google.com/gist/LakshmiKalaKadali/06533824610d6e85ea4aa3c6399819e6/tf_model_13490.ipynb#scrollTo=zSlE-3YDjL91) also attached. This PR closes [#13490](https://github.com/tensorflow/models/issues/13490) Thank You --- official/vision/ops/augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/official/vision/ops/augment.py b/official/vision/ops/augment.py index f632c2232fb..ea00506ca42 100644 --- a/official/vision/ops/augment.py +++ b/official/vision/ops/augment.py @@ -2697,8 +2697,8 @@ def distort(self, images: tf.Tensor, @staticmethod def _sample_from_beta(alpha, beta, shape): - sample_alpha = tf.random.gamma(shape, 1., beta=alpha) - sample_beta = tf.random.gamma(shape, 1., beta=beta) + sample_alpha = tf.random.gamma(shape, alpha, beta=1.0) + sample_beta = tf.random.gamma(shape, alpha, beta=1.0) return sample_alpha / (sample_alpha + sample_beta) def _cutmix(self, images: tf.Tensor,