Skip to content

Commit 2a95dcb

Browse files
committed
Another rewrite of the HiZ shader.
Memory access pattern is critical for perf.
1 parent 8ef38e0 commit 2a95dcb

File tree

2 files changed

+153
-140
lines changed

2 files changed

+153
-140
lines changed

assets/shaders/post/hiz.comp

Lines changed: 151 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
layout(local_size_x = 256) in;
99

10-
layout(set = 0, binding = 0, r32f) uniform coherent image2D uImages[13];
10+
layout(set = 0, binding = 0, r32f) uniform writeonly image2D uImageTop;
11+
layout(set = 0, binding = 1, r32f) coherent uniform image2D uImages[12];
1112
layout(set = 1, binding = 0) uniform sampler2D uTexture;
1213
layout(set = 1, binding = 1) buffer Counter
1314
{
@@ -62,23 +63,38 @@ vec4 transform_z(vec4 zs)
6263
void write_image(ivec2 coord, int mip, float v)
6364
{
6465
// Rely on image robustness to clean up the OOB writes here.
65-
imageStore(uImages[mip], coord, vec4(v));
66+
imageStore(uImages[mip - 1], coord, vec4(v));
6667
}
6768

68-
void write_image4(ivec2 coord, int mip, vec4 v)
69+
void write_image4_top(ivec2 coord, int mip, vec4 v)
6970
{
70-
imageStore(uImages[mip], coord + ivec2(0, 0), v.xxxx);
71-
imageStore(uImages[mip], coord + ivec2(1, 0), v.yyyy);
72-
imageStore(uImages[mip], coord + ivec2(0, 1), v.zzzz);
73-
imageStore(uImages[mip], coord + ivec2(1, 1), v.wwww);
71+
imageStore(uImageTop, coord + ivec2(0, 0), v.xxxx);
72+
imageStore(uImageTop, coord + ivec2(1, 0), v.yyyy);
73+
imageStore(uImageTop, coord + ivec2(0, 1), v.zzzz);
74+
imageStore(uImageTop, coord + ivec2(1, 1), v.wwww);
7475
}
7576

7677
const int SHARED_WIDTH = 32;
7778
const int SHARED_HEIGHT = 32;
78-
const int BANK_STRIDE = SHARED_WIDTH * SHARED_HEIGHT;
79-
shared float shared_buffer[2 * BANK_STRIDE];
79+
shared float shared_buffer[SHARED_HEIGHT][SHARED_WIDTH];
8080
shared bool shared_is_last_workgroup;
8181

82+
void store_shared(ivec2 coord, float d)
83+
{
84+
shared_buffer[coord.y][coord.x] = d;
85+
}
86+
87+
float load_shared(ivec2 coord)
88+
{
89+
return shared_buffer[coord.y][coord.x];
90+
}
91+
92+
vec4 fetch_2x2_texture(ivec2 base_coord)
93+
{
94+
vec2 fcoord = vec2(base_coord) * registers.inv_resolution;
95+
return textureGatherOffset(uTexture, fcoord, ivec2(1, 1)).wzxy;
96+
}
97+
8298
mat4 fetch_4x4_texture(ivec2 base_coord)
8399
{
84100
vec2 fcoord = vec2(base_coord) * registers.inv_resolution;
@@ -92,59 +108,38 @@ mat4 fetch_4x4_texture(ivec2 base_coord)
92108
vec4 fetch_2x2_image_mip6(ivec2 base_coord)
93109
{
94110
ivec2 max_coord = mip_resolution(6) - 1;
95-
float d0 = imageLoad(uImages[6], min(base_coord + ivec2(0, 0), max_coord)).x;
96-
float d1 = imageLoad(uImages[6], min(base_coord + ivec2(1, 0), max_coord)).x;
97-
float d2 = imageLoad(uImages[6], min(base_coord + ivec2(0, 1), max_coord)).x;
98-
float d3 = imageLoad(uImages[6], min(base_coord + ivec2(1, 1), max_coord)).x;
111+
float d0 = imageLoad(uImages[5], min(base_coord + ivec2(0, 0), max_coord)).x;
112+
float d1 = imageLoad(uImages[5], min(base_coord + ivec2(1, 0), max_coord)).x;
113+
float d2 = imageLoad(uImages[5], min(base_coord + ivec2(0, 1), max_coord)).x;
114+
float d3 = imageLoad(uImages[5], min(base_coord + ivec2(1, 1), max_coord)).x;
99115
return vec4(d0, d1, d2, d3);
100116
}
101117

102118
float fetch_image_mip6(ivec2 coord)
103119
{
104-
return imageLoad(uImages[6], coord).x;
120+
return imageLoad(uImages[5], coord).x;
105121
}
106122

107-
mat4 write_mip0_transformed(mat4 M, ivec2 base_coord)
123+
vec4 write_mip0_transformed(vec4 v, ivec2 base_coord)
108124
{
109-
vec4 q00 = transform_z(M[0]);
110-
vec4 q10 = transform_z(M[1]);
111-
vec4 q01 = transform_z(M[2]);
112-
vec4 q11 = transform_z(M[3]);
113-
125+
v = transform_z(v);
114126
// Write out transformed LOD 0
115-
write_image4(base_coord + ivec2(0, 0), 0, q00);
116-
write_image4(base_coord + ivec2(2, 0), 0, q10);
117-
write_image4(base_coord + ivec2(0, 2), 0, q01);
118-
write_image4(base_coord + ivec2(2, 2), 0, q11);
119-
120-
return mat4(q00, q10, q01, q11);
127+
write_image4_top(base_coord, 0, v);
128+
return v;
121129
}
122130

123131
// For LOD 0 to 6, it is expected that the division is exact,
124132
// i.e., the lower resolution mip is exactly half resolution.
125133
// This way we avoid needing to fold in neighbors.
126134

127-
float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip)
135+
float reduce_mip_simple(vec4 v, ivec2 base_coord, int mip)
128136
{
129-
vec4 q00 = M[0];
130-
vec4 q10 = M[1];
131-
vec4 q01 = M[2];
132-
vec4 q11 = M[3];
133-
134-
ivec2 mip_res = mip_resolution(mip);
135-
136-
float d00 = reduce(q00);
137-
float d10 = reduce(q10);
138-
float d01 = reduce(q01);
139-
float d11 = reduce(q11);
140-
141-
q00 = vec4(d00, d10, d01, d11);
142-
write_image4(base_coord, mip, q00);
143-
144-
return reduce(q00);
137+
float reduced = reduce(v);
138+
write_image(base_coord, mip, reduced);
139+
return reduced;
145140
}
146141

147-
void reduce_mip_shared(ivec2 base_coord, int mip)
142+
float reduce_mip_shared(ivec2 base_coord, int mip)
148143
{
149144
ivec2 mip_res_higher = mip_resolution(mip - 1);
150145
ivec2 mip_res_target = mip_resolution(mip);
@@ -153,37 +148,31 @@ void reduce_mip_shared(ivec2 base_coord, int mip)
153148
bool vert_fold = base_coord.y + 1 == mip_res_target.y && (mip_res_higher.y & 1) != 0;
154149
bool diag_fold = horiz_fold && vert_fold;
155150

156-
const int DOUBLE_SHARED_WIDTH = SHARED_WIDTH * 2;
157-
158151
// Ping-pong the shared buffer to avoid double barrier.
159-
int out_offset = (mip & 1) * BANK_STRIDE;
160-
int in_offset = BANK_STRIDE - out_offset;
161-
int base_in_coord = in_offset + base_coord.y * DOUBLE_SHARED_WIDTH + base_coord.x * 2;
162-
163-
float d00 = shared_buffer[base_in_coord];
164-
float d10 = shared_buffer[base_in_coord + 1];
165-
float d01 = shared_buffer[base_in_coord + SHARED_WIDTH];
166-
float d11 = shared_buffer[base_in_coord + SHARED_WIDTH + 1];
152+
float d00 = load_shared(2 * base_coord + ivec2(0, 0));
153+
float d10 = load_shared(2 * base_coord + ivec2(1, 0));
154+
float d01 = load_shared(2 * base_coord + ivec2(0, 1));
155+
float d11 = load_shared(2 * base_coord + ivec2(1, 1));
167156

168157
float reduced = reduce(vec4(d00, d10, d01, d11));
169158

170159
if (horiz_fold)
171160
{
172-
reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + 2]);
173-
reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + 2 + SHARED_WIDTH]);
161+
reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(2, 0)));
162+
reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(2, 1)));
174163
}
175164

