layout(local_size_x = 256) in;

- layout(set = 0, binding = 0, r32f) uniform coherent image2D uImages[13];
+ layout(set = 0, binding = 0, r32f) uniform writeonly image2D uImageTop;
+ layout(set = 0, binding = 1, r32f) coherent uniform image2D uImages[12];
layout(set = 1, binding = 0) uniform sampler2D uTexture;
layout(set = 1, binding = 1) buffer Counter
{
@@ -62,23 +63,38 @@ vec4 transform_z(vec4 zs)
void write_image(ivec2 coord, int mip, float v)
{
    // Rely on image robustness to clean up the OOB writes here.
-     imageStore(uImages[mip], coord, vec4(v));
+     imageStore(uImages[mip - 1], coord, vec4(v));
}

- void write_image4(ivec2 coord, int mip, vec4 v)
+ void write_image4_top(ivec2 coord, int mip, vec4 v)
{
-     imageStore(uImages[mip], coord + ivec2(0, 0), v.xxxx);
-     imageStore(uImages[mip], coord + ivec2(1, 0), v.yyyy);
-     imageStore(uImages[mip], coord + ivec2(0, 1), v.zzzz);
-     imageStore(uImages[mip], coord + ivec2(1, 1), v.wwww);
+     imageStore(uImageTop, coord + ivec2(0, 0), v.xxxx);
+     imageStore(uImageTop, coord + ivec2(1, 0), v.yyyy);
+     imageStore(uImageTop, coord + ivec2(0, 1), v.zzzz);
+     imageStore(uImageTop, coord + ivec2(1, 1), v.wwww);
}

const int SHARED_WIDTH = 32;
const int SHARED_HEIGHT = 32;
- const int BANK_STRIDE = SHARED_WIDTH * SHARED_HEIGHT;
- shared float shared_buffer[2 * BANK_STRIDE];
+ shared float shared_buffer[SHARED_HEIGHT][SHARED_WIDTH];
shared bool shared_is_last_workgroup;

+ void store_shared(ivec2 coord, float d)
+ {
+     shared_buffer[coord.y][coord.x] = d;
+ }
+
+ float load_shared(ivec2 coord)
+ {
+     return shared_buffer[coord.y][coord.x];
+ }
+
+ vec4 fetch_2x2_texture(ivec2 base_coord)
+ {
+     vec2 fcoord = vec2(base_coord) * registers.inv_resolution;
+     return textureGatherOffset(uTexture, fcoord, ivec2(1, 1)).wzxy;
+ }
+
mat4 fetch_4x4_texture(ivec2 base_coord)
{
    vec2 fcoord = vec2(base_coord) * registers.inv_resolution;
@@ -92,59 +108,38 @@ mat4 fetch_4x4_texture(ivec2 base_coord)
vec4 fetch_2x2_image_mip6(ivec2 base_coord)
{
    ivec2 max_coord = mip_resolution(6) - 1;
-     float d0 = imageLoad(uImages[6], min(base_coord + ivec2(0, 0), max_coord)).x;
-     float d1 = imageLoad(uImages[6], min(base_coord + ivec2(1, 0), max_coord)).x;
-     float d2 = imageLoad(uImages[6], min(base_coord + ivec2(0, 1), max_coord)).x;
-     float d3 = imageLoad(uImages[6], min(base_coord + ivec2(1, 1), max_coord)).x;
+     float d0 = imageLoad(uImages[5], min(base_coord + ivec2(0, 0), max_coord)).x;
+     float d1 = imageLoad(uImages[5], min(base_coord + ivec2(1, 0), max_coord)).x;
+     float d2 = imageLoad(uImages[5], min(base_coord + ivec2(0, 1), max_coord)).x;
+     float d3 = imageLoad(uImages[5], min(base_coord + ivec2(1, 1), max_coord)).x;
    return vec4(d0, d1, d2, d3);
}

float fetch_image_mip6(ivec2 coord)
{
-     return imageLoad(uImages[6], coord).x;
+     return imageLoad(uImages[5], coord).x;
}

- mat4 write_mip0_transformed(mat4 M, ivec2 base_coord)
+ vec4 write_mip0_transformed(vec4 v, ivec2 base_coord)
{
-     vec4 q00 = transform_z(M[0]);
-     vec4 q10 = transform_z(M[1]);
-     vec4 q01 = transform_z(M[2]);
-     vec4 q11 = transform_z(M[3]);
-
+     v = transform_z(v);
    // Write out transformed LOD 0
-     write_image4(base_coord + ivec2(0, 0), 0, q00);
-     write_image4(base_coord + ivec2(2, 0), 0, q10);
-     write_image4(base_coord + ivec2(0, 2), 0, q01);
-     write_image4(base_coord + ivec2(2, 2), 0, q11);
-
-     return mat4(q00, q10, q01, q11);
+     write_image4_top(base_coord, 0, v);
+     return v;
}

// For LOD 0 to 6, it is expected that the division is exact,
// i.e., the lower resolution mip is exactly half resolution.
// This way we avoid needing to fold in neighbors.

- float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip)
+ float reduce_mip_simple(vec4 v, ivec2 base_coord, int mip)
{
-     vec4 q00 = M[0];
-     vec4 q10 = M[1];
-     vec4 q01 = M[2];
-     vec4 q11 = M[3];
-
-     ivec2 mip_res = mip_resolution(mip);
-
-     float d00 = reduce(q00);
-     float d10 = reduce(q10);
-     float d01 = reduce(q01);
-     float d11 = reduce(q11);
-
-     q00 = vec4(d00, d10, d01, d11);
-     write_image4(base_coord, mip, q00);
-
-     return reduce(q00);
+     float reduced = reduce(v);
+     write_image(base_coord, mip, reduced);
+     return reduced;
}

- void reduce_mip_shared(ivec2 base_coord, int mip)
+ float reduce_mip_shared(ivec2 base_coord, int mip)
{
    ivec2 mip_res_higher = mip_resolution(mip - 1);
    ivec2 mip_res_target = mip_resolution(mip);
@@ -153,37 +148,31 @@ void reduce_mip_shared(ivec2 base_coord, int mip)
    bool vert_fold = base_coord.y + 1 == mip_res_target.y && (mip_res_higher.y & 1) != 0;
    bool diag_fold = horiz_fold && vert_fold;

-     const int DOUBLE_SHARED_WIDTH = SHARED_WIDTH * 2;
-
    // Ping-pong the shared buffer to avoid double barrier.
-     int out_offset = (mip & 1) * BANK_STRIDE;
-     int in_offset = BANK_STRIDE - out_offset;
-     int base_in_coord = in_offset + base_coord.y * DOUBLE_SHARED_WIDTH + base_coord.x * 2;
-
-     float d00 = shared_buffer[base_in_coord];
-     float d10 = shared_buffer[base_in_coord + 1];
-     float d01 = shared_buffer[base_in_coord + SHARED_WIDTH];
-     float d11 = shared_buffer[base_in_coord + SHARED_WIDTH + 1];
+     float d00 = load_shared(2 * base_coord + ivec2(0, 0));
+     float d10 = load_shared(2 * base_coord + ivec2(1, 0));
+     float d01 = load_shared(2 * base_coord + ivec2(0, 1));
+     float d11 = load_shared(2 * base_coord + ivec2(1, 1));

    float reduced = reduce(vec4(d00, d10, d01, d11));

    if (horiz_fold)
    {
-         reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + 2]);
-         reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + 2 + SHARED_WIDTH]);
+         reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(2, 0)));
+         reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(2, 1)));
    }

    if (vert_fold)
    {
-         reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + DOUBLE_SHARED_WIDTH]);
-         reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + DOUBLE_SHARED_WIDTH + 1]);
+         reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(0, 2)));
+         reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(1, 2)));
    }

    if (diag_fold)
-         reduced = REDUCE_OPERATOR(reduced, shared_buffer[base_in_coord + DOUBLE_SHARED_WIDTH + 2]);
+         reduced = REDUCE_OPERATOR(reduced, load_shared(2 * base_coord + ivec2(2, 2)));

-     shared_buffer[out_offset + base_coord.y * SHARED_WIDTH + base_coord.x] = reduced;
    write_image(base_coord, mip, reduced);
+     return reduced;
}

void reduce_mip_lod7(ivec2 base_coord)
@@ -217,35 +206,18 @@ void reduce_mip_lod7(ivec2 base_coord)
    reduced = REDUCE_OPERATOR(reduced, fetch_image_mip6(2 * base_coord + ivec2(2, 2)));

    write_image(base_coord, 7, reduced);
-     shared_buffer[BANK_STRIDE + base_coord.y * SHARED_WIDTH + base_coord.x] = reduced;
+     store_shared(base_coord, reduced);
}

- float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d)
+ float reduce_mip_simd4(float d, ivec2 base_coord, int mip)
{
-     ivec2 mip_res = mip_resolution(mip);
-     float d_horiz, d_vert, d_diag;
-     bool swap_horiz, swap_vert;
-
-     d_horiz = subgroupQuadSwapHorizontal(d);
-     d_vert = subgroupQuadSwapVertical(d);
-     d_diag = subgroupQuadSwapDiagonal(d);
-     write_image(base_coord, mip, d);
-
-     if (registers.mips > mip + 1)
-     {
-         base_coord >>= 1;
-         mip_res = mip_resolution(mip + 1);
-         d = reduce(vec4(d, d_horiz, d_vert, d_diag));
-
-         // This requires only SIMD16, which everyone can do.
-         d_horiz = subgroupShuffleXor(d, SHUFFLE_X1);
-         d_vert = subgroupShuffleXor(d, SHUFFLE_Y1);
-         d_diag = subgroupShuffleXor(d, SHUFFLE_X1 | SHUFFLE_Y1);
-         if ((local_index & 3) == 0)
-             write_image(base_coord, mip + 1, d);
-     }
-
-     return reduce(vec4(d, d_horiz, d_vert, d_diag));
+     float d_horiz = subgroupQuadSwapHorizontal(d);
+     float d_vert = subgroupQuadSwapVertical(d);
+     float d_diag = subgroupQuadSwapDiagonal(d);
+     d = reduce(vec4(d, d_horiz, d_vert, d_diag));
+     if ((gl_SubgroupInvocationID & 3) == 0)
+         write_image(base_coord, mip, d);
+     return d;
}

// Each workgroup reduces 64x64 on its own.
@@ -256,37 +228,99 @@ void main()
    uint local_index = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
    uvec2 local_coord = unswizzle16x16(local_index);

-     // LOD 0 feedback
-     ivec2 base_coord = ivec2(local_coord) * 4 + ivec2(gl_WorkGroupID.xy * 64u);
-     mat4 M = fetch_4x4_texture(base_coord);
-     M = write_mip0_transformed(M, base_coord);
+     bool is_8x8 = local_index < 64u;
+     bool is_2x2 = local_index < 4u;
+
+     ivec2 base_coord = ivec2(local_coord) * 2 + ivec2(gl_WorkGroupID.xy * 64u);
+     ivec2 base_coord_00 = base_coord + ivec2( 0, 0);
+     ivec2 base_coord_10 = base_coord + ivec2(32, 0);
+     ivec2 base_coord_01 = base_coord + ivec2( 0, 32);
+     ivec2 base_coord_11 = base_coord + ivec2(32, 32);

-     // Write LOD 1, Compute LOD 2
+     // Follow FFX SPD's access pattern here.
+     // It seems like we need to be super careful about memory access patterns to get optimal bandwidth.
+
+     // LOD 0 feedback with transform.
+     vec4 tile00 = write_mip0_transformed(fetch_2x2_texture(base_coord_00), base_coord_00);
+     vec4 tile10 = write_mip0_transformed(fetch_2x2_texture(base_coord_10), base_coord_10);
+     vec4 tile01 = write_mip0_transformed(fetch_2x2_texture(base_coord_01), base_coord_01);
+     vec4 tile11 = write_mip0_transformed(fetch_2x2_texture(base_coord_11), base_coord_11);
    if (registers.mips <= 1)
        return;
-     float d = reduce_mip_registers(M, base_coord >> 1, 1);
+
+     // Write LOD 1
+     ivec2 base_coord_lod1 = base_coord >> 1;
+     float reduced00 = reduce_mip_simple(tile00, base_coord_lod1 + ivec2( 0, 0), 1);
+     float reduced10 = reduce_mip_simple(tile10, base_coord_lod1 + ivec2(16, 0), 1);
+     float reduced01 = reduce_mip_simple(tile01, base_coord_lod1 + ivec2( 0, 16), 1);
+     float reduced11 = reduce_mip_simple(tile11, base_coord_lod1 + ivec2(16, 16), 1);
    if (registers.mips <= 2)
        return;

-     // Write LOD 2, Compute LOD 3-4
-     d = reduce_mips_simd16(base_coord >> 2, local_index, 2, d);
-     if (registers.mips <= 4)
+     // Write LOD 2
+     ivec2 base_coord_lod2 = base_coord >> 2;
+     reduced00 = reduce_mip_simd4(reduced00, base_coord_lod2 + ivec2(0, 0), 2);
+     reduced10 = reduce_mip_simd4(reduced10, base_coord_lod2 + ivec2(8, 0), 2);
+     reduced01 = reduce_mip_simd4(reduced01, base_coord_lod2 + ivec2(0, 8), 2);
+     reduced11 = reduce_mip_simd4(reduced11, base_coord_lod2 + ivec2(8, 8), 2);
+
+     if (registers.mips <= 3)
        return;

-     // Write LOD 4 to shared
-     if ((local_index & 15) == 0)
-         shared_buffer[local_index >> 4] = d;
+     if ((gl_SubgroupInvocationID & 3) == 0)
+     {
+         ivec2 local_coord_shared = ivec2(local_coord) >> 1;
+         store_shared(local_coord_shared + ivec2(0, 0), reduced00);
+         store_shared(local_coord_shared + ivec2(8, 0), reduced10);
+         store_shared(local_coord_shared + ivec2(0, 8), reduced01);
+         store_shared(local_coord_shared + ivec2(8, 8), reduced11);
+     }
    barrier();

-     // Write LOD 4, Compute LOD 5-6.
-     if (local_index < 16)
-         d = reduce_mips_simd16(ivec2(gl_WorkGroupID.xy * 4u + local_coord), local_index, 4, shared_buffer[local_index]);
+     // Write LOD 3
+     float reduced = 0.0;
+     if (is_8x8)
+     {
+         ivec2 base_coord_lod3 = ivec2(gl_WorkGroupID.xy * 8u) + ivec2(local_coord);
+         ivec2 shared_coord = ivec2(local_coord) * 2;
+         float d00 = load_shared(shared_coord + ivec2(0, 0));
+         float d10 = load_shared(shared_coord + ivec2(1, 0));
+         float d01 = load_shared(shared_coord + ivec2(0, 1));
+         float d11 = load_shared(shared_coord + ivec2(1, 1));
+         reduced = reduce_mip_simple(vec4(d00, d10, d01, d11), base_coord_lod3, 3);
+
+         // Write LOD 4
+         if (registers.mips > 4)
+             reduced = reduce_mip_simd4(reduced, base_coord_lod3 >> 1, 4);
+     }

-     // Write LOD 6.
-     if (registers.mips <= 6)
+     if (registers.mips <= 5)
        return;
-     if (local_index == 0)
-         write_image(ivec2(gl_WorkGroupID.xy), 6, d);
+
+     // Need this to ensure there is no write-after-read hazard on the shared buffer.
+     barrier();
+
+     if (is_8x8 && (gl_SubgroupInvocationID & 3) == 0)
+         store_shared(ivec2(local_coord) >> 1, reduced);
+
+     barrier();
+
+     // Write LOD 5.
+     if (is_2x2)
+     {
+         ivec2 base_coord_lod5 = ivec2(gl_WorkGroupID.xy * 2u) + ivec2(local_coord);
+         ivec2 shared_coord = ivec2(local_coord) * 2;
+         float d00 = load_shared(shared_coord + ivec2(0, 0));
+         float d10 = load_shared(shared_coord + ivec2(1, 0));
+         float d01 = load_shared(shared_coord + ivec2(0, 1));
+         float d11 = load_shared(shared_coord + ivec2(1, 1));
+         reduced = reduce_mip_simple(vec4(d00, d10, d01, d11), base_coord_lod5, 5);
+
+         // Write LOD 6
+         if (registers.mips > 6)
+             reduce_mip_simd4(reduced, base_coord_lod5 >> 1, 6);
+     }
+
    if (registers.mips <= 7)
        return;
@@ -302,43 +336,22 @@ void main()
    if (local_index == 0)
        atomic_counter = 0u;

-     // At this point, the mip resolutions may be non-POT and things get spicy.
-     // Not using subgroup ops anymore, so use straight linear coordinates.
-     local_coord.x = bitfieldExtract(local_index, 0, 4);
-     local_coord.y = bitfieldExtract(local_index, 4, 4);
-
    // Write LOD 7-8, Compute LOD 8
    ivec2 mip_res7 = mip_resolution(7);
    for (int y = 0; y < mip_res7.y; y += 16)
        for (int x = 0; x < mip_res7.x; x += 16)
            reduce_mip_lod7(ivec2(local_coord) + ivec2(x, y));

-     if (registers.mips <= 8)
-         return;
-     barrier();
-     reduce_mip_shared(ivec2(local_coord), 8);
-
-     if (registers.mips <= 9)
-         return;
-     barrier();
-     if (local_index < 64)
-         reduce_mip_shared(ivec2(local_coord), 9);
-
-     if (registers.mips <= 10)
-         return;
-     barrier();
-     if (local_index < 16)
-         reduce_mip_shared(ivec2(local_coord), 10);
-
-     if (registers.mips <= 11)
-         return;
-     barrier();
-     if (local_index < 4)
-         reduce_mip_shared(ivec2(local_coord), 11);
-
-     if (registers.mips <= 12)
-         return;
-     barrier();
-     if (local_index == 0)
-         reduce_mip_shared(ivec2(0), 12);
+     for (int mip = 8, invocations = 256; mip <= 12; mip++, invocations /= 4)
+     {
+         if (registers.mips <= mip)
+             break;
+         barrier();
+         float d;
+         if (local_index < invocations)
+             d = reduce_mip_shared(ivec2(local_coord), mip);
+         barrier();
+         if (local_index < invocations)
+             store_shared(ivec2(local_coord), d);
+     }
}
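
The hunks above call into several helpers that live elsewhere in the shader and are untouched by this diff: reduce(), REDUCE_OPERATOR, mip_resolution(), unswizzle16x16(), transform_z() (its vec4-in/vec4-out signature is visible in a hunk header) and the registers push-constant block. Their real definitions are not shown here, so the following is only a rough sketch of plausible definitions for a min-reducing depth pyramid; the block layout, the swizzle pattern and the reduction operator are assumptions, not part of the PR.

// Hypothetical reconstruction for reference only -- not part of the diff.
#define REDUCE_OPERATOR min               // assumed: HiZ keeps the nearest depth per 2x2 block

layout(push_constant) uniform Registers
{
    ivec2 resolution;      // assumed: LOD 0 resolution in texels
    vec2 inv_resolution;   // used by fetch_2x2_texture / fetch_4x4_texture above
    int mips;              // number of mip levels to produce
} registers;

// Collapse a 2x2 block into a single value with the chosen operator.
float reduce(vec4 v)
{
    return REDUCE_OPERATOR(REDUCE_OPERATOR(v.x, v.y), REDUCE_OPERATOR(v.z, v.w));
}

// Assumed: each mip halves the previous resolution (floor), clamped to at least one texel.
ivec2 mip_resolution(int mip)
{
    return max(registers.resolution >> mip, ivec2(1));
}

// Assumed mapping of a linear invocation index in [0, 255] to a 16x16 coordinate such that
// indices 4k..4k+3 form a 2x2 pixel quad, matching the subgroupQuad* usage above.
uvec2 unswizzle16x16(uint index)
{
    uint x = bitfieldExtract(index, 0, 1) | (bitfieldExtract(index, 2, 3) << 1u);
    uint y = bitfieldExtract(index, 1, 1) | (bitfieldExtract(index, 5, 3) << 1u);
    return uvec2(x, y);
}

// transform_z(vec4 zs) converts the four sampled depth values into whatever Z encoding
// the pyramid stores; its body is application-specific and not reproduced here.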