Skip to content

Commit d20f77f

Browse files
committed
Reshape before approx
1 parent 93ed721 commit d20f77f

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

flux.hpp

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -603,7 +603,7 @@ namespace Flux {
603603
bool qkv_bias = true;
604604
bool guidance_embed = true;
605605
bool flash_attn = true;
606-
bool is_chroma = false;
606+
bool is_chroma = false;
607607
};
608608

609609
struct Flux : public GGMLBlock {
@@ -850,20 +850,19 @@ namespace Flux {
850850

851851
// auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends
852852
auto arrange = y;
853-
auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f);// [1, 344, 32]
854-
853+
auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f); // [1, 344, 32]
854+
855855
// Batch broadcast (will it ever be useful?)
856-
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2], modulation_index->ne[3]));// [N, 344, 32]
856+
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
857857

858+
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
859+
timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32]
858860

859-
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
860-
timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32]
861-
862861
vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
863-
vec = approx->forward(ctx, vec); // [N, 344, hidden_size]
864-
865862
// Permute for consistency with non-distilled modulation implementation
866-
vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, hidden_size]
863+
vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64]
864+
vec = approx->forward(ctx, vec); // [344, N, hidden_size]
865+
867866
} else {
868867
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
869868
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);

0 commit comments

Comments
 (0)