Skip to content

Commit fe6fcec

Browse files
committed
use einsum
1 parent 94ddff7 commit fe6fcec

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

datashader/transfer_functions/__init__.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def _interpolate(agg, cmap, how, alpha, span, min_alpha, name, rescale_discrete_
356356

357357
return Image(img, coords=agg.coords, dims=agg.dims, name=name)
358358

359+
_EINSUM_PATH_CACHE = {}
359360

360361
def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
361362
rescale_discrete_levels):
@@ -389,7 +390,6 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
389390
if da and isinstance(data, da.Array):
390391
data = data.compute()
391392

392-
H, W, C = data.shape
393393
color_data = np.ascontiguousarray(data)
394394
if data is color_data:
395395
color_data = color_data.copy()
@@ -416,26 +416,33 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
416416
np.nan_to_num(color_data, copy=False) # NaN -> 0
417417
color_total = np.sum(color_data, axis=2)
418418

419-
# --- Compute all 3 weighted sums in one BLAS call ---
420-
# Stack weights -> (C, 3)
419+
# --- Optimized matrix multiplication using einsum with path caching ---
421420
RGB = np.stack([rs, gs, bs], axis=1) # (C,3)
422421

423-
# Reshape to 2D so @ uses fast GEMM: (-1, C) @ (C, 3) -> (-1, 3)
424-
cd2 = color_data.reshape(-1, C)
425-
rgb_sum = (cd2 @ RGB).reshape(H, W, 3) # weighted sums for r,g,b
422+
# Calculate color_data
423+
cache_key = (color_data.shape, RGB.shape)
424+
if cache_key not in _EINSUM_PATH_CACHE:
425+
_EINSUM_PATH_CACHE[cache_key] = np.einsum_path(
426+
'hwc,cr->hwr', color_data, RGB, optimize=True
427+
)[0]
428+
429+
cached_path = _EINSUM_PATH_CACHE[cache_key]
430+
rgb_sum = np.einsum('hwc,cr->hwr', color_data, RGB, optimize=cached_path)
431+
rgb_avg_present = np.einsum(
432+
'hwc,cr->hwr',
433+
color_mask.astype(color_data.dtype, copy=False),
434+
RGB,
435+
optimize=cached_path,
436+
)
426437

427438
# Divide by totals (broadcast) once, then cast once
428439
with np.errstate(divide='ignore', invalid='ignore'):
429440
rgb_array = (rgb_sum / color_total[..., None]).astype(np.uint8)
430441

431-
# --- Average color of non-NaN categories path (also as one matmul) ---
442+
# --- "Average color of non-NaN categories" path ---
432443
# Sum of True values per pixel
433444
cmask_sum = np.sum(color_mask, axis=2)
434445

435-
# Cast mask to float once and reuse the same GEMM path
436-
cm2 = color_mask.reshape(-1, C).astype(np.float32, copy=False)
437-
rgb_avg_present = (cm2 @ RGB).reshape(H, W, 3) # sums of rs/gs/bs over present cats
438-
439446
with np.errstate(divide='ignore', invalid='ignore'):
440447
rgb2 = (rgb_avg_present / cmask_sum[..., None]).astype(np.uint8)
441448

@@ -449,7 +456,11 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
449456
mask = np.isnan(total)
450457
a = _interpolate_alpha(data, total, mask, how, alpha, span, min_alpha, rescale_discrete_levels)
451458

452-
values = np.dstack([rgb_array, a]).view(np.uint32).reshape(a.shape)
459+
rgba_array = np.dstack([rgb_array, a])
460+
# Ensure array is contiguous for view operation
461+
if not rgba_array.flags.c_contiguous:
462+
rgba_array = np.ascontiguousarray(rgba_array)
463+
values = rgba_array.view(np.uint32).reshape(a.shape)
453464
if cupy and isinstance(values, cupy.ndarray):
454465
# Convert cupy array to numpy for final image
455466
values = cupy.asnumpy(values)

0 commit comments

Comments
 (0)