Skip to content

Commit

Permalink
implement transform_bounding_boxes for random_shear (#20704)
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka authored Jan 2, 2025
1 parent 476a664 commit 5b29974
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 2 deletions.
142 changes: 140 additions & 2 deletions keras/src/layers/preprocessing/image_preprocessing/random_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
clip_to_image_size,
)
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
convert_format,
)
from keras.src.random.seed_generator import SeedGenerator
from keras.src.utils import backend_utils


@keras_export("keras.layers.RandomShear")
Expand Down Expand Up @@ -175,7 +182,7 @@ def get_random_transformation(self, data, training=True, seed=None):
)
* invert
)
return {"shear_factor": shear_factor}
return {"shear_factor": shear_factor, "input_shape": images_shape}

def transform_images(self, images, transformation, training=True):
images = self.backend.cast(images, self.compute_dtype)
Expand Down Expand Up @@ -231,13 +238,144 @@ def _get_shear_matrix(self, shear_factors):
def transform_labels(self, labels, transformation, training=True):
return labels

def get_transformed_x_y(self, x, y, transform):
a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split(
transform, 8, axis=-1
)

k = c0 * x + c1 * y + 1
x_transformed = (a0 * x + a1 * y + a2) / k
y_transformed = (b0 * x + b1 * y + b2) / k
return x_transformed, y_transformed

def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor):
bboxes = bounding_boxes["boxes"]
x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1)

w_shift_factor = self.backend.convert_to_tensor(
w_shift_factor, dtype=x1.dtype
)
h_shift_factor = self.backend.convert_to_tensor(
h_shift_factor, dtype=x1.dtype
)

if len(bboxes.shape) == 3:
w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1)
h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1)

bounding_boxes["boxes"] = self.backend.numpy.concatenate(
[
x1 - w_shift_factor,
x2 - h_shift_factor,
x3 - w_shift_factor,
x4 - h_shift_factor,
],
axis=-1,
)
return bounding_boxes

def transform_bounding_boxes(
self,
bounding_boxes,
transformation,
training=True,
):
raise NotImplementedError
def _get_height_width(transformation):
if self.data_format == "channels_first":
height_axis = -2
width_axis = -1
else:
height_axis = -3
width_axis = -2
input_height, input_width = (
transformation["input_shape"][height_axis],
transformation["input_shape"][width_axis],
)
return input_height, input_width

if training:
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

input_height, input_width = _get_height_width(transformation)

bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="rel_xyxy",
height=input_height,
width=input_width,
dtype=self.compute_dtype,
)

bounding_boxes = self._shear_bboxes(bounding_boxes, transformation)

bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=input_height,
width=input_width,
bounding_box_format="rel_xyxy",
)

bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target=self.bounding_box_format,
height=input_height,
width=input_width,
dtype=self.compute_dtype,
)

self.backend.reset()

return bounding_boxes

def _shear_bboxes(self, bounding_boxes, transformation):
shear_factor = self.backend.cast(
transformation["shear_factor"], dtype=self.compute_dtype
)
shear_x_amount, shear_y_amount = self.backend.numpy.split(
shear_factor, 2, axis=-1
)

x1, y1, x2, y2 = self.backend.numpy.split(
bounding_boxes["boxes"], 4, axis=-1
)
x1 = self.backend.numpy.squeeze(x1, axis=-1)
y1 = self.backend.numpy.squeeze(y1, axis=-1)
x2 = self.backend.numpy.squeeze(x2, axis=-1)
y2 = self.backend.numpy.squeeze(y2, axis=-1)

if shear_x_amount is not None:
x1_top = x1 - (shear_x_amount * y1)
x1_bottom = x1 - (shear_x_amount * y2)
x1 = self.backend.numpy.where(shear_x_amount < 0, x1_top, x1_bottom)

x2_top = x2 - (shear_x_amount * y1)
x2_bottom = x2 - (shear_x_amount * y2)
x2 = self.backend.numpy.where(shear_x_amount < 0, x2_bottom, x2_top)

