Skip to content

Commit 26f87e3

Browse files
committed
Fix broadcasting for empty sparse arrays related to recent changes (bug #67714)
* Fix broadcasting for empty sparse arrays in min/max binary oprations. * Fix broadcasting for plus, minus, product, and quotient operators in sparse arrays. The following files are changed: * CSparse.cc (min, max): Fix broadcasting for emptry matrices. Add tests. * dSparse.cc (min, max): Fix broadcasting for emptry matrices. Add tests. * MSparse.cc (plus_or_minus, product, quotient*) : Fix broadcasting for emptry matrices. * Sparse-op-defs.h (SPARSE_SMSM_BIN_OP_*): Fix broadcasting for emptry matrices.
1 parent 3a4d567 commit 26f87e3

File tree

4 files changed

+348
-44
lines changed

4 files changed

+348
-44
lines changed

liboctave/array/CSparse.cc

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8097,13 +8097,33 @@ min (const SparseComplexMatrix& a, const SparseComplexMatrix& b,
80978097
else
80988098
{
80998099
if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
8100-
r.resize (a_nr, std::max (a_nc, b_nc));
8100+
{
8101+
if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
8102+
r.resize (a_nr, std::max (a_nc, b_nc));
8103+
else
8104+
octave::err_nonconformant ("min", a_nr, a_nc, b_nr, b_nc);
8105+
}
81018106
else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
8102-
r.resize (std::max (a_nr, b_nr), a_nc);
8107+
{
8108+
if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
8109+
r.resize (std::max (a_nr, b_nr), a_nc);
8110+
else
8111+
octave::err_nonconformant ("min", a_nr, a_nc, b_nr, b_nc);
8112+
}
81038113
else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
8104-
r.resize (b_nr, std::max (a_nc, b_nc));
8114+
{
8115+
if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
8116+
r.resize (b_nr, std::max (a_nc, b_nc));
8117+
else
8118+
octave::err_nonconformant ("min", a_nr, a_nc, b_nr, b_nc);
8119+
}
81058120
else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
8106-
r.resize (std::max (a_nr, b_nr), b_nc);
8121+
{
8122+
if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
8123+
r.resize (std::max (a_nr, b_nr), b_nc);
8124+
else
8125+
octave::err_nonconformant ("min", a_nr, a_nc, b_nr, b_nc);
8126+
}
81078127
else
81088128
octave::err_nonconformant ("min", a_nr, a_nc, b_nr, b_nc);
81098129
}
@@ -8271,13 +8291,33 @@ max (const SparseComplexMatrix& a, const SparseComplexMatrix& b,
82718291
else
82728292
{
82738293
if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
8274-
r.resize (a_nr, std::max (a_nc, b_nc));
8294+
{
8295+
if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
8296+
r.resize (a_nr, std::max (a_nc, b_nc));
8297+
else
8298+
octave::err_nonconformant ("max", a_nr, a_nc, b_nr, b_nc);
8299+
}
82758300
else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
8276-
r.resize (std::max (a_nr, b_nr), a_nc);
8301+
{
8302+
if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
8303+
r.resize (std::max (a_nr, b_nr), a_nc);
8304+
else
8305+
octave::err_nonconformant ("max", a_nr, a_nc, b_nr, b_nc);
8306+
}
82778307
else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
8278-
r.resize (b_nr, std::max (a_nc, b_nc));
8308+
{
8309+
if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
8310+
r.resize (b_nr, std::max (a_nc, b_nc));
8311+
else
8312+
octave::err_nonconformant ("max", a_nr, a_nc, b_nr, b_nc);
8313+
}
82798314
else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
8280-
r.resize (std::max (a_nr, b_nr), b_nc);
8315+
{
8316+
if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
8317+
r.resize (std::max (a_nr, b_nr), b_nc);
8318+
else
8319+
octave::err_nonconformant ("max", a_nr, a_nc, b_nr, b_nc);
8320+
}
82818321
else
82828322
octave::err_nonconformant ("max", a_nr, a_nc, b_nr, b_nc);
82838323
}
@@ -8287,17 +8327,59 @@ max (const SparseComplexMatrix& a, const SparseComplexMatrix& b,
82878327

82888328
/*
82898329
8330+
## Testing broadcasting of empty matrices with min/max functions
8331+
%!assert (min (sparse (zeros (0,1)), sparse ([1, 2, 3, 4i])), sparse (zeros (0,4)))
8332+
%!error min (sparse (zeros (0,2)), sparse ([1, 2, 3, 4i]))
8333+
%!assert (max (sparse (zeros (0,1)), sparse ([1, 2, 3, 4i])), sparse (zeros (0,4)))
8334+
%!error max (sparse (zeros (0,2)), sparse ([1, 2, 3, 4i]))
8335+
%!assert (min (sparse (zeros (1,0)), sparse ([1; 2; 3; 4i])), sparse (zeros (4,0)))
8336+
%!error min (sparse (zeros (2,0)), sparse ([1; 2; 3; 4i]))
8337+
%!assert (max (sparse (zeros (1,0)), sparse ([1; 2; 3; 4i])), sparse (zeros (4,0)))
8338+
%!error max (sparse (zeros (2,0)), sparse ([1; 2; 3; 4i]))
8339+
82908340
## Testing broadcasting of empty matrices with math operators.
82918341
## This has been fixed in MSparse.cc and Sparse-op-defs.h
82928342
%!assert (sparse (zeros (0,1)) + sparse ([1, 2, 3, 4i]), sparse (zeros (0,4)))
8343+
%!error <operator \+: nonconformant arguments \(op1 is 0x2, op2 is 1x4\)> ...
8344+
%! sparse (zeros (0,2)) + sparse ([1, 2, 3, 4i])
82938345
%!assert (sparse (zeros (0,1)) - sparse ([1, 2, 3, 4i]), sparse (zeros (0,4)))
8346+
%!error <operator -: nonconformant arguments \(op1 is 0x2, op2 is 1x4\)> ...
8347+
%! sparse (zeros (0,2)) - sparse ([1, 2, 3, 4i])
82948348
%!assert (sparse (zeros (0,1)) * sparse ([1, 2, 3, 4i]), sparse (zeros (0,4)))
8349+
%!error <operator \*: nonconformant arguments \(op1 is 0x2, op2 is 1x4\)> ...
8350+
%! sparse (zeros (0,2)) * sparse ([1, 2, 3, 4i])
8351+
%!assert (sparse (zeros (0,1)) .* sparse ([1, 2, 3, 4i]), sparse (zeros (0,4)))
8352+
%!error <product: nonconformant arguments \(op1 is 0x2, op2 is 1x4\)> ...
8353+
%! sparse (zeros (0,2)) .* sparse ([1, 2, 3, 4i])
82958354
%!assert (sparse (zeros (0,1)) ./ sparse ([1, 2, 3, 4i]), sparse (zeros (0,4)))
8355+
%!error <quotient: nonconformant arguments \(op1 is 0x2, op2 is 1x4\)> ...
8356+
%! sparse (zeros (0,2)) ./ sparse ([1, 2, 3, 4i])
82968357
%!test
82978358
%! a = sparse (zeros (0,1));
82988359
%! assert (a += sparse ([1, 2, 3, 4i]), sparse (zeros (0,4)))
82998360
%! assert (a -= sparse ([1, 2, 3, 4i]), sparse (zeros (0,4)))
83008361
8362+
%!assert (sparse (zeros (1,0)) + sparse ([1; 2; 3; 4i]), sparse (zeros (4,0)))
8363+
%!error <operator \+: nonconformant arguments \(op1 is 2x0, op2 is 4x1\)> ...
8364+
%! sparse (zeros (2,0)) + sparse ([1; 2; 3; 4i])
8365+
%!assert (sparse (zeros (1,0)) - sparse ([1; 2; 3; 4i]), sparse (zeros (4,0)))
8366+
%!error <operator -: nonconformant arguments \(op1 is 2x0, op2 is 4x1\)> ...
8367+
%! sparse (zeros (2,0)) - sparse ([1; 2; 3; 4i])
8368+
%!error <operator \*: nonconformant arguments \(op1 is 1x0, op2 is 4x1\)>
8369+
%! sparse (zeros (1,0)) * sparse ([1; 2; 3; 4i])
8370+
%!error <operator \*: nonconformant arguments \(op1 is 2x0, op2 is 4x1\)> ...
8371+
%! sparse (zeros (2,0)) * sparse ([1; 2; 3; 4i])
8372+
%!assert (sparse (zeros (1,0)) .* sparse ([1; 2; 3; 4i]), sparse (zeros (4,0)))
8373+
%!error <product: nonconformant arguments \(op1 is 2x0, op2 is 4x1\)> ...
8374+
%! sparse (zeros (2,0)) .* sparse ([1; 2; 3; 4i])
8375+
%!assert (sparse (zeros (1,0)) ./ sparse ([1; 2; 3; 4i]), sparse (zeros (4,0)));
8376+
%!error <quotient: nonconformant arguments \(op1 is 2x0, op2 is 4x1\)> ...
8377+
%! sparse (zeros (2,0)) ./ sparse ([1; 2; 3; 4i])
8378+
%!test
8379+
%! a = sparse (zeros (1,0));
8380+
%! assert (a += sparse ([1; 2; 3; 4i]), sparse (zeros (4,0)))
8381+
%! assert (a -= sparse ([1; 2; 3; 4i]), sparse (zeros (4,0)))
8382+
83018383
*/
83028384

83038385
SPARSE_SMS_CMP_OPS (SparseComplexMatrix, Complex)

liboctave/array/MSparse.cc

Lines changed: 96 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,33 @@ plus_or_minus (MSparse<T>& a, const MSparse<T>& b, OP op, const char *op_name)
4040
octave_idx_type b_nc = b.cols ();
4141

4242
if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
43-
r.resize (a_nr, std::max (a_nc, b_nc));
43+
{
44+
if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
45+
r.resize (a_nr, std::max (a_nc, b_nc));
46+
else
47+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
48+
}
4449
else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
45-
r.resize (std::max (a_nr, b_nr), a_nc);
50+
{
51+
if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
52+
r.resize (std::max (a_nr, b_nr), a_nc);
53+
else
54+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
55+
}
4656
else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
47-
r.resize (b_nr, std::max (a_nc, b_nc));
57+
{
58+
if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
59+
r.resize (b_nr, std::max (a_nc, b_nc));
60+
else
61+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
62+
}
4863
else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
49-
r.resize (std::max (a_nr, b_nr), b_nc);
64+
{
65+
if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
66+
r.resize (std::max (a_nr, b_nr), b_nc);
67+
else
68+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
69+
}
5070
else if (a_nr != b_nr || a_nc != b_nc)
5171
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
5272
else
@@ -312,13 +332,33 @@ plus_or_minus (const MSparse<T>& a, const MSparse<T>& b, OP op,
312332
}
313333
}
314334
if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
315-
r.resize (a_nr, std::max (a_nc, b_nc));
335+
{
336+
if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
337+
r.resize (a_nr, std::max (a_nc, b_nc));
338+
else
339+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
340+
}
316341
else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
317-
r.resize (std::max (a_nr, b_nr), a_nc);
342+
{
343+
if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
344+
r.resize (std::max (a_nr, b_nr), a_nc);
345+
else
346+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
347+
}
318348
else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
319-
r.resize (b_nr, std::max (a_nc, b_nc));
349+
{
350+
if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
351+
r.resize (b_nr, std::max (a_nc, b_nc));
352+
else
353+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
354+
}
320355
else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
321-
r.resize (std::max (a_nr, b_nr), b_nc);
356+
{
357+
if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
358+
r.resize (std::max (a_nr, b_nr), b_nc);
359+
else
360+
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
361+
}
322362
else if (a_nr != b_nr || a_nc != b_nc)
323363
octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
324364
else
@@ -441,13 +481,33 @@ product (const MSparse<T>& a, const MSparse<T>& b)
441481
}
442482
}
443483
if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
444-
r.resize (a_nr, std::max (a_nc, b_nc));
484+
{
485+
if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
486+
r.resize (a_nr, std::max (a_nc, b_nc));
487+
else
488+
octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
489+
}
445490
else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
446-
r.resize (std::max (a_nr, b_nr), a_nc);
491+
{
492+
if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
493+
r.resize (std::max (a_nr, b_nr), a_nc);
494+
else
495+
octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
496+
}
447497
else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
448-
r.resize (b_nr, std::max (a_nc, b_nc));
498+
{
499+
if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
500+
r.resize (b_nr, std::max (a_nc, b_nc));
501+
else
502+
octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
503+
}
449504
else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
450-
r.resize (std::max (a_nr, b_nr), b_nc);
505+
{
506+
if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
507+
r.resize (std::max (a_nr, b_nr), b_nc);
508+
else
509+
octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
510+
}
451511
else if (a_nr != b_nr || a_nc != b_nc)
452512
octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
453513
else
@@ -569,13 +629,33 @@ quotient (const MSparse<T>& a, const MSparse<T>& b)
569629
}
570630
}
571631
if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
572-
r.resize (a_nr, std::max (a_nc, b_nc));
632+
{
633+
if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
634+
r.resize (a_nr, std::max (a_nc, b_nc));
635+
else
636+
octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
637+
}
573638
else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
574-
r.resize (std::max (a_nr, b_nr), a_nc);
639+
{
640+
if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
641+
r.resize (std::max (a_nr, b_nr), a_nc);
642+
else
643+
octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
644+
}
575645
else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
576-
r.resize (b_nr, std::max (a_nc, b_nc));
646+
{
647+
if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
648+
r.resize (b_nr, std::max (a_nc, b_nc));
649+
else
650+
octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
651+
}
577652
else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
578-
r.resize (std::max (a_nr, b_nr), b_nc);
653+
{
654+
if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
655+
r.resize (std::max (a_nr, b_nr), b_nc);
656+
else
657+
octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
658+
}
579659
else if (a_nr != b_nr || a_nc != b_nc)
580660
octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
581661
else

0 commit comments

Comments
 (0)