176165
if (vert_fold)
177166
{
178-
reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + DOUBLE_SHARED_WIDTH]);
179-
reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + DOUBLE_SHARED_WIDTH + 1]);
167+
reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(0, 2)));
168+
reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(1, 2)));
180169
}
181170

182171
if (diag_fold)
183-
reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + DOUBLE_SHARED_WIDTH + 2]);
172+
reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(2, 2)));
184173

185-
shared_buffer[out_offset + base_coord.y * SHARED_WIDTH + base_coord.x] = reduced;
186174
write_image(base_coord, mip, reduced);
175+
return reduced;
187176
}
188177

189178
void reduce_mip_lod7(ivec2 base_coord)
@@ -217,35 +206,18 @@ void reduce_mip_lod7(ivec2 base_coord)
217206
reduced = REDUCE_OPERATOR(reduced, fetch_image_mip6(2 * base_coord + ivec2(2, 2)));
218207

219208
write_image(base_coord, 7, reduced);
220-
shared_buffer[BANK_STRIDE + base_coord.y * SHARED_WIDTH + base_coord.x] = reduced;
209+
store_shared(base_coord, reduced);
221210
}
222211

223-
float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d)
212+
float reduce_mip_simd4(float d, ivec2 base_coord, int mip)
224213
{
225-
ivec2 mip_res = mip_resolution(mip);
226-
float d_horiz, d_vert, d_diag;
227-
bool swap_horiz, swap_vert;
228-
229-
d_horiz = subgroupQuadSwapHorizontal(d);
230-
d_vert = subgroupQuadSwapVertical(d);
231-
d_diag = subgroupQuadSwapDiagonal(d);
232-
write_image(base_coord, mip, d);
233-
234-
if (registers.mips > mip + 1)
235-
{
236-
base_coord >>= 1;
237-
mip_res = mip_resolution(mip + 1);
238-
d = reduce(vec4(d, d_horiz, d_vert, d_diag));
239-
240-
// This requires only SIMD16, which everyone can do.
241-
d_horiz = subgroupShuffleXor(d, SHUFFLE_X1);
242-
d_vert = subgroupShuffleXor(d, SHUFFLE_Y1);
243-
d_diag = subgroupShuffleXor(d, SHUFFLE_X1 | SHUFFLE_Y1);
244-
if ((local_index & 3) == 0)
245-
write_image(base_coord, mip + 1, d);
246-
}
247-
248-
return reduce(vec4(d, d_horiz, d_vert, d_diag));
214+
float d_horiz = subgroupQuadSwapHorizontal(d);
215+
float d_vert = subgroupQuadSwapVertical(d);
216+
float d_diag = subgroupQuadSwapDiagonal(d);
217+
d = reduce(vec4(d, d_horiz, d_vert, d_diag));
218+
if ((gl_SubgroupInvocationID & 3) == 0)
219+
write_image(base_coord, mip, d);
220+
return d;
249221
}
250222

