@@ -1310,8 +1310,72 @@ class StableDiffusionGGML {
13101310 uint32_t dim = latents->ne [ggml_n_dims (latents) - 1 ];
13111311
13121312 if (preview_mode == PREVIEW_PROJ) {
1313- const float (*latent_rgb_proj)[channel] = nullptr ;
1314- float * latent_rgb_bias = nullptr ;
1313+ int64_t patch_sz = 1 ;
1314+ if (sd_version_is_flux2 (version)) {
1315+ patch_sz = 2 ;
1316+ }
1317+ if (patch_sz != 1 ) {
1318+ // unshuffle latents
1319+ const int64_t N = latents->ne [3 ];
1320+ const int64_t C_in = latents->ne [2 ];
1321+ const int64_t H_in = latents->ne [1 ];
1322+ const int64_t W_in = latents->ne [0 ];
1323+
1324+ const int64_t C_out = C_in / (patch_sz * patch_sz);
1325+ const int64_t H_out = H_in * patch_sz;
1326+ const int64_t W_out = W_in * patch_sz;
1327+
1328+ const char * src_ptr = (char *)latents->data ;
1329+ size_t elem_size = latents->nb [0 ];
1330+
1331+ std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1332+ char * dst_base = dst_buffer.data ();
1333+
1334+ size_t dst_stride_w = elem_size;
1335+ size_t dst_stride_h = dst_stride_w * W_out;
1336+ size_t dst_stride_c = dst_stride_h * H_out;
1337+ size_t dst_stride_n = dst_stride_c * C_out;
1338+
1339+ size_t dst_step_w = dst_stride_w * patch_sz;
1340+ size_t dst_step_h = dst_stride_h * patch_sz;
1341+
1342+ for (int64_t n = 0 ; n < N; ++n) {
1343+ for (int64_t c = 0 ; c < C_in; ++c) {
1344+ int64_t c_out = c / (patch_sz * patch_sz);
1345+ int64_t rem = c % (patch_sz * patch_sz);
1346+ int64_t py = rem / patch_sz;
1347+ int64_t px = rem % patch_sz;
1348+
1349+ char * dst_layer = dst_base + n * dst_stride_n + c_out * dst_stride_c + py * dst_stride_h + px * dst_stride_w;
1350+
1351+ for (int64_t y = 0 ; y < H_in; ++y) {
1352+ char * dst_row = dst_layer + y * dst_step_h;
1353+
1354+ for (int64_t x = 0 ; x < W_in; ++x) {
1355+ memcpy (dst_row + x * dst_step_w, src_ptr, elem_size);
1356+ src_ptr += elem_size;
1357+ }
1358+ }
1359+ }
1360+ }
1361+
1362+ memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1363+
1364+ latents->ne [0 ] = W_out;
1365+ latents->ne [1 ] = H_out;
1366+ latents->ne [2 ] = C_out;
1367+
1368+ latents->nb [0 ] = dst_stride_w;
1369+ latents->nb [1 ] = dst_stride_h;
1370+ latents->nb [2 ] = dst_stride_c;
1371+ latents->nb [3 ] = dst_stride_n;
1372+
1373+ width = W_out;
1374+ height = H_out;
1375+ dim = C_out;
1376+ }
1377+ const float (*latent_rgb_proj)[channel] = nullptr ;
1378+ float * latent_rgb_bias = nullptr ;
13151379
13161380 if (dim == 48 ) {
13171381 if (sd_version_is_wan (version)) {
@@ -1381,6 +1445,63 @@ class StableDiffusionGGML {
13811445 step_callback (step, frames, images, is_noisy);
13821446 free (data);
13831447 free (images);
1448+
1449+ if (patch_sz != 1 ) {
1450+ // restore shuffled latents
1451+ const int64_t N = latents->ne [3 ];
1452+ const int64_t C_in = latents->ne [2 ];
1453+ const int64_t H_in = latents->ne [1 ];
1454+ const int64_t W_in = latents->ne [0 ];
1455+
1456+ const int64_t C_out = C_in * patch_sz * patch_sz;
1457+ const int64_t H_out = H_in / patch_sz;
1458+ const int64_t W_out = W_in / patch_sz;
1459+
1460+ const char * src_base = (char *)latents->data ;
1461+ const size_t elem_size = latents->nb [0 ];
1462+
1463+ const size_t src_stride_w = latents->nb [0 ];
1464+ const size_t src_stride_h = latents->nb [1 ];
1465+ const size_t src_stride_c = latents->nb [2 ];
1466+ const size_t src_stride_n = latents->nb [3 ];
1467+
1468+ std::vector<char > dst_buffer (N * C_out * H_out * W_out * elem_size);
1469+ char * dst_ptr = dst_buffer.data ();
1470+
1471+ const size_t src_step_h = src_stride_h * patch_sz;
1472+ const size_t src_step_w = src_stride_w * patch_sz;
1473+
1474+ for (int64_t n = 0 ; n < N; ++n) {
1475+ for (int64_t c = 0 ; c < C_out; ++c) {
1476+ int64_t c_rem = c % (patch_sz * patch_sz);
1477+ int64_t c_in = c / (patch_sz * patch_sz);
1478+ int64_t py = c_rem / patch_sz;
1479+ int64_t px = c_rem % patch_sz;
1480+
1481+ const char * src_layer = src_base + n * src_stride_n + c_in * src_stride_c + py * src_stride_h + px * src_stride_w;
1482+
1483+ for (int64_t y = 0 ; y < H_out; ++y) {
1484+ const char * src_row = src_layer + y * src_step_h;
1485+
1486+ for (int64_t x = 0 ; x < W_out; ++x) {
1487+ memcpy (dst_ptr, src_row + x * src_step_w, elem_size);
1488+ dst_ptr += elem_size;
1489+ }
1490+ }
1491+ }
1492+ }
1493+
1494+ memcpy (latents->data , dst_buffer.data (), dst_buffer.size ());
1495+
1496+ latents->ne [0 ] = W_out;
1497+ latents->ne [1 ] = H_out;
1498+ latents->ne [2 ] = C_out;
1499+
1500+ latents->nb [0 ] = elem_size;
1501+ latents->nb [1 ] = latents->nb [0 ] * W_out;
1502+ latents->nb [2 ] = latents->nb [1 ] * H_out;
1503+ latents->nb [3 ] = latents->nb [2 ] * C_out;
1504+ }
13841505 } else {
13851506 if (preview_mode == PREVIEW_VAE) {
13861507 process_latent_out (latents);
0 commit comments