conditioner.hpp: 11 changes (11 additions, 0 deletions)
@@ -1825,6 +1825,17 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = prompt.size();

prompt += "[/INST]";
+ } else if (sd_version_is_longcat(version)) {
+     prompt_template_encode_start_idx = 36;
+     // prompt_template_encode_end_idx = 5;
+
+     prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n";
+
+     prompt_attn_range.first = static_cast<int>(prompt.size());
+     prompt += conditioner_params.text;
+     prompt_attn_range.second = static_cast<int>(prompt.size());
+
+     prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else {
prompt_template_encode_start_idx = 34;

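Note: the hunk above wraps the user text in a ChatML-style `<|im_start|>`/`<|im_end|>` template and records the character range of the user text in prompt_attn_range. A minimal standalone sketch of just that string handling (the token-level prompt_template_encode_start_idx of 36 comes from the tokenizer and is taken from the diff, not recomputed here):

```cpp
// Minimal sketch of the LongCat prompt assembly from the hunk above.
// It mirrors only the string handling; real encoding indices come from the tokenizer.
#include <cstdio>
#include <string>
#include <utility>

int main() {
    std::string text = "a cat sitting on a windowsill";  // stands in for conditioner_params.text

    std::string prompt =
        "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt "
        "based on an image content, suitable for input to a text-to-image model.<|im_end|>\n"
        "<|im_start|>user\n";

    std::pair<int, int> prompt_attn_range;
    prompt_attn_range.first = static_cast<int>(prompt.size());
    prompt += text;
    prompt_attn_range.second = static_cast<int>(prompt.size());

    prompt += "<|im_end|>\n<|im_start|>assistant\n";

    // The range marks where the user text sits inside the assembled template.
    std::printf("attn range: [%d, %d)\n", prompt_attn_range.first, prompt_attn_range.second);
    return 0;
}
```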
flux.hpp: 71 changes (53 additions, 18 deletions)
@@ -88,14 +88,19 @@ namespace Flux {

public:
SelfAttention(int64_t dim,
- int64_t num_heads = 8,
- bool qkv_bias = false,
- bool proj_bias = true)
+ int64_t num_heads = 8,
+ bool qkv_bias = false,
+ bool proj_bias = true,
+ bool diffusers_style = false)
: num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
- blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
- blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
- blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
+ if (diffusers_style) {
+     blocks["qkv"] = std::shared_ptr<GGMLBlock>(new SplitLinear(dim, {dim, dim, dim}, qkv_bias));
+ } else {
+     blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
+ }
+ blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
+ blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
}

std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
@@ -210,7 +215,8 @@ namespace Flux {
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
- bool use_mlp_silu_act = false)
+ bool use_mlp_silu_act = false,
+ bool diffusers_style = false)
: idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
@@ -219,7 +225,7 @@
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
- blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
+ blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
@@ -230,7 +236,7 @@
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
- blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
+ blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style));

blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
@@ -383,6 +389,7 @@ namespace Flux {
int idx = 0;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
+ bool diffusers_style = false;

public:
SingleStreamBlock(int64_t hidden_size,
@@ -393,7 +400,8 @@
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
- bool use_mlp_silu_act = false)
+ bool use_mlp_silu_act = false,
+ bool diffusers_style = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
@@ -405,8 +413,11 @@
if (use_mlp_silu_act) {
mlp_mult_factor = 2;
}
-
- blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
+ if (diffusers_style) {
+     blocks["linear1"] = std::shared_ptr<GGMLBlock>(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias));
+ } else {
+     blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
+ }
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
@@ -728,6 +739,7 @@ namespace Flux {
bool share_modulation = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
+ bool diffusers_style = false;
ChromaRadianceParams chroma_radiance_params;
};

@@ -770,7 +782,8 @@ namespace Flux {
params.is_chroma,
params.share_modulation,
!params.disable_bias,
- params.use_mlp_silu_act);
+ params.use_mlp_silu_act,
+ params.diffusers_style);
}

for (int i = 0; i < params.depth_single_blocks; i++) {
@@ -782,7 +795,8 @@
params.is_chroma,
params.share_modulation,
!params.disable_bias,
- params.use_mlp_silu_act);
+ params.use_mlp_silu_act,
+ params.diffusers_style);
}

if (params.version == VERSION_CHROMA_RADIANCE) {
@@ -829,6 +843,11 @@
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
+ if (params.patch_size == 1) {
+     x = ggml_reshape_3d(ctx, x, H * W, C, N);              // [N, C, H*W]
+     x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [N, H*W, C]
+     return x;
+ }
int64_t p = params.patch_size;
int64_t h = H / params.patch_size;
int64_t w = W / params.patch_size;
@@ -863,6 +882,12 @@
int64_t W = w * params.patch_size;
int64_t p = params.patch_size;

+ if (params.patch_size == 1) {
+     x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [N, C, H*W]
+     x = ggml_reshape_4d(ctx, x, W, H, C, N);               // [N, C, H, W]
+     return x;
+ }
+
GGML_ASSERT(C * p * p == x->ne[0]);

x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
@@ -1222,6 +1247,10 @@ namespace Flux {
flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
+ } else if (sd_version_is_longcat(version)) {
+     flux_params.context_in_dim = 3584;
+     flux_params.vec_in_dim = 0;
+     flux_params.patch_size = 1;
}
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
@@ -1231,6 +1260,9 @@
// not schnell
flux_params.guidance_embed = true;
}
+ if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) {
+     flux_params.diffusers_style = true;
+ }
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;
@@ -1260,6 +1292,10 @@
LOG_INFO("Flux guidance is disabled (Schnell mode)");
}

+ if (flux_params.diffusers_style) {
+     LOG_INFO("Using diffusers-style attention blocks");
+ }
+
flux = Flux(flux_params);
flux.init(params_ctx, tensor_storage_map, prefix);
}
@@ -1363,7 +1399,6 @@ namespace Flux {
for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]);
}
-
pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
flux_params.patch_size,
@@ -1373,10 +1408,10 @@
sd_version_is_flux2(version) ? true : increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
- flux_params.axes_dim);
+ flux_params.axes_dim,
+ sd_version_is_longcat(version));
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
- auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
+ auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = nullptr;
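Note: with patch_size == 1 the patchify/unpatchify changes above degenerate to a [N, C, H, W] <-> [N, H*W, C] transpose and its inverse, so a round trip must reproduce the input. A small self-contained check of that index math (plain std::vector loops, illustrative only, not the ggml tensor layout):

```cpp
// Round-trip check for the patch_size == 1 fast path: [N, C, H, W] -> [N, H*W, C] -> [N, C, H, W].
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const int N = 2, C = 3, H = 4, W = 5;
    std::vector<float> x(N * C * H * W);
    for (std::size_t i = 0; i < x.size(); i++) x[i] = static_cast<float>(i);

    // patchify with p == 1: each pixel becomes a token, its channels become the token features
    std::vector<float> tokens(N * H * W * C);
    for (int n = 0; n < N; n++)
        for (int c = 0; c < C; c++)
            for (int hw = 0; hw < H * W; hw++)
                tokens[(n * H * W + hw) * C + c] = x[(n * C + c) * H * W + hw];

    // unpatchify with p == 1: transpose back to [N, C, H, W]
    std::vector<float> y(N * C * H * W);
    for (int n = 0; n < N; n++)
        for (int hw = 0; hw < H * W; hw++)
            for (int c = 0; c < C; c++)
                y[(n * C + c) * H * W + hw] = tokens[(n * H * W + hw) * C + c];

    assert(x == y);  // the 1x1-patch round trip is lossless
    return 0;
}
```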
ggml_extend.hpp: 77 changes (77 additions, 0 deletions)
@@ -2159,6 +2159,83 @@ class Linear : public UnaryBlock {
}
};

class SplitLinear : public Linear {
protected:
int64_t in_features;
std::vector<int64_t> out_features_vec;
bool bias;
bool force_f32;
bool force_prec_f32;
float scale;
std::string prefix;

void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
this->prefix = prefix;
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
// most likely same type as the first weight
params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]);
}
if (bias) {
enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]);
for (int i = 1; i < out_features_vec.size(); i++) {
params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]);
}
}
}

public:
SplitLinear(int64_t in_features,
std::vector<int64_t> out_features_vec,
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
: Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale),
in_features(in_features),
out_features_vec(out_features_vec),
bias(bias),
force_f32(force_f32),
force_prec_f32(force_prec_f32),
scale(scale) {}

struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = nullptr;
if (bias) {
b = params["bias"];
}
if (ctx->weight_adapter) {
// concat all weights and biases together so it runs in one linear layer
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
forward_params.linear.force_prec_f32 = force_prec_f32;
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
for (int i = 1; i < out_features_vec.size(); i++) {
auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0);
}

return x0;
}
};

__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) {
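Note: SplitLinear keeps each output slice as a separate weight/bias tensor and, in the plain path, runs one linear per slice and concatenates the results; the weight-adapter path instead concatenates the weights and biases first and runs a single linear. Both paths rely on the identity concat(W0 x + b0, W1 x + b1, ...) == concat(W0; W1; ...) x + concat(b0; b1; ...). A tiny numeric check of that equivalence, independent of ggml (all names below are illustrative):

```cpp
// Checks that running per-split linears and concatenating the outputs equals
// running one linear over the row-concatenated weights and biases.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// y[o] = sum_i w[o][i] * x[i] + b[o]
static std::vector<float> linear(const std::vector<std::vector<float>>& w,
                                 const std::vector<float>& b,
                                 const std::vector<float>& x) {
    std::vector<float> y(w.size(), 0.f);
    for (std::size_t o = 0; o < w.size(); o++) {
        for (std::size_t i = 0; i < x.size(); i++) y[o] += w[o][i] * x[i];
        y[o] += b[o];
    }
    return y;
}

int main() {
    std::vector<float> x = {1.f, 2.f, 3.f};
    std::vector<std::vector<float>> w0 = {{0.5f, -1.f, 2.f}};                  // split 0: one output row
    std::vector<std::vector<float>> w1 = {{1.f, 1.f, 1.f}, {0.f, 2.f, -1.f}};  // split 1: two output rows
    std::vector<float> b0 = {0.1f}, b1 = {-0.2f, 0.3f};

    // path 1: run the splits separately, then concatenate the outputs
    std::vector<float> y = linear(w0, b0, x);
    for (float v : linear(w1, b1, x)) y.push_back(v);

    // path 2: concatenate weights/biases first, then run once (the weight-adapter path)
    std::vector<std::vector<float>> w = w0;
    w.insert(w.end(), w1.begin(), w1.end());
    std::vector<float> b = b0;
    b.insert(b.end(), b1.begin(), b1.end());
    std::vector<float> y_fused = linear(w, b, x);

    for (std::size_t i = 0; i < y.size(); i++) assert(std::fabs(y[i] - y_fused[i]) < 1e-6f);
    return 0;
}
```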