@@ -151,3 +151,167 @@ def test_cast():
151
151
yc64 = x .astype ("complex64" )
152
152
with pytest .raises (TypeError , match = "Casting from complex to real is ambiguous" ):
153
153
yc64 .astype ("float64" )
154
+
155
+
156
+ def test_dot ():
157
+ """Test basic dot product operations."""
158
+ # Test matrix-vector dot product (with multiple-letter dim names)
159
+ x = xtensor ("x" , dims = ("aa" , "bb" ), shape = (2 , 3 ))
160
+ y = xtensor ("y" , dims = ("bb" ,), shape = (3 ,))
161
+ z = x .dot (y )
162
+ fn = xr_function ([x , y ], z )
163
+
164
+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("aa" , "bb" ))
165
+ y_test = DataArray (np .ones (3 ), dims = ("bb" ,))
166
+ z_test = fn (x_test , y_test )
167
+ expected = x_test .dot (y_test )
168
+ xr_assert_allclose (z_test , expected )
169
+
170
+ # Test matrix-vector dot product with ellipsis
171
+ z = x .dot (y , dim = ...)
172
+ fn = xr_function ([x , y ], z )
173
+ z_test = fn (x_test , y_test )
174
+ expected = x_test .dot (y_test , dim = ...)
175
+ xr_assert_allclose (z_test , expected )
176
+
177
+ # Test matrix-matrix dot product
178
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
179
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
180
+ z = x .dot (y )
181
+ fn = xr_function ([x , y ], z )
182
+
183
+ x_test = DataArray (np .add .outer (np .arange (2.0 ), np .arange (3.0 )), dims = ("a" , "b" ))
184
+ y_test = DataArray (np .add .outer (np .arange (3.0 ), np .arange (4.0 )), dims = ("b" , "c" ))
185
+ z_test = fn (x_test , y_test )
186
+ expected = x_test .dot (y_test )
187
+ xr_assert_allclose (z_test , expected )
188
+
189
+ # Test matrix-matrix dot product with string dim
190
+ z = x .dot (y , dim = "b" )
191
+ fn = xr_function ([x , y ], z )
192
+ z_test = fn (x_test , y_test )
193
+ expected = x_test .dot (y_test , dim = "b" )
194
+ xr_assert_allclose (z_test , expected )
195
+
196
+ # Test matrix-matrix dot product with list of dims
197
+ z = x .dot (y , dim = ["b" ])
198
+ fn = xr_function ([x , y ], z )
199
+ z_test = fn (x_test , y_test )
200
+ expected = x_test .dot (y_test , dim = ["b" ])
201
+ xr_assert_allclose (z_test , expected )
202
+
203
+ # Test matrix-matrix dot product with ellipsis
204
+ z = x .dot (y , dim = ...)
205
+ fn = xr_function ([x , y ], z )
206
+ z_test = fn (x_test , y_test )
207
+ expected = x_test .dot (y_test , dim = ...)
208
+ xr_assert_allclose (z_test , expected )
209
+
210
+ # Test a case where there are two dimensions to sum over
211
+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
212
+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
213
+ z = x .dot (y )
214
+ fn = xr_function ([x , y ], z )
215
+
216
+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
217
+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
218
+ z_test = fn (x_test , y_test )
219
+ expected = x_test .dot (y_test )
220
+ xr_assert_allclose (z_test , expected )
221
+
222
+ # Same but with explicit dimensions
223
+ z = x .dot (y , dim = ["b" , "c" ])
224
+ fn = xr_function ([x , y ], z )
225
+ z_test = fn (x_test , y_test )
226
+ expected = x_test .dot (y_test , dim = ["b" , "c" ])
227
+ xr_assert_allclose (z_test , expected )
228
+
229
+ # Same but with ellipses
230
+ z = x .dot (y , dim = ...)
231
+ fn = xr_function ([x , y ], z )
232
+ z_test = fn (x_test , y_test )
233
+ expected = x_test .dot (y_test , dim = ...)
234
+ xr_assert_allclose (z_test , expected )
235
+
236
+ # Dot product with sum
237
+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
238
+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
239
+ expected = x_test .dot (y_test , dim = ("a" , "b" , "c" ))
240
+
241
+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
242
+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
243
+ z = x .dot (y , dim = ("a" , "b" , "c" ))
244
+ fn = xr_function ([x , y ], z )
245
+ z_test = fn (x_test , y_test )
246
+ xr_assert_allclose (z_test , expected )
247
+
248
+ # Dot product with sum in the middle
249
+ x_test = DataArray (np .arange (120.0 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
250
+ y_test = DataArray (np .arange (360.0 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
251
+ expected = x_test .dot (y_test , dim = ("b" , "d" ))
252
+ x = xtensor ("x" , dims = ("a" , "b" , "c" , "d" ), shape = (2 , 3 , 4 , 5 ))
253
+ y = xtensor ("y" , dims = ("b" , "c" , "d" , "e" ), shape = (3 , 4 , 5 , 6 ))
254
+ z = x .dot (y , dim = ("b" , "d" ))
255
+ fn = xr_function ([x , y ], z )
256
+ z_test = fn (x_test , y_test )
257
+ xr_assert_allclose (z_test , expected )
258
+
259
+ # Same but with first two dims
260
+ expected = x_test .dot (y_test , dim = ["a" , "b" ])
261
+ z = x .dot (y , dim = ["a" , "b" ])
262
+ fn = xr_function ([x , y ], z )
263
+ z_test = fn (x_test , y_test )
264
+ xr_assert_allclose (z_test , expected )
265
+
266
+ # Same but with last two
267
+ expected = x_test .dot (y_test , dim = ["d" , "e" ])
268
+ z = x .dot (y , dim = ["d" , "e" ])
269
+ fn = xr_function ([x , y ], z )
270
+ z_test = fn (x_test , y_test )
271
+ xr_assert_allclose (z_test , expected )
272
+
273
+ # Same but with every other dim
274
+ expected = x_test .dot (y_test , dim = ["a" , "c" , "e" ])
275
+ z = x .dot (y , dim = ["a" , "c" , "e" ])
276
+ fn = xr_function ([x , y ], z )
277
+ z_test = fn (x_test , y_test )
278
+ xr_assert_allclose (z_test , expected )
279
+
280
+ # Test symbolic shapes
281
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (None , 3 )) # First dimension is symbolic
282
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , None )) # Second dimension is symbolic
283
+ z = x .dot (y )
284
+ fn = xr_function ([x , y ], z )
285
+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
286
+ y_test = DataArray (np .ones ((3 , 4 )), dims = ("b" , "c" ))
287
+ z_test = fn (x_test , y_test )
288
+ expected = x_test .dot (y_test )
289
+ xr_assert_allclose (z_test , expected )
290
+
291
+
292
+ def test_dot_errors ():
293
+ # No matching dimensions
294
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
295
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
296
+ with pytest .raises (ValueError , match = "Dimension e not found in either input" ):
297
+ x .dot (y , dim = "e" )
298
+
299
+ # Concrete dimension size mismatches
300
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
301
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (4 , 5 ))
302
+ with pytest .raises (
303
+ ValueError ,
304
+ match = "Size of dim 'b' does not match" ,
305
+ ):
306
+ x .dot (y )
307
+
308
+ # Symbolic dimension size mismatches
309
+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , None ))
310
+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (None , 5 ))
311
+ z = x .dot (y )
312
+ fn = xr_function ([x , y ], z )
313
+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
314
+ y_test = DataArray (np .ones ((4 , 5 )), dims = ("b" , "c" ))
315
+ # Doesn't fail until the rewrite
316
+ with pytest .raises (ValueError , match = "not aligned" ):
317
+ fn (x_test , y_test )
0 commit comments