Skip to content

Commit 144d871

Browse files
committed
Change tensor_spec_utils.bounds to not broadcast scalar bounds.
PiperOrigin-RevId: 371299720 Change-Id: I166d2a8d42e0544eec0a4e425181d391e3fe7410
1 parent 81c55e4 commit 144d871

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

dm_env_rpc/v1/tensor_spec_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ def _get_value(min_max_value, shape, default):
6565
value = which and getattr(min_max_value, which)
6666

6767
if value is None:
68-
min_max = np.broadcast_to(default, shape) if shape else default
68+
min_max = default
6969
elif which in _SCALAR_VALUE_TYPES:
70-
min_max = np.broadcast_to(value, shape) if shape else value
70+
min_max = value
7171
else:
7272
unpacked = tensor_utils.unpack_proto(min_max_value)
73-
min_max = tensor_utils.reshape_array(unpacked, shape)
73+
min_max = tensor_utils.reshape_array(
74+
unpacked, shape) if len(unpacked) > 1 else unpacked[0]
7475

7576
if (shape is not None
7677
and np.any(np.array(shape) < 0)

dm_env_rpc/v1/tensor_spec_utils_test.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_min_n_shape(self):
6262
tensor_spec.shape[:] = minimum.shape
6363
bounds = tensor_spec_utils.bounds(tensor_spec)
6464
np.testing.assert_array_equal(minimum, bounds.min)
65-
np.testing.assert_array_equal(np.full(minimum.shape, 2**32 - 1), bounds.max)
65+
self.assertEqual(2**32 - 1, bounds.max)
6666

6767
def test_max_n_shape(self):
6868
maximum = np.array([[1, 2], [3, 4]])
@@ -72,7 +72,7 @@ def test_max_n_shape(self):
7272
tensor_spec.max.uint32s.array[:] = maximum.flatten().data.tolist()
7373
tensor_spec.shape[:] = maximum.shape
7474
bounds = tensor_spec_utils.bounds(tensor_spec)
75-
np.testing.assert_array_equal(np.full(maximum.shape, 0), bounds.min)
75+
self.assertEqual(0, bounds.min)
7676
np.testing.assert_array_equal(maximum, bounds.max)
7777

7878
def test_invalid_min_shape(self):
@@ -134,34 +134,33 @@ def test_invalid_min_var_shape(self):
134134
'can only have scalar ranges.'):
135135
tensor_spec_utils.bounds(tensor_spec)
136136

137-
def test_min_broadcast(self):
137+
def test_min_scalar_doesnt_broadcast(self):
138138
tensor_spec = dm_env_rpc_pb2.TensorSpec()
139139
tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
140140
tensor_spec.min.uint32s.array[:] = [1]
141141
tensor_spec.shape[:] = (2, 2)
142142
bounds = tensor_spec_utils.bounds(tensor_spec)
143-
np.testing.assert_array_equal(np.full(tensor_spec.shape, 1), bounds.min)
144-
np.testing.assert_array_equal(
145-
np.full(tensor_spec.shape, 2**32 - 1), bounds.max)
143+
self.assertEqual(1, bounds.min)
144+
self.assertEqual(2**32 - 1, bounds.max)
146145

147-
def test_max_broadcast(self):
146+
def test_max_scalar_doesnt_broadcast(self):
148147
tensor_spec = dm_env_rpc_pb2.TensorSpec()
149148
tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
150149
tensor_spec.max.uint32s.array[:] = [1]
151150
tensor_spec.shape[:] = (2, 2)
152151
bounds = tensor_spec_utils.bounds(tensor_spec)
153-
np.testing.assert_array_equal(np.full(tensor_spec.shape, 0), bounds.min)
154-
np.testing.assert_array_equal(np.full(tensor_spec.shape, 1), bounds.max)
152+
self.assertEqual(0, bounds.min)
153+
self.assertEqual(1, bounds.max)
155154

156-
def test_min_max_broadcast(self):
155+
def test_min_max_scalars_dont_broadcast(self):
157156
tensor_spec = dm_env_rpc_pb2.TensorSpec()
158157
tensor_spec.dtype = dm_env_rpc_pb2.DataType.UINT32
159158
tensor_spec.min.uint32s.array[:] = [1]
160159
tensor_spec.max.uint32s.array[:] = [2]
161160
tensor_spec.shape[:] = (4,)
162161
bounds = tensor_spec_utils.bounds(tensor_spec)
163-
np.testing.assert_array_equal(np.full(tensor_spec.shape, 1), bounds.min)
164-
np.testing.assert_array_equal(np.full(tensor_spec.shape, 2), bounds.max)
162+
self.assertEqual(1, bounds.min)
163+
self.assertEqual(2, bounds.max)
165164

166165
def test_min_mismatches_type_raises_error(self):
167166
tensor_spec = dm_env_rpc_pb2.TensorSpec()

0 commit comments

Comments
 (0)