Skip to content

Commit 6068423

Browse files
committed
support Flux.2 patched latents for proj preview
1 parent 1888672 commit 6068423

File tree

1 file changed

+123
-2
lines changed

1 file changed

+123
-2
lines changed

stable-diffusion.cpp

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,8 +1310,72 @@ class StableDiffusionGGML {
13101310
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
13111311

13121312
if (preview_mode == PREVIEW_PROJ) {
1313-
const float(*latent_rgb_proj)[channel] = nullptr;
1314-
float* latent_rgb_bias = nullptr;
1313+
int64_t patch_sz = 1;
1314+
if (sd_version_is_flux2(version)) {
1315+
patch_sz = 2;
1316+
}
1317+
if (patch_sz != 1) {
1318+
// unshuffle latents
1319+
const int64_t N = latents->ne[3];
1320+
const int64_t C_in = latents->ne[2];
1321+
const int64_t H_in = latents->ne[1];
1322+
const int64_t W_in = latents->ne[0];
1323+
1324+
const int64_t C_out = C_in / (patch_sz * patch_sz);
1325+
const int64_t H_out = H_in * patch_sz;
1326+
const int64_t W_out = W_in * patch_sz;
1327+
1328+
const char* src_ptr = (char*)latents->data;
1329+
size_t elem_size = latents->nb[0];
1330+
1331+
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1332+
char* dst_base = dst_buffer.data();
1333+
1334+
size_t dst_stride_w = elem_size;
1335+
size_t dst_stride_h = dst_stride_w * W_out;
1336+
size_t dst_stride_c = dst_stride_h * H_out;
1337+
size_t dst_stride_n = dst_stride_c * C_out;
1338+
1339+
size_t dst_step_w = dst_stride_w * patch_sz;
1340+
size_t dst_step_h = dst_stride_h * patch_sz;
1341+
1342+
for (int64_t n = 0; n < N; ++n) {
1343+
for (int64_t c = 0; c < C_in; ++c) {
1344+
int64_t c_out = c / (patch_sz * patch_sz);
1345+
int64_t rem = c % (patch_sz * patch_sz);
1346+
int64_t py = rem / patch_sz;
1347+
int64_t px = rem % patch_sz;
1348+
1349+
char* dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1350+
1351+
for (int64_t y = 0; y < H_in; ++y) {
1352+
char* dst_row = dst_layer + y * dst_step_h;
1353+
1354+
for (int64_t x = 0; x < W_in; ++x) {
1355+
memcpy(dst_row + x * dst_step_w, src_ptr, elem_size);
1356+
src_ptr += elem_size;
1357+
}
1358+
}
1359+
}
1360+
}
1361+
1362+
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1363+
1364+
latents->ne[0] = W_out;
1365+
latents->ne[1] = H_out;
1366+
latents->ne[2] = C_out;
1367+
1368+
latents->nb[0] = dst_stride_w;
1369+
latents->nb[1] = dst_stride_h;
1370+
latents->nb[2] = dst_stride_c;
1371+
latents->nb[3] = dst_stride_n;
1372+
1373+
width = W_out;
1374+
height = H_out;
1375+
dim = C_out;
1376+
}
1377+
const float (*latent_rgb_proj)[channel] = nullptr;
1378+
float* latent_rgb_bias = nullptr;
13151379

13161380
if (dim == 48) {
13171381
if (sd_version_is_wan(version)) {
@@ -1381,6 +1445,63 @@ class StableDiffusionGGML {
13811445
step_callback(step, frames, images, is_noisy);
13821446
free(data);
13831447
free(images);
1448+
1449+
if (patch_sz != 1) {
1450+
// restore shuffled latents
1451+
const int64_t N = latents->ne[3];
1452+
const int64_t C_in = latents->ne[2];
1453+
const int64_t H_in = latents->ne[1];
1454+
const int64_t W_in = latents->ne[0];
1455+
1456+
const int64_t C_out = C_in * patch_sz * patch_sz;
1457+
const int64_t H_out = H_in / patch_sz;
1458+
const int64_t W_out = W_in / patch_sz;
1459+
1460+
const char* src_base = (char*)latents->data;
1461+
const size_t elem_size = latents->nb[0];
1462+
1463+
const size_t src_stride_w = latents->nb[0];
1464+
const size_t src_stride_h = latents->nb[1];
1465+
const size_t src_stride_c = latents->nb[2];
1466+
const size_t src_stride_n = latents->nb[3];
1467+
1468+
std::vector<char> dst_buffer(N * C_out * H_out * W_out * elem_size);
1469+
char* dst_ptr = dst_buffer.data();
1470+
1471+
const size_t src_step_h = src_stride_h * patch_sz;
1472+
const size_t src_step_w = src_stride_w * patch_sz;
1473+
1474+
for (int64_t n = 0; n < N; ++n) {
1475+
for (int64_t c = 0; c < C_out; ++c) {
1476+
int64_t c_rem = c % (patch_sz * patch_sz);
1477+
int64_t c_in = c / (patch_sz * patch_sz);
1478+
int64_t py = c_rem / patch_sz;
1479+
int64_t px = c_rem % patch_sz;
1480+
1481+
const char* src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1482+
1483+
for (int64_t y = 0; y < H_out; ++y) {
1484+
const char* src_row = src_layer + y * src_step_h;
1485+
1486+
for (int64_t x = 0; x < W_out; ++x) {
1487+
memcpy(dst_ptr, src_row + x * src_step_w, elem_size);
1488+
dst_ptr += elem_size;
1489+
}
1490+
}
1491+
}
1492+
}
1493+
1494+
memcpy(latents->data, dst_buffer.data(), dst_buffer.size());
1495+
1496+
latents->ne[0] = W_out;
1497+
latents->ne[1] = H_out;
1498+
latents->ne[2] = C_out;
1499+
1500+
latents->nb[0] = elem_size;
1501+
latents->nb[1] = latents->nb[0] * W_out;
1502+
latents->nb[2] = latents->nb[1] * H_out;
1503+
latents->nb[3] = latents->nb[2] * C_out;
1504+
}
13841505
} else {
13851506
if (preview_mode == PREVIEW_VAE) {
13861507
process_latent_out(latents);

0 commit comments

Comments
 (0)