Skip to content

Commit ac708e8

Browse files
authored
non-square VAE tiling (#3)
* refactor tile number calculation * support non-square tiles * add env var to change tile overlap * add safeguards and better error messages for SD_TILE_OVERLAP * add safeguards and include overlapping factor for SD_TILE_SIZE * avoid rounding issues when specifying SD_TILE_SIZE as a factor * lower SD_TILE_OVERLAP limit
1 parent e201588 commit ac708e8

File tree

2 files changed

+129
-55
lines changed

2 files changed

+129
-55
lines changed

ggml_extend.hpp

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,38 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
607607

608608
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
609609

610+
/**
 * Compute how many tiles are needed along one dimension, and the overlap
 * factor that makes that many tiles cover the dimension exactly.
 *
 * @param num_tiles_dim           [out] number of tiles along this dimension (always >= 1).
 * @param tile_overlap_factor_dim [out] overlap, as a fraction of tile_size, at which
 *                                num_tiles_dim tiles span small_dim exactly; 0 when a
 *                                single tile is used.
 * @param small_dim               extent of the dimension being tiled.
 * @param tile_size               tile extent along this dimension.
 * @param tile_overlap_factor     desired overlap between neighbouring tiles, as a
 *                                fraction of tile_size.
 *
 * Degenerate inputs (non-positive tile_size, small_dim not larger than tile_size,
 * or a requested overlap that would leave no stride between tiles) never cause a
 * division by zero: the first two yield a single tile, the last has the overlap
 * capped at tile_size - 1.
 */
__STATIC_INLINE__ void
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
    // A single tile already covers the whole dimension (this also catches
    // small_dim <= 0), or the tile size is unusable: nothing to split.
    // Handling this first avoids the 0/0 the general formula below would hit
    // for num_tiles_dim == 1, and the modulo by small_dim == 0.
    if (tile_size <= 0 || small_dim <= tile_size) {
        num_tiles_dim           = 1;
        tile_overlap_factor_dim = 0;
        return;
    }

    // Requested overlap in elements. Cap it so the stride between tiles stays
    // at least 1; otherwise the integer division below would divide by zero
    // (this also maps a NaN factor to the cap).
    int tile_overlap = tile_overlap_factor < 1.0f ? static_cast<int>(tile_size * tile_overlap_factor)
                                                  : tile_size - 1;
    if (tile_overlap >= tile_size) {
        tile_overlap = tile_size - 1;
    }
    const int non_tile_overlap = tile_size - tile_overlap;

    num_tiles_dim           = (small_dim - tile_overlap) / non_tile_overlap;
    const int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;

    if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
        // if tiles don't fit perfectly using the desired overlap
        // and there is enough room to squeeze an extra tile without overlap becoming >0.5
        num_tiles_dim++;
    }

    if (num_tiles_dim <= 2) {
        // small_dim > tile_size here, so at least two tiles are required.
        num_tiles_dim           = 2;
        tile_overlap_factor_dim = (2 * tile_size - small_dim) / static_cast<float>(tile_size);
    } else {
        // Overlap at which num_tiles_dim tiles span small_dim exactly.
        tile_overlap_factor_dim = static_cast<float>(tile_size * num_tiles_dim - small_dim) /
                                  static_cast<float>(tile_size * (num_tiles_dim - 1));
    }
}
636+
610637
// Tiling
611-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
638+
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
639+
const int p_tile_size_x, const int p_tile_size_y,
640+
const float tile_overlap_factor, on_tile_process on_processing) {
641+
612642
output = ggml_set_f32(output, 0);
613643

614644
int input_width = (int)input->ne[0];
@@ -629,62 +659,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
629659
small_height = input_height;
630660
}
631661

632-
int tile_overlap = (tile_size * tile_overlap_factor);
633-
int non_tile_overlap = tile_size - tile_overlap;
634-
635-
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
636-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
637-
638-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
639-
// if tiles don't fit perfectly using the desired overlap
640-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
641-
num_tiles_x++;
642-
}
643-
644-
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
645-
if (num_tiles_x <= 2) {
646-
if (small_width <= tile_size) {
647-
num_tiles_x = 1;
648-
tile_overlap_factor_x = 0;
649-
} else {
650-
num_tiles_x = 2;
651-
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
652-
}
653-
}
654-
655-
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
656-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
662+
int num_tiles_x;
663+
float tile_overlap_factor_x;
664+
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
657665

658-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
659-
// if tiles don't fit perfectly using the desired overlap
660-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
661-
num_tiles_y++;
662-
}
663-
664-
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
665-
if (num_tiles_y <= 2) {
666-
if (small_height <= tile_size) {
667-
num_tiles_y = 1;
668-
tile_overlap_factor_y = 0;
669-
} else {
670-
num_tiles_y = 2;
671-
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
672-
}
673-
}
666+
int num_tiles_y;
667+
float tile_overlap_factor_y;
668+
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
674669

675670
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
676671
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
677672

678673
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
679674

680-
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
681-
int non_tile_overlap_x = tile_size - tile_overlap_x;
675+
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
676+
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
682677

683-
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
684-
int non_tile_overlap_y = tile_size - tile_overlap_y;
678+
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
679+
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
685680

