diff --git a/README.md b/README.md
index 0ff8186cf..b5a9e63f0 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,7 @@ API and command-line option may change frequently.***
     - [Chroma1-Radiance](./docs/chroma_radiance.md)
     - [Qwen Image](./docs/qwen_image.md)
     - [Z-Image](./docs/z_image.md)
+    - [Ovis-Image](./docs/ovis_image.md)
 - Image Edit Models
     - [FLUX.1-Kontext-dev](./docs/kontext.md)
     - [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
@@ -134,6 +135,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
 - [🔥Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
 - [🔥Wan2.1/Wan2.2](./docs/wan.md)
 - [🔥Z-Image](./docs/z_image.md)
+- [Ovis-Image](./docs/ovis_image.md)
 - [LoRA](./docs/lora.md)
 - [LCM/LCM-LoRA](./docs/lcm.md)
 - [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)
diff --git a/assets/ovis_image/example.png b/assets/ovis_image/example.png
new file mode 100644
index 000000000..ea7f6e126
Binary files /dev/null and b/assets/ovis_image/example.png differ
diff --git a/conditioner.hpp b/conditioner.hpp
index 403120d9b..2e5972c1b 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -1638,7 +1638,7 @@
         LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
         if (sd_version_is_flux2(version)) {
             arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
-        } else if (sd_version_is_z_image(version)) {
+        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
             arch = LLM::LLMArch::QWEN3;
         }
         if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1728,6 +1728,7 @@
         std::vector<std::pair<int, ggml_tensor*>> image_embeds;
         std::pair<int, int> prompt_attn_range;
         int prompt_template_encode_start_idx = 34;
+        int max_length                       = 0;
         std::set<int> out_layers;
         if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
@@ -1825,6 +1826,17 @@
             prompt_attn_range.second = prompt.size();
 
             prompt += "[/INST]";
+        } else if (version == VERSION_OVIS_IMAGE) {
+            prompt_template_encode_start_idx = 28;
+            max_length                       = prompt_template_encode_start_idx + 256;
+
+            prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
+
+            prompt_attn_range.first = static_cast<int>(prompt.size());
+            prompt += " " + conditioner_params.text;
+            prompt_attn_range.second = static_cast<int>(prompt.size());
+
+            prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n";
         } else {
             prompt_template_encode_start_idx = 34;
 
@@ -1837,7 +1849,7 @@
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
         }
 
-        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
+        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
         auto& tokens            = std::get<0>(tokens_and_weights);
         auto& weights           = std::get<1>(tokens_and_weights);
@@ -1870,9 +1882,13 @@
         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
 
-        int64_t zero_pad_len = 0;
+        int64_t min_length = 0;
         if (sd_version_is_flux2(version)) {
-            int64_t min_length = 512;
+            min_length = 512;
+        }
+
+        int64_t zero_pad_len = 0;
+        if (min_length > 0) {
             if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
                 zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
             }
         }
@@ -1892,6 +1908,8 @@
             ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
         });
 
+        // print_ggml_tensor(new_hidden_states);
+
         int64_t t1 = ggml_time_ms();
         LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
         return {new_hidden_states, nullptr, nullptr};
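For reference, a minimal standalone sketch of the Ovis-Image prompt bookkeeping added above: the attention range is a byte span around the user text inside the chat template, and `max_length` (28 + 256) caps the padded token count. The `main` wrapper and the sample text are illustrative only, not part of the patch:

```
// Mirrors the Ovis-Image branch of LLMEmbedder above: attn_range marks the
// byte span of the user text inside the chat template; tokens outside it
// keep weight 1.0 when per-token weights are applied after tokenization.
#include <cstdio>
#include <string>
#include <utility>

int main() {
    std::string user_text = "a lovely cat";  // hypothetical prompt

    int prompt_template_encode_start_idx = 28;
    int max_length                       = prompt_template_encode_start_idx + 256;

    std::pair<int, int> attn_range;
    std::string prompt =
        "<|im_start|>user\nDescribe the image by detailing the color, quantity, "
        "text, shape, size, texture, spatial relationships of the objects and background:";
    attn_range.first = static_cast<int>(prompt.size());
    prompt += " " + user_text;
    attn_range.second = static_cast<int>(prompt.size());
    prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n";

    std::printf("weighted byte range: [%d, %d), max_length: %d\n",
                attn_range.first, attn_range.second, max_length);
    return 0;
}
```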
diff --git a/docs/ovis_image.md b/docs/ovis_image.md
new file mode 100644
index 000000000..20e641a82
--- /dev/null
+++ b/docs/ovis_image.md
@@ -0,0 +1,19 @@
+# How to Use
+
+## Download weights
+
+- Download Ovis-Image-7B
+    - safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/diffusion_models
+    - gguf: https://huggingface.co/leejet/Ovis-Image-7B-GGUF
+- Download VAE
+    - safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
+- Download Ovis 2.5
+    - safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/text_encoders
+
+## Examples
+
+```
+.\bin\Release\sd.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
+```
+
+![ovis image example](../assets/ovis_image/example.png)
\ No newline at end of file
diff --git a/flux.hpp b/flux.hpp
index f0c65e3d7..1df2874ae 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -134,6 +134,54 @@
         }
     };
 
+    struct MLP : public UnaryBlock {
+        bool use_mlp_silu_act;
+
+    public:
+        MLP(int64_t hidden_size, int64_t intermediate_size, bool use_mlp_silu_act = false, bool bias = false)
+            : use_mlp_silu_act(use_mlp_silu_act) {
+            int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
+            blocks["0"]             = std::make_shared<Linear>(hidden_size, intermediate_size * mlp_mult_factor, bias);
+            blocks["2"]             = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
+        }
+
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+            auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["0"]);
+            auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);
+
+            x = mlp_0->forward(ctx, x);
+            if (use_mlp_silu_act) {
+                x = ggml_ext_silu_act(ctx->ggml_ctx, x);
+            } else {
+                x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            }
+            x = mlp_2->forward(ctx, x);
+            return x;
+        }
+    };
+
+    struct YakMLP : public UnaryBlock {
+    public:
+        YakMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = true) {
+            blocks["gate_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
+            blocks["up_proj"]   = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
+            blocks["down_proj"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
+        }
+
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+            auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
+            auto up_proj   = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
+            auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
+
+            auto gate = gate_proj->forward(ctx, x);
+            gate      = ggml_silu_inplace(ctx->ggml_ctx, gate);
+            x         = up_proj->forward(ctx, x);
+            x         = ggml_mul(ctx->ggml_ctx, x, gate);
+            x         = down_proj->forward(ctx, x);
+            return x;
+        }
+    };
+
     struct ModulationOut {
         ggml_tensor* shift = nullptr;
         ggml_tensor* scale = nullptr;
@@ -199,7 +247,6 @@
     struct DoubleStreamBlock : public GGMLBlock {
         bool prune_mod;
         int idx = 0;
-        bool use_mlp_silu_act;
 
     public:
         DoubleStreamBlock(int64_t hidden_size,
@@ -210,10 +257,10 @@
                           bool prune_mod        = false,
                           bool share_modulation = false,
                           bool mlp_proj_bias    = true,
+                          bool use_yak_mlp      = false,
                           bool use_mlp_silu_act = false)
-            : idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
-            int64_t mlp_hidden_dim  = hidden_size * mlp_ratio;
-            int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
+            : idx(idx), prune_mod(prune_mod) {
+            int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
 
             if (!prune_mod && !share_modulation) {
                 blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@@ -222,9 +269,11 @@
             blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
 
             blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
-            // img_mlp.1 is nn.GELU(approximate="tanh")
-            blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
+            if (use_yak_mlp) {
+                blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
+            } else {
+                blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
+            }
 
             if (!prune_mod && !share_modulation) {
                 blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@@ -233,9 +282,11 @@
             blocks["txt_attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
             blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
-            // img_mlp.1 is nn.GELU(approximate="tanh")
-            blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
+            if (use_yak_mlp) {
+                blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
+            } else {
+                blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
+            }
 
         }
 
         std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
@@ -272,15 +323,13 @@
             auto img_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm1"]);
             auto img_attn  = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);
             auto img_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm2"]);
-            auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
-            auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
+            auto img_mlp   = std::dynamic_pointer_cast<UnaryBlock>(blocks["img_mlp"]);
 
             auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
             auto txt_attn  = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);
             auto txt_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm2"]);
-            auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
-            auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
+            auto txt_mlp   = std::dynamic_pointer_cast<UnaryBlock>(blocks["txt_mlp"]);
 
             if (img_mods.empty()) {
                 if (prune_mod) {
@@ -348,27 +397,15 @@
             // calculate the img bloks
             img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
 
-            auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
-            if (use_mlp_silu_act) {
-                img_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, img_mlp_out);
-            } else {
-                img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
-            }
-            img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
+            auto img_mlp_out = img_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
 
             img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));
 
             // calculate the txt bloks
             txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
 
-            auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
-            if (use_mlp_silu_act) {
-                txt_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, txt_mlp_out);
-            } else {
-                txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
-            }
-            txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
-            txt         = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
+            auto txt_mlp_out = txt_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
+            txt              = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
 
             return {img, txt};
         }
@@ -381,6 +418,7 @@
         int64_t mlp_hidden_dim;
         bool prune_mod;
         int idx = 0;
+        bool use_yak_mlp;
         bool use_mlp_silu_act;
         int64_t mlp_mult_factor;
 
@@ -393,8 +431,9 @@
                           bool prune_mod        = false,
                           bool share_modulation = false,
                           bool mlp_proj_bias    = true,
+                          bool use_yak_mlp      = false,
                           bool use_mlp_silu_act = false)
-            : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
+            : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
             int64_t head_dim = hidden_size / num_heads;
             float scale      = qk_scale;
             if (scale <= 0.f) {
@@ -402,7 +441,7 @@
             }
             mlp_hidden_dim  = hidden_size * mlp_ratio;
             mlp_mult_factor = 1;
-            if (use_mlp_silu_act) {
+            if (use_yak_mlp || use_mlp_silu_act) {
                 mlp_mult_factor = 2;
             }
 
@@ -481,7 +520,9 @@
             k         = norm->key_norm(ctx, k);
             auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
 
-            if (use_mlp_silu_act) {
+            if (use_yak_mlp) {
+                mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
+            } else if (use_mlp_silu_act) {
                 mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
             } else {
                 mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
@@ -726,6 +767,8 @@
         int64_t in_dim         = 64;
         bool disable_bias      = false;
        bool share_modulation  = false;
+        bool semantic_txt_norm = false;
+        bool use_yak_mlp       = false;
         bool use_mlp_silu_act  = false;
         float ref_index_scale  = 1.f;
         ChromaRadianceParams chroma_radiance_params;
@@ -759,6 +802,9 @@
                     blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
                 }
             }
+            if (params.semantic_txt_norm) {
+                blocks["txt_norm"] = std::make_shared<RMSNorm>(params.context_in_dim);
+            }
             blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, !params.disable_bias);
 
             for (int i = 0; i < params.depth; i++) {
@@ -770,6 +816,7 @@
                                                                    params.is_chroma,
                                                                    params.share_modulation,
                                                                    !params.disable_bias,
+                                                                   params.use_yak_mlp,
                                                                    params.use_mlp_silu_act);
             }
 
@@ -782,6 +829,7 @@
                                                                    params.is_chroma,
                                                                    params.share_modulation,
                                                                    !params.disable_bias,
+                                                                   params.use_yak_mlp,
                                                                    params.use_mlp_silu_act);
             }
 
@@ -948,6 +996,12 @@
                 ss_mods = single_stream_modulation->forward(ctx, vec);
             }
 
+            if (params.semantic_txt_norm) {
+                auto semantic_txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
+
+                txt = semantic_txt_norm->forward(ctx, txt);
+            }
+
             txt = txt_in->forward(ctx, txt);
 
             for (int i = 0; i < params.depth; i++) {
@@ -1206,6 +1260,11 @@
         } else if (version == VERSION_CHROMA_RADIANCE) {
             flux_params.in_channels = 3;
             flux_params.patch_size  = 16;
+        } else if (version == VERSION_OVIS_IMAGE) {
+            flux_params.semantic_txt_norm = true;
+            flux_params.use_yak_mlp       = true;
+            flux_params.context_in_dim    = 2048;
+            flux_params.vec_in_dim        = 0;
         } else if (sd_version_is_flux2(version)) {
             flux_params.context_in_dim = 15360;
             flux_params.in_channels    = 128;
@@ -1364,13 +1423,22 @@
                 ref_latents[i] = to_backend(ref_latents[i]);
             }
 
+            std::set<int> txt_arange_dims;
+            if (sd_version_is_flux2(version)) {
+                txt_arange_dims    = {3};
+                increase_ref_index = true;
+            } else if (version == VERSION_OVIS_IMAGE) {
+                txt_arange_dims = {1, 2};
+            }
+
             pe_vec = Rope::gen_flux_pe(x->ne[1],
                                        x->ne[0],
                                        flux_params.patch_size,
                                        x->ne[3],
                                        context->ne[1],
+                                       txt_arange_dims,
                                        ref_latents,
-                                       sd_version_is_flux2(version) ? true : increase_ref_index,
+                                       increase_ref_index,
                                        flux_params.ref_index_scale,
                                        flux_params.theta,
                                        flux_params.axes_dim);
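The new `YakMLP` is a standard SwiGLU feed-forward: `down_proj(silu(gate_proj(x)) * up_proj(x))`. Below is a self-contained sketch of that dataflow on plain float vectors; the toy dimensions and weights are hypothetical, only the arithmetic mirrors `YakMLP::forward`:

```
// SwiGLU sketch: y = down * (silu(gate * x) elementwise* (up * x)),
// the same dataflow as YakMLP::forward above. Weights are toy values.
#include <cmath>
#include <cstdio>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // row-major: one row per output dim

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

static Vec matvec(const Mat& w, const Vec& x) {
    Vec y(w.size(), 0.0f);
    for (size_t i = 0; i < w.size(); i++)
        for (size_t j = 0; j < x.size(); j++)
            y[i] += w[i][j] * x[j];
    return y;
}

static Vec swiglu_mlp(const Mat& gate, const Mat& up, const Mat& down, const Vec& x) {
    Vec g = matvec(gate, x);  // gate_proj
    Vec u = matvec(up, x);    // up_proj
    for (size_t i = 0; i < g.size(); i++)
        u[i] *= silu(g[i]);   // gated activation
    return matvec(down, u);   // down_proj
}

int main() {
    // hypothetical 2 -> 4 -> 2 network
    Mat gate = {{1, 0}, {0, 1}, {1, 1}, {1, -1}};
    Mat up   = {{0.5f, 0}, {0, 0.5f}, {0.5f, 0.5f}, {0.5f, -0.5f}};
    Mat down = {{1, 0, 0, 0}, {0, 1, 0, 0}};
    Vec y    = swiglu_mlp(gate, up, down, {1.0f, 2.0f});
    std::printf("%f %f\n", y[0], y[1]);
    return 0;
}
```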
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index f74fadb6e..2b4ce5d85 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -760,17 +760,23 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
     return chunks;
 }
 
-__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x) {
+__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x, bool gate_first = true) {
     // x: [ne3, ne2, ne1, ne0]
     // return: [ne3, ne2, ne1, ne0/2]
     auto x_vec = ggml_ext_chunk(ctx, x, 2, 0);
-    auto x1    = x_vec[0];  // [ne3, ne2, ne1, ne0/2]
-    auto x2    = x_vec[1];  // [ne3, ne2, ne1, ne0/2]
+    ggml_tensor* gate;
+    if (gate_first) {
+        gate = x_vec[0];
+        x    = x_vec[1];
+    } else {
+        x    = x_vec[0];
+        gate = x_vec[1];
+    }
 
-    x1 = ggml_silu_inplace(ctx, x1);
+    gate = ggml_silu_inplace(ctx, gate);
 
-    x = ggml_mul(ctx, x1, x2);  // [ne3, ne2, ne1, ne0/2]
+    x = ggml_mul(ctx, x, gate);  // [ne3, ne2, ne1, ne0/2]
 
     return x;
 }
 
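A quick standalone check of the `gate_first` semantics added to `ggml_ext_silu_act`: the innermost dimension is split in half, the gate half passes through SiLU and multiplies the value half, and `gate_first` only selects which half acts as the gate. Plain-buffer sketch with a hypothetical `main`:

```
// Mirrors ggml_ext_silu_act above on a flat buffer split into two halves.
#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Split x into [a | b]; with gate_first the gate is a, otherwise b.
static std::vector<float> silu_act(const std::vector<float>& x, bool gate_first) {
    size_t half = x.size() / 2;
    std::vector<float> out(half);
    for (size_t i = 0; i < half; i++) {
        float gate = gate_first ? x[i] : x[half + i];
        float val  = gate_first ? x[half + i] : x[i];
        out[i]     = val * silu(gate);
    }
    return out;
}

int main() {
    std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
    auto a = silu_act(x, true);   // gate = {1,2}, value = {3,4}
    auto b = silu_act(x, false);  // gate = {3,4}, value = {1,2}
    std::printf("%f %f | %f %f\n", a[0], a[1], b[0], b[1]);
    return 0;
}
```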
diff --git a/llm.hpp b/llm.hpp
index c42c56499..dc04c84cc 100644
--- a/llm.hpp
+++ b/llm.hpp
@@ -356,6 +356,10 @@
             "<|fim_pad|>",
             "<|repo_name|>",
             "<|file_sep|>",
+            "<tool_response>",
+            "</tool_response>",
+            "<think>",
+            "</think>",
         };
 
         if (merges_utf8_str.size() > 0) {
@@ -859,11 +863,11 @@
             }
 
             if (arch == LLMArch::MISTRAL_SMALL_3_2) {
-                q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
-                k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
+                q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
+                k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
             } else if (arch == LLMArch::QWEN3) {
-                q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
-                k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
+                q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
+                k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
             } else {
                 int sections[4] = {16, 24, 24, 0};
                 q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@@ -1073,29 +1077,22 @@
             : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
             params.arch = arch;
             if (arch == LLMArch::MISTRAL_SMALL_3_2) {
-                params.num_layers        = 40;
-                params.hidden_size       = 5120;
-                params.intermediate_size = 32768;
-                params.head_dim          = 128;
-                params.num_heads         = 32;
-                params.num_kv_heads      = 8;
-                params.qkv_bias          = false;
-                params.vocab_size        = 131072;
-                params.rms_norm_eps      = 1e-5f;
+                params.head_dim     = 128;
+                params.num_heads    = 32;
+                params.num_kv_heads = 8;
+                params.qkv_bias     = false;
+                params.rms_norm_eps = 1e-5f;
             } else if (arch == LLMArch::QWEN3) {
-                params.num_layers        = 36;
-                params.hidden_size       = 2560;
-                params.intermediate_size = 9728;
-                params.head_dim          = 128;
-                params.num_heads         = 32;
-                params.num_kv_heads      = 8;
-                params.qkv_bias          = false;
-                params.qk_norm           = true;
-                params.vocab_size        = 151936;
-                params.rms_norm_eps      = 1e-6f;
+                params.head_dim     = 128;
+                params.num_heads    = 32;
+                params.num_kv_heads = 8;
+                params.qkv_bias     = false;
+                params.qk_norm      = true;
+                params.rms_norm_eps = 1e-6f;
             }
             bool have_vision_weight = false;
             bool llama_cpp_style    = false;
+            params.num_layers       = 0;
             for (auto pair : tensor_storage_map) {
                 std::string tensor_name = pair.first;
                 if (tensor_name.find(prefix) == std::string::npos)
@@ -1105,10 +1102,36 @@
                     have_vision_weight = true;
                     if (contains(tensor_name, "attn.q_proj")) {
                         llama_cpp_style = true;
-                        break;
                     }
+                    continue;
                 }
+                pos = tensor_name.find("layers.");
+                if (pos != std::string::npos) {
+                    tensor_name = tensor_name.substr(pos);  // remove prefix
+                    auto items  = split_string(tensor_name, '.');
+                    if (items.size() > 1) {
+                        int block_index = atoi(items[1].c_str());
+                        if (block_index + 1 > params.num_layers) {
+                            params.num_layers = block_index + 1;
+                        }
+                    }
+                }
+                if (contains(tensor_name, "embed_tokens.weight")) {
+                    params.hidden_size = pair.second.ne[0];
+                    params.vocab_size  = pair.second.ne[1];
+                }
+                if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) {
+                    params.intermediate_size = pair.second.ne[1];
+                }
+            }
+            if (arch == LLMArch::QWEN3 && params.num_layers == 28) {  // Qwen3 2B
+                params.num_heads = 16;
             }
+            LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
+                      params.num_layers,
+                      params.vocab_size,
+                      params.hidden_size,
+                      params.intermediate_size);
             if (enable_vision && !have_vision_weight) {
                 LOG_WARN("no vision weights detected, vision disabled");
                 enable_vision = false;
diff --git a/model.cpp b/model.cpp
index b314139c2..0480efefb 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1056,6 +1056,9 @@ SDVersion ModelLoader::get_sd_version() {
         if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
             return VERSION_FLUX2;
         }
+        if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
+            return VERSION_OVIS_IMAGE;
+        }
         if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
             return VERSION_Z_IMAGE;
         }
diff --git a/model.h b/model.h
index 71a22a8f9..d38aee1c1 100644
--- a/model.h
+++ b/model.h
@@ -45,6 +45,7 @@ enum SDVersion {
     VERSION_QWEN_IMAGE,
     VERSION_FLUX2,
     VERSION_Z_IMAGE,
+    VERSION_OVIS_IMAGE,
     VERSION_COUNT,
 };
 
@@ -90,6 +91,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
         version == VERSION_FLUX_FILL ||
         version == VERSION_FLUX_CONTROLS ||
         version == VERSION_FLEX_2 ||
+        version == VERSION_OVIS_IMAGE ||
         version == VERSION_CHROMA_RADIANCE) {
         return true;
     }
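The llm.hpp change above replaces hard-coded per-model sizes with shape-driven detection: `num_layers` from the highest `layers.N.` index, `hidden_size` and `vocab_size` from `embed_tokens.weight`, and `intermediate_size` from `layers.0.mlp.gate_proj.weight`. A sketch of the same scan over a hypothetical name-to-shape map (the example shapes are illustrative Qwen3-style values, not measured from a real checkpoint):

```
// Shape-driven hyperparameter detection, as in the llm.hpp change above.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>
#include <vector>

int main() {
    // hypothetical name -> ne[] map, standing in for tensor_storage_map
    std::map<std::string, std::vector<int64_t>> tensors = {
        {"model.layers.0.mlp.gate_proj.weight", {2560, 9728}},
        {"model.layers.35.mlp.gate_proj.weight", {2560, 9728}},
        {"model.embed_tokens.weight", {2560, 151936}},
    };

    int64_t num_layers = 0, hidden_size = 0, vocab_size = 0, intermediate_size = 0;
    for (const auto& kv : tensors) {
        size_t pos = kv.first.find("layers.");
        if (pos != std::string::npos) {
            // "layers.N." => N + 1 is a lower bound on the layer count
            int64_t idx = std::atoi(kv.first.c_str() + pos + 7);
            if (idx + 1 > num_layers)
                num_layers = idx + 1;
        }
        if (kv.first.find("embed_tokens.weight") != std::string::npos) {
            hidden_size = kv.second[0];  // ne[0]
            vocab_size  = kv.second[1];  // ne[1]
        }
        if (kv.first.find("layers.0.mlp.gate_proj.weight") != std::string::npos) {
            intermediate_size = kv.second[1];
        }
    }
    std::printf("layers=%lld hidden=%lld vocab=%lld intermediate=%lld\n",
                (long long)num_layers, (long long)hidden_size,
                (long long)vocab_size, (long long)intermediate_size);
    return 0;
}
```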
diff --git a/rope.hpp b/rope.hpp
index 7a35926eb..4abc51469 100644
--- a/rope.hpp
+++ b/rope.hpp
@@ -72,11 +72,13 @@ namespace Rope {
     }
 
     // Generate IDs for image patches and text
-    __STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num) {
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num, std::set<int> arange_dims) {
         auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
-        if (axes_dim_num == 4) {
-            for (int i = 0; i < bs * context_len; i++) {
-                txt_ids[i][3] = (i % context_len);
+        for (int dim = 0; dim < axes_dim_num; dim++) {
+            if (arange_dims.find(dim) != arange_dims.end()) {
+                for (int i = 0; i < bs * context_len; i++) {
+                    txt_ids[i][dim] = (i % context_len);
+                }
             }
         }
         return txt_ids;
@@ -211,10 +213,11 @@ namespace Rope {
                                                                 int bs,
                                                                 int axes_dim_num,
                                                                 int context_len,
+                                                                std::set<int> txt_arange_dims,
                                                                 const std::vector<ggml_tensor*>& ref_latents,
                                                                 bool increase_ref_index,
                                                                 float ref_index_scale) {
-        auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num);
+        auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
         auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
         auto ids     = concat_ids(txt_ids, img_ids, bs);
 
@@ -231,6 +234,7 @@ namespace Rope {
                                                            int patch_size,
                                                            int bs,
                                                            int context_len,
+                                                           std::set<int> txt_arange_dims,
                                                            const std::vector<ggml_tensor*>& ref_latents,
                                                            bool increase_ref_index,
                                                            float ref_index_scale,
@@ -242,6 +246,7 @@ namespace Rope {
                                                bs,
                                                static_cast<int>(axes_dim.size()),
                                                context_len,
+                                               txt_arange_dims,
                                                ref_latents,
                                                increase_ref_index,
                                                ref_index_scale);
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index fe2a26ca3..98c0c84de 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -46,6 +46,7 @@ const char* model_version_to_str[] = {
     "Qwen Image",
     "Flux.2",
     "Z-Image",
+    "Ovis Image",
 };
 
 const char* sampling_methods_str[] = {
@@ -424,6 +425,13 @@ class StableDiffusionGGML {
                                                                  tensor_storage_map,
                                                                  sd_ctx_params->chroma_use_t5_mask,
                                                                  sd_ctx_params->chroma_t5_mask_pad);
+            } else if (version == VERSION_OVIS_IMAGE) {
+                cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
+                                                                 offload_params_to_cpu,
+                                                                 tensor_storage_map,
+                                                                 version,
+                                                                 "",
+                                                                 false);
             } else {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
                                                                                        offload_params_to_cpu,
@@ -690,6 +698,11 @@ class StableDiffusionGGML {
             ignore_tensors.insert("first_stage_model.quant");
             ignore_tensors.insert("text_encoders.llm.visual.");
         }
+        if (version == VERSION_OVIS_IMAGE) {
+            ignore_tensors.insert("text_encoders.llm.vision_model.");
+            ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
+            ignore_tensors.insert("text_encoders.llm.vte.");
+        }
         if (version == VERSION_SVD) {
             ignore_tensors.insert("conditioner.embedders.3");
         }
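Finally, the `txt_arange_dims` plumbing in rope.hpp: each text token gets one coordinate per RoPE axis, all zeros except the axes in the set, which carry the token index (Flux.2 uses {3}, Ovis-Image uses {1, 2}). A standalone sketch of the id layout produced by the new `gen_flux_txt_ids`, with a toy context length:

```
// Text-token position ids per the updated gen_flux_txt_ids above.
#include <cstdio>
#include <set>
#include <vector>

int main() {
    int context_len  = 4;  // hypothetical token count
    int axes_dim_num = 4;
    std::set<int> arange_dims = {1, 2};  // Ovis-Image; Flux.2 would use {3}

    std::vector<std::vector<float>> txt_ids(
        context_len, std::vector<float>(axes_dim_num, 0.0f));
    for (int dim = 0; dim < axes_dim_num; dim++)
        if (arange_dims.count(dim))
            for (int i = 0; i < context_len; i++)
                txt_ids[i][dim] = float(i);  // arange along the selected axis

    for (const auto& row : txt_ids)
        std::printf("(%g, %g, %g, %g)\n", row[0], row[1], row[2], row[3]);
    return 0;
}
```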