Skip to content

Conversation

@philippjfr
Copy link
Member

@philippjfr philippjfr commented Aug 26, 2025

As the title says, this attempts to optimize the colorize part of the shade operation by avoiding temporary copies and performing a single matmul operation rather than multiple dot operations. In my testing this is about a 10% speedup. My guess is that this could result in even better performance for systems with MKL support.

@codecov
Copy link

codecov bot commented Aug 26, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 88.34%. Comparing base (f44670c) to head (36a703c).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1437   +/-   ##
=======================================
  Coverage   88.33%   88.34%           
=======================================
  Files          96       96           
  Lines       18901    18908    +7     
=======================================
+ Hits        16696    16704    +8     
+ Misses       2205     2204    -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment on lines -402 to +411
color_data -= baseline
np.subtract(color_data, baseline, out=color_data, where=color_mask)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't this already in place? Is it to avoid calculation along the color_mask axis?

Comment on lines 432 to 433
cd2 = color_data.reshape(-1, C)
rgb_sum = (cd2 @ RGB).reshape(H, W, 3) # weighted sums for r,g,b
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tried to play around with Einstein notation or np.tensordot for this? AI generated benchmark:

import numpy as np
import timeit

def benchmark(H, W, C):
    # Random test data
    color_data = np.random.rand(H, W, C)
    RGB = np.random.rand(C, 3)

    # Method 1: reshape + matmul
    def method_matmul():
        cd2 = color_data.reshape(-1, C)
        return (cd2 @ RGB).reshape(H, W, 3)

    # Method 2: einsum
    def method_einsum():
        return np.einsum('hwc,cj->hwj', color_data, RGB)

    # Method 3: einsum with optimize=True
    def method_einsum_opt():
        return np.einsum('hwc,cj->hwj', color_data, RGB, optimize=True)

    # Method 4: tensordot
    def method_tensordot():
        return np.tensordot(color_data, RGB, axes=([2],[0]))  # shape: (H, W, 3)

    # Verify correctness
    out1 = method_matmul()
    out2 = method_einsum()
    out3 = method_einsum_opt()
    out4 = method_tensordot()
    assert np.allclose(out1, out2)
    assert np.allclose(out1, out3)
    assert np.allclose(out1, out4)

    # Benchmark
    time_matmul = timeit.timeit(method_matmul, number=10)
    time_einsum = timeit.timeit(method_einsum, number=10)
    time_einsum_opt = timeit.timeit(method_einsum_opt, number=10)
    time_tensordot = timeit.timeit(method_tensordot, number=10)

    print(f"H={H}, W={W}, C={C}")
    print(f"reshape+matmul: {time_matmul:.4f} s")
    print(f"einsum: {time_einsum:.4f} s")
    print(f"einsum (optimize=True): {time_einsum_opt:.4f} s")
    print(f"tensordot: {time_tensordot:.4f} s")
    print("-" * 50)

# Test different shapes
benchmark(256, 256, 64)
benchmark(512, 512, 64)
benchmark(256, 256, 256)
benchmark(128, 128, 1024)
H=256, W=256, C=64
reshape+matmul: 0.0284 s
einsum: 0.1792 s
einsum (optimize=True): 0.0188 s
tensordot: 0.0280 s
--------------------------------------------------
H=512, W=512, C=64
reshape+matmul: 0.1330 s
einsum: 0.5891 s
einsum (optimize=True): 0.0341 s
tensordot: 0.1112 s
--------------------------------------------------
H=256, W=256, C=256
reshape+matmul: 0.1042 s
einsum: 0.6018 s
einsum (optimize=True): 0.0419 s
tensordot: 0.1038 s
--------------------------------------------------
H=128, W=128, C=1024
reshape+matmul: 0.0713 s
einsum: 0.6043 s
einsum (optimize=True): 0.0497 s
tensordot: 0.0712 s
--------------------------------------------------

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, not 100% sure what to take from that though I will state that in most cases C << 100.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why I said it was AI-generated. I assume you had a more fully fledged way to actually measure your performance. This was more to show that there was a tiny bit of performance that could be gained here.


# Replace NaNs with 0s for dot/matmul in one pass (in-place)
# If you don't want to mutate color_data contents further, copy first.
np.nan_to_num(color_data, copy=False) # NaN -> 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use an existing mask for this? I would assume nan_to_mask would do this calculation behind the scenes.

