@@ -603,7 +603,7 @@ namespace Flux {
603603 bool qkv_bias = true ;
604604 bool guidance_embed = true ;
605605 bool flash_attn = true ;
606- bool is_chroma = false ;
606+ bool is_chroma = false ;
607607 };
608608
609609 struct Flux : public GGMLBlock {
@@ -850,20 +850,19 @@ namespace Flux {
850850
851851 // auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends
852852 auto arrange = y;
853- auto modulation_index = ggml_nn_timestep_embedding (ctx, arrange, 32 , 10000 , 1000 .f );// [1, 344, 32]
854-
853+ auto modulation_index = ggml_nn_timestep_embedding (ctx, arrange, 32 , 10000 , 1000 .f ); // [1, 344, 32]
854+
855855 // Batch broadcast (will it ever be useful)
856- modulation_index = ggml_repeat (ctx, modulation_index, ggml_new_tensor_4d (ctx, GGML_TYPE_F32, modulation_index->ne [0 ], modulation_index->ne [1 ], img->ne [2 ], modulation_index-> ne [ 3 ] ));// [N, 344, 32]
856+ modulation_index = ggml_repeat (ctx, modulation_index, ggml_new_tensor_3d (ctx, GGML_TYPE_F32, modulation_index->ne [0 ], modulation_index->ne [1 ], img->ne [2 ])); // [N, 344, 32]
857857
858+ auto timestep_guidance = ggml_concat (ctx, distill_timestep, distill_guidance, 0 ); // [N, 1, 32]
859+ timestep_guidance = ggml_repeat (ctx, timestep_guidance, modulation_index); // [N, 344, 32]
858860
859- auto timestep_guidance = ggml_concat (ctx, distill_timestep, distill_guidance, 0 ); // [N, 1, 32]
860- timestep_guidance = ggml_repeat (ctx, timestep_guidance, modulation_index); // [N, 344, 32]
861-
862861 vec = ggml_concat (ctx, timestep_guidance, modulation_index, 0 ); // [N, 344, 64]
863- vec = approx->forward (ctx, vec); // [N, 344, hidden_size]
864-
865862 // Permute for consistency with non-distilled modulation implementation
866- vec = ggml_cont (ctx, ggml_permute (ctx, vec, 0 , 2 , 1 , 3 )); // [344, N, hidden_size]
863+ vec = ggml_cont (ctx, ggml_permute (ctx, vec, 0 , 2 , 1 , 3 )); // [344, N, 64]
864+ vec = approx->forward (ctx, vec); // [344, N, hidden_size]
865+
867866 } else {
868867 auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks[" time_in" ]);
869868 auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks[" vector_in" ]);
0 commit comments