@@ -491,19 +491,29 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
491491typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
492492
493493// Tiling
494- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
494+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
495495 int input_width = (int )input->ne [0 ];
496496 int input_height = (int )input->ne [1 ];
497497 int output_width = (int )output->ne [0 ];
498498 int output_height = (int )output->ne [1 ];
499+
500+ int input_tile_size, output_tile_size;
501+ if (scaled_out) {
502+ input_tile_size = tile_size;
503+ output_tile_size = tile_size * scale;
504+ } else {
505+ input_tile_size = tile_size * scale;
506+ output_tile_size = tile_size;
507+ }
508+
499509 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
500510
501- int tile_overlap = (int32_t )(tile_size * tile_overlap_factor);
502- int non_tile_overlap = tile_size - tile_overlap;
511+ int tile_overlap = (int32_t )(input_tile_size * tile_overlap_factor);
512+ int non_tile_overlap = input_tile_size - tile_overlap;
503513
504514 struct ggml_init_params params = {};
505- params.mem_size += tile_size * tile_size * input->ne [2 ] * sizeof (float ); // input chunk
506- params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne [2 ] * sizeof (float ); // output chunk
515+ params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
516+ params.mem_size += output_tile_size * output_tile_size * output->ne [2 ] * sizeof (float ); // output chunk
507517 params.mem_size += 3 * ggml_tensor_overhead ();
508518 params.mem_buffer = NULL ;
509519 params.no_alloc = false ;
@@ -518,8 +528,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
518528 }
519529
520530 // tiling
521- ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size, tile_size , input->ne [2 ], 1 );
522- ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale , output->ne [2 ], 1 );
531+ ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size , input->ne [2 ], 1 );
532+ ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size , output->ne [2 ], 1 );
523533 on_processing (input_tile, NULL , true );
524534 int num_tiles = ceil ((float )input_width / non_tile_overlap) * ceil ((float )input_height / non_tile_overlap);
525535 LOG_INFO (" processing %i tiles" , num_tiles);
@@ -528,19 +538,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
528538 bool last_y = false , last_x = false ;
529539 float last_time = 0 .0f ;
530540 for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap) {
531- if (y + tile_size >= input_height) {
532- y = input_height - tile_size ;
541+ if (y + input_tile_size >= input_height) {
542+ y = input_height - input_tile_size ;
533543 last_y = true ;
534544 }
535545 for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap) {
536- if (x + tile_size >= input_width) {
537- x = input_width - tile_size ;
546+ if (x + input_tile_size >= input_width) {
547+ x = input_width - input_tile_size ;
538548 last_x = true ;
539549 }
540550 int64_t t1 = ggml_time_ms ();
541551 ggml_split_tensor_2d (input, input_tile, x, y);
542552 on_processing (input_tile, output_tile, false );
543- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
553+ if (scaled_out) {
554+ ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
555+ } else {
556+ ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap / scale);
557+ }
544558 int64_t t2 = ggml_time_ms ();
545559 last_time = (t2 - t1) / 1000 .0f ;
546560 pretty_progress (tile_count, num_tiles, last_time);
@@ -673,13 +687,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
673687#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
674688 struct ggml_tensor * kqv = ggml_flash_attn (ctx, q, k, v, false ); // [N * n_head, n_token, d_head]
675689#else
676- float d_head = (float )q->ne [0 ];
690+ float d_head = (float )q->ne [0 ];
677691 struct ggml_tensor * kq = ggml_mul_mat (ctx, k, q); // [N * n_head, n_token, n_k]
678692 kq = ggml_scale_inplace (ctx, kq, 1 .0f / sqrt (d_head));
679693 if (mask) {
680694 kq = ggml_diag_mask_inf_inplace (ctx, kq, 0 );
681695 }
682- kq = ggml_soft_max_inplace (ctx, kq);
696+ kq = ggml_soft_max_inplace (ctx, kq);
683697 struct ggml_tensor * kqv = ggml_mul_mat (ctx, v, kq); // [N * n_head, n_token, d_head]
684698#endif
685699 return kqv;
0 commit comments