@@ -150,3 +150,167 @@ def test_cast():
150
150
yc64 = x .astype ("complex64" )
151
151
with pytest .raises (TypeError , match = "Casting from complex to real is ambiguous" ):
152
152
yc64 .astype ("float64" )
153
+
154
+
155
+ def test_dot ():
156
+ """Test basic dot product operations."""
157
+ # Test matrix-vector dot product (with multiple-letter dim names)
158
+ x = xtensor ("x" , dims = ("aa" , "bb" ), shape = (2 , 3 ))
159
+ y = xtensor ("y" , dims = ("bb" ,), shape = (3 ,))
160
+ z = x .dot (y )
161
+ fn = xr_function ([x , y ], z )
162
+
163
+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("aa" , "bb" ))
164
+ y_test = DataArray (np .ones (3 ), dims = ("bb" ,))
165
+ z_test = fn (x_test , y_test )
166
+ expected = x_test .dot (y_test )
167
+ xr_assert_allclose (z_test , expected )
168
+
169
+ # Test matrix-vector dot product with ellipsis
170
+ z = x .dot (y , dim = ...)
171
+ fn = xr_function ([x , y ], z )
172
+ z_test = fn (x_test , y_test )
173
+ expected = x_test .dot (y_test , dim = ...)
174
+ xr_assert_allclose (z_test , expected )
175
+
176
+ # Test matrix-matrix dot product
177
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
178
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
179
+ z = x .dot (y )
180
+ fn = xr_function ([x , y ], z )
181
+
182
+ x_test = DataArray (np .add .outer (np .arange (2.0 ), np .arange (3.0 )), dims = ("a" , "b" ))
183
+ y_test = DataArray (np .add .outer (np .arange (3.0 ), np .arange (4.0 )), dims = ("b" , "c" ))
184
+ z_test = fn (x_test , y_test )
185
+ expected = x_test .dot (y_test )
186
+ xr_assert_allclose (z_test , expected )
187
+
188
+ # Test matrix-matrix dot product with string dim
189
+ z = x .dot (y , dim = "b" )
190
+ fn = xr_function ([x , y ], z )
191
+ z_test = fn (x_test , y_test )
192
+ expected = x_test .dot (y_test , dim = "b" )
193
+ xr_assert_allclose (z_test , expected )
194
+
195
+ # Test matrix-matrix dot product with list of dims
196
+ z = x .dot (y , dim = ["b" ])
197
+ fn = xr_function ([x , y ], z )
198
+ z_test = fn (x_test , y_test )
199
+ expected = x_test .dot (y_test , dim = ["b" ])
200
+ xr_assert_allclose (z_test , expected )
201
+
202
+ # Test matrix-matrix dot product with ellipsis
203
+ z = x .dot (y , dim = ...)
204
+ fn = xr_function ([x , y ], z )
205
+ z_test = fn (x_test , y_test )
206
+ expected = x_test .dot (y_test , dim = ...)
207
+ xr_assert_allclose (z_test , expected )
208
+
209
+ # Test a case where there are two dimensions to sum over
210
+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
211
+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
212
+ z = x .dot (y )
213
+ fn = xr_function ([x , y ], z )
214
+
215
+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
216
+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
217
+ z_test = fn (x_test , y_test )
218
+ expected = x_test .dot (y_test )
219
+ xr_assert_allclose (z_test , expected )
220
+
221
+ # Same but with explicit dimensions
222
+ z = x .dot (y , dim = ["b" , "c" ])
223
+ fn = xr_function ([x , y ], z )
224
+ z_test = fn (x_test , y_test )
225
+ expected = x_test .dot (y_test , dim = ["b" , "c" ])
226
+ xr_assert_allclose (z_test , expected )
227
+
228
+ # Same but with ellipses
229
+ z = x .dot (y , dim = ...)
230
+ fn = xr_function ([x , y ], z )
231
+ z_test = fn (x_test , y_test )
232
+ expected = x_test .dot (y_test , dim = ...)
233
+ xr_assert_allclose (z_test , expected )
234
+
235
+ # Dot product with sum
236
+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
237
+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
238
+ expected = x_test .dot (y_test , dim = ("a" , "b" , "c" ))
239
+
240
+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
241
+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
242
+ z = x .dot (y , dim = ("a" , "b" , "c" ))
243
+ fn = xr_function ([x , y ], z )
244
+ z_test = fn (x_test , y_test )
245
+ xr_assert_allclose (z_test , expected )
246
+
247
+ # Dot product with sum in the middle
248
+ x_test = DataArray (np .arange (120.0 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
249
+ y_test = DataArray (np .arange (360.0 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
250
+ expected = x_test .dot (y_test , dim = ("b" , "d" ))
251
+ x = xtensor ("x" , dims = ("a" , "b" , "c" , "d" ), shape = (2 , 3 , 4 , 5 ))
252
+ y = xtensor ("y" , dims = ("b" , "c" , "d" , "e" ), shape = (3 , 4 , 5 , 6 ))
253
+ z = x .dot (y , dim = ("b" , "d" ))
254
+ fn = xr_function ([x , y ], z )
255
+ z_test = fn (x_test , y_test )
256
+ xr_assert_allclose (z_test , expected )
257
+
258
+ # Same but with first two dims
259
+ expected = x_test .dot (y_test , dim = ["a" , "b" ])
260
+ z = x .dot (y , dim = ["a" , "b" ])
261
+ fn = xr_function ([x , y ], z )
262
+ z_test = fn (x_test , y_test )
263
+ xr_assert_allclose (z_test , expected )
264
+
265
+ # Same but with last two
266
+ expected = x_test .dot (y_test , dim = ["d" , "e" ])
267
+ z = x .dot (y , dim = ["d" , "e" ])
268
+ fn = xr_function ([x , y ], z )
269
+ z_test = fn (x_test , y_test )
270
+ xr_assert_allclose (z_test , expected )
271
+
272
+ # Same but with every other dim
273
+ expected = x_test .dot (y_test , dim = ["a" , "c" , "e" ])
274
+ z = x .dot (y , dim = ["a" , "c" , "e" ])
275
+ fn = xr_function ([x , y ], z )
276
+ z_test = fn (x_test , y_test )
277
+ xr_assert_allclose (z_test , expected )
278
+
279
+ # Test symbolic shapes
280
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (None , 3 )) # First dimension is symbolic
281
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , None )) # Second dimension is symbolic
282
+ z = x .dot (y )
283
+ fn = xr_function ([x , y ], z )
284
+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
285
+ y_test = DataArray (np .ones ((3 , 4 )), dims = ("b" , "c" ))
286
+ z_test = fn (x_test , y_test )
287
+ expected = x_test .dot (y_test )
288
+ xr_assert_allclose (z_test , expected )
289
+
290
+
291
+ def test_dot_errors ():
292
+ # No matching dimensions
293
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
294
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
295
+ with pytest .raises (ValueError , match = "Dimension e not found in either input" ):
296
+ x .dot (y , dim = "e" )
297
+
298
+ # Concrete dimension size mismatches
299
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
300
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (4 , 5 ))
301
+ with pytest .raises (
302
+ ValueError ,
303
+ match = "Size of dim 'b' does not match" ,
304
+ ):
305
+ x .dot (y )
306
+
307
+ # Symbolic dimension size mismatches
308
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , None ))
309
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (None , 5 ))
310
+ z = x .dot (y )
311
+ fn = xr_function ([x , y ], z )
312
+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
313
+ y_test = DataArray (np .ones ((4 , 5 )), dims = ("b" , "c" ))
314
+ # Doesn't fail until the rewrite
315
+ with pytest .raises (ValueError , match = "not aligned" ):
316
+ fn (x_test , y_test )
0 commit comments