Also note that this will convert inf values to floats.

@codspeed-hq
Copy link

codspeed-hq bot commented Sep 8, 2025

CodSpeed Instrumentation Performance Report

Merging #1437 will degrade performances by 17.88%

Comparing optimize_colorize (36a703c) with main (a4d57be)

Summary

❌ 1 regression
✅ 42 untouched

⚠️ Please fix the performance issues or acknowledge them on CodSpeed.

Benchmarks breakdown

Benchmark BASE HEAD Change
test_layout[forceatlas2_layout] 68.7 ms 83.6 ms -17.88%

@philippjfr
Copy link
Member Author

My profiling code

import time

import numpy as np
import pandas as pd
import datashader as ds

N = int(10e6)
C = 20

def gen_data(N=int(10e6), C=20):
    xy = np.random.randn(int(N), 2)
    c = np.random.choice([chr(65+i) for i in range(C)], size=N)
    df = pd.DataFrame(xy, columns=['x', 'y'])
    df['c'] = pd.Series(c).astype('category')
    return df

def profile(df, size=1000):
    W = H = size
    cvs = ds.Canvas(plot_width=W, plot_height=H)
    agg = cvs.points(df, x='x', y='y', agg=ds.count_cat('c'))

    pre = time.monotonic()
    ds.transfer_functions.shade(agg)
    return time.monotonic()-pre

# Warmup
df = gen_data(C=1)
profile(df, size=10)

results = []
for c in (1, 5, 10, 20):
    df = gen_data(C=c)
    for s in range(1000, 6000, 1000):
        timing = profile(df, size=s)
        results.append((c, s, timing))
Screenshot 2025-09-10 at 17 50 04

@hoxbro
Copy link
Member

hoxbro commented Sep 12, 2025

Current status:

current = 6b0982b, before = fb2c16e, main = 184ef3c

Benchmark code

import time
import numpy as np
import pandas as pd
import datashader as ds
import datashader.transfer_functions as tf

N = int(10e6)
C = 20


class Profile:
    def __init__(self, output_file):
        import cProfile

        self.profiler = cProfile.Profile()
        self.output_file = output_file

    def __enter__(self):
        self.profiler.enable()
        return self.profiler

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.profiler.disable()
        self.profiler.dump_stats(self.output_file)


class LineProfileContext:
    def __init__(self, output_file):
        from line_profiler import LineProfiler

        self.profiler = LineProfiler()
        self.output_file = output_file
        self.functions_to_profile = []

    def add_function(self, func):
        """Add a function to be profiled line-by-line"""
        self.profiler.add_function(func)
        self.functions_to_profile.append(func)
        return func

    def __enter__(self):
        self.profiler.enable_by_count()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.profiler.disable_by_count()

        self.profiler.dump_stats(self.output_file)
        self.profiler.print_stats()


def gen_data(N=N, C=C):
    np.random.seed(1)
    xy = np.random.randn(int(N), 2)
    c = np.random.choice([chr(65 + i) for i in range(C)], size=N)
    df = pd.DataFrame(xy, columns=["x", "y"])
    df["c"] = pd.Series(c).astype("category")
    return df


def profile(df, size=1000):
    W = H = size
    cvs = ds.Canvas(plot_width=W, plot_height=H)
    agg = cvs.points(df, x="x", y="y", agg=ds.count_cat("c"))

    tf.shade(agg)  # warm up
    pre = time.monotonic()

    # with LineProfileContext("line_profile.lprof") as line_profiler:
    #     line_profiler.add_function(tf._colorize)
    #     tf.shade(agg)

    # with Profile(output_file="optional.perf"):
    #     ds.transfer_functions.shade(agg)
    return time.monotonic() - pre


# Warmup
df = gen_data(C=20)
profile(df, size=5000)

results = []
for c in (1, 5, 10, 20):
    df = gen_data(C=c)
    for s in range(1000, 6000, 1000):
        timing = profile(df, size=s)
        results.append((c, s, timing))
        print(f"{c=}, {s=}, {timing=}")

Plotting