if shear_y_amount is not None:
y1_left = y1 - (shear_y_amount * x1)
y1_right = y1 - (shear_y_amount * x2)
y1 = self.backend.numpy.where(shear_y_amount > 0, y1_right, y1_left)

y2_left = y2 - (shear_y_amount * x1)
y2_right = y2 - (shear_y_amount * x2)
y2 = self.backend.numpy.where(shear_y_amount > 0, y2_left, y2_right)

boxes = self.backend.numpy.concatenate(
[
self.backend.numpy.expand_dims(x1, axis=-1),
self.backend.numpy.expand_dims(y1, axis=-1),
self.backend.numpy.expand_dims(x2, axis=-1),
self.backend.numpy.expand_dims(y2, axis=-1),
],
axis=-1,
)
bounding_boxes["boxes"] = boxes

return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np
import pytest
from absl.testing import parameterized
from tensorflow import data as tf_data

import keras
from keras.src import backend
from keras.src import layers
from keras.src import testing
from keras.src.utils import backend_utils


class RandomShearTest(testing.TestCase):
Expand Down Expand Up @@ -74,3 +76,127 @@ def test_tf_data_compatibility(self):
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

@parameterized.named_parameters(
(
"with_x_shift",
[[1.0, 0.0]],
[[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],
),
(
"with_y_shift",
[[0.0, 1.0]],
[[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],
),
(
"with_xy_shift",
[[1.0, 1.0]],
[[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],
),
)
def test_random_shear_bounding_boxes(self, translation, expected_boxes):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
image_shape = (10, 8, 3)
else:
image_shape = (3, 10, 8)
input_image = np.random.random(image_shape)
bounding_boxes = {
"boxes": np.array(
[
[2, 1, 4, 3],
[6, 4, 8, 6],
]
),
"labels": np.array([[1, 2]]),
}
input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
layer = layers.RandomShear(
x_factor=0.5,
y_factor=0.5,
data_format=data_format,
seed=42,
bounding_box_format="xyxy",
)

transformation = {
"shear_factor": backend_utils.convert_tf_tensor(
np.array(translation)
),
"input_shape": image_shape,
}
output = layer.transform_bounding_boxes(
input_data["bounding_boxes"],
transformation=transformation,
training=True,
)

self.assertAllClose(output["boxes"], expected_boxes)

@parameterized.named_parameters(
(
"with_x_shift",
[[1.0, 0.0]],
[[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],
),
(
"with_y_shift",
[[0.0, 1.0]],
[[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],
),
(
"with_xy_shift",
[[1.0, 1.0]],
[[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],
),
)
def test_random_shear_tf_data_bounding_boxes(
self, translation, expected_boxes
):
data_format = backend.config.image_data_format()
if backend.config.image_data_format() == "channels_last":
image_shape = (1, 10, 8, 3)
else:
image_shape = (1, 3, 10, 8)
input_image = np.random.random(image_shape)
bounding_boxes = {
"boxes": np.array(
[
[
[2, 1, 4, 3],
[6, 4, 8, 6],
]
]
),
"labels": np.array([[1, 2]]),
}

input_data = {"images": input_image, "bounding_boxes": bounding_boxes}

ds = tf_data.Dataset.from_tensor_slices(input_data)
layer = layers.RandomShear(
x_factor=0.5,
y_factor=0.5,
data_format=data_format,
seed=42,
bounding_box_format="xyxy",
)

transformation = {
"shear_factor": backend_utils.convert_tf_tensor(
np.array(translation)
),
"input_shape": image_shape,
}

ds = ds.map(
lambda x: layer.transform_bounding_boxes(
x["bounding_boxes"],
transformation=transformation,
training=True,
)
)

output = next(iter(ds))
expected_boxes = np.array(expected_boxes)
self.assertAllClose(output["boxes"], expected_boxes)

0 comments on commit 5b29974

Please sign in to comment.