Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 54 additions & 33 deletions datashader/transfer_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import xarray as xr

from datashader.colors import rgb, Sets1to3
from datashader.utils import nansum_missing, ngjit, uint32_to_uint8
from datashader.utils import ngjit, uint32_to_uint8

try:
import dask.array as da
Expand Down Expand Up @@ -356,13 +356,11 @@ def _interpolate(agg, cmap, how, alpha, span, min_alpha, name, rescale_discrete_

return Image(img, coords=agg.coords, dims=agg.dims, name=name)

_EINSUM_PATH_CACHE = {}

def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
rescale_discrete_levels):
if cupy and isinstance(agg.data, cupy.ndarray):
array = cupy.array
else:
array = np.array
xp = cupy if cupy and isinstance(agg.data, cupy.ndarray) else np

if not agg.ndim == 3:
raise ValueError("agg must be 3D")
Expand All @@ -384,57 +382,81 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
f"fields available ({len(cats)})")

colors = [rgb(color_key[c]) for c in cats]
rs, gs, bs = map(array, zip(*colors))
rs, gs, bs = map(xp.array, zip(*colors))

# Reorient array (transposing the category dimension first)
agg_t = agg.transpose(*((agg.dims[-1],)+agg.dims[:2]))
agg_t = agg.transpose(*(agg.dims[-1], *agg.dims[:2]))
data = agg_t.data.transpose([1, 2, 0])
if da and isinstance(data, da.Array):
data = data.compute()
color_data = data.copy()

color_data = np.ascontiguousarray(data)
if data is color_data:
color_data = color_data.copy()

nan_mask = np.isnan(data)
color_mask = ~nan_mask

# subtract color_baseline if needed
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'All-NaN slice encountered')
baseline = np.nanmin(color_data) if color_baseline is None else color_baseline
with np.errstate(invalid='ignore'):
# in-place add/sub to minimize temporaries
if baseline > 0:
color_data -= baseline
np.subtract(color_data, baseline, out=color_data, where=color_mask)
Comment on lines -402 to +407
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't this already in-place? Is the `where=color_mask` there to skip the calculation for the NaN (masked-out) entries?

elif baseline < 0:
color_data += -baseline
if color_data.dtype.kind != 'u' and color_baseline is not None:
color_data[color_data<0]=0
np.add(color_data, -baseline, out=color_data, where=color_mask)

# If an explicit baseline was given and dtype is signed, clip negatives to 0 (in-place)
if (color_baseline is not None) and (color_data.dtype.kind != 'u'):
np.maximum(color_data, 0, out=color_data)

color_total = nansum_missing(color_data, axis=2)
# dot does not handle nans, so replace with zeros
color_data[np.isnan(data)] = 0
# Replace NaNs with 0s for dot/matmul in one pass (in-place)
np.nan_to_num(color_data, copy=False) # NaN -> 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use an existing mask for this? I would assume `np.nan_to_num` does this same mask calculation behind the scenes.

Also note that this will convert inf values to very large finite floats.

color_total = np.sum(color_data, axis=2)

# zero-count pixels will be 0/0, but it's safe to ignore that when dividing
# --- Optimized matrix multiplication using einsum with path caching ---
RGB = np.stack([rs, gs, bs], axis=1) # (C,3)

# Calculate color_data
cache_key = (color_data.shape, RGB.shape)
if cache_key not in _EINSUM_PATH_CACHE:
_EINSUM_PATH_CACHE[cache_key] = np.einsum_path(
'hwc,cr->hwr', color_data, RGB, optimize=True
)[0]

cached_path = _EINSUM_PATH_CACHE[cache_key]
rgb_sum = np.einsum('hwc,cr->hwr', color_data, RGB, optimize=cached_path)
rgb_avg_present = np.einsum('hwc,cr->hwr', color_mask, RGB, optimize=cached_path)

# Divide by totals (broadcast) once, then cast once
with np.errstate(divide='ignore', invalid='ignore'):
r = (color_data.dot(rs)/color_total).astype(np.uint8)
g = (color_data.dot(gs)/color_total).astype(np.uint8)
b = (color_data.dot(bs)/color_total).astype(np.uint8)
rgb_array = (rgb_sum / color_total[..., None]).astype(np.uint8)

# special case -- to give an appropriate color when min_alpha != 0 and data=0,
# take avg color of all non-nan categories
color_mask = ~np.isnan(data)
# --- "Average color of non-NaN categories" path ---
# Sum of True values per pixel
cmask_sum = np.sum(color_mask, axis=2)

with np.errstate(divide='ignore', invalid='ignore'):
r2 = (color_mask.dot(rs)/cmask_sum).astype(np.uint8)
g2 = (color_mask.dot(gs)/cmask_sum).astype(np.uint8)
b2 = (color_mask.dot(bs)/cmask_sum).astype(np.uint8)
rgb2 = (rgb_avg_present / cmask_sum[..., None]).astype(np.uint8)

missing_colors = np.sum(color_data, axis=2) == 0
r = np.where(missing_colors, r2, r)
g = np.where(missing_colors, g2, g)
b = np.where(missing_colors, b2, b)
# --- Fill pixels with no color mass using the avg-present fallback ---
missing_colors = (color_total == 0)
if np.any(missing_colors):
rgb_array = np.where(missing_colors[..., None], rgb2, rgb_array)

total = nansum_missing(data, axis=2)
mask = np.isnan(total)
# total = nansum_missing(data, axis=2)
# mask = np.isnan(total)
total = np.sum(data, axis=2)
mask = np.any(nan_mask, axis=2)
a = _interpolate_alpha(data, total, mask, how, alpha, span, min_alpha, rescale_discrete_levels)

values = np.dstack([r, g, b, a]).view(np.uint32).reshape(a.shape)
rgba_array = np.dstack([rgb_array, a])
# Ensure array is contiguous for view operation
if not rgba_array.flags.c_contiguous:
rgba_array = np.ascontiguousarray(rgba_array)
values = rgba_array.view(np.uint32).reshape(a.shape)
if cupy and isinstance(values, cupy.ndarray):
# Convert cupy array to numpy for final image
values = cupy.asnumpy(values)
Expand All @@ -449,7 +471,6 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,


def _interpolate_alpha(data, total, mask, how, alpha, span, min_alpha, rescale_discrete_levels):

if cupy and isinstance(data, cupy.ndarray):
from ._cuda_utils import interp, masked_clip_2d
array_module = cupy
Expand Down
Loading