251223
// Each workgroup reduces 64x64 on its own.
@@ -256,37 +228,99 @@ void main()
256228
uint local_index = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
257229
uvec2 local_coord = unswizzle16x16(local_index);
258230

259-
// LOD 0 feedback
260-
ivec2 base_coord = ivec2(local_coord) * 4 + ivec2(gl_WorkGroupID.xy * 64u);
261-
mat4 M = fetch_4x4_texture(base_coord);
262-
M = write_mip0_transformed(M, base_coord);
231+
bool is_8x8 = local_index < 64u;
232+
bool is_2x2 = local_index < 4u;
233+
234+
ivec2 base_coord = ivec2(local_coord) * 2 + ivec2(gl_WorkGroupID.xy * 64u);
235+
ivec2 base_coord_00 = base_coord + ivec2( 0, 0);
236+
ivec2 base_coord_10 = base_coord + ivec2(32, 0);
237+
ivec2 base_coord_01 = base_coord + ivec2( 0, 32);
238+
ivec2 base_coord_11 = base_coord + ivec2(32, 32);
263239

264-
// Write LOD 1, Compute LOD 2
240+
// Follow FFX SPD's access pattern here.
241+
// It seems like we need to be super careful about memory access patterns to get optimal bandwidth.
242+
243+
// LOD 0 feedback with transform.
244+
vec4 tile00 = write_mip0_transformed(fetch_2x2_texture(base_coord_00), base_coord_00);
245+
vec4 tile10 = write_mip0_transformed(fetch_2x2_texture(base_coord_10), base_coord_10);
246+
vec4 tile01 = write_mip0_transformed(fetch_2x2_texture(base_coord_01), base_coord_01);
247+
vec4 tile11 = write_mip0_transformed(fetch_2x2_texture(base_coord_11), base_coord_11);
265248
if (registers.mips <= 1)
266249
return;
267-
float d = reduce_mip_registers(M, base_coord >> 1, 1);
250+
251+
// Write LOD 1
252+
ivec2 base_coord_lod1 = base_coord >> 1;
253+
float reduced00 = reduce_mip_simple(tile00, base_coord_lod1 + ivec2( 0, 0), 1);
254+
float reduced10 = reduce_mip_simple(tile10, base_coord_lod1 + ivec2(16, 0), 1);
255+
float reduced01 = reduce_mip_simple(tile01, base_coord_lod1 + ivec2( 0, 16), 1);
256+
float reduced11 = reduce_mip_simple(tile11, base_coord_lod1 + ivec2(16, 16), 1);
268257
if (registers.mips <= 2)
269258
return;
270259

