Skip to content

Commit a7c44c0

Browse files
cbm755Alex Vong
authored andcommitted
Use a SymPy Array not a Matrix for non-Expr
It has mostly the same semantics as Matrix but can contain more general things. There might be some issues about different shapes of empties though... WIP on a fix for #1052.
1 parent 62ff091 commit a7c44c0

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

inst/@sym/private/elementwise_op.m

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
%% Copyright (C) 2014, 2016, 2018-2019, 2022 Colin B. Macdonald
1+
%% Copyright (C) 2014, 2016, 2018-2019, 2021-2022 Colin B. Macdonald
22
%% Copyright (C) 2016 Lagu
33
%%
44
%% This file is part of OctSymPy.
@@ -84,7 +84,7 @@
8484
% Make sure all matrices in the input are the same size, and set q to one of them
8585
'q = None'
8686
'for A in _ins:'
87-
' if isinstance(A, MatrixBase):'
87+
' if isinstance(A, (MatrixBase, NDimArray)):'
8888
' if q is None:'
8989
' q = A'
9090
' else:'
@@ -97,7 +97,10 @@
9797
'elements = []'
9898
'for i in range(0, len(q)):'
9999
' elements.append(_op(*[k[i] if isinstance(k, MatrixBase) else k for k in _ins]))'
100-
'return Matrix(*q.shape, elements)' ];
100+
'if all(isinstance(x, Expr) for x in elements):'
101+
' return Matrix(*q.shape, elements)'
102+
'dbout(f"elementwise_op: returning an Array not a Matrix")'
103+
'return Array(elements, shape=q.shape)' ];
101104

102105
z = pycall_sympy__ (cmd, varargin{:});
103106

inst/private/python_header.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def octoutput(x, et):
152152
f.text = str(OCTCODE_BOOL)
153153
f = ET.SubElement(a, "f")
154154
f.text = str(x)
155-
elif x is None or isinstance(x, (sp.Basic, sp.MatrixBase)):
155+
elif x is None or isinstance(x, (sp.Basic, sp.MatrixBase, sp.NDimArray)):
156156
# FIXME: is it weird to pretend None is a SymPy object?
157157
if isinstance(x, (sp.Matrix, sp.ImmutableMatrix)):
158158
_d = x.shape
@@ -161,6 +161,9 @@ def octoutput(x, et):
161161
_d = [float(r) if (isinstance(r, sp.Basic) and r.is_Integer)
162162
else float('nan') if isinstance(r, sp.Basic)
163163
else r for r in x.shape]
164+
elif isinstance(x, sp.NDimArray):
165+
_d = x.shape
166+
dbout(f"I am here with an array with shape {_d}")
164167
elif x is None:
165168
_d = (1,1)
166169
else:

0 commit comments

Comments
 (0)