@@ -607,8 +607,38 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
607607
608608typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
609609
610+ __STATIC_INLINE__ void
611+ sd_tiling_calc_tiles (int &num_tiles_dim, float & tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
612+
613+ int tile_overlap = (tile_size * tile_overlap_factor);
614+ int non_tile_overlap = tile_size - tile_overlap;
615+
616+ num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
617+ int overshoot_dim = ((num_tiles_dim + 1 ) * non_tile_overlap + tile_overlap) % small_dim;
618+
619+ if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
620+ // if tiles don't fit perfectly using the desired overlap
621+ // and there is enough room to squeeze an extra tile without overlap becoming >0.5
622+ num_tiles_dim++;
623+ }
624+
625+ tile_overlap_factor_dim = (float )(tile_size * num_tiles_dim - small_dim) / (float )(tile_size * (num_tiles_dim - 1 ));
626+ if (num_tiles_dim <= 2 ) {
627+ if (small_dim <= tile_size) {
628+ num_tiles_dim = 1 ;
629+ tile_overlap_factor_dim = 0 ;
630+ } else {
631+ num_tiles_dim = 2 ;
632+ tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float )tile_size;
633+ }
634+ }
635+ }
636+
610637// Tiling
611- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
638+ __STATIC_INLINE__ void sd_tiling_non_square (ggml_tensor* input, ggml_tensor* output, const int scale,
639+ const int p_tile_size_x, const int p_tile_size_y,
640+ const float tile_overlap_factor, on_tile_process on_processing) {
641+
612642 output = ggml_set_f32 (output, 0 );
613643
614644 int input_width = (int )input->ne [0 ];
@@ -629,62 +659,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
629659 small_height = input_height;
630660 }
631661
632- int tile_overlap = (tile_size * tile_overlap_factor);
633- int non_tile_overlap = tile_size - tile_overlap;
634-
635- int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
636- int overshoot_x = ((num_tiles_x + 1 ) * non_tile_overlap + tile_overlap) % small_width;
637-
638- if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
639- // if tiles don't fit perfectly using the desired overlap
640- // and there is enough room to squeeze an extra tile without overlap becoming >0.5
641- num_tiles_x++;
642- }
643-
644- float tile_overlap_factor_x = (float )(tile_size * num_tiles_x - small_width) / (float )(tile_size * (num_tiles_x - 1 ));
645- if (num_tiles_x <= 2 ) {
646- if (small_width <= tile_size) {
647- num_tiles_x = 1 ;
648- tile_overlap_factor_x = 0 ;
649- } else {
650- num_tiles_x = 2 ;
651- tile_overlap_factor_x = (2 * tile_size - small_width) / (float )tile_size;
652- }
653- }
654-
655- int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
656- int overshoot_y = ((num_tiles_y + 1 ) * non_tile_overlap + tile_overlap) % small_height;
662+ int num_tiles_x;
663+ float tile_overlap_factor_x;
664+ sd_tiling_calc_tiles (num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
657665
658- if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
659- // if tiles don't fit perfectly using the desired overlap
660- // and there is enough room to squeeze an extra tile without overlap becoming >0.5
661- num_tiles_y++;
662- }
663-
664- float tile_overlap_factor_y = (float )(tile_size * num_tiles_y - small_height) / (float )(tile_size * (num_tiles_y - 1 ));
665- if (num_tiles_y <= 2 ) {
666- if (small_height <= tile_size) {
667- num_tiles_y = 1 ;
668- tile_overlap_factor_y = 0 ;
669- } else {
670- num_tiles_y = 2 ;
671- tile_overlap_factor_y = (2 * tile_size - small_height) / (float )tile_size;
672- }
673- }
666+ int num_tiles_y;
667+ float tile_overlap_factor_y;
668+ sd_tiling_calc_tiles (num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
674669
675670 LOG_DEBUG (" num tiles : %d, %d " , num_tiles_x, num_tiles_y);
676671 LOG_DEBUG (" optimal overlap : %f, %f (targeting %f)" , tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
677672
678673 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
679674
680- int tile_overlap_x = (int32_t )(tile_size * tile_overlap_factor_x);
681- int non_tile_overlap_x = tile_size - tile_overlap_x;
675+ int tile_overlap_x = (int32_t )(p_tile_size_x * tile_overlap_factor_x);
676+ int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
682677
683- int tile_overlap_y = (int32_t )(tile_size * tile_overlap_factor_y);
684- int non_tile_overlap_y = tile_size - tile_overlap_y;
678+ int tile_overlap_y = (int32_t )(p_tile_size_y * tile_overlap_factor_y);
679+ int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
685680
686- int tile_size_x = tile_size < small_width ? tile_size : small_width;
687- int tile_size_y = tile_size < small_height ? tile_size : small_height;
681+ int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
682+ int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
688683
689684 int input_tile_size_x = tile_size_x;
690685 int input_tile_size_y = tile_size_y;
@@ -773,6 +768,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
773768 ggml_free (tiles_ctx);
774769}
775770
771+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale,
772+ const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
773+ sd_tiling_non_square (input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
774+ }
775+
776776__STATIC_INLINE__ struct ggml_tensor * ggml_group_norm_32 (struct ggml_context * ctx,
777777 struct ggml_tensor * a) {
778778 const float eps = 1e-6f ; // default eps parameter
0 commit comments