diff --git a/official/vision/utils/object_detection/visualization_utils.py b/official/vision/utils/object_detection/visualization_utils.py index 0b0b9016fc5..53143d5f914 100644 --- a/official/vision/utils/object_detection/visualization_utils.py +++ b/official/vision/utils/object_detection/visualization_utils.py @@ -438,6 +438,11 @@ def _denormalize_images(images: tf.Tensor) -> tf.Tensor: ) return tf.cast(images, dtype=tf.uint8) + if images.shape[3] > 3: + images = images[:, :, :, 0:3] + elif images.shape[3] == 1: + images = tf.image.grayscale_to_rgb(images) + images = tf.nest.map_structure( tf.identity, tf.map_fn(