@@ -34,13 +34,17 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434
3535#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
3636
37+ // shared memory to hold calculated positions, this would reduce register usage thus improving performance.
38+ shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
39+
3740/*
3841 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
3942 * output tile for pointwise convolution is more efficient because the kernel
4043 * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
4144 */
4245void main() {
4346 const uint16_t out_limits_y_scaled = uint16_t((out_limits.y + TILE_SIZE - 1 ) / TILE_SIZE);
47+ const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
4448
4549 const u16vec3 gpos = u16vec3(
4650 gl_GlobalInvocationID.x / (out_limits_y_scaled * out_limits.z),
@@ -58,6 +62,7 @@ void main() {
5862 for (int x = 0 ; x < TILE_SIZE; ++ x) {
5963 pos[i] = u16vec2(
6064 gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
65+ pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
6166 i++ ;
6267 }
6368 }
@@ -73,7 +78,7 @@ void main() {
7378 // the top-left element is in a region added by padding.
7479 u16vec2 ipos[TILE_SIZE * TILE_SIZE];
7580 for (int i = 0 ; i < TILE_SIZE * TILE_SIZE; ++ i) {
76- ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding);
81+ ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding);
7782 }
7883
7984 vec4 sum[TILE_SIZE * TILE_SIZE];
@@ -138,8 +143,9 @@ void main() {
138143 }
139144
140145 for (int i = 0 ; i < TILE_SIZE * TILE_SIZE; ++ i) {
141- if (all (lessThan (u16vec3(pos[i], gpos.z), out_limits))) {
142- imageStore(t_out, u16vec3(pos[i], gpos.z), op(sum[i], out_min, out_max));
146+ const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
147+ if (all (lessThan (u16vec3(pos, gpos.z), out_limits))) {
148+ imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
143149 }
144150 }
145151}
0 commit comments