diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index bb5d6862c..4f60250c7 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -655,6 +656,21 @@ int main(int argc, const char* argv[]) { } } + SDAudioPtr input_audio; + if (gen_params.init_audio_path.size() > 0) { + input_audio.reset(static_cast(malloc(sizeof(sd_audio_t)))); + if (input_audio == nullptr) { + LOG_ERROR("malloc input audio failed"); + return 1; + } + *input_audio = load_pcm_wav_from_file(gen_params.init_audio_path); + if (input_audio->data == nullptr || input_audio->sample_count == 0) { + LOG_ERROR("load audio from '%s' failed", gen_params.init_audio_path.c_str()); + return 1; + } + gen_params.input_audio = input_audio.get(); + } + if (gen_params.ref_image_paths.size() > 0) { gen_params.ref_images.clear(); for (auto& path : gen_params.ref_image_paths) { diff --git a/examples/common/common.cpp b/examples/common/common.cpp index dd5d35055..759f2c589 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -870,6 +870,10 @@ ArgOptions SDGenerationParams::get_options() { "--end-img", "path to the end image, required by flf2v", &end_image_path}, + {"", + "--init-audio", + "path to the init audio WAV, for use with audio-to-video models", + &init_audio_path}, {"", "--mask", "path to the mask image", @@ -2223,6 +2227,11 @@ bool SDGenerationParams::validate(SDMode mode) { } } + if (mode != VID_GEN && init_audio_path.length() > 0) { + LOG_ERROR("error: init audio (--init-audio) is only supported in vid_gen mode\n"); + return false; + } + return true; } @@ -2362,6 +2371,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { params.clip_skip = clip_skip; params.init_image = init_image.get(); params.end_image = end_image.get(); + params.input_audio = input_audio; params.control_frames = control_frame_views.empty() ? nullptr : control_frame_views.data(); params.control_frames_size = static_cast(control_frame_views.size()); params.width = get_resolved_width(); @@ -2431,6 +2441,7 @@ std::string SDGenerationParams::to_string() const { << " batch_count: " << batch_count << ",\n" << " init_image_path: \"" << init_image_path << "\",\n" << " end_image_path: \"" << end_image_path << "\",\n" + << " init_audio_path: \"" << init_audio_path << "\",\n" << " mask_image_path: \"" << mask_image_path << "\",\n" << " control_image_path: \"" << control_image_path << "\",\n" << " ref_image_paths: " << vec_str_to_string(ref_image_paths) << ",\n" diff --git a/examples/common/common.h b/examples/common/common.h index fcf9840db..b2e9a4e4b 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -204,6 +204,7 @@ struct SDGenerationParams { std::string init_image_path; std::string end_image_path; + std::string init_audio_path; std::string mask_image_path; std::string control_image_path; std::vector ref_image_paths; @@ -268,6 +269,7 @@ struct SDGenerationParams { SDImageOwner control_image; std::vector pm_id_images; std::vector control_frames; + const sd_audio_t* input_audio = nullptr; // Backing storage for sd_img_gen_params_t view fields. std::vector ref_image_views; diff --git a/examples/common/media_io.cpp b/examples/common/media_io.cpp index 506c67f4d..f0fdf374e 100644 --- a/examples/common/media_io.cpp +++ b/examples/common/media_io.cpp @@ -191,6 +191,20 @@ uint32_t read_u32_le_bytes(const uint8_t* data) { (static_cast(data[3]) << 24); } +uint16_t read_u16_le_bytes(const uint8_t* p) { + return static_cast(p[0]) | (static_cast(p[1]) << 8); +} + +int32_t read_s24_le_bytes(const uint8_t* p) { + int32_t value = static_cast(p[0]) | + (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16); + if (value & 0x00800000) { + value |= 0xff000000; + } + return value; +} + int stbi_ext_write_png_to_func(stbi_write_func* func, void* context, int x, @@ -1374,3 +1388,104 @@ bool write_wav_to_file(const std::string& path, file.write(reinterpret_cast(pcm.data()), static_cast(pcm.size() * sizeof(int16_t))); return file.good(); } + +sd_audio_t load_pcm_wav_from_file(const std::string& path) { + sd_audio_t audio = {0, 0, 0, nullptr}; + if (path.empty()) { + return audio; + } + + std::vector wav; + if (!read_binary_file_bytes(path.c_str(), wav)) { + LOG_ERROR("load WAV from '%s' failed", path.c_str()); + return audio; + } + if (wav.size() < 44 || std::memcmp(wav.data(), "RIFF", 4) != 0 || std::memcmp(wav.data() + 8, "WAVE", 4) != 0) { + LOG_ERROR("input audio file '%s' is not a RIFF/WAVE file", path.c_str()); + return audio; + } + + uint16_t format = 0; + uint16_t channels = 0; + uint32_t sample_rate = 0; + uint16_t bits_per_sample = 0; + const uint8_t* data = nullptr; + uint32_t data_size = 0; + + size_t pos = 12; + while (pos + 8 <= wav.size()) { + const uint8_t* chunk = wav.data() + pos; + uint32_t chunk_size = read_u32_le_bytes(chunk + 4); + size_t chunk_data = pos + 8; + if (chunk_data + chunk_size > wav.size()) { + break; + } + + if (std::memcmp(chunk, "fmt ", 4) == 0 && chunk_size >= 16) { + format = read_u16_le_bytes(wav.data() + chunk_data); + channels = read_u16_le_bytes(wav.data() + chunk_data + 2); + sample_rate = read_u32_le_bytes(wav.data() + chunk_data + 4); + bits_per_sample = read_u16_le_bytes(wav.data() + chunk_data + 14); + } else if (std::memcmp(chunk, "data", 4) == 0) { + data = wav.data() + chunk_data; + data_size = chunk_size; + } + pos = chunk_data + chunk_size + (chunk_size & 1); + } + + if (data == nullptr || data_size == 0 || channels == 0 || sample_rate == 0) { + LOG_ERROR("input WAV '%s' is missing fmt/data chunks", path.c_str()); + return audio; + } + if (format != 1 && format != 3) { + LOG_ERROR("unsupported WAV format %u in '%s', only PCM and float WAV are supported", + static_cast(format), + path.c_str()); + return audio; + } + + uint16_t bytes_per_sample = static_cast((bits_per_sample + 7) / 8); + uint32_t frame_bytes = static_cast(bytes_per_sample) * channels; + if (bytes_per_sample == 0 || frame_bytes == 0 || data_size < frame_bytes) { + LOG_ERROR("invalid WAV sample format in '%s'", path.c_str()); + return audio; + } + + uint64_t sample_count = data_size / frame_bytes; + size_t float_count = static_cast(sample_count) * channels; + float* samples = (float*)malloc(float_count * sizeof(float)); + if (samples == nullptr) { + return audio; + } + + for (uint64_t i = 0; i < sample_count; ++i) { + for (uint16_t ch = 0; ch < channels; ++ch) { + const uint8_t* src = data + i * frame_bytes + ch * bytes_per_sample; + float sample = 0.f; + if (format == 3 && bits_per_sample == 32) { + std::memcpy(&sample, src, sizeof(float)); + } else if (format == 1 && bits_per_sample == 8) { + sample = (static_cast(src[0]) - 128) / 128.f; + } else if (format == 1 && bits_per_sample == 16) { + sample = static_cast(read_u16_le_bytes(src)) / 32768.f; + } else if (format == 1 && bits_per_sample == 24) { + sample = read_s24_le_bytes(src) / 8388608.f; + } else if (format == 1 && bits_per_sample == 32) { + sample = static_cast(read_u32_le_bytes(src)) / 2147483648.f; + } else { + LOG_ERROR("unsupported WAV bit depth %u in '%s'", + static_cast(bits_per_sample), + path.c_str()); + free(samples); + return audio; + } + samples[i * channels + ch] = std::clamp(sample, -1.0f, 1.0f); + } + } + + audio.sample_rate = sample_rate; + audio.channels = channels; + audio.sample_count = sample_count; + audio.data = samples; + return audio; +} diff --git a/examples/common/media_io.h b/examples/common/media_io.h index 0f7679d7f..df2fd019b 100644 --- a/examples/common/media_io.h +++ b/examples/common/media_io.h @@ -110,4 +110,6 @@ bool write_wav_to_file(const std::string& path, uint32_t channels, uint32_t sample_rate); +sd_audio_t load_pcm_wav_from_file(const std::string& path); + #endif // __MEDIA_IO_H__ diff --git a/examples/common/resource_owners.hpp b/examples/common/resource_owners.hpp index d47134abe..5b5c3c506 100644 --- a/examples/common/resource_owners.hpp +++ b/examples/common/resource_owners.hpp @@ -40,12 +40,21 @@ struct UpscalerCtxDeleter { } }; +struct SDAudioDeleter { + void operator()(sd_audio_t* audio) const { + if (audio != nullptr) { + free_sd_audio(audio); + } + } +}; + template using FreeUniquePtr = std::unique_ptr; using FilePtr = std::unique_ptr; using SDCtxPtr = std::unique_ptr; using UpscalerCtxPtr = std::unique_ptr; +using SDAudioPtr = std::unique_ptr; class SDImageOwner { private: diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 1c04367b1..bfcd909cc 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -395,6 +395,7 @@ typedef struct { int64_t seed; int video_frames; int fps; + const sd_audio_t* input_audio; float vace_strength; sd_tiling_params_t vae_tiling_params; sd_cache_params_t cache; diff --git a/src/model/vae/ltx_audio_vae.hpp b/src/model/vae/ltx_audio_vae.hpp index 997c57a5b..49847445d 100644 --- a/src/model/vae/ltx_audio_vae.hpp +++ b/src/model/vae/ltx_audio_vae.hpp @@ -1,6 +1,7 @@ #ifndef __SD_MODEL_VAE_LTX_AUDIO_VAE_HPP__ #define __SD_MODEL_VAE_LTX_AUDIO_VAE_HPP__ +#include #include #include #include @@ -21,6 +22,10 @@ namespace LTXV { int latent_channels = 8; int latent_frequency_bins = 16; int audio_channels = 2; + bool has_encoder = false; + int encoder_channels = 128; + std::vector encoder_channel_multipliers = {1, 2, 4}; + int encoder_num_res_blocks = 2; int decoder_channels = 128; std::vector decoder_channel_multipliers = {1, 2, 4}; int decoder_num_res_blocks = 2; @@ -74,6 +79,7 @@ namespace LTXV { const TensorStorage* decoder_conv_in = require("audio_vae.decoder.conv_in.conv.weight"); const TensorStorage* decoder_conv_out = require("audio_vae.decoder.conv_out.conv.weight"); + const TensorStorage* encoder_conv_in = require("audio_vae.encoder.conv_in.conv.weight"); const TensorStorage* latent_std = require("audio_vae.per_channel_statistics.std-of-means"); const TensorStorage* vocoder_conv_pre = require("vocoder.vocoder.conv_pre.weight"); const TensorStorage* vocoder_conv_post = require("vocoder.vocoder.conv_post.weight"); @@ -98,6 +104,40 @@ namespace LTXV { return config; } + if (encoder_conv_in != nullptr) { + config.has_encoder = true; + config.audio_channels = static_cast(encoder_conv_in->ne[2]); + config.encoder_channels = static_cast(encoder_conv_in->ne[3]); + + std::vector> encoder_level_channels; + for (const auto& pair : tensor_storage_map) { + const std::string& name = pair.first; + const std::string prefix = "audio_vae.encoder.down."; + const std::string suffix = ".block.0.conv1.conv.weight"; + if (!starts_with(name, prefix) || !ends_with(name, suffix)) { + continue; + } + std::string level_str = name.substr(prefix.size(), name.size() - prefix.size() - suffix.size()); + int level = std::stoi(level_str); + encoder_level_channels.push_back({level, static_cast(pair.second.ne[3])}); + } + std::sort(encoder_level_channels.begin(), encoder_level_channels.end()); + if (!encoder_level_channels.empty()) { + config.encoder_channel_multipliers.clear(); + for (const auto& level_channel : encoder_level_channels) { + config.encoder_channel_multipliers.push_back(level_channel.second / std::max(1, config.encoder_channels)); + } + } + + int encoder_block_count = 0; + while (tensor_storage_map.find("audio_vae.encoder.down.0.block." + std::to_string(encoder_block_count) + ".conv1.conv.weight") != tensor_storage_map.end()) { + ++encoder_block_count; + } + if (encoder_block_count > 0) { + config.encoder_num_res_blocks = encoder_block_count; + } + } + std::vector> level_channels; for (const auto& pair : tensor_storage_map) { const std::string& name = pair.first; @@ -171,16 +211,89 @@ namespace LTXV { if (config.audio_channels != 2 || config.latent_channels != 8 || config.mel_bins != 64) { return config; } - LOG_DEBUG("ltx_audio_vae: sample_rate = %d, mel_bins = %d, latent_channels = %d, latent_frequency_bins = %d, has_bwe = %s", + LOG_DEBUG("ltx_audio_vae: sample_rate = %d, mel_bins = %d, latent_channels = %d, latent_frequency_bins = %d, has_encoder = %s, has_bwe = %s", config.sample_rate, config.mel_bins, config.latent_channels, config.latent_frequency_bins, + config.has_encoder ? "true" : "false", config.has_bwe ? "true" : "false"); return config; } }; + static double ltx_audio_hz_to_mel(double freq) { + constexpr double min_log_hz = 1000.0; + constexpr double min_log_mel = 15.0; + constexpr double logstep = 0.06875177742094912; // log(6.4) / 27 + constexpr double f_sp = 200.0 / 3.0; + if (freq < min_log_hz) { + return freq / f_sp; + } + return min_log_mel + std::log(freq / min_log_hz) / logstep; + } + + static double ltx_audio_mel_to_hz(double mel) { + constexpr double min_log_hz = 1000.0; + constexpr double min_log_mel = 15.0; + constexpr double logstep = 0.06875177742094912; // log(6.4) / 27 + constexpr double f_sp = 200.0 / 3.0; + if (mel < min_log_mel) { + return mel * f_sp; + } + return min_log_hz * std::exp(logstep * (mel - min_log_mel)); + } + + static sd::Tensor build_encoder_stft_basis(int n_fft) { + constexpr double kPi = 3.14159265358979323846; + const int n_freqs = n_fft / 2 + 1; + sd::Tensor basis({n_fft, 1, n_freqs * 2}); + for (int k = 0; k < n_freqs; ++k) { + for (int n = 0; n < n_fft; ++n) { + double window = 0.5 - 0.5 * std::cos(2.0 * kPi * n / static_cast(n_fft)); + double phase = 2.0 * kPi * k * n / static_cast(n_fft); + basis.index(n, 0, k) = static_cast(std::cos(phase) * window); + basis.index(n, 0, k + n_freqs) = static_cast(-std::sin(phase) * window); + } + } + return basis; + } + + static sd::Tensor build_encoder_mel_basis(int sample_rate, int n_fft, int n_mels) { + const int n_freqs = n_fft / 2 + 1; + sd::Tensor basis({n_freqs, n_mels}); + std::vector fft_freqs(n_freqs); + for (int i = 0; i < n_freqs; ++i) { + fft_freqs[i] = (static_cast(sample_rate) * 0.5) * static_cast(i) / static_cast(n_freqs - 1); + } + + std::vector mel_f(n_mels + 2); + const double min_mel = ltx_audio_hz_to_mel(0.0); + const double max_mel = ltx_audio_hz_to_mel(static_cast(sample_rate) * 0.5); + for (int i = 0; i < n_mels + 2; ++i) { + double mel = min_mel + (max_mel - min_mel) * static_cast(i) / static_cast(n_mels + 1); + mel_f[i] = ltx_audio_mel_to_hz(mel); + } + + for (int m = 0; m < n_mels; ++m) { + double lower = mel_f[m]; + double center = mel_f[m + 1]; + double upper = mel_f[m + 2]; + double enorm = 2.0 / std::max(upper - lower, 1e-12); + for (int f = 0; f < n_freqs; ++f) { + double freq = fft_freqs[f]; + double value = 0.0; + if (freq > lower && freq <= center) { + value = (freq - lower) / std::max(center - lower, 1e-12); + } else if (freq > center && freq < upper) { + value = (upper - freq) / std::max(upper - center, 1e-12); + } + basis.index(f, m) = static_cast(value * enorm); + } + } + return basis; + } + static ggml_tensor* compute_log_mel_spectrogram(GGMLRunnerContext* runner_ctx, ggml_tensor* waveform, ggml_tensor* forward_basis, @@ -478,6 +591,31 @@ namespace LTXV { } }; + struct AudioDownsample2D : public GGMLBlock { + AudioDownsample2D(int64_t channels) { + blocks["conv"] = std::make_shared(channels, + channels, + std::pair{3, 3}, + std::pair{2, 2}, + std::pair{0, 0}); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + x = ggml_ext_pad_ext(ctx->ggml_ctx, + x, + 0, + 1, + 2, + 0, + 0, + 0, + 0, + 0); + return conv->forward(ctx, x); + } + }; + struct AudioResnetBlock2D : public GGMLBlock { int64_t in_channels; int64_t out_channels; @@ -514,6 +652,86 @@ namespace LTXV { } }; + struct AudioEncoder : public GGMLBlock { + LTXAudioVAEConfig config; + + explicit AudioEncoder(const LTXAudioVAEConfig& config) + : config(config) { + int block_in = config.encoder_channels; + blocks["conv_in"] = std::make_shared(config.audio_channels, block_in, std::pair{3, 3}); + + for (int level = 0; level < static_cast(config.encoder_channel_multipliers.size()); ++level) { + int block_out = config.encoder_channels * config.encoder_channel_multipliers[level]; + for (int block_idx = 0; block_idx < config.encoder_num_res_blocks; ++block_idx) { + blocks["down." + std::to_string(level) + ".block." + std::to_string(block_idx)] = + std::make_shared(block_in, block_out); + block_in = block_out; + } + if (level != static_cast(config.encoder_channel_multipliers.size()) - 1) { + blocks["down." + std::to_string(level) + ".downsample"] = std::make_shared(block_in); + } + } + + blocks["mid.block_1"] = std::make_shared(block_in, block_in); + blocks["mid.block_2"] = std::make_shared(block_in, block_in); + blocks["norm_out"] = std::make_shared(); + blocks["conv_out"] = std::make_shared(block_in, config.latent_channels * 2, std::pair{3, 3}); + } + + ggml_tensor* normalize_latent(GGMLRunnerContext* ctx, + ggml_tensor* latent, + ggml_tensor* mean, + ggml_tensor* stddev) { + latent = ggml_ext_slice(ctx->ggml_ctx, latent, 2, 0, config.latent_channels); + latent = ggml_permute(ctx->ggml_ctx, latent, 0, 2, 1, 3); + latent = ggml_cont(ctx->ggml_ctx, latent); + latent = ggml_reshape_4d(ctx->ggml_ctx, latent, config.latent_frequency_bins * config.latent_channels, latent->ne[2], 1, latent->ne[3]); + + mean = ggml_reshape_4d(ctx->ggml_ctx, mean, mean->ne[0], 1, 1, 1); + stddev = ggml_reshape_4d(ctx->ggml_ctx, stddev, stddev->ne[0], 1, 1, 1); + latent = ggml_div(ctx->ggml_ctx, ggml_sub(ctx->ggml_ctx, latent, mean), stddev); + + latent = ggml_reshape_4d(ctx->ggml_ctx, + latent, + config.latent_frequency_bins, + config.latent_channels, + latent->ne[1], + latent->ne[3]); + latent = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, latent, 0, 2, 1, 3)); + return latent; + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* spectrogram, + ggml_tensor* mean, + ggml_tensor* stddev) { + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto mid_block_1 = std::dynamic_pointer_cast(blocks["mid.block_1"]); + auto mid_block_2 = std::dynamic_pointer_cast(blocks["mid.block_2"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + + auto x = conv_in->forward(ctx, spectrogram); + for (int level = 0; level < static_cast(config.encoder_channel_multipliers.size()); ++level) { + for (int block_idx = 0; block_idx < config.encoder_num_res_blocks; ++block_idx) { + auto block = std::dynamic_pointer_cast(blocks["down." + std::to_string(level) + ".block." + std::to_string(block_idx)]); + x = block->forward(ctx, x); + } + if (level != static_cast(config.encoder_channel_multipliers.size()) - 1) { + auto downsample = std::dynamic_pointer_cast(blocks["down." + std::to_string(level) + ".downsample"]); + x = downsample->forward(ctx, x); + } + } + + x = mid_block_1->forward(ctx, x); + x = mid_block_2->forward(ctx, x); + x = norm_out->forward(ctx, x); + x = ggml_silu_inplace(ctx->ggml_ctx, x); + x = conv_out->forward(ctx, x); + return normalize_latent(ctx, x, mean, stddev); + } + }; + struct Conv1D : public UnaryBlock { int64_t in_channels; int64_t out_channels; @@ -914,6 +1132,9 @@ namespace LTXV { explicit LTXAudioVAE(const LTXAudioVAEConfig& config) : config(config) { + if (config.has_encoder) { + blocks["audio_vae.encoder"] = std::make_shared(config); + } blocks["audio_vae.decoder"] = std::make_shared(config); blocks["vocoder.vocoder"] = std::make_shared(config); if (config.has_bwe) { @@ -993,6 +1214,18 @@ namespace LTXV { return waveform; } + + ggml_tensor* encode(GGMLRunnerContext* ctx, + ggml_tensor* waveform, + ggml_tensor* stft_basis, + ggml_tensor* mel_basis) { + GGML_ASSERT(config.has_encoder); + auto encoder = std::dynamic_pointer_cast(blocks["audio_vae.encoder"]); + auto mean = params["audio_vae.per_channel_statistics.mean-of-means"]; + auto stddev = params["audio_vae.per_channel_statistics.std-of-means"]; + auto mel = compute_log_mel_spectrogram(ctx, waveform, stft_basis, mel_basis, config.mel_hop_length); + return encoder->forward(ctx, mel, mean, stddev); + } }; struct LTXAudioVAERunner : public GGMLRunner { @@ -1000,6 +1233,8 @@ namespace LTXV { LTXAudioVAE model; std::string weight_prefix; sd::Tensor bwe_skip_filter_tensor; + sd::Tensor encoder_stft_basis_tensor; + sd::Tensor encoder_mel_basis_tensor; LTXAudioVAERunner(ggml_backend_t backend, const String2TensorStorage& tensor_storage_map, @@ -1014,6 +1249,10 @@ namespace LTXV { const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate; bwe_skip_filter_tensor = sd::Tensor::from_vector(build_hann_resample_filter(bwe_ratio)); } + if (config.has_encoder) { + encoder_stft_basis_tensor = build_encoder_stft_basis(config.n_fft); + encoder_mel_basis_tensor = build_encoder_mel_basis(config.sample_rate, config.n_fft, config.mel_bins); + } } void get_param_tensors(std::map& tensors) { @@ -1046,6 +1285,29 @@ namespace LTXV { return result; } + sd::Tensor encode(int n_threads, + const sd::Tensor& waveform_tensor) { + if (!config.has_encoder || waveform_tensor.empty() || + encoder_stft_basis_tensor.empty() || encoder_mel_basis_tensor.empty()) { + return {}; + } + int64_t t0 = ggml_time_ms(); + auto get_graph = [&]() -> ggml_cgraph* { + auto waveform = make_input(waveform_tensor); + auto stft_basis = make_input(encoder_stft_basis_tensor); + auto mel_basis = make_input(encoder_mel_basis_tensor); + ggml_cgraph* gf = new_graph_custom(655360); + auto runner_ctx = GGMLRunner::get_context(); + auto latent = model.encode(&runner_ctx, waveform, stft_basis, mel_basis); + ggml_build_forward_expand(gf, latent); + return gf; + }; + auto result = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false, false, false), 4); + int64_t t1 = ggml_time_ms(); + LOG_INFO("ltx audio vae encode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + return result; + } + void test(const std::string& input_path) { auto z = sd::load_tensor_from_file_as_tensor(input_path); GGML_ASSERT(!z.empty()); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 4df047dd5..b3092cc39 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -1101,9 +1102,6 @@ class StableDiffusionGGML { ignore_tensors.insert("model.diffusion_model.__32x32__"); ignore_tensors.insert("model.diffusion_model.__index_timestep_zero__"); - if (audio_vae_model) { - ignore_tensors.insert("audio_vae.encoder"); - } if (version == VERSION_OVIS_IMAGE) { ignore_tensors.insert("text_encoders.llm.vision_model."); ignore_tensors.insert("text_encoders.llm.visual_tokenizer."); @@ -2928,6 +2926,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->seed = -1; sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->fps = 16; + sd_vid_gen_params->input_audio = nullptr; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr}; @@ -3027,6 +3026,77 @@ static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd, return audio; } +static sd_audio_t* clone_sd_audio(const sd_audio_t* src) { + if (src == nullptr || src->data == nullptr || src->sample_rate == 0 || src->channels == 0 || src->sample_count == 0) { + return nullptr; + } + + sd_audio_t* audio = (sd_audio_t*)malloc(sizeof(sd_audio_t)); + if (audio == nullptr) { + return nullptr; + } + + audio->sample_rate = src->sample_rate; + audio->channels = src->channels; + audio->sample_count = src->sample_count; + size_t sample_bytes = static_cast(src->sample_count) * static_cast(src->channels) * sizeof(float); + audio->data = (float*)malloc(sample_bytes); + if (audio->data == nullptr) { + free(audio); + return nullptr; + } + + std::memcpy(audio->data, src->data, sample_bytes); + return audio; +} + +static sd::Tensor sd_audio_to_ltx_waveform_tensor(const sd_audio_t* audio, + int target_sample_rate, + int target_channels) { + if (audio == nullptr || audio->data == nullptr || audio->sample_rate == 0 || + audio->channels == 0 || audio->sample_count == 0 || target_sample_rate <= 0 || + target_channels <= 0) { + return {}; + } + + uint64_t out_samples_u64 = (audio->sample_count * static_cast(target_sample_rate) + + static_cast(audio->sample_rate) - 1) / + static_cast(audio->sample_rate); + if (out_samples_u64 == 0 || out_samples_u64 > static_cast(std::numeric_limits::max())) { + return {}; + } + + int64_t out_samples = static_cast(out_samples_u64); + sd::Tensor waveform({out_samples, target_channels, 1, 1}); + const double src_rate = static_cast(audio->sample_rate); + const double dst_rate = static_cast(target_sample_rate); + const int src_channels = static_cast(audio->channels); + + auto src_value = [&](uint64_t sample, int channel) -> float { + int src_channel = channel; + if (src_channels == 1) { + src_channel = 0; + } else if (channel >= src_channels) { + src_channel = src_channels - 1; + } + return audio->data[static_cast(sample) * static_cast(src_channels) + static_cast(src_channel)]; + }; + + for (int64_t t = 0; t < out_samples; ++t) { + double src_pos = static_cast(t) * src_rate / dst_rate; + uint64_t i0 = static_cast(std::floor(src_pos)); + uint64_t i1 = std::min(i0 + 1, audio->sample_count - 1); + float frac = static_cast(src_pos - static_cast(i0)); + for (int ch = 0; ch < target_channels; ++ch) { + float v0 = src_value(i0, ch); + float v1 = src_value(i1, ch); + waveform.index(t, ch, 0, 0) = v0 + (v1 - v0) * frac; + } + } + + return waveform; +} + void free_sd_audio(sd_audio_t* audio) { if (audio == nullptr) { return; @@ -3595,7 +3665,8 @@ static sd::Tensor pack_ltxav_audio_and_video_latents(const sd::Tensor pack_ltxav_audio_and_video_denoise_mask(const sd::Tensor& video_mask, const sd::Tensor& video_latent, - const sd::Tensor& audio_latent) { + const sd::Tensor& audio_latent, + float audio_mask_value = 1.f) { if (video_mask.empty() || audio_latent.empty()) { return video_mask; } @@ -3638,7 +3709,7 @@ static sd::Tensor pack_ltxav_audio_and_video_denoise_mask(const sd::Tenso std::vector audio_mask_shape = video_latent.shape(); audio_mask_shape[3] = extra_ch; - auto audio_mask = sd::Tensor::ones(audio_mask_shape); + auto audio_mask = sd::full(audio_mask_shape, audio_mask_value); return sd::ops::concat(video_mask_full, audio_mask, 3); } @@ -4680,6 +4751,44 @@ static std::optional prepare_video_generation_latents(sd if (sd_version_is_ltxav(sd_ctx->sd->version)) { latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps); latents.audio_latent = make_ltxav_empty_audio_latent(latents.audio_length); + if (sd_vid_gen_params->input_audio != nullptr && + sd_vid_gen_params->input_audio->data != nullptr && + sd_vid_gen_params->input_audio->sample_count > 0) { + if (sd_ctx->sd->audio_vae_model == nullptr || !sd_ctx->sd->audio_vae_model->config.has_encoder) { + LOG_ERROR("LTX A2V requires an audio VAE with encoder weights"); + return std::nullopt; + } + + int64_t audio_encode_start = ggml_time_ms(); + auto waveform = sd_audio_to_ltx_waveform_tensor(sd_vid_gen_params->input_audio, + sd_ctx->sd->audio_vae_model->config.sample_rate, + sd_ctx->sd->audio_vae_model->config.audio_channels); + if (waveform.empty()) { + LOG_ERROR("failed to convert source audio for LTX A2V encoding"); + return std::nullopt; + } + + auto encoded_audio_latent = sd_ctx->sd->audio_vae_model->encode(sd_ctx->sd->n_threads, waveform); + if (encoded_audio_latent.empty()) { + LOG_ERROR("LTX A2V audio latent encoding failed"); + return std::nullopt; + } + + latents.audio_latent = resize_ltxav_audio_latent(encoded_audio_latent, latents.audio_length); + if (latents.audio_latent.empty()) { + LOG_ERROR("failed to resize encoded LTX A2V audio latent"); + return std::nullopt; + } + + int64_t audio_encode_end = ggml_time_ms(); + LOG_INFO("encoded LTX A2V source audio latent %dx%dx%dx%d -> length %d, taking %.2fs", + (int)encoded_audio_latent.shape()[0], + (int)encoded_audio_latent.shape()[1], + (int)encoded_audio_latent.shape()[2], + (int)encoded_audio_latent.shape()[3], + latents.audio_length, + (audio_encode_end - audio_encode_start) * 1.0f / 1000); + } } if (sd_version_is_ltxav(sd_ctx->sd->version)) { @@ -4957,10 +5066,17 @@ static std::optional prepare_video_generation_latents(sd } if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) { + bool has_input_audio = sd_vid_gen_params->input_audio != nullptr && + sd_vid_gen_params->input_audio->data != nullptr && + sd_vid_gen_params->input_audio->sample_count > 0; + if (has_input_audio && latents.denoise_mask.empty()) { + latents.denoise_mask = make_ltxav_video_denoise_mask(latents.init_latent, 1.f); + } if (!latents.denoise_mask.empty()) { latents.denoise_mask = pack_ltxav_audio_and_video_denoise_mask(latents.denoise_mask, latents.init_latent, - latents.audio_latent); + latents.audio_latent, + has_input_audio ? 0.f : 1.f); } latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent); } @@ -5232,8 +5348,14 @@ static bool apply_ltxv_refine_image_conditioning(sd_ctx_t* sd_ctx, } if (!audio_latent.empty()) { + bool has_input_audio = sd_vid_gen_params->input_audio != nullptr && + sd_vid_gen_params->input_audio->data != nullptr && + sd_vid_gen_params->input_audio->sample_count > 0; *latent = pack_ltxav_audio_and_video_latents(video_latent, audio_latent); - *denoise_mask = pack_ltxav_audio_and_video_denoise_mask(video_mask, video_latent, audio_latent); + *denoise_mask = pack_ltxav_audio_and_video_denoise_mask(video_mask, + video_latent, + audio_latent, + has_input_audio ? 0.f : 1.f); } else { *latent = std::move(video_latent); *denoise_mask = std::move(video_mask); @@ -5265,6 +5387,9 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, int64_t t0 = ggml_time_ms(); sd_ctx->sd->vae_tiling_params = sd_vid_gen_params->vae_tiling_params; GenerationRequest request(sd_ctx, sd_vid_gen_params); + bool has_input_audio = sd_vid_gen_params->input_audio != nullptr && + sd_vid_gen_params->input_audio->data != nullptr && + sd_vid_gen_params->input_audio->sample_count > 0; bool latent_upscale_enabled = request.hires.enabled; GenerationRequest hires_request = request; if (latent_upscale_enabled) { @@ -5482,6 +5607,17 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, &hires_video_positions)) { return false; } + if (has_input_audio && hires_denoise_mask.empty() && x_t.shape()[3] > sd_ctx->sd->get_latent_channel()) { + int latent_channels = sd_ctx->sd->get_latent_channel(); + auto video_latent = sd::ops::slice(x_t, 3, 0, latent_channels); + auto audio_latent = unpack_ltxav_audio_latent(x_t, latents.audio_length, latent_channels); + if (!audio_latent.empty()) { + hires_denoise_mask = pack_ltxav_audio_and_video_denoise_mask(make_ltxav_video_denoise_mask(video_latent, 1.f), + video_latent, + audio_latent, + 0.f); + } + } noise = sd::Tensor::randn_like(x_t, sd_ctx->sd->rng); W = hires_request.width / hires_request.vae_scale_factor; @@ -5550,6 +5686,9 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd_audio_t* generated_audio = nullptr; if (sd_version_is_ltxav(sd_ctx->sd->version) && + has_input_audio) { + generated_audio = clone_sd_audio(sd_vid_gen_params->input_audio); + } else if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0 && sd_ctx->sd->audio_vae_model != nullptr) { if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {