Skip to content

Commit f7ff7d4

Browse files
committed
More changes to matmul blog
* Add header for dispatch * Move to 16,16 workgroups everywhere 2D * Make it clear dispatch count it running on CPU not dispatches are for CPU
1 parent 9fd6eeb commit f7ff7d4

File tree

4 files changed

+9
-6
lines changed

4 files changed

+9
-6
lines changed

blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/variants.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl Gpu for Workgroup2d {
8181

8282
impl GridComputation for Workgroup2d {
8383
fn workgroup(&self) -> UVec3 {
84-
UVec3::new(8, 8, 1)
84+
UVec3::new(16, 16, 1)
8585
}
8686

8787
fn dispatch_count(&self, m: u32, n: u32) -> UVec3 {

blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/workgroup_2d/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use settings::Dimensions;
44
use spirv_std::glam::UVec3;
55
use spirv_std::spirv;
66

7-
#[spirv(compute(threads(8, 8)))]
7+
#[spirv(compute(threads(16, 16)))]
88
pub fn matmul(
99
#[spirv(global_invocation_id)] global_id: UVec3,
1010
#[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,

blog/2024-11-21-optimizing-matrix-mul/index.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ examples.
175175

176176
:::
177177

178-
Each workgroup, since it’s only one thread, processes one `result[i, j]`.
178+
#### Dispatching workgroups
179+
180+
Each workgroup, since it’s only one thread (`#[spirv(compute(threads(1)))]`), processes
181+
one `result[i, j]`.
179182

180183
To calculate the full matrix, we need to launch as many entries as there are in the
181184
matrix. Here we specify that (`Uvec3::new(m * n, 1, 1`) on the CPU:
@@ -241,7 +244,7 @@ Although we don't change much about our code, if we distribute our work in 2 dim
241244
we're able to bypass these limits and launch more workgroups that are larger. This
242245
allows us to calculate a 4096x4096 matmul.
243246

244-
We update our `compute(threads(256)))` to `compute(threads((8, 8)))`, and make the small
247+
We update our `compute(threads(256)))` to `compute(threads((16, 16)))`, and make the small
245248
change to `row` and `col` from Zach's post to increase speed:
246249

247250
import { RustWorkgroup2d } from './snippets/workgroup_2d.tsx';

blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export const RustNaiveWorkgroupCount: React.FC = () => (
5959
language="rust"
6060
className="text-xs"
6161
lines="26-34"
62-
title="Calculating how many workgroup dispatches are needed on the CPU"
62+
title="Calculating on the CPU how many workgroup dispatches are needed"
6363
>
6464
{RustWorkgroupCount}
6565
</Snippet>
@@ -71,7 +71,7 @@ export const RustNaiveDispatch: React.FC = () => (
7171
className="text-xs"
7272
lines="145,147"
7373
strip_leading_spaces
74-
title="Using wgpu on the CPU to dispatch to the GPU"
74+
title="Using wgpu on the CPU to dispatch workgroups to the GPU"
7575
>
7676
{RustWgpuBackend}
7777
</Snippet>

0 commit comments

Comments
 (0)