Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ static void sample_k_diffusion(sample_method_t method,
ggml_context* work_ctx,
ggml_tensor* x,
std::vector<float> sigmas,
int initial_step,
std::shared_ptr<RNG> rng,
float eta) {
size_t steps = sigmas.size() - 1;
Expand Down Expand Up @@ -1248,12 +1249,13 @@ static void sample_k_diffusion(sample_method_t method,
// - pred_sample_direction -> "direction pointing to
// x_t"
// - pred_prev_sample -> "x_t-1"
int timestep =
roundf(TIMESTEPS -
i * ((float)TIMESTEPS / steps)) -
1;
int timestep = TIMESTEPS - 1 -
(int)roundf((initial_step + i) *
(TIMESTEPS / float(initial_step + steps)));
// 1. get previous step value (=t-1)
int prev_timestep = timestep - TIMESTEPS / steps;
int prev_timestep = TIMESTEPS - 1 -
(int)roundf((initial_step + i + 1) *
(TIMESTEPS / float(initial_step + steps)));
// The sigma here is chosen to cause the
// CompVisDenoiser to produce t = timestep
float sigma = compvis_sigmas[timestep];
Expand Down Expand Up @@ -1425,9 +1427,14 @@ static void sample_k_diffusion(sample_method_t method,
// Analytic form for TCD timesteps
int timestep = TIMESTEPS - 1 -
(TIMESTEPS / original_steps) *
(int)floor(i * ((float)original_steps / steps));
(int)floor((initial_step + i) *
((float)original_steps / (initial_step + steps)));
// 1. get previous step value
int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps));
int prev_timestep = i >= steps - 1 ? 0 :
TIMESTEPS - 1 -
(TIMESTEPS / original_steps) *
(int)floor((initial_step + i + 1) *
((float)original_steps / (initial_step + steps)));
// Here timestep_s is tau_n' in Algorithm 4. The _s
// notation appears to be that from C. Lu,
// "DPM-Solver: A Fast ODE Solver for Diffusion
Expand Down
12 changes: 10 additions & 2 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,7 @@ class StableDiffusionGGML {
int shifted_timestep,
sample_method_t method,
const std::vector<float>& sigmas,
int initial_step,
int start_merge_step,
SDCondition id_cond,
std::vector<ggml_tensor*> ref_latents = {},
Expand Down Expand Up @@ -1837,7 +1838,7 @@ class StableDiffusionGGML {
return denoised;
};

sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, initial_step, sampler_rng, eta);

if (easycache_enabled) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
Expand Down Expand Up @@ -2762,6 +2763,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int height,
enum sample_method_t sample_method,
const std::vector<float>& sigmas,
int initial_step,
int64_t seed,
int batch_count,
sd_image_t control_image,
Expand Down Expand Up @@ -3056,6 +3058,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
shifted_timestep,
sample_method,
sigmas,
initial_step,
start_merge_step,
id_cond,
ref_latents,
Expand Down Expand Up @@ -3173,6 +3176,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
int initial_step = 0;

ggml_tensor* init_latent = nullptr;
ggml_tensor* concat_latent = nullptr;
Expand All @@ -3185,7 +3189,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
t_enc--;
LOG_INFO("target t_enc is %zu steps", t_enc);
std::vector<float> sigma_sched;
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
initial_step = sample_steps - t_enc - 1;
sigma_sched.assign(sigmas.begin() + initial_step, sigmas.end());
sigmas = sigma_sched;

ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
Expand Down Expand Up @@ -3373,6 +3378,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
height,
sample_method,
sigmas,
initial_step,
seed,
sd_img_gen_params->batch_count,
sd_img_gen_params->control_image,
Expand Down Expand Up @@ -3709,6 +3715,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
high_noise_sample_method,
high_noise_sigmas,
0,
-1,
{},
{},
Expand Down Expand Up @@ -3746,6 +3753,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->sample_params.shifted_timestep,
sample_method,
sigmas,
0,
-1,
{},
{},
Expand Down
Loading