271-
// Write LOD 2, Compute LOD 3-4
272-
d = reduce_mips_simd16(base_coord >> 2, local_index, 2, d);
273-
if (registers.mips <= 4)
260+
// Write LOD 2
261+
ivec2 base_coord_lod2 = base_coord >> 2;
262+
reduced00 = reduce_mip_simd4(reduced00, base_coord_lod2 + ivec2(0, 0), 2);
263+
reduced10 = reduce_mip_simd4(reduced10, base_coord_lod2 + ivec2(8, 0), 2);
264+
reduced01 = reduce_mip_simd4(reduced01, base_coord_lod2 + ivec2(0, 8), 2);
265+
reduced11 = reduce_mip_simd4(reduced11, base_coord_lod2 + ivec2(8, 8), 2);
266+
267+
if (registers.mips <= 3)
274268
return;
275269

276-
// Write LOD 4 to shared
277-
if ((local_index & 15) == 0)
278-
shared_buffer[local_index >> 4] = d;
270+
if ((gl_SubgroupInvocationID & 3) == 0)
271+
{
272+
ivec2 local_coord_shared = ivec2(local_coord) >> 1;
273+
store_shared(local_coord_shared + ivec2(0, 0), reduced00);
274+
store_shared(local_coord_shared + ivec2(8, 0), reduced10);
275+
store_shared(local_coord_shared + ivec2(0, 8), reduced01);
276+
store_shared(local_coord_shared + ivec2(8, 8), reduced11);
277+
}
279278
barrier();
280279

281-
// Write LOD 4, Compute LOD 5-6.
282-
if (local_index < 16)
283-
d = reduce_mips_simd16(ivec2(gl_WorkGroupID.xy * 4u + local_coord), local_index, 4, shared_buffer[local_index]);
280+
// Write LOD 3
281+
float reduced = 0.0;
282+
if (is_8x8)
283+
{
284+
ivec2 base_coord_lod3 = ivec2(gl_WorkGroupID.xy * 8u) + ivec2(local_coord);
285+
ivec2 shared_coord = ivec2(local_coord) * 2;
286+
float d00 = load_shared(shared_coord + ivec2(0, 0));
287+
float d10 = load_shared(shared_coord + ivec2(1, 0));
288+
float d01 = load_shared(shared_coord + ivec2(0, 1));
289+
float d11 = load_shared(shared_coord + ivec2(1, 1));
290+
reduced = reduce_mip_simple(vec4(d00, d10, d01, d11), base_coord_lod3, 3);
291+
292+
// Write LOD 4
293+
if (registers.mips > 4)
294+
reduced = reduce_mip_simd4(reduced, base_coord_lod3 >> 1, 4);
295+
}
284296

