Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion examples/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ ArgOptions SDGenerationParams::get_options() {
&extra_sample_args},
{"",
"--extra-tiling-args",
"extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
"extra VAE tiling args, key=value list. max_buffer_size (bytes) forces the auto fallback to tile when an untiled VAE compute buffer would exceed it. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
&extra_tiling_args},
};

Expand Down Expand Up @@ -1097,6 +1097,12 @@ ArgOptions SDGenerationParams::get_options() {
"process vae in tiles to reduce memory usage",
true,
&vae_tiling_params.enabled},
{"",
"--no-vae-tiling-fallback",
"disable the automatic fallback to VAE tiling when an untiled decode would exceed the "
"backend's max buffer size (fail instead of tiling)",
false,
&vae_tiling_params.auto_tile},
{"",
"--temporal-tiling",
"enable temporal tiling for LTX video VAE decode",
Expand Down Expand Up @@ -1841,6 +1847,9 @@ bool SDGenerationParams::from_json_str(
if (tiling_json.contains("enabled") && tiling_json["enabled"].is_boolean()) {
vae_tiling_params.enabled = tiling_json["enabled"];
}
if (tiling_json.contains("auto_tile") && tiling_json["auto_tile"].is_boolean()) {
vae_tiling_params.auto_tile = tiling_json["auto_tile"];
}
if (tiling_json.contains("temporal_tiling") && tiling_json["temporal_tiling"].is_boolean()) {
vae_tiling_params.temporal_tiling = tiling_json["temporal_tiling"];
}
Expand Down Expand Up @@ -2660,10 +2669,12 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
}

if (gen_params.vae_tiling_params.enabled ||
!gen_params.vae_tiling_params.auto_tile ||
gen_params.vae_tiling_params.temporal_tiling ||
!gen_params.extra_tiling_args.empty()) {
root["vae_tiling"] = {
{"enabled", gen_params.vae_tiling_params.enabled},
{"auto_tile", gen_params.vae_tiling_params.auto_tile},
{"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling},
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},
Expand Down
2 changes: 1 addition & 1 deletion examples/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ struct SDGenerationParams {
int video_frames = 1;
int fps = 16;
float vace_strength = 1.f;
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr, true};
std::string extra_tiling_args;

std::string pm_id_images_dir;
Expand Down
2 changes: 1 addition & 1 deletion examples/server/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ Shared default fields used by both `img_gen` and `vid_gen`:
| `output_format` | `string` |
| `output_compression` | `integer` |

`vae_tiling_params.extra_tiling_args` accepts a key=value list. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.
`vae_tiling_params.extra_tiling_args` accepts a key=value list. `max_buffer_size` (bytes) forces the automatic tiling fallback when an untiled VAE compute buffer would exceed it. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.

`img_gen`-specific default fields:

Expand Down
3 changes: 2 additions & 1 deletion include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,15 @@ enum lora_apply_mode_t {
};

// VAE tiling configuration.
// NOTE(review): `enabled` is listed twice below; this looks like the before/after pair of a
// diff rendering rather than two real members — confirm against the actual header.
typedef struct {
bool enabled;
bool enabled; // true => always tile (ON)
bool temporal_tiling;  // temporal tiling toggle; presumably the "LTX video VAE decode" option — TODO confirm
int tile_size_x;  // presumably the requested tile size along x; units not visible here — TODO confirm
int tile_size_y;  // presumably the requested tile size along y; units not visible here — TODO confirm
float target_overlap;  // initialised to 0.5f by the visible default initialisers
float rel_size_x;  // NOTE(review): looks like a relative tile size along x — confirm semantics
float rel_size_y;  // NOTE(review): looks like a relative tile size along y — confirm semantics
const char* extra_tiling_args;  // key=value list (e.g. max_buffer_size, temporal_tile_frames, temporal_tile_overlap); may be nullptr
bool auto_tile; // AUTO (default): tile only when an untiled VAE decode would exceed the backend's max buffer size
} sd_tiling_params_t;

typedef struct {
Expand Down
82 changes: 80 additions & 2 deletions src/core/ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1705,11 +1705,18 @@ struct GGMLRunner {

ggml_context* compute_ctx = nullptr;
ggml_gallocr* compute_allocr = nullptr;
// Set when alloc_compute_buffer deferred to tiling on purpose (not a failure).
bool compute_buffer_deferred_to_tiling = false;

size_t max_graph_vram_bytes = 0;
bool stream_layers_enabled = false;
size_t observed_max_effective_budget_ = 0;

// When set, alloc_compute_buffer declines a too-large untiled decode so VAE AUTO can tile.
bool probe_compute_buffer_fits_ = false;
// Optional user cap (bytes) to force tiling; 0 = no cap.
size_t probe_max_bytes_ = 0;

std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
std::weak_ptr<RunnerWeightManager> weight_manager;
std::unordered_set<const ggml_tensor*> kept_compute_param_tensor_set;
Expand Down Expand Up @@ -1978,10 +1985,74 @@ struct GGMLRunner {
}

bool alloc_compute_buffer(ggml_cgraph* gf) {
compute_buffer_deferred_to_tiling = false;
if (compute_allocr != nullptr) {
return true;
}
compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(runtime_backend);

if (probe_compute_buffer_fits_) {
// Defer a too-large untiled decode to tiling before the reserve hits a raw backend error.
if (probe_max_bytes_ > 0) {
ggml_gallocr* probe = ggml_gallocr_new(buft);
size_t sizes[1] = {0};
ggml_gallocr_reserve_n_size(probe, gf, nullptr, nullptr, sizes);
ggml_gallocr_free(probe);
if (sizes[0] > probe_max_bytes_) {
LOG_DEBUG("%s: untiled compute buffer %.2f MB exceeds requested max_buffer_size %.2f MB; deferring to tiling",
get_desc().c_str(),
sizes[0] / 1024.0 / 1024.0,
probe_max_bytes_ / 1024.0 / 1024.0);
compute_buffer_deferred_to_tiling = true;
return false;
}
}
if (sd_backend_is(runtime_backend, "Vulkan")) {
// buft_get_max_size only reports Vulkan's suballocation block; supports_op has the real per-buffer limit.
for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
ggml_tensor* op = ggml_graph_node(gf, i);
if (!ggml_backend_supports_op(runtime_backend, op)) {
LOG_DEBUG("%s: untiled compute op %.2f MB exceeds backend support; deferring to tiling",
get_desc().c_str(),
ggml_nbytes(op) / 1024.0 / 1024.0);
compute_buffer_deferred_to_tiling = true;
return false;
}
}
} else {
ggml_gallocr* probe = ggml_gallocr_new(buft);
size_t sizes[1] = {0};
ggml_gallocr_reserve_n_size(probe, gf, nullptr, nullptr, sizes);
ggml_gallocr_free(probe);
size_t planned = sizes[0];

size_t max_size = ggml_backend_buft_get_max_size(buft);
bool over_buffer_cap = max_size > 0 && planned > max_size;

// CUDA/ROCm have no per-buffer cap, so gate on free VRAM plus a margin for the scratch pool the reserve omits.
bool over_free_vram = false;
ggml_backend_dev_t dev = ggml_backend_get_device(runtime_backend);
if (dev != nullptr && ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
size_t free_vram = 0, total_vram = 0;
ggml_backend_dev_memory(dev, &free_vram, &total_vram);
size_t margin = planned / 3;
if (margin < 512ull * 1024 * 1024) {
margin = 512ull * 1024 * 1024;
}
over_free_vram = free_vram > 0 && free_vram < planned + margin;
}

if (over_buffer_cap || over_free_vram) {
LOG_DEBUG("%s: untiled compute buffer %.2f MB won't fit free VRAM; deferring to tiling",
get_desc().c_str(),
planned / 1024.0 / 1024.0);
compute_buffer_deferred_to_tiling = true;
return false;
}
}
}

compute_allocr = ggml_gallocr_new(buft);

if (!ggml_gallocr_reserve(compute_allocr, gf)) {
// failed to allocate the compute buffer
Expand Down Expand Up @@ -2432,7 +2503,9 @@ struct GGMLRunner {
GraphWeightDoneGuard graph_weight_done_guard(this, &params_to_prepare);

if (!alloc_compute_buffer(gf)) {
LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str());
if (!compute_buffer_deferred_to_tiling) {
LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str());
}
return std::nullopt;
}
struct ComputeBufferGuard {
Expand Down Expand Up @@ -2822,6 +2895,11 @@ struct GGMLRunner {
// Setter for the stream_layers_enabled flag; stores the value exactly as given.
void set_stream_layers_enabled(bool enabled) {
stream_layers_enabled = enabled;
}

// Arms or disarms the fit probe that alloc_compute_buffer consults before reserving.
//   enabled   - stored in probe_compute_buffer_fits_.
//   max_bytes - byte cap stored in probe_max_bytes_ while the probe is armed (0 = no cap);
//               ignored when disarming, which always resets the cap to 0.
void set_probe_compute_buffer_fits(bool enabled, size_t max_bytes = 0) {
    if (enabled) {
        probe_compute_buffer_fits_ = true;
        probe_max_bytes_           = max_bytes;
        return;
    }
    // Disarming drops any previously requested cap as well.
    probe_compute_buffer_fits_ = false;
    probe_max_bytes_           = 0;
}
};

class GGMLBlock {
Expand Down
45 changes: 45 additions & 0 deletions src/model/vae/vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,52 @@ struct VAE : public GGMLRunner {
"vae decode compute failed while processing a tile",
silent);
} else {
// AUTO: probe first so a too-large decode tiles instead of erroring; output.empty() backstops a real OOM.
const bool auto_probe = !tiling_params.enabled && tiling_params.auto_tile;
if (auto_probe) {
size_t max_bytes = 0;
if (tiling_params.extra_tiling_args != nullptr) {
for (const auto& [key, value] : parse_key_value_args(tiling_params.extra_tiling_args, "VAE extra tiling arg")) {
if (key == "max_buffer_size") {
max_bytes = strtoull(value.c_str(), nullptr, 10);
}
}
}
set_probe_compute_buffer_fits(true, max_bytes);
}
output = _compute(n_threads, input, true);
if (auto_probe) {
set_probe_compute_buffer_fits(false);
}
if (output.empty() && !tiling_params.enabled && tiling_params.auto_tile) {
free_compute_buffer();
if (!silent) {
LOG_WARN("vae: untiled decode buffer exceeded the backend limit; retrying with tiling");
}
sd_tiling_params_t auto_tiling = tiling_params;
auto_tiling.enabled = true;
set_tiling_params(auto_tiling);
const int scale_factor = get_scale_factor();
int64_t W = input.shape()[0] * scale_factor;
int64_t H = input.shape()[1] * scale_factor;
float tile_overlap;
int tile_size_x, tile_size_y;
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, auto_tiling, input.shape()[0], input.shape()[1]);
output = tiled_compute(
input,
n_threads,
static_cast<int>(W),
static_cast<int>(H),
scale_factor,
tile_size_x,
tile_size_y,
tile_overlap,
circular_x,
circular_y,
true,
"vae decode compute failed while processing a tile",
silent);
}
}

free_compute_buffer();
Expand Down
6 changes: 3 additions & 3 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class StableDiffusionGGML {
bool apply_lora_immediately = false;

std::string taesd_path;
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr};
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr, true};
bool enable_mmap = false;
sd::ggml_graph_cut::MaxVramAssignment max_vram_assignment;
bool stream_layers = false;
Expand Down Expand Up @@ -2843,7 +2843,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->pulid_params = {nullptr, 1.0f};
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr, true};
sd_cache_params_init(&sd_img_gen_params->cache);
sd_hires_params_init(&sd_img_gen_params->hires);
}
Expand Down Expand Up @@ -2930,7 +2930,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->fps = 16;
sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f;
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr, true};
sd_vid_gen_params->hires.enabled = false;
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
sd_vid_gen_params->hires.scale = 2.f;
Expand Down
Loading