@@ -356,6 +356,7 @@ def _interpolate(agg, cmap, how, alpha, span, min_alpha, name, rescale_discrete_
356356
357357 return Image (img , coords = agg .coords , dims = agg .dims , name = name )
358358
359+ _EINSUM_PATH_CACHE = {}
359360
360361def _colorize (agg , color_key , how , alpha , span , min_alpha , name , color_baseline ,
361362 rescale_discrete_levels ):
@@ -389,7 +390,6 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
389390 if da and isinstance (data , da .Array ):
390391 data = data .compute ()
391392
392- H , W , C = data .shape
393393 color_data = np .ascontiguousarray (data )
394394 if data is color_data :
395395 color_data = color_data .copy ()
@@ -416,26 +416,33 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
416416 np .nan_to_num (color_data , copy = False ) # NaN -> 0
417417 color_total = np .sum (color_data , axis = 2 )
418418
419- # --- Compute all 3 weighted sums in one BLAS call ---
420- # Stack weights -> (C, 3)
419+ # --- Optimized matrix multiplication using einsum with path caching ---
421420 RGB = np .stack ([rs , gs , bs ], axis = 1 ) # (C,3)
422421
423- # Reshape to 2D so @ uses fast GEMM: (-1, C) @ (C, 3) -> (-1, 3)
424- cd2 = color_data .reshape (- 1 , C )
425- rgb_sum = (cd2 @ RGB ).reshape (H , W , 3 ) # weighted sums for r,g,b
422+ # Contract color_data and the presence mask with RGB, reusing a cached einsum path
423+ cache_key = (color_data .shape , RGB .shape )
424+ if cache_key not in _EINSUM_PATH_CACHE :
425+ _EINSUM_PATH_CACHE [cache_key ] = np .einsum_path (
426+ 'hwc,cr->hwr' , color_data , RGB , optimize = True
427+ )[0 ]
428+
429+ cached_path = _EINSUM_PATH_CACHE [cache_key ]
430+ rgb_sum = np .einsum ('hwc,cr->hwr' , color_data , RGB , optimize = cached_path )
431+ rgb_avg_present = np .einsum (
432+ 'hwc,cr->hwr' ,
433+ color_mask .astype (color_data .dtype , copy = False ),
434+ RGB ,
435+ optimize = cached_path ,
436+ )
426437
427438 # Divide by totals (broadcast) once, then cast once
428439 with np .errstate (divide = 'ignore' , invalid = 'ignore' ):
429440 rgb_array = (rgb_sum / color_total [..., None ]).astype (np .uint8 )
430441
431- # --- “ Average color of non-NaN categories” path (also as one matmul) ---
442+ # --- " Average color of non-NaN categories" path ---
432443 # Sum of True values per pixel
433444 cmask_sum = np .sum (color_mask , axis = 2 )
434445
435- # Cast mask to float once and reuse the same GEMM path
436- cm2 = color_mask .reshape (- 1 , C ).astype (np .float32 , copy = False )
437- rgb_avg_present = (cm2 @ RGB ).reshape (H , W , 3 ) # sums of rs/gs/bs over present cats
438-
439446 with np .errstate (divide = 'ignore' , invalid = 'ignore' ):
440447 rgb2 = (rgb_avg_present / cmask_sum [..., None ]).astype (np .uint8 )
441448
@@ -449,7 +456,11 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
449456 mask = np .isnan (total )
450457 a = _interpolate_alpha (data , total , mask , how , alpha , span , min_alpha , rescale_discrete_levels )
451458
452- values = np .dstack ([rgb_array , a ]).view (np .uint32 ).reshape (a .shape )
459+ rgba_array = np .dstack ([rgb_array , a ])
460+ # Ensure array is contiguous for view operation
461+ if not rgba_array .flags .c_contiguous :
462+ rgba_array = np .ascontiguousarray (rgba_array )
463+ values = rgba_array .view (np .uint32 ).reshape (a .shape )
453464 if cupy and isinstance (values , cupy .ndarray ):
454465 # Convert cupy array to numpy for final image
455466 values = cupy .asnumpy (values )
0 commit comments