285-
// Write LOD 6.
286-
if (registers.mips <= 6)
297+
if (registers.mips <= 5)
287298
return;
288-
if (local_index == 0)
289-
write_image(ivec2(gl_WorkGroupID.xy), 6, d);
299+
300+
// Need this to ensure there is no write-after-read hazard on the shared buffer.
301+
barrier();
302+
303+
if (is_8x8 && (gl_SubgroupInvocationID & 3) == 0)
304+
store_shared(ivec2(local_coord) >> 1, reduced);
305+
306+
barrier();
307+
308+
// Write LOD 5.
309+
if (is_2x2)
310+
{
311+
ivec2 base_coord_lod5 = ivec2(gl_WorkGroupID.xy * 2u) + ivec2(local_coord);
312+
ivec2 shared_coord = ivec2(local_coord) * 2;
313+
float d00 = load_shared(shared_coord + ivec2(0, 0));
314+
float d10 = load_shared(shared_coord + ivec2(1, 0));
315+
float d01 = load_shared(shared_coord + ivec2(0, 1));
316+
float d11 = load_shared(shared_coord + ivec2(1, 1));
317+
reduced = reduce_mip_simple(vec4(d00, d10, d01, d11), base_coord_lod5, 5);
318+
319+
// Write LOD 6
320+
if (registers.mips > 6)
321+
reduce_mip_simd4(reduced, base_coord_lod5 >> 1, 6);
322+
}
323+
290324
if (registers.mips <= 7)
291325
return;
292326

@@ -302,43 +336,22 @@ void main()
302336
if (local_index == 0)
303337
atomic_counter = 0u;
304338

305-
// At this point, the mip resolutions may be non-POT and things get spicy.
306-
// Not using subgroup ops anymore, so use straight linear coordinates.
307-
local_coord.x = bitfieldExtract(local_index, 0, 4);
308-
local_coord.y = bitfieldExtract(local_index, 4, 4);
309-
310339
// Write LOD 7-8, Compute LOD 8
311340
ivec2 mip_res7 = mip_resolution(7);
312341
for (int y = 0; y < mip_res7.y; y += 16)
313342
for (int x = 0; x < mip_res7.x; x += 16)
314343
reduce_mip_lod7(ivec2(local_coord) + ivec2(x, y));
315344

316-
if (registers.mips <= 8)
317-
return;
318-
barrier();
319-
reduce_mip_shared(ivec2(local_coord), 8);
320-
321-
if (registers.mips <= 9)
322-
return;
323-
barrier();
324-
if (local_index < 64)
325-
reduce_mip_shared(ivec2(local_coord), 9);
326-
327-
if (registers.mips <= 10)
328-
return;
329-
barrier();
330-
if (local_index < 16)
331-
reduce_mip_shared(ivec2(local_coord), 10);
332-
333-
if (registers.mips <= 11)
334-
return;
335-
barrier();
336-
if (local_index < 4)
337-
reduce_mip_shared(ivec2(local_coord), 11);
338-
339-
if (registers.mips <= 12)
340-
return;
341-
barrier();
342-
if (local_index == 0)
343-
reduce_mip_shared(ivec2(0), 12);
345+
for (int mip = 8, invocations = 256; mip <= 12; mip++, invocations /= 4)
346+
{
347+
if (registers.mips <= mip)
348+
break;
349+
barrier();
350+
float d;
351+
if (local_index < invocations)
352+
d = reduce_mip_shared(ivec2(local_coord), mip);
353+
barrier();
354+
if (local_index < invocations)
355+
store_shared(ivec2(local_coord), d);
356+
}
344357
}

0 commit comments

Comments (0)