Skip to content

Commit 625ac6f

Browse files
authored
Merge pull request #14 from Aatricks/cond-dynamic-width-scaling-451456543888025473
Dynamic Width Scaling in Condition Encoding
2 parents 74476ce + d6b7d6c commit 625ac6f

2 files changed

Lines changed: 14 additions & 4 deletions

File tree

src/Utilities/Latent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class LatentFormat:
1515

1616
scale_factor: float = 1.0
1717
latent_channels: int = 4
18-
18+
downscale_factor: int = 8
19+
1920
def process_in(self, latent: torch.Tensor) -> torch.Tensor:
2021
"""#### Process the latent input, by multiplying it by the scale factor.
2122

src/cond/cond.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -588,10 +588,19 @@ def encode_model_conds(
588588
params["device"] = device
589589
params["noise"] = noise
590590
default_width = None
591-
if len(noise.shape) >= 4: # TODO: 8 multiple should be set by the model
592-
default_width = noise.shape[3] * 8
591+
592+
downscale_factor = 8
593+
if hasattr(model_function, "__self__"):
594+
model = model_function.__self__
595+
if hasattr(model, "latent_format") and hasattr(
596+
model.latent_format, "downscale_factor"
597+
):
598+
downscale_factor = model.latent_format.downscale_factor
599+
600+
if len(noise.shape) >= 4:
601+
default_width = noise.shape[3] * downscale_factor
593602
params["width"] = params.get("width", default_width)
594-
params["height"] = params.get("height", noise.shape[2] * 8)
603+
params["height"] = params.get("height", noise.shape[2] * downscale_factor)
595604
params["prompt_type"] = params.get("prompt_type", prompt_type)
596605
for k in kwargs:
597606
if k not in params:

0 commit comments

Comments
 (0)