@@ -62,7 +62,7 @@ def test_min_n_shape(self):
6262 tensor_spec .shape [:] = minimum .shape
6363 bounds = tensor_spec_utils .bounds (tensor_spec )
6464 np .testing .assert_array_equal (minimum , bounds .min )
65- np . testing . assert_array_equal ( np . full ( minimum . shape , 2 ** 32 - 1 ) , bounds .max )
65+ self . assertEqual ( 2 ** 32 - 1 , bounds .max )
6666
6767 def test_max_n_shape (self ):
6868 maximum = np .array ([[1 , 2 ], [3 , 4 ]])
@@ -72,7 +72,7 @@ def test_max_n_shape(self):
7272 tensor_spec .max .uint32s .array [:] = maximum .flatten ().data .tolist ()
7373 tensor_spec .shape [:] = maximum .shape
7474 bounds = tensor_spec_utils .bounds (tensor_spec )
75- np . testing . assert_array_equal ( np . full ( maximum . shape , 0 ) , bounds .min )
75+ self . assertEqual ( 0 , bounds .min )
7676 np .testing .assert_array_equal (maximum , bounds .max )
7777
7878 def test_invalid_min_shape (self ):
@@ -134,34 +134,33 @@ def test_invalid_min_var_shape(self):
134134 'can only have scalar ranges.' ):
135135 tensor_spec_utils .bounds (tensor_spec )
136136
137- def test_min_broadcast (self ):
137+ def test_min_scalar_doesnt_broadcast (self ):
138138 tensor_spec = dm_env_rpc_pb2 .TensorSpec ()
139139 tensor_spec .dtype = dm_env_rpc_pb2 .DataType .UINT32
140140 tensor_spec .min .uint32s .array [:] = [1 ]
141141 tensor_spec .shape [:] = (2 , 2 )
142142 bounds = tensor_spec_utils .bounds (tensor_spec )
143- np .testing .assert_array_equal (np .full (tensor_spec .shape , 1 ), bounds .min )
144- np .testing .assert_array_equal (
145- np .full (tensor_spec .shape , 2 ** 32 - 1 ), bounds .max )
143+ self .assertEqual (1 , bounds .min )
144+ self .assertEqual (2 ** 32 - 1 , bounds .max )
146145
147- def test_max_broadcast (self ):
146+ def test_max_scalar_doesnt_broadcast (self ):
148147 tensor_spec = dm_env_rpc_pb2 .TensorSpec ()
149148 tensor_spec .dtype = dm_env_rpc_pb2 .DataType .UINT32
150149 tensor_spec .max .uint32s .array [:] = [1 ]
151150 tensor_spec .shape [:] = (2 , 2 )
152151 bounds = tensor_spec_utils .bounds (tensor_spec )
153- np . testing . assert_array_equal ( np . full ( tensor_spec . shape , 0 ) , bounds .min )
154- np . testing . assert_array_equal ( np . full ( tensor_spec . shape , 1 ) , bounds .max )
152+ self . assertEqual ( 0 , bounds .min )
153+ self . assertEqual ( 1 , bounds .max )
155154
156- def test_min_max_broadcast (self ):
155+ def test_min_max_scalars_dont_broadcast (self ):
157156 tensor_spec = dm_env_rpc_pb2 .TensorSpec ()
158157 tensor_spec .dtype = dm_env_rpc_pb2 .DataType .UINT32
159158 tensor_spec .min .uint32s .array [:] = [1 ]
160159 tensor_spec .max .uint32s .array [:] = [2 ]
161160 tensor_spec .shape [:] = (4 ,)
162161 bounds = tensor_spec_utils .bounds (tensor_spec )
163- np . testing . assert_array_equal ( np . full ( tensor_spec . shape , 1 ) , bounds .min )
164- np . testing . assert_array_equal ( np . full ( tensor_spec . shape , 2 ) , bounds .max )
162+ self . assertEqual ( 1 , bounds .min )
163+ self . assertEqual ( 2 , bounds .max )
165164
166165 def test_min_mismatches_type_raises_error (self ):
167166 tensor_spec = dm_env_rpc_pb2 .TensorSpec ()
0 commit comments