Commit 66e27de

add flux2 support

1 parent 2034588 commit 66e27de

File tree

12 files changed: +489569 -521 lines changed

conditioner.hpp

Lines changed: 91 additions & 47 deletions
@@ -1623,61 +1623,72 @@ struct T5CLIPEmbedder : public Conditioner {
     }
 };

-struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
-    Qwen::Qwen2Tokenizer tokenizer;
-    std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;
-
-    Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
-                           bool offload_params_to_cpu,
-                           const String2TensorStorage& tensor_storage_map = {},
-                           const std::string prefix = "",
-                           bool enable_vision = false) {
-        qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
-                                                          offload_params_to_cpu,
-                                                          tensor_storage_map,
-                                                          "text_encoders.qwen2vl",
-                                                          enable_vision);
+struct LLMEmbedder : public Conditioner {
+    SDVersion version;
+    std::shared_ptr<LLM::BPETokenizer> tokenizer;
+    std::shared_ptr<LLM::LLMRunner> llm;
+
+    LLMEmbedder(ggml_backend_t backend,
+                bool offload_params_to_cpu,
+                const String2TensorStorage& tensor_storage_map = {},
+                SDVersion version = VERSION_QWEN_IMAGE,
+                const std::string prefix = "",
+                bool enable_vision = false)
+        : version(version) {
+        LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
+        if (sd_version_is_flux2(version)) {
+            arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
+        }
+        if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
+            tokenizer = std::make_shared<LLM::MistralTokenizer>();
+        } else {
+            tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
+        }
+        llm = std::make_shared<LLM::LLMRunner>(arch,
+                                               backend,
+                                               offload_params_to_cpu,
+                                               tensor_storage_map,
+                                               "text_encoders.qwen2vl",
+                                               enable_vision);
     }

     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
-        qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
+        llm->get_param_tensors(tensors, "text_encoders.qwen2vl");
     }

     void alloc_params_buffer() override {
-        qwenvl->alloc_params_buffer();
+        llm->alloc_params_buffer();
     }

     void free_params_buffer() override {
-        qwenvl->free_params_buffer();
+        llm->free_params_buffer();
     }

     size_t get_params_buffer_size() override {
         size_t buffer_size = 0;
-        buffer_size += qwenvl->get_params_buffer_size();
+        buffer_size += llm->get_params_buffer_size();
         return buffer_size;
     }

     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
-        if (qwenvl) {
-            qwenvl->set_weight_adapter(adapter);
+        if (llm) {
+            llm->set_weight_adapter(adapter);
         }
     }

     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
-                                                              size_t max_length = 0,
-                                                              size_t system_prompt_length = 0,
-                                                              bool padding = false) {
+                                                              std::pair<int, int> attn_range,
+                                                              size_t max_length = 0,
+                                                              bool padding = false) {
         std::vector<std::pair<std::string, float>> parsed_attention;
-        if (system_prompt_length > 0) {
-            parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
-            auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
+        parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
+        if (attn_range.second - attn_range.first > 0) {
+            auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
             parsed_attention.insert(parsed_attention.end(),
                                     new_parsed_attention.begin(),
                                     new_parsed_attention.end());
-        } else {
-            parsed_attention = parse_prompt_attention(text);
         }
-
+        parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
         {
             std::stringstream ss;
             ss << "[";
@@ -1693,12 +1704,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight = item.second;
-            std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
+            std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
             tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
             weights.insert(weights.end(), curr_tokens.size(), curr_weight);
         }

-        tokenizer.pad_tokens(tokens, weights, max_length, padding);
+        tokenizer->pad_tokens(tokens, weights, max_length, padding);

         // for (int i = 0; i < tokens.size(); i++) {
         //     std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
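The . to -> changes in this hunk follow from the member change in the first hunk: tokenizer is now a std::shared_ptr<LLM::BPETokenizer> holding either a Qwen2Tokenizer or a MistralTokenizer. A condensed sketch of the dispatch, assuming tokenize() and pad_tokens() are virtual on the BPETokenizer base (those headers are not part of this diff):

// Sketch: one tokenize path, two tokenizer implementations.
std::shared_ptr<LLM::BPETokenizer> tok;
if (sd_version_is_flux2(version)) {
    tok = std::make_shared<LLM::MistralTokenizer>();  // flux2 / Mistral path
} else {
    tok = std::make_shared<LLM::Qwen2Tokenizer>();    // Qwen Image path
}
std::vector<int> ids = tok->tokenize("a photo of a cat", nullptr);  // virtual dispatch
std::vector<float> ws(ids.size(), 1.f);
tok->pad_tokens(ids, ws, /*max_length*/ 0, /*padding*/ false);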
@@ -1713,9 +1724,10 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                                      const ConditionerParams& conditioner_params) override {
         std::string prompt;
         std::vector<std::pair<int, ggml_tensor*>> image_embeds;
-        size_t system_prompt_length = 0;
+        std::pair<int, int> prompt_attn_range;
         int prompt_template_encode_start_idx = 34;
-        if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
+        std::set<int> out_layers;
+        if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
             prompt_template_encode_start_idx = 64;
             int image_embed_idx = 64 + 6;
@@ -1727,7 +1739,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

             for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
                 sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
-                double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size;
+                double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
                 int height = image.height;
                 int width = image.width;
                 int h_bar = static_cast<int>(std::round(height / factor)) * factor;
@@ -1757,7 +1769,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                 resized_image.data = nullptr;

                 ggml_tensor* image_embed = nullptr;
-                qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
+                llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
                 image_embeds.emplace_back(image_embed_idx, image_embed);
                 image_embed_idx += 1 + image_embed->ne[1] + 6;

@@ -1771,17 +1783,37 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
             }

             prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
-
-            system_prompt_length = prompt.size();
-
             prompt += img_prompt;
+
+            prompt_attn_range.first = prompt.size();
             prompt += conditioner_params.text;
+            prompt_attn_range.second = prompt.size();
+
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
+        } else if (sd_version_is_flux2(version)) {
+            prompt_template_encode_start_idx = 0;
+            out_layers = {10, 20, 30};
+
+            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+
+            prompt_attn_range.first = prompt.size();
+            prompt += conditioner_params.text;
+            prompt_attn_range.second = prompt.size();
+
+            prompt += "[/INST]";
         } else {
-            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
+            prompt_template_encode_start_idx = 34;
+
+            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
+
+            prompt_attn_range.first = prompt.size();
+            prompt += conditioner_params.text;
+            prompt_attn_range.second = prompt.size();
+
+            prompt += "<|im_end|>\n<|im_start|>assistant\n";
         }

-        auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
+        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
         auto& tokens = std::get<0>(tokens_and_weights);
         auto& weights = std::get<1>(tokens_and_weights);

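Before the encode call, each of the three branches above has filled in the same pieces of state. A condensed view of what each branch sets (values copied from the diff; out_layers defaults to empty):

// Branch                         start_idx  out_layers    template
// Qwen edit (vision + refs)      64         {}            ChatML system/user + image tokens
// flux2 (Mistral)                0          {10, 20, 30}  [SYSTEM_PROMPT]...[INST]...[/INST]
// Qwen Image (default)           34         {}            ChatML system/user/assistant
//
// In every branch prompt_attn_range brackets conditioner_params.text, so
// attention weighting applies to the user prompt only, never to the template.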
@@ -1790,11 +1822,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

         auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

-        qwenvl->compute(n_threads,
-                        input_ids,
-                        image_embeds,
-                        &hidden_states,
-                        work_ctx);
+        llm->compute(n_threads,
+                     input_ids,
+                     image_embeds,
+                     out_layers,
+                     &hidden_states,
+                     work_ctx);
         {
             auto tensor = hidden_states;
             float original_mean = ggml_ext_tensor_mean(tensor);
@@ -1813,14 +1846,25 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

+        int64_t zero_pad_len = 0;
+        if (sd_version_is_flux2(version)) {
+            int64_t min_length = 512;
+            if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
+                zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
+            }
+        }
+
         ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
                                                             GGML_TYPE_F32,
                                                             hidden_states->ne[0],
-                                                            hidden_states->ne[1] - prompt_template_encode_start_idx,
+                                                            hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
                                                             hidden_states->ne[2]);

         ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
-            float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
+            float value = 0.f;
+            if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
+                value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
+            }
             ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
         });

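The last two hunks implement the flux2 sequence-length floor: after the template prefix is trimmed, the conditioning is zero-padded up to at least 512 positions. A worked instance of the arithmetic (the 300-token sequence length is hypothetical):

// Sketch: flux2 zero-padding math from the hunk above.
int64_t seq_len      = 300;  // hidden_states->ne[1], hypothetical
int64_t start_idx    = 0;    // prompt_template_encode_start_idx is 0 for flux2
int64_t min_length   = 512;
int64_t zero_pad_len = 0;
if (seq_len - start_idx < min_length) {
    zero_pad_len = min_length - seq_len + start_idx;  // 512 - 300 + 0 = 212
}
// new_hidden_states gets ne[1] = (seq_len - start_idx) + zero_pad_len = 512 rows;
// the copy loop writes 0.f for every row with i1 >= seq_len - start_idx.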