Skip to content

Commit 452034b

Browse files
committed
Clean up header
1 parent add5a3e commit 452034b

File tree

2 files changed

+11
-19
lines changed

2 files changed

+11
-19
lines changed

coconut/compiler/templates/header.py_template

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -771,10 +771,7 @@ Additionally supports Cartesian products of numpy arrays."""
771771
if iterables:
772772
it_modules = [_coconut_get_base_module(it) for it in iterables]
773773
if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules):
774-
if _coconut.any(mod in _coconut.xarray_modules for mod in it_modules):
775-
iterables = tuple((_coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules))
776-
if _coconut.any(mod in _coconut.pandas_modules for mod in it_modules):
777-
iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else it) for it, mod in _coconut.zip(iterables, it_modules))
774+
iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else _coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules))
778775
if _coconut.any(mod in _coconut.jax_numpy_modules for mod in it_modules):
779776
from jax import numpy
780777
else:
@@ -1104,12 +1101,7 @@ class multi_enumerate(_coconut_has_iter):
11041101
through inner iterables and produces a tuple index representing the index
11051102
in each inner iterable. Supports indexing.
11061103

1107-
For numpy arrays, effectively equivalent to:
1108-
it = np.nditer(iterable, flags=["multi_index", "refs_ok"])
1109-
for x in it:
1110-
yield it.multi_index, x
1111-
1112-
Also supports len for numpy arrays.
1104+
For numpy arrays, uses np.nditer under the hood and supports len.
11131105
"""
11141106
__slots__ = ()
11151107
def __repr__(self):
@@ -1960,10 +1952,10 @@ def all_equal(iterable, to=_coconut_sentinel):
19601952
"""
19611953
iterable_module = _coconut_get_base_module(iterable)
19621954
if iterable_module in _coconut.numpy_modules:
1963-
if iterable_module in _coconut.xarray_modules:
1964-
iterable = _coconut_xarray_to_numpy(iterable)
1965-
elif iterable_module in _coconut.pandas_modules:
1955+
if iterable_module in _coconut.pandas_modules:
19661956
iterable = iterable.to_numpy()
1957+
elif iterable_module in _coconut.xarray_modules:
1958+
iterable = _coconut_xarray_to_numpy(iterable)
19671959
return not _coconut.len(iterable) or (iterable == (iterable[0] if to is _coconut_sentinel else to)).all()
19681960
first_item = to
19691961
for item in iterable:

coconut/constants.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,22 +178,22 @@ def get_path_env_var(env_var, default):
178178
sys.setrecursionlimit(default_recursion_limit)
179179

180180
# modules that numpy-like arrays can live in
181-
xarray_modules = (
182-
"xarray",
181+
jax_numpy_modules = (
182+
"jaxlib",
183183
)
184184
pandas_modules = (
185185
"pandas",
186186
)
187-
jax_numpy_modules = (
188-
"jaxlib",
187+
xarray_modules = (
188+
"xarray",
189189
)
190190
numpy_modules = (
191191
"numpy",
192192
"torch",
193193
) + (
194-
xarray_modules
194+
jax_numpy_modules
195195
+ pandas_modules
196-
+ jax_numpy_modules
196+
+ xarray_modules
197197
)
198198

199199
legal_indent_chars = " \t" # the only Python-legal indent chars

0 commit comments

Comments
 (0)