686-
int tile_size_x = tile_size < small_width ? tile_size : small_width;
687-
int tile_size_y = tile_size < small_height ? tile_size : small_height;
681+
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
682+
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
688683

689684
int input_tile_size_x = tile_size_x;
690685
int input_tile_size_y = tile_size_y;
@@ -773,6 +768,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
773768
ggml_free(tiles_ctx);
774769
}
775770

771+
// Square-tile entry point: forwards to sd_tiling_non_square, using the same
// tile_size for both the x and y tile extents. All other arguments are passed
// through unchanged.
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input,
                                 ggml_tensor* output,
                                 const int scale,
                                 const int tile_size,
                                 const float tile_overlap_factor,
                                 on_tile_process on_processing) {
    sd_tiling_non_square(input,
                         output,
                         scale,
                         tile_size,  // tile extent along x
                         tile_size,  // tile extent along y
                         tile_overlap_factor,
                         on_processing);
}
775+
776776
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
777777
struct ggml_tensor* a) {
778778
const float eps = 1e-6f; // default eps parameter

stable-diffusion.cpp

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,23 +1427,91 @@ class StableDiffusionGGML {
14271427
x->ne[3]); // channels
14281428
int64_t t0 = ggml_time_ms();
14291429

1430-
int tile_size = 32;
1431-
// TODO: arg instead of env?
1430+
// TODO: args instead of env for tile size / overlap?
1431+
1432+
float tile_overlap = 0.5f;
1433+
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
1434+
if (SD_TILE_OVERLAP != nullptr) {
1435+
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1436+
try {
1437+
tile_overlap = std::stof(sd_tile_overlap_str);
1438+
if (tile_overlap < 0.0) {
1439+
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
1440+
tile_overlap = 0.0;
1441+
}
1442+
else if (tile_overlap > 0.5) {
1443+
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
1444+
tile_overlap = 0.5;
1445+
}
1446+
} catch (const std::invalid_argument&) {
1447+
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
1448+
} catch (const std::out_of_range&) {
1449+
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
1450+
}
1451+
}
1452+
1453+
int tile_size_x = 32;
1454+
int tile_size_y = 32;
14321455
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
14331456
if (SD_TILE_SIZE != nullptr) {
1457+
// format is AxB, or just A (equivalent to AxA)
1458+
// A and B can be integers (tile size) or floating point
1459+
// floating point <= 1 means simple fraction of the latent dimension
1460+
// floating point > 1 means number of tiles across that dimension
1461+
// a single number gets applied to both
1462+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1463+
float factor = std::stof(factor_str);
1464+
if (factor > 1.0)
1465+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1466+
return factor;
1467+
};
1468+
const int latent_x = W / (decode ? 1 : 8);
1469+
const int latent_y = H / (decode ? 1 : 8);
1470+
const int min_tile_dimension = 4;
14341471
std::string sd_tile_size_str = SD_TILE_SIZE;
1472+
size_t x_pos = sd_tile_size_str.find('x');
14351473
try {
1436-
tile_size = std::stoi(sd_tile_size_str);
1474+
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1475+
if (x_pos != std::string::npos) {
1476+
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
1477+
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
1478+
if (tile_x_str.find('.') != std::string::npos) {
1479+
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
1480+
}
1481+
else {
1482+
tmp_x = std::stoi(tile_x_str);
1483+
}
1484+
if (tile_y_str.find('.') != std::string::npos) {
1485+
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
1486+
}
1487+
else {
1488+
tmp_y = std::stoi(tile_y_str);
1489+
}
1490+
}
1491+
else {
1492+
if (sd_tile_size_str.find('.') != std::string::npos) {
1493+
float tile_factor = get_tile_factor(sd_tile_size_str);
1494+
tmp_x = std::round(latent_x * tile_factor);
1495+
tmp_y = std::round(latent_y * tile_factor);
1496+
}
1497+
else {
1498+
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
1499+
}
1500+
}
1501+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1502+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
14371503
} catch (const std::invalid_argument&) {
1438-
LOG_WARN("Invalid");
1504+
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
14391505
} catch (const std::out_of_range&) {
1440-
LOG_WARN("OOR");
1506+
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
14411507
}
14421508
}
1509+
14431510
if(!decode){
14441511
// TODO: also use an arg for this one?
14451512
// to keep the compute buffer size consistent
1446-
tile_size*=1.30539;
1513+
tile_size_x*=1.30539;
1514+
tile_size_y*=1.30539;
14471515
}
14481516
if (!use_tiny_autoencoder) {
14491517
if (decode) {
@@ -1452,11 +1520,17 @@ class StableDiffusionGGML {
14521520
ggml_tensor_scale_input(x);
14531521
}
14541522
if (vae_tiling) {
1523+
if (SD_TILE_SIZE != nullptr) {
1524+
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
1525+
}
1526+
if (SD_TILE_OVERLAP != nullptr) {
1527+
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
1528+
}
14551529
// split the latent into tiles (32x32 by default) and compute in several steps
14561530
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14571531
first_stage_model->compute(n_threads, in, decode, &out);
14581532
};
1459-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1533+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
14601534
} else {
14611535
first_stage_model->compute(n_threads, x, decode, &result);
14621536
}

0 commit comments

Comments
 (0)