@@ -26,63 +26,23 @@ layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
   BUF_T buffer_in[];
 };

-// Corresponds to {1,4,9,24} in the example below.
 layout(set = 0, binding = 2) uniform PRECISION restrict Sizes {
   ivec4 sizes;
 };

-// Corresponds to {3,3,7,10} in the example below.
 layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
   ivec4 original_sizes;
 };

-// Corresponds to {8,12} in the example below.
-layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
-  ivec2 padded_sizes;
-};
-
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 layout(constant_id = 3) const int packed_dim = C_DIM;

 /*
  * Computes special prepacking for a 2D convolution. Each shader invocation
- * calculates the input buffer location to read into the desired texel. This
- * packing was originally developed on CPU and that approach is described in the
- * rest of this comment. Refer to the code-level comments for how we translate
- * it to GPU by reversing the steps.
- *
- * Consider an example weight tensor of size {10,7,3,3}. The following
- * transformations will be applied.
- *
- * 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2
- *    batches and 1 channel of padding are added, producing a tensor of size
- *    {12,8,3,3}.
- *    at::pad(x, {0,0,0,0,0,1,0,2}, "constant", 0);
- *
- * 2. Split the tensor along the C dim so that each split has 4 channels.
- *    x.reshape({12,2,4,3,3});
- *
- * 3. For each split, "fold" the C dim into the W dim. Suppose the first rows
- *    at H=0 of the split have values
- *    0,1,2 | 10,11,12 | 20,21,22 | 30,31,32
- *
- *    where | denotes a channel boundary. Then, the goal is to combine those
- *    rows into one row with the values
- *    0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32
- *
- *    x.permute({0,1,3,4,2}).reshape({12,2,3,12});
- *
- * 4. Stack the splits belonging to the same batch horizontally by swapping the
- *    C and H dims.
- *    x.permute({0,2,1,3}).reshape({12,3,24});
- *
- * 5. Repeat a similar process to "fold" the N dim into the C dim. Split along
- *    the N dim so that each split has 4 batches.
- *    x.reshape({3,4,3,24});
- *
- * 6. Stack the batches on each other vertically by swapping the N and C dims.
- *    x.permute({1,0,2,3}).reshape({4,9,24});
+ * calculates the input buffer locations to read into the desired texel. This
+ * packing was originally developed on CPU here:
+ * https://github.com/pytorch/pytorch/blob/d63e7d0aa2e0a1b1fd7518f917224774afe97bae/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L120-L211
  */
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -92,49 +52,44 @@ void main() {
     return;
   }

-  // As in usual staging shaders, map from GPU texel position to normal CPU
-  // buffer indices: (24,9) -> (4,9,24)
+  // Map tensor_idx to normal buffer_i
   const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim);

-  // Re-map the normal CPU buffer indices to special indices, through a series
-  // of mappings: reshape is a no-op to the underlying indices, so we only map
-  // for pad and permute.
-  const int Np = padded_sizes.y;
-  const int Cp = padded_sizes.x;
+  // Compute modified tensor_idx by inverting the CPU function
   const int N = original_sizes.w;
   const int C = original_sizes.z;
   const int H = original_sizes.y;
   const int W = original_sizes.x;
+  const int J = sizes.x / (4 * W);
+  const int K = sizes.y / H;
+
+  const ivec4 p1 = p0 / 4;
+  const ivec4 p2 = p1 / W;
+  const ivec4 p3 = p2 / J;
+  const ivec4 p4 = p3 / H;
+
+  const ivec4 n = (p4 % K) * 4 + (p4 / K);
+  const ivec4 c = (p2 % J) * 4 + (p0 % 4);
+  const ivec4 h = p3 % H;
+  const ivec4 w = p1 % W;

-  // Undo step 6 permute: (4,3,3,24) -> (3,4,3,24)
-  // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12)
-  // Undo step 3 permute, part 1: (12,2,3h,3w,4) -> (12,2,3h,4,3w)
-  // Undo step 3 permute, part 2: (12,2,3h,4,3w) -> (12,2,4,3h,3w)
-  const ivec4 p1 = swap_adj_dims(p0, 4, (Np / 4), (H * Cp * W));
-  const ivec4 p2 = swap_adj_dims(p1, H, (Cp / 4), (W * 4));
-  const ivec4 p3 = swap_adj_dims(p2, W, 4, 1);
-  const ivec4 p4 = swap_adj_dims(p3, H, 4, W);
-
-  // Undo step 1 pad: (12,8,3,3) -> (10,7,3,3)
-  // For values in the padded region, write zero instead of buffer data.
-  const ivec4 c = p4 % (Cp * H * W) / (H * W);
-  const ivec4 n = p4 / (Cp * H * W);
-  const ivec4 p5 = p4 - n * (Cp - C) * H * W;
-  const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) |
-      ivec4(greaterThanEqual(n, ivec4(N)));
+  // Map modified tensor_idx to modified buffer_i
+  // Zero out if modified tensor idx is out of bounds
+  const ivec4 buf_i = n * C * H * W + c * H * W + h * W + w;
+  const bvec4 mask = bvec4(ivec4(lessThan(n, ivec4(N))) & ivec4(lessThan(c, ivec4(C))));

   VEC4_T texel = VEC4_T(0);
-  if (mask.x == 0) {
-    texel.x = SCALAR_T(buffer_in[p5.x]);
+  if (mask.x) {
+    texel.x = SCALAR_T(buffer_in[buf_i.x]);
   }
-  if (mask.y == 0) {
-    texel.y = SCALAR_T(buffer_in[p5.y]);
+  if (mask.y) {
+    texel.y = SCALAR_T(buffer_in[buf_i.y]);
   }
-  if (mask.z == 0) {
-    texel.z = SCALAR_T(buffer_in[p5.z]);
+  if (mask.z) {
+    texel.z = SCALAR_T(buffer_in[buf_i.z]);
   }
-  if (mask.w == 0) {
-    texel.w = SCALAR_T(buffer_in[p5.w]);
+  if (mask.w) {
+    texel.w = SCALAR_T(buffer_in[buf_i.w]);
   }

   imageStore(image_out, pos.xy, texel);
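
For reference, the six packing steps described in the deleted comment (pad N and C up to multiples of 4, fold each group of 4 channels into the W dim, then fold each group of 4 batches into the H dim) can be reproduced with a short PyTorch sketch. The function name pack_conv2d_weights is illustrative and not part of the commit; the general Np/Cp handling is an extrapolation of the {10,7,3,3} example from the removed comment.

import torch
import torch.nn.functional as F

def pack_conv2d_weights(x):
    # x: weight tensor of shape (N, C, H, W)
    N, C, H, W = x.shape
    Np = (N + 3) // 4 * 4  # N padded up to a multiple of 4
    Cp = (C + 3) // 4 * 4  # C padded up to a multiple of 4

    # 1. Pad the N and C dims so that both are a multiple of 4.
    x = F.pad(x, (0, 0, 0, 0, 0, Cp - C, 0, Np - N))
    # 2. Split the C dim into groups of 4 channels.
    x = x.reshape(Np, Cp // 4, 4, H, W)
    # 3. Fold each group of 4 channels into the W dim.
    x = x.permute(0, 1, 3, 4, 2).reshape(Np, Cp // 4, H, 4 * W)
    # 4. Stack the channel groups of each batch side by side.
    x = x.permute(0, 2, 1, 3).reshape(Np, H, Cp // 4 * 4 * W)
    # 5. Split the N dim into groups of 4 batches.
    x = x.reshape(Np // 4, 4, H, Cp // 4 * 4 * W)
    # 6. Stack the batch groups on top of each other.
    return x.permute(1, 0, 2, 3).reshape(4, Np // 4 * H, Cp // 4 * 4 * W)

weights = torch.arange(10 * 7 * 3 * 3, dtype=torch.float32).reshape(10, 7, 3, 3)
packed = pack_conv2d_weights(weights)
print(packed.shape)  # torch.Size([4, 9, 24])

On the example weights this produces the {4,9,24} packed tensor whose WHCN extents are what the Sizes uniform carries.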
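The new shader body recovers (n, c, h, w) from each packed buffer index with a chain of integer divisions and modulos, where J = sizes.x / (4 * W) is the number of 4-channel groups (Cp / 4) and K = sizes.y / H is the number of 4-batch groups (Np / 4). A small Python mirror of that arithmetic (unpack_index is a hypothetical helper, not part of the commit) can round-trip against the packing sketch above; indices with n or c out of bounds fall in the padded region, which the shader masks to zero.

def unpack_index(p0, packed_whcn, original_whcn):
    # Mirror the shader's index math: map a linear index into the packed
    # buffer back to (n, c, h, w) in the original weight tensor.
    N, C, H, W = original_whcn[3], original_whcn[2], original_whcn[1], original_whcn[0]
    J = packed_whcn[0] // (4 * W)  # number of 4-channel groups (Cp / 4)
    K = packed_whcn[1] // H        # number of 4-batch groups (Np / 4)

    p1 = p0 // 4
    p2 = p1 // W
    p3 = p2 // J
    p4 = p3 // H

    n = (p4 % K) * 4 + (p4 // K)
    c = (p2 % J) * 4 + (p0 % 4)
    h = p3 % H
    w = p1 % W
    return n, c, h, w

# Round-trip: every element of the packed tensor from the sketch above should
# equal the corresponding original weight, or zero in the padded region.
flat = packed.flatten()
for p0 in range(flat.numel()):
    n, c, h, w = unpack_index(p0, (24, 9, 4, 1), (3, 3, 7, 10))
    expected = weights[n, c, h, w].item() if n < 10 and c < 7 else 0.0
    assert flat[p0].item() == expected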