@@ -99,12 +99,10 @@ class AttnBlock : public UnaryBlock {
9999 k = ggml_cont (ctx, ggml_permute (ctx, k, 1 , 2 , 0 , 3 )); // [N, h, w, in_channels]
100100 k = ggml_reshape_3d (ctx, k, c, h * w, n); // [N, h * w, in_channels]
101101
102- auto v = v_proj->forward (ctx, h_); // [N, in_channels, h, w]
103- v = ggml_cont (ctx, ggml_permute (ctx, v, 1 , 2 , 0 , 3 )); // [N, h, w, in_channels]
104- v = ggml_reshape_3d (ctx, v, c, h * w, n); // [N, h * w, in_channels]
102+ auto v = v_proj->forward (ctx, h_); // [N, in_channels, h, w]
103+ v = ggml_reshape_3d (ctx, v, h * w, c, n); // [N, in_channels, h * w]
105104
106- // h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
107- h_ = ggml_nn_attention_ext (ctx, q, k, v, 1 , nullptr , false , true , false );
105+ h_ = ggml_nn_attention (ctx, q, k, v, false ); // [N, h * w, in_channels]
108106
109107 h_ = ggml_cont (ctx, ggml_permute (ctx, h_, 1 , 0 , 2 , 3 )); // [N, in_channels, h * w]
110108 h_ = ggml_reshape_4d (ctx, h_, w, h, c, n); // [N, in_channels, h, w]
0 commit comments