current = [  # 6b0982b
    dict(c=1, s=1000, timing=0.07571580299918423),
    dict(c=1, s=2000, timing=0.295644159999938),
    dict(c=1, s=3000, timing=0.6464670440000191),
    dict(c=1, s=4000, timing=1.1230143669999961),
    dict(c=1, s=5000, timing=1.7188509200004773),
    dict(c=5, s=1000, timing=0.11476161499922455),
    dict(c=5, s=2000, timing=0.44561883799906354),
    dict(c=5, s=3000, timing=0.9790756620004686),
    dict(c=5, s=4000, timing=1.7118233849996614),
    dict(c=5, s=5000, timing=2.6856131889999233),
    dict(c=10, s=1000, timing=0.14612587800002075),
    dict(c=10, s=2000, timing=0.5675542549997772),
    dict(c=10, s=3000, timing=1.2379318599996623),
    dict(c=10, s=4000, timing=2.2251677369986282),
    dict(c=10, s=5000, timing=3.397321854999973),
    dict(c=20, s=1000, timing=0.1993868179997662),
    dict(c=20, s=2000, timing=0.8214870430001611),
    dict(c=20, s=3000, timing=1.7614306820014463),
    dict(c=20, s=4000, timing=3.0943053329992836),
    dict(c=20, s=5000, timing=4.7508491489988955),
]

before = [  # fb2c16e
    dict(c=1, s=1000, timing=0.07645769699956873),
    dict(c=1, s=2000, timing=0.3170905290007795),
    dict(c=1, s=3000, timing=0.7142776969994884),
    dict(c=1, s=4000, timing=1.2551025209995714),
    dict(c=1, s=5000, timing=1.9599227520011482),
    dict(c=5, s=1000, timing=0.16230177999932494),
    dict(c=5, s=2000, timing=0.5520949959991412),
    dict(c=5, s=3000, timing=1.2177506650004943),
    dict(c=5, s=4000, timing=2.171157504999428),
    dict(c=5, s=5000, timing=3.3560801679996075),
    dict(c=10, s=1000, timing=0.2009295749994635),
    dict(c=10, s=2000, timing=0.7160231019988714),
    dict(c=10, s=3000, timing=1.6094946339999296),
    dict(c=10, s=4000, timing=2.7828460880009516),
    dict(c=10, s=5000, timing=4.274540911001168),
    dict(c=20, s=1000, timing=0.2542700350004452),
    dict(c=20, s=2000, timing=0.9284682460001932),
    dict(c=20, s=3000, timing=2.0608999519990903),
    dict(c=20, s=4000, timing=3.6744658019997587),
    dict(c=20, s=5000, timing=5.747611536000477),
]
main = [  # 184ef3c
    dict(c=1, s=1000, timing=0.0718935530003364),
    dict(c=1, s=2000, timing=0.31208833799973945),
    dict(c=1, s=3000, timing=0.7055044320004527),
    dict(c=1, s=4000, timing=1.2214937410008133),
    dict(c=1, s=5000, timing=1.899291293000715),
    dict(c=5, s=1000, timing=0.1668667740013916),
    dict(c=5, s=2000, timing=0.6655240790005337),
    dict(c=5, s=3000, timing=1.5014597809986299),
    dict(c=5, s=4000, timing=2.5980365989998973),
    dict(c=5, s=5000, timing=4.086923677999948),
    dict(c=10, s=1000, timing=0.2160664650000399),
    dict(c=10, s=2000, timing=0.8515692499986471),
    dict(c=10, s=3000, timing=1.879354708999017),
    dict(c=10, s=4000, timing=3.3537094929997693),
    dict(c=10, s=5000, timing=5.179373872999349),
    dict(c=20, s=1000, timing=0.3006982959996094),
    dict(c=20, s=2000, timing=1.1935125410000182),
    dict(c=20, s=3000, timing=2.72798394000165),
    dict(c=20, s=4000, timing=4.852396742000565),
    dict(c=20, s=5000, timing=7.331704080999771),
]

import hvplot.pandas
import pandas as pd

fn = (
    lambda x: pd.DataFrame(eval(x))
    .hvplot.bar(x="s", y="timing", by="c", title=x)
    .opts(show_grid=True)
)

(fn("current") + fn("before") + fn("main")).cols(1)

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants