Skip to content

Commit

Permalink
Use sertit.types.make_iterable
Browse files Browse the repository at this point in the history
  • Loading branch information
remi-braun committed Jul 29, 2024
1 parent 0e6e16c commit fd8c192
Show file tree
Hide file tree
Showing 13 changed files with 20 additions and 34 deletions.
5 changes: 2 additions & 3 deletions CI/SCRIPTS_WEEKLY/test_all_sat_end_to_end_on_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import xarray as xr
from lxml import etree
from sertit import AnyPath, ci, path
from sertit import AnyPath, ci, path, types

from CI.scripts_utils import (
CI_EOREADER_S3,
Expand Down Expand Up @@ -166,8 +166,7 @@ def _test_core(

with xr.set_options(warn_for_unclosed_files=debug):
# DATA paths
if not isinstance(prod_dirs, list):
prod_dirs = [prod_dirs]
prod_dirs = types.make_interable(prod_dirs)

pattern_paths = []
for prod_dir in prod_dirs:
Expand Down
7 changes: 3 additions & 4 deletions eoreader/bands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,8 @@ def convert_to_band(tc) -> BandNames:

if as_list:
band_list = []
if not types.is_iterable(to_convert):
to_convert = [to_convert]
to_convert = types.make_interable(to_convert)

for tc in to_convert:
tc_band = convert_to_band(tc=tc)
band_list.append(tc_band)
Expand Down Expand Up @@ -492,8 +492,7 @@ def to_str(
list: str bands
"""
if as_list:
if not types.is_iterable(to_convert):
to_convert = [to_convert]
to_convert = types.make_interable(to_convert)

bands_str = []
for tc in to_convert:
Expand Down
4 changes: 2 additions & 2 deletions eoreader/bands/band_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def from_list(cls, name_list: Union[list, str]) -> list:
Returns:
list: List of enums
"""
if not types.is_iterable(name_list):
name_list = [name_list]
name_list = types.make_interable(name_list)

try:
band_names = [cls(name) for name in name_list]
except ValueError as ex:
Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/custom_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,7 @@ def _load_bands(
return {}

# Get band paths
if not types.is_iterable(bands):
bands = [bands]
bands = types.make_interable(bands)

if pixel_size is None and size is not None:
pixel_size = self._pixel_size_from_img_size(size)
Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/optical/hls_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,8 +873,7 @@ def _load_bands(
return {}

# Get band paths
if not types.is_iterable(bands):
bands = [bands]
bands = types.make_interable(bands)

if pixel_size is None and size is not None:
pixel_size = self._pixel_size_from_img_size(size)
Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/optical/landsat_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,8 +1430,7 @@ def _load_bands(
return {}

# Get band paths
if not types.is_iterable(bands):
bands = [bands]
bands = types.make_interable(bands)

if pixel_size is None and size is not None:
pixel_size = self._pixel_size_from_img_size(size)
Expand Down
4 changes: 2 additions & 2 deletions eoreader/products/optical/planet_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,8 +1004,8 @@ def _get_path(
Union[list, str]: Paths(s)
"""
if invalid_lookahead is not None and not types.is_iterable(invalid_lookahead):
invalid_lookahead = [invalid_lookahead]
if invalid_lookahead is not None:
invalid_lookahead = types.make_interable(invalid_lookahead)

ok_paths = []
try:
Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/optical/s2_e84_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,7 @@ def _load_bands(
return {}

# Get band paths
if not types.is_iterable(bands):
bands = [bands]
bands = types.make_interable(bands)

if pixel_size is None and size is not None:
pixel_size = self._pixel_size_from_img_size(size)
Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/optical/s2_theia_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,7 @@ def _create_mask(
xr.DataArray: Mask masked array
"""
if not types.is_iterable(bit_ids):
bit_ids = [bit_ids]
bit_ids = types.make_interable(bit_ids)
conds = rasters.read_bit_array(bit_array.astype(np.uint8), bit_ids)
cond = reduce(lambda x, y: x | y, conds) # Use every condition (bitwise or)

Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/optical/s3_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,7 @@ def _load_bands(
return {}

# Get band paths
if not types.is_iterable(bands):
bands = [bands]
bands = types.make_interable(bands)

if pixel_size is None and size is not None:
pixel_size = self._pixel_size_from_img_size(size)
Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/optical/s3_slstr_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,8 +1098,7 @@ def _create_mask(
xr.DataArray: Mask masked array
"""
if not types.is_iterable(bit_ids):
bit_ids = [bit_ids]
bit_ids = types.make_interable(bit_ids)
conds = rasters.read_bit_array(bit_array, bit_ids)
cond = reduce(lambda x, y: x | y, conds) # Use every condition (bitwise or)

Expand Down
10 changes: 3 additions & 7 deletions eoreader/products/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,9 +1239,7 @@ def has_bands(self, bands: Union[list, BandNames, str]) -> bool:
Returns:
bool: True if the products has the specified band
"""
if not types.is_iterable(bands):
bands = [bands]

bands = types.make_interable(bands)
return all([self.has_band(band) for band in set(bands)])

@abstractmethod
Expand Down Expand Up @@ -1715,8 +1713,7 @@ def _update_attrs(
# Are we sure of that ?
xarr.attrs = {}

if not types.is_iterable(bands):
bands = [bands]
bands = types.make_interable(bands)
long_name = to_str(bands)
xr_name = "_".join(long_name)
attr_name = " ".join(long_name)
Expand Down Expand Up @@ -2032,8 +2029,7 @@ def to_band(self, raw_bands: Union[list, BandNames, str, int]) -> list:
Returns:
list: Mapped bands
"""
if not types.is_iterable(raw_bands):
raw_bands = [raw_bands]
raw_bands = types.make_interable(raw_bands)

bands = []
for raw_band in raw_bands:
Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/sar/sar_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,7 @@ def _load_bands(
return {}

# Get band paths
if not types.is_iterable(bands):
bands = [bands]
bands = types.make_interable(bands)

if pixel_size is None and size is not None:
pixel_size = self._pixel_size_from_img_size(size)
Expand Down

0 comments on commit fd8c192

Please sign in to comment.