@@ -428,12 +428,7 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
428428
429429 cached_path = _EINSUM_PATH_CACHE [cache_key ]
430430 rgb_sum = np .einsum ('hwc,cr->hwr' , color_data , RGB , optimize = cached_path )
431- rgb_avg_present = np .einsum (
432- 'hwc,cr->hwr' ,
433- color_mask .astype (color_data .dtype , copy = False ),
434- RGB ,
435- optimize = cached_path ,
436- )
431+ rgb_avg_present = np .einsum ('hwc,cr->hwr' , color_mask , RGB , optimize = cached_path )
437432
438433 # Divide by totals (broadcast) once, then cast once
439434 with np .errstate (divide = 'ignore' , invalid = 'ignore' ):
@@ -447,10 +442,9 @@ def _colorize(agg, color_key, how, alpha, span, min_alpha, name, color_baseline,
447442 rgb2 = (rgb_avg_present / cmask_sum [..., None ]).astype (np .uint8 )
448443
449444 # --- Fill pixels with no color mass using the avg-present fallback ---
450- # Reuse color_total instead of re-summing
451445 missing_colors = (color_total == 0 )
452- # Select per-channel in one shot to reduce passes
453- rgb_array [ missing_colors ] = rgb2 [ missing_colors ]
446+ if np . any ( missing_colors ):
447+ rgb_array = np . where ( missing_colors [..., None ], rgb2 , rgb_array )
454448
455449 total = nansum_missing (data , axis = 2 )
456450 mask = np .isnan (total )
0 commit comments