
Commit 3da8a6a

More changes to matmul blog (#28)
* Reword and move dispatch around
* More changes to matmul blog
* Add header for dispatch
* Move to 16,16 workgroups everywhere 2D
* Make it clear dispatch count is running on CPU, not that dispatches are for CPU
1 parent 318cd88 commit 3da8a6a

File tree

5 files changed: +90 -42 lines changed

blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/variants.rs (+1 -1)

```diff
@@ -81,7 +81,7 @@ impl Gpu for Workgroup2d {
 
 impl GridComputation for Workgroup2d {
     fn workgroup(&self) -> UVec3 {
-        UVec3::new(8, 8, 1)
+        UVec3::new(16, 16, 1)
     }
 
     fn dispatch_count(&self, m: u32, n: u32) -> UVec3 {
```

blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/workgroup_2d/src/lib.rs (+1 -1)

```diff
@@ -4,7 +4,7 @@ use settings::Dimensions;
 use spirv_std::glam::UVec3;
 use spirv_std::spirv;
 
-#[spirv(compute(threads(8, 8)))]
+#[spirv(compute(threads(16, 16)))]
 pub fn matmul(
     #[spirv(global_invocation_id)] global_id: UVec3,
     #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
```

blog/2024-11-21-optimizing-matrix-mul/index.md (+63 -27)

```diff
@@ -88,6 +88,22 @@ platforms, including Windows, Linux, macOS, iOS[^1], Android, and the web[^2].
 By using Rust GPU and `wgpu`, we have a clean, portable setup with everything written in
 Rust.
 
+## GPU program basics
+
+The smallest unit of execution is a thread, which executes the GPU program.
+
+Workgroups are groups of threads: they are grouped together and run in parallel (they’re
+called [thread blocks in
+CUDA](<https://en.wikipedia.org/wiki/Thread_block_(CUDA_programming)>)). They can access
+the same shared memory.
+
+We can dispatch many of these workgroups at once. CUDA calls this a grid (which is made
+of thread blocks).
+
+Workgroups and dispatching workgroups are defined in 3D. The size of a workgroup is
+defined by `compute(threads((x, y, z)))` where the number of threads per workgroup is
+x \* y \* z.
+
 ## Writing the kernel
 
 ### Kernel 1: Naive kernel
```
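The workgroup arithmetic added in the hunk above (threads per workgroup = x \* y \* z) can be checked with a tiny CPU-side sketch. `threads_per_workgroup` is a hypothetical helper for illustration, not code from the post:

```rust
/// Hypothetical helper: a workgroup sized (x, y, z) contains x * y * z threads.
fn threads_per_workgroup(x: u32, y: u32, z: u32) -> u32 {
    x * y * z
}

fn main() {
    // compute(threads((16, 16))) is (16, 16, 1) in 3D: 256 threads per workgroup,
    // the same count as the 1D compute(threads(256)).
    println!("{}", threads_per_workgroup(16, 16, 1)); // prints 256
}
```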
```diff
@@ -159,6 +175,35 @@ examples.
 
 :::
 
+#### Dispatching workgroups
+
+Each workgroup, since it’s only one thread (`#[spirv(compute(threads(1)))]`), processes
+one `result[i, j]`.
+
+To calculate the full matrix, we need to launch as many entries as there are in the
+matrix. Here we specify that (`Uvec3::new(m * n, 1, 1`) on the CPU:
+
+import { RustNaiveWorkgroupCount } from './snippets/naive.tsx';
+
+<RustNaiveWorkgroupCount/>
+
+The `dispatch_count()` function runs on the CPU and is used by the CPU-to-GPU API (in
+our case `wgpu`) to configure and dispatch work to the GPU:
+
+import { RustNaiveDispatch } from './snippets/naive.tsx';
+
+<RustNaiveDispatch />
+
+:::warning
+
+This code appears more complicated than it needs to be. I abstracted the CPU-side code
+that talks to the GPU using generics and traits so I could easily slot in different
+kernels and their settings while writing the blog post.
+
+You could just hardcode the value for simplicity.
+
+:::
+
 ### Kernel 2: Moarrr threads!
 
 With the first kernel, we're only able to compute small square matrices due to limits on
```
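The naive dispatch scheme described in the hunk above (one single-thread workgroup per output entry) reduces to simple arithmetic. This sketch uses a hypothetical stand-in for the post's `dispatch_count()`:

```rust
/// Hypothetical stand-in for the post's `dispatch_count()`: with
/// compute(threads(1)), every workgroup computes exactly one `result[i, j]`,
/// so an m x n output needs m * n workgroup dispatches.
fn naive_dispatch_count(m: u32, n: u32) -> (u32, u32, u32) {
    (m * n, 1, 1)
}

fn main() {
    // Even a modest 64x64 matmul already needs 4096 dispatched workgroups,
    // which is why this scheme runs into hardware limits quickly.
    println!("{:?}", naive_dispatch_count(64, 64)); // prints (4096, 1, 1)
}
```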
```diff
@@ -187,33 +232,19 @@ import { RustWorkgroup256WorkgroupCount } from './snippets/workgroup_256.tsx';
 
 <RustWorkgroup256WorkgroupCount/>
 
-The `dispatch_count()` function runs on the CPU and is used by the CPU-to-GPU API (in
-our case `wgpu`) to configure and dispatch to the GPU:
-
-import { RustWorkgroup256WgpuDispatch } from './snippets/workgroup_256.tsx';
-
-<RustWorkgroup256WgpuDispatch />
-
-:::warning
-
-This code appears more complicated than it needs to be. I abstracted the CPU-side code
-that talks to the GPU using generics and traits so I could easily slot in different
-kernels and their settings while writing the blog post.
-
-You could just hardcode a value for simplicity.
-
-:::
+With these two small changes we can handle larger matrices without hitting hardware
+workgroup limits.
 
 ### Kernel 3: Calculating with 2D workgroups
 
-However doing all the computation in "1 dimension" limits the matrix size we can
+However, doing all the computation in "1 dimension" still limits the matrix size we can
 calculate.
 
 Although we don't change much about our code, if we distribute our work in 2 dimensions
 we're able to bypass these limits and launch more workgroups that are larger. This
 allows us to calculate a 4096x4096 matmul.
 
-We update our `compute(threads(256)))` to `compute(threads((8, 8)))`, and make the small
+We update our `compute(threads(256)))` to `compute(threads((16, 16)))`, and make the small
 change to `row` and `col` from Zach's post to increase speed:
 
 import { RustWorkgroup2d } from './snippets/workgroup_2d.tsx';
```
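For the 2D layout this hunk switches to, the CPU-side dispatch math looks roughly like the following. This is a sketch assuming the commit's 16x16 workgroups; the real code lives in `variants.rs`:

```rust
/// Sketch of a 2D dispatch count, assuming 16x16 workgroups as in the commit.
/// Rounding up (div_ceil) ensures edge rows/columns are still covered when
/// m or n is not a multiple of 16.
fn dispatch_count_2d(m: u32, n: u32) -> (u32, u32, u32) {
    const WG: u32 = 16;
    (m.div_ceil(WG), n.div_ceil(WG), 1)
}

fn main() {
    // The 4096x4096 matmul mentioned above: a 256x256 grid of workgroups,
    // each containing 16 * 16 = 256 threads.
    println!("{:?}", dispatch_count_2d(4096, 4096)); // prints (256, 256, 1)
}
```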
```diff
@@ -257,24 +288,29 @@ import { RustTiling2dSimd } from './snippets/tiling_2d_simd.tsx';
 Each thread now calculates a 4x4 grid of the output matrix and we see a slight
 improvement over the last kernel.
 
+To stay true to the spirit of Zach's original blog post, we'll wrap things up here and
+leave the "fancier" experiments for another time.
+
 ## Reflections on porting to Rust GPU
 
 Porting to Rust GPU went quickly, as the kernels Zach used were fairly simple. Most of
 the time was spent with concerns that were not specifically about writing GPU code. For
 example, deciding how much to abstract vs how much to make the code easy to follow, if
 everything should be available at runtime or if each kernel should be a compilation
-target, etc. The code is not _great_ as it is still blog post code!
+target, etc. [The
+code](https://github.com/Rust-GPU/rust-gpu.github.io/tree/main/blog/2024-11-21-optimizing-matrix-mul/code)
+is not _great_ as it is still blog post code!
 
 My background is not in GPU programming, but I do have Rust experience. I joined the
 Rust GPU project because I tried to use standard GPU languages and knew there must be a
 better way. Writing these GPU kernels felt like writing any other Rust code (other than
-debugging, more on that later) which is a huge win to me. Not only the language itself,
+debugging, more on that later) which is a huge win to me. Not just the language itself,
 but the entire development experience.
 
 ## Rust-specific party tricks
 
 Rust lets us write code for both the CPU and GPU in ways that are often impossible—or at
-least less elegant—with other languages. I'm going to highlight some benefits of Rust I
+least less elegant—with other languages. I'm going to highlight some benefits I
 experienced while working on this blog post.
 
 ### Shared code across GPU and CPU
```
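The "4x4 grid per thread" idea at the top of this hunk can be sketched on the CPU. This illustrates the tiling technique in general, not the post's actual kernel; all names here are hypothetical:

```rust
/// CPU illustration of tiling: one "thread" accumulates a TILE x TILE block
/// of the output in local storage instead of a single element.
const TILE: usize = 4;

/// `a` is m x k and `b` is k x n, both row-major; (row0, col0) is the
/// top-left corner of this thread's output tile. Assumes the tile is in bounds.
fn compute_tile(
    a: &[f32],
    b: &[f32],
    n: usize,
    k: usize,
    row0: usize,
    col0: usize,
) -> [[f32; TILE]; TILE] {
    let mut acc = [[0.0f32; TILE]; TILE];
    for p in 0..k {
        for i in 0..TILE {
            for j in 0..TILE {
                acc[i][j] += a[(row0 + i) * k + p] * b[p * n + (col0 + j)];
            }
        }
    }
    acc
}

fn main() {
    // A 4x4 identity times a 4x4 matrix of 0..15 reproduces the input block.
    let a = [
        1.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 1.0, 0.0,
        0.0, 0.0, 0.0, 1.0,
    ];
    let b: Vec<f32> = (0..16).map(|x| x as f32).collect();
    let tile = compute_tile(&a, &b, 4, 4, 0, 0);
    println!("{:?}", tile[2]); // prints [8.0, 9.0, 10.0, 11.0]
}
```

The payoff of this layout is that the accumulator lives in registers, amortizing each loaded element of `a` and `b` over TILE multiply-adds.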
```diff
@@ -351,8 +387,9 @@ Testing the kernel in isolation is useful, but it does not reflect how the GPU e
 it with multiple invocations across workgroups and dispatches. To test the kernel
 end-to-end, I needed a test harness that simulated this behavior on the CPU.
 
-Building the harness was straightforward. By enforcing the same invariants as the GPU I
-could validate the kernel under the same conditions the GPU would run it:
+Building the harness was straightforward due to the borrow checker. By enforcing the
+same invariants as the GPU I could validate the kernel under the same conditions the GPU
+would run it:
 
 import { RustCpuBackendHarness } from './snippets/party.tsx';
 
```
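The harness described in the hunk above can be approximated like this. `naive_kernel` is a simplified stand-in for the real kernel body, and the plain loop plays the role of the GPU dispatch:

```rust
/// Simplified stand-in for the naive kernel body: one invocation computes one
/// output element from its global invocation id.
fn naive_kernel(global_x: u32, n: usize, k: usize, a: &[f32], b: &[f32], out: &mut [f32]) {
    let idx = global_x as usize;
    let (row, col) = (idx / n, idx % n);
    let mut sum = 0.0;
    for p in 0..k {
        sum += a[row * k + p] * b[p * n + col];
    }
    out[row * n + col] = sum;
}

/// CPU "dispatch": invoke the kernel once per single-thread workgroup, exactly
/// as the GPU would for the naive m * n dispatch.
fn run_on_cpu(m: usize, n: usize, k: usize, a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for gx in 0..(m * n) as u32 {
        naive_kernel(gx, n, k, a, b, &mut out);
    }
    out
}

fn main() {
    // 2x2 check: [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]].
    let (a, b) = ([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]);
    println!("{:?}", run_on_cpu(2, 2, 2, &a, &b)); // prints [19.0, 22.0, 43.0, 50.0]
}
```

The exclusive `&mut out` borrow is where the borrow checker earns its keep: the harness cannot accidentally alias the output buffer across invocations.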

```diff
@@ -484,10 +521,9 @@ future.
 This kernel doesn't use conditional compilation, but it's a key feature of Rust that
 works with Rust GPU. With `#[cfg(...)]`, you can adapt kernels to different hardware or
 configurations without duplicating code. GPU languages like WGSL or GLSL offer
-preprocessor directives, but these tools lack standardization across ecosystems. Rust
-GPU leverages the existing Cargo ecosystem, so conditional compilation follows the same
-standards all Rust developers already know. This makes adapting kernels for different
-targets easier and more maintainable.
+preprocessor directives, but these tools lack standardization across projects. Rust GPU
+leverages the existing Cargo ecosystem, so conditional compilation follows the same
+standards all Rust developers already know.
 
 ## Come join us!
 
```
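As a hedged illustration of the `#[cfg(...)]` point in the hunk above (not code from the repo, and `large_workgroups` is a hypothetical Cargo feature): the same source can select a workgroup size per feature flag, with no preprocessor.

```rust
// Illustrative only: `large_workgroups` is a made-up Cargo feature name.
#[cfg(feature = "large_workgroups")]
const WORKGROUP_SIDE: u32 = 16;

#[cfg(not(feature = "large_workgroups"))]
const WORKGROUP_SIDE: u32 = 8;

fn main() {
    // Built without the feature enabled, this prints "workgroup is 8x8".
    println!("workgroup is {0}x{0}", WORKGROUP_SIDE);
}
```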

blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx (+25)

```diff
@@ -2,6 +2,8 @@ import React from "react";
 import CodeBlock from "@theme/CodeBlock";
 import Snippet from "@site/src/components/Snippet";
 import RustKernelSource from "!!raw-loader!../code/crates/gpu/naive/src/lib.rs";
+import RustWorkgroupCount from "!!raw-loader!../code/crates/cpu/matmul/src/variants.rs";
+import RustWgpuBackend from "!!raw-loader!../code/crates/cpu/matmul/src/backends/wgpu.rs";
 
 export const WebGpuInputs: React.FC = () => (
   <CodeBlock language="wgsl" title="WGSL" className="text-xs">
@@ -52,6 +54,29 @@ export const RustNaiveInputs: React.FC = () => (
   </Snippet>
 );
 
+export const RustNaiveWorkgroupCount: React.FC = () => (
+  <Snippet
+    language="rust"
+    className="text-xs"
+    lines="26-34"
+    title="Calculating on the CPU how many workgroup dispatches are needed"
+  >
+    {RustWorkgroupCount}
+  </Snippet>
+);
+
+export const RustNaiveDispatch: React.FC = () => (
+  <Snippet
+    language="rust"
+    className="text-xs"
+    lines="145,147"
+    strip_leading_spaces
+    title="Using wgpu on the CPU to dispatch workgroups to the GPU"
+  >
+    {RustWgpuBackend}
+  </Snippet>
+);
+
 export const RustNaiveWorkgroup: React.FC = () => (
   <Snippet language="rust" className="text-xs" lines="7">
     {RustKernelSource}
```
blog/2024-11-21-optimizing-matrix-mul/snippets/workgroup_256.tsx (-13)

```diff
@@ -2,7 +2,6 @@ import React from "react";
 import Snippet from "@site/src/components/Snippet";
 import RustKernelSource from "!!raw-loader!../code/crates/gpu/workgroup_256/src/lib.rs";
 import VariantsSource from "!!raw-loader!../code/crates/cpu/matmul/src/variants.rs";
-import WgpuBackendSource from "!!raw-loader!../code/crates/cpu/matmul/src/backends/wgpu.rs";
 
 export const RustWorkgroup256Workgroup: React.FC = () => (
   <Snippet language="rust" className="text-xs" lines="7">
@@ -20,15 +19,3 @@ export const RustWorkgroup256WorkgroupCount: React.FC = () => (
     {VariantsSource}
   </Snippet>
 );
-
-export const RustWorkgroup256WgpuDispatch: React.FC = () => (
-  <Snippet
-    language="rust"
-    className="text-xs"
-    lines="144,145,147"
-    strip_leading_spaces
-    title="Using wgpu on the CPU to dispatch to the GPU"
-  >
-    {WgpuBackendSource}
-  </Snippet>
-);
```
