@@ -150,3 +150,167 @@ def test_cast():
150150 yc64 = x .astype ("complex64" )
151151 with pytest .raises (TypeError , match = "Casting from complex to real is ambiguous" ):
152152 yc64 .astype ("float64" )
153+
154+
155+ def test_dot ():
156+ """Test basic dot product operations."""
157+ # Test matrix-vector dot product (with multiple-letter dim names)
158+ x = xtensor ("x" , dims = ("aa" , "bb" ), shape = (2 , 3 ))
159+ y = xtensor ("y" , dims = ("bb" ,), shape = (3 ,))
160+ z = x .dot (y )
161+ fn = xr_function ([x , y ], z )
162+
163+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("aa" , "bb" ))
164+ y_test = DataArray (np .ones (3 ), dims = ("bb" ,))
165+ z_test = fn (x_test , y_test )
166+ expected = x_test .dot (y_test )
167+ xr_assert_allclose (z_test , expected )
168+
169+ # Test matrix-vector dot product with ellipsis
170+ z = x .dot (y , dim = ...)
171+ fn = xr_function ([x , y ], z )
172+ z_test = fn (x_test , y_test )
173+ expected = x_test .dot (y_test , dim = ...)
174+ xr_assert_allclose (z_test , expected )
175+
176+ # Test matrix-matrix dot product
177+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
178+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
179+ z = x .dot (y )
180+ fn = xr_function ([x , y ], z )
181+
182+ x_test = DataArray (np .add .outer (np .arange (2.0 ), np .arange (3.0 )), dims = ("a" , "b" ))
183+ y_test = DataArray (np .add .outer (np .arange (3.0 ), np .arange (4.0 )), dims = ("b" , "c" ))
184+ z_test = fn (x_test , y_test )
185+ expected = x_test .dot (y_test )
186+ xr_assert_allclose (z_test , expected )
187+
188+ # Test matrix-matrix dot product with string dim
189+ z = x .dot (y , dim = "b" )
190+ fn = xr_function ([x , y ], z )
191+ z_test = fn (x_test , y_test )
192+ expected = x_test .dot (y_test , dim = "b" )
193+ xr_assert_allclose (z_test , expected )
194+
195+ # Test matrix-matrix dot product with list of dims
196+ z = x .dot (y , dim = ["b" ])
197+ fn = xr_function ([x , y ], z )
198+ z_test = fn (x_test , y_test )
199+ expected = x_test .dot (y_test , dim = ["b" ])
200+ xr_assert_allclose (z_test , expected )
201+
202+ # Test matrix-matrix dot product with ellipsis
203+ z = x .dot (y , dim = ...)
204+ fn = xr_function ([x , y ], z )
205+ z_test = fn (x_test , y_test )
206+ expected = x_test .dot (y_test , dim = ...)
207+ xr_assert_allclose (z_test , expected )
208+
209+ # Test a case where there are two dimensions to sum over
210+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
211+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
212+ z = x .dot (y )
213+ fn = xr_function ([x , y ], z )
214+
215+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
216+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
217+ z_test = fn (x_test , y_test )
218+ expected = x_test .dot (y_test )
219+ xr_assert_allclose (z_test , expected )
220+
221+ # Same but with explicit dimensions
222+ z = x .dot (y , dim = ["b" , "c" ])
223+ fn = xr_function ([x , y ], z )
224+ z_test = fn (x_test , y_test )
225+ expected = x_test .dot (y_test , dim = ["b" , "c" ])
226+ xr_assert_allclose (z_test , expected )
227+
228+ # Same but with ellipses
229+ z = x .dot (y , dim = ...)
230+ fn = xr_function ([x , y ], z )
231+ z_test = fn (x_test , y_test )
232+ expected = x_test .dot (y_test , dim = ...)
233+ xr_assert_allclose (z_test , expected )
234+
235+ # Dot product with sum
236+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
237+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
238+ expected = x_test .dot (y_test , dim = ("a" , "b" , "c" ))
239+
240+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
241+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
242+ z = x .dot (y , dim = ("a" , "b" , "c" ))
243+ fn = xr_function ([x , y ], z )
244+ z_test = fn (x_test , y_test )
245+ xr_assert_allclose (z_test , expected )
246+
247+ # Dot product with sum in the middle
248+ x_test = DataArray (np .arange (120.0 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
249+ y_test = DataArray (np .arange (360.0 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
250+ expected = x_test .dot (y_test , dim = ("b" , "d" ))
251+ x = xtensor ("x" , dims = ("a" , "b" , "c" , "d" ), shape = (2 , 3 , 4 , 5 ))
252+ y = xtensor ("y" , dims = ("b" , "c" , "d" , "e" ), shape = (3 , 4 , 5 , 6 ))
253+ z = x .dot (y , dim = ("b" , "d" ))
254+ fn = xr_function ([x , y ], z )
255+ z_test = fn (x_test , y_test )
256+ xr_assert_allclose (z_test , expected )
257+
258+ # Same but with first two dims
259+ expected = x_test .dot (y_test , dim = ["a" , "b" ])
260+ z = x .dot (y , dim = ["a" , "b" ])
261+ fn = xr_function ([x , y ], z )
262+ z_test = fn (x_test , y_test )
263+ xr_assert_allclose (z_test , expected )
264+
265+ # Same but with last two
266+ expected = x_test .dot (y_test , dim = ["d" , "e" ])
267+ z = x .dot (y , dim = ["d" , "e" ])
268+ fn = xr_function ([x , y ], z )
269+ z_test = fn (x_test , y_test )
270+ xr_assert_allclose (z_test , expected )
271+
272+ # Same but with every other dim
273+ expected = x_test .dot (y_test , dim = ["a" , "c" , "e" ])
274+ z = x .dot (y , dim = ["a" , "c" , "e" ])
275+ fn = xr_function ([x , y ], z )
276+ z_test = fn (x_test , y_test )
277+ xr_assert_allclose (z_test , expected )
278+
279+ # Test symbolic shapes
280+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (None , 3 )) # First dimension is symbolic
281+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , None )) # Second dimension is symbolic
282+ z = x .dot (y )
283+ fn = xr_function ([x , y ], z )
284+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
285+ y_test = DataArray (np .ones ((3 , 4 )), dims = ("b" , "c" ))
286+ z_test = fn (x_test , y_test )
287+ expected = x_test .dot (y_test )
288+ xr_assert_allclose (z_test , expected )
289+
290+
291+ def test_dot_errors ():
292+ # No matching dimensions
293+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
294+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
295+ with pytest .raises (ValueError , match = "Dimension e not found in either input" ):
296+ x .dot (y , dim = "e" )
297+
298+ # Concrete dimension size mismatches
299+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
300+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (4 , 5 ))
301+ with pytest .raises (
302+ ValueError ,
303+ match = "Size of dim 'b' does not match" ,
304+ ):
305+ x .dot (y )
306+
307+ # Symbolic dimension size mismatches
308+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , None ))
309+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (None , 5 ))
310+ z = x .dot (y )
311+ fn = xr_function ([x , y ], z )
312+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
313+ y_test = DataArray (np .ones ((4 , 5 )), dims = ("b" , "c" ))
314+ # Doesn't fail until the rewrite
315+ with pytest .raises (ValueError , match = "not aligned" ):
316+ fn (x_test , y_test )
0 commit comments