From 74f0b39b862f36a416befcccdd19e7a724b8d5ca Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 20:58:22 +0200 Subject: [PATCH 1/9] Update linear_fns.py --- inference_lib/src/fp_quant/module/linear_fns.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/inference_lib/src/fp_quant/module/linear_fns.py b/inference_lib/src/fp_quant/module/linear_fns.py index 20679a1..a7b1c4f 100644 --- a/inference_lib/src/fp_quant/module/linear_fns.py +++ b/inference_lib/src/fp_quant/module/linear_fns.py @@ -40,7 +40,7 @@ def _(x_flat, hadamard_matrix, forward_method): @torch.library.custom_op("fp_quant::matmul_mxf4_bf16_tn_op", mutates_args=()) def matmul_mxf4_bf16_tn_op( - x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: float + x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor ) -> torch.Tensor: return matmul_mxf4_bf16_tn( x, w, to_blocked(xs), to_blocked(ws).view(torch.float8_e8m0fnu), alpha @@ -54,7 +54,7 @@ def _(x, w, xs, ws, alpha): @torch.library.custom_op("fp_quant::matmul_ada_mxf4_bf16_tn_op", mutates_args=()) def matmul_ada_mxf4_bf16_tn_op( - x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: float + x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor ) -> torch.Tensor: return matmul_ada_mxf4_bf16_tn(x, w, xs, ws.view(torch.float8_e8m0fnu), alpha) @@ -248,7 +248,7 @@ def forward( x_flat, forward_hadamard_matrix, dtype, forward_method ) - y = forward_gemm(x_flat_q, weight_q, x_flat_scales, weight_scales, 1.0 / 9.0) + y = forward_gemm(x_flat_q, weight_q, x_flat_scales, weight_scales, torch.tensor([1.0 / 9.0], device=x.device)) y = y.unflatten(dim=0, sizes=x.shape[:-1]) if bias is not None: From a1bdd9f6ebd93fca797b2b2864a237eede8a94b4 Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 20:59:38 +0200 Subject: [PATCH 2/9] Add files via upload --- .../src/fp_quant/module/layer_analytics.py | 131 ++++++++++++++++++ inference_lib/src/fp_quant/module/linear.py | 50 ++++++- 2 files changed, 174 insertions(+), 7 deletions(-) create mode 100644 inference_lib/src/fp_quant/module/layer_analytics.py diff --git a/inference_lib/src/fp_quant/module/layer_analytics.py b/inference_lib/src/fp_quant/module/layer_analytics.py new file mode 100644 index 0000000..c51e29f --- /dev/null +++ b/inference_lib/src/fp_quant/module/layer_analytics.py @@ -0,0 +1,131 @@ +import torch +import json + +from ..utils.config import FPQuantConfig, FPQuantDtype + + +all_added_layer_names = [] + +layer_list = { + +} +config = FPQuantConfig(forward_dtype=FPQuantDtype.MXFP4, forward_method="abs_max", + backward_dtype=FPQuantDtype.BF16, hadamard_group_size=32) + + +def bench_ms(fn, warmup=10, iters=100): + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + for _ in range(warmup): + _ = fn() + torch.cuda.synchronize() + times = [] + for _ in range(iters): + start_evt.record() + _ = fn() + end_evt.record() + torch.cuda.synchronize() + times.append(start_evt.elapsed_time(end_evt)) # ms + t = torch.tensor(times, device="cpu") + return float(t.mean().item()), float(t.std(unbiased=False).item()) + +def add_layer(layer: torch.nn.Module, layer_name: str, input_shape: torch.Size, in_features: int, out_features: int, device, dtype): + + global layer_list + + if layer_name in layer_list: + return + + layer_info = { + "bias": layer.bias is not None, + "layer_name": layer_name, + "input_shape": list(input_shape), + 
"in_features": in_features, + "out_features": out_features, + "device": str(device), + "dtype": str(dtype), + } + + layer_list[layer_name] = layer_info + + print(f"Layer name: {layer_name}, input shape: {input_shape}, in_features: {in_features}, out_features: {out_features}") + + if len(all_added_layer_names) == len(layer_list): + analyze_layers() + print("All layers have been analyzed.") + +def analyze_layers(): + + # First find all unique pairs of (input_shape, in_features, out_features, device, dtype) + + global layer_list + + unique_layers = {} + + for name in layer_list: + bias = layer_list[name]["bias"] + input_shape = tuple(layer_list[name]["input_shape"]) + in_features = layer_list[name]["in_features"] + out_features = layer_list[name]["out_features"] + device = layer_list[name]["device"] + dtype = layer_list[name]["dtype"] + + key = (input_shape, in_features, out_features, device, dtype, bias) + + # Only keep the first layer encountered for each unique key + if key not in unique_layers: + unique_layers[key] = { + } + + # Now for each unique layer + for key in unique_layers: + input_shape, in_features, out_features, device, dtype, bias = key + + device = torch.device(device) + dtype = torch.__dict__[dtype.split('.')[-1]] + nn_layer = torch.nn.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) + from .linear import FPQuantLinear + quantized_layer = FPQuantLinear(in_features, out_features, bias=bias, config=config, device='cuda', dtype=torch.bfloat16) + quantized_layer.pre_forward() + + with torch.no_grad(): + sample_input = torch.randn(*input_shape, device=device, dtype=dtype) + + def b1(): + return nn_layer(sample_input) + + def b2(): + return quantized_layer(sample_input) + + if input_shape[0] == 12288 or input_shape[1] == 12288: + print(f"AAAAAAAAAAAAAAAA") + + quantized_layer_time, _ = bench_ms(b2, warmup=100, iters=200) + nn_layer_time, _ = bench_ms(b1, warmup=100, iters=200) + + if input_shape[0] == 12288 or input_shape[1] == 12288: + print(f"AAAAAAAAAAAAAAAA") + + print(f"Layer: {key}, nn_layer_time: {nn_layer_time:.3f} ms, quantized_layer_time: {quantized_layer_time:.3f} ms, ratio: {quantized_layer_time / nn_layer_time:.3f}") + + unique_layers[key]["quantized_layer_time"] = quantized_layer_time + unique_layers[key]["nn_layer_time"] = nn_layer_time + + del quantized_layer + del nn_layer + torch.cuda.empty_cache() + + for name in layer_list: + input_shape = tuple(layer_list[name]["input_shape"]) + in_features = layer_list[name]["in_features"] + out_features = layer_list[name]["out_features"] + device = layer_list[name]["device"] + dtype = layer_list[name]["dtype"] + key = (input_shape, in_features, out_features, device, dtype, bias) + layer_list[name]["quantized_layer_time"] = unique_layers[key]["quantized_layer_time"] + layer_list[name]["nn_layer_time"] = unique_layers[key]["nn_layer_time"] + + # Save to file + with open("layer_analytics.json", "w") as f: + json.dump(layer_list, f, indent=4) + diff --git a/inference_lib/src/fp_quant/module/linear.py b/inference_lib/src/fp_quant/module/linear.py index ce8e3d8..ef0d05c 100644 --- a/inference_lib/src/fp_quant/module/linear.py +++ b/inference_lib/src/fp_quant/module/linear.py @@ -5,6 +5,9 @@ from scipy.linalg import hadamard from ..utils import FPQuantConfig, FPQuantDtype + +from . 
import layer_analytics + from .linear_fns import ( HAS_QUTLASS, FPQuant4x16MasterFn, @@ -23,6 +26,17 @@ def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.devic hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device ) +def free_param(param: torch.nn.Parameter): + if param is not None: + # Move off GPU to free CUDA memory + if param.device.type == "cuda": + param.data = param.data.cpu() + + # Break the computational graph and drop storage reference + param.detach_() + + # Remove the parameter reference + del param class FPQuantLinear(nn.Module): def __init__( @@ -33,6 +47,8 @@ def __init__( bias: bool = True, device: torch.device = None, dtype: torch.dtype = None, + name: str = None, + enable_analytics: bool = False, ): super().__init__() @@ -44,12 +60,14 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features + self.name = name + self.name_analyzed = not enable_analytics + self.weight = nn.Parameter( torch.empty((out_features, in_features), **factory_kwargs) ) - self.dqweight = nn.Parameter( - torch.empty((out_features, in_features), **factory_kwargs) - ) + self.dqweight = None + if bias: self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) else: @@ -159,10 +177,25 @@ def pre_forward(self): self.scales = nn.Parameter( scales.view(dtype=torch.uint8), requires_grad=False ) + + if self.weight is not None: + free_param(self.weight) + self.register_parameter("weight", None) + torch.cuda.empty_cache() + del self.weight + self.weight = None self.dqweight = None def forward(self, x) -> torch.Tensor: + + if self.name is not None and not self.name_analyzed: + self.name_analyzed = True + layer_analytics.add_layer(self, self.name, x.shape, self.in_features, self.out_features, x.device, x.dtype) + + + result = None + match ( self.config.forward_dtype, self.config.backward_dtype, @@ -170,7 +203,7 @@ def forward(self, x) -> torch.Tensor: self.config.pseudoquantization, ): case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, True, False): - return FPQuant4x16MasterFn.apply( + result = FPQuant4x16MasterFn.apply( x, self.weight, self.bias, @@ -179,7 +212,7 @@ def forward(self, x) -> torch.Tensor: self.config.forward_method, ) case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, False, False): - return FPQuant4x16NoMasterFn.apply( + result = FPQuant4x16NoMasterFn.apply( x, self.qweight, self.scales, @@ -189,7 +222,7 @@ def forward(self, x) -> torch.Tensor: self.config.forward_method, ) case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, True, True): - return PseudoQuant4x16MasterFn.apply( + result = PseudoQuant4x16MasterFn.apply( x, self.dqweight, self.bias, @@ -198,7 +231,7 @@ def forward(self, x) -> torch.Tensor: self.config.forward_method, ) case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, False, True): - return PseudoQuant4x16NoMasterFn.apply( + result = PseudoQuant4x16NoMasterFn.apply( x, self.dqweight, self.bias, @@ -210,3 +243,6 @@ def forward(self, x) -> torch.Tensor: raise ValueError( f"Forward dtype: {self.config.forward_dtype}, backward dtype: {self.config.backward_dtype}, store_master_weights: {self.config.store_master_weights}, pseudoquantization: {self.config.pseudoquantization} isn't supported yet." 
) + + return result + \ No newline at end of file From 474bbbe1fce37d145029c26a5cdc00e78b929047 Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 21:00:27 +0200 Subject: [PATCH 3/9] Add files via upload --- inference_lib/src/fp_quant/utils/replace.py | 70 ++++++++++++++++----- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/inference_lib/src/fp_quant/utils/replace.py b/inference_lib/src/fp_quant/utils/replace.py index cceb026..218ee64 100644 --- a/inference_lib/src/fp_quant/utils/replace.py +++ b/inference_lib/src/fp_quant/utils/replace.py @@ -1,14 +1,18 @@ +import gc import torch from torch import nn from .config import FPQuantConfig +from ..module import layer_analytics def replace_with_fp_quant_linear( model, fp_quant_linear_config: FPQuantConfig, current_key_name=None, has_been_replaced=False, + apply_pre_forward=False, + enable_analytics=False, ): from ..module import FPQuantLinear @@ -30,8 +34,10 @@ def replace_with_fp_quant_linear( """ from accelerate import init_empty_weights + num_of_params = 0 for name, module in model.named_children(): + if current_key_name is None: current_key_name = [] current_key_name.append(name) @@ -40,32 +46,68 @@ def replace_with_fp_quant_linear( # Check if the current key is not in the `quantization_config.modules_to_not_convert` current_key_name_str = ".".join(current_key_name) if not any( - current_key_name_str.endswith(key) + current_key_name_str.startswith(key) for key in fp_quant_linear_config.modules_to_not_convert ): - with init_empty_weights(): - in_features = module.in_features - out_features = module.out_features - - model._modules[name] = FPQuantLinear( - in_features, - out_features, - config=fp_quant_linear_config, - bias=module.bias is not None, - ) + in_features = module.in_features + out_features = module.out_features + + layer_analytics.all_added_layer_names.append(current_key_name_str) + + new = FPQuantLinear(in_features, out_features, config=fp_quant_linear_config, bias=module.bias is not None, device=module.weight.device, dtype=module.weight.dtype, name=current_key_name_str, enable_analytics=enable_analytics) + with torch.no_grad(): + if hasattr(new, "load_from_linear"): + new.load_from_linear(module) # hypothetical helper + else: + # fallback if FPQuantLinear stores real-valued weights + if hasattr(new, "weight") and hasattr(module, "weight"): + new.weight.copy_(module.weight) + if hasattr(new, "bias") and hasattr(module, "bias") and module.bias is not None: + new.bias.copy_(module.bias) + + model._modules[name] = new + + module.weight.to(device='cpu') + if module.bias is not None: + module.bias.to(device='cpu') + has_been_replaced = True # Store the module class in case we need to transpose the weight later model._modules[name].source_cls = type(module) # Force requires grad to False to avoid unexpected errors model._modules[name].requires_grad_(False) - if len(list(module.children())) > 0: - _, has_been_replaced = replace_with_fp_quant_linear( + + # Force delete the tensors here + if hasattr(module, "weight") and module.weight is not None: + num_of_params += module.weight.numel() + del module.weight + module._parameters.pop("weight", None) + module.register_parameter("weight", None) + + if hasattr(module, "bias") and module.bias is not None: + del module.bias + module._parameters.pop("bias", None) + module.register_parameter("bias", None) + + del module + + if apply_pre_forward: + model._modules[name].pre_forward() + + torch.cuda.empty_cache() + + elif len(list(module.children())) > 0: + _, 
has_been_replaced, num_of_params_temp = replace_with_fp_quant_linear( module, fp_quant_linear_config=fp_quant_linear_config, current_key_name=current_key_name, has_been_replaced=has_been_replaced, + apply_pre_forward=apply_pre_forward, + enable_analytics=enable_analytics ) + num_of_params += num_of_params_temp + # Remove the last key for recursion current_key_name.pop(-1) - return model, has_been_replaced + return model, has_been_replaced, num_of_params From a67b3c343f3f906eb56f98400513c34a0522785b Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 21:20:21 +0200 Subject: [PATCH 4/9] Add files via upload --- inference_lib/src/fp_quant/module/layer_analytics.py | 10 ++++++---- inference_lib/src/fp_quant/module/linear.py | 3 +++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/inference_lib/src/fp_quant/module/layer_analytics.py b/inference_lib/src/fp_quant/module/layer_analytics.py index c51e29f..c023e45 100644 --- a/inference_lib/src/fp_quant/module/layer_analytics.py +++ b/inference_lib/src/fp_quant/module/layer_analytics.py @@ -9,9 +9,6 @@ layer_list = { } -config = FPQuantConfig(forward_dtype=FPQuantDtype.MXFP4, forward_method="abs_max", - backward_dtype=FPQuantDtype.BF16, hadamard_group_size=32) - def bench_ms(fn, warmup=10, iters=100): start_evt = torch.cuda.Event(enable_timing=True) @@ -37,6 +34,7 @@ def add_layer(layer: torch.nn.Module, layer_name: str, input_shape: torch.Size, return layer_info = { + "config": layer.config, "bias": layer.bias is not None, "layer_name": layer_name, "input_shape": list(input_shape), @@ -69,23 +67,26 @@ def analyze_layers(): out_features = layer_list[name]["out_features"] device = layer_list[name]["device"] dtype = layer_list[name]["dtype"] + config = layer_list[name]["config"] key = (input_shape, in_features, out_features, device, dtype, bias) # Only keep the first layer encountered for each unique key if key not in unique_layers: unique_layers[key] = { + "config": config } # Now for each unique layer for key in unique_layers: input_shape, in_features, out_features, device, dtype, bias = key + config = unique_layers[key]["config"] device = torch.device(device) dtype = torch.__dict__[dtype.split('.')[-1]] nn_layer = torch.nn.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) from .linear import FPQuantLinear - quantized_layer = FPQuantLinear(in_features, out_features, bias=bias, config=config, device='cuda', dtype=torch.bfloat16) + quantized_layer = FPQuantLinear(in_features, out_features, bias=bias, config=config, device=device, dtype=dtype) quantized_layer.pre_forward() with torch.no_grad(): @@ -124,6 +125,7 @@ def b2(): key = (input_shape, in_features, out_features, device, dtype, bias) layer_list[name]["quantized_layer_time"] = unique_layers[key]["quantized_layer_time"] layer_list[name]["nn_layer_time"] = unique_layers[key]["nn_layer_time"] + del layer_list[name]["config"] # Save to file with open("layer_analytics.json", "w") as f: diff --git a/inference_lib/src/fp_quant/module/linear.py b/inference_lib/src/fp_quant/module/linear.py index ef0d05c..2d27f77 100644 --- a/inference_lib/src/fp_quant/module/linear.py +++ b/inference_lib/src/fp_quant/module/linear.py @@ -58,6 +58,9 @@ def __init__( ) factory_kwargs = {"device": device, "dtype": dtype} + self.device = device + self.dtype = dtype + self.in_features = in_features self.out_features = out_features self.name = name From 16c74810f95c9b28598ccd12a9ce100c6fb92774 Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 21:20:39 
+0200 Subject: [PATCH 5/9] Add files via upload From 35b655fd107172893b95109f539f2480366f78f6 Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 21:21:37 +0200 Subject: [PATCH 6/9] Add files via upload From 3a6db49f258b65a8d0fb85ecc34a86aed042719b Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 21:28:19 +0200 Subject: [PATCH 7/9] Update README.md --- inference_lib/README.md | 83 ++++++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/inference_lib/README.md b/inference_lib/README.md index 427a20b..5241ea1 100644 --- a/inference_lib/README.md +++ b/inference_lib/README.md @@ -1,21 +1,76 @@ -# fp_quant +# Quantizing Flux Kontext -A library that wraps [`qutlass`](https://github.com/IST-DASLab/qutlass) kernels with linear layer wrappers for integrations into training and inference engines. +- Here is a quick example for how to use FP-Quant to quantize the models on the fly. -## Installation +~~~python -```bash -pip install . -``` +pipe = FluxKontextPipeline.from_pretrained("/home/cropy/flux_kontext", + local_files_only=True, + quantization_config=pipeline_quant_config, + torch_dtype=torch.bfloat16) +pipe.to("cuda") -## Usage +# Apply Qutlass quantization to the transformer +# Read the layer analytics (if present) to compare each layer’s quantized runtime with the normal runtime. -```python -from fp_quant import replace_with_fp_quant_linear, FPQuantConfig +try: + with open("layer_analytics.json", "r") as f: + layer_analytics_list = json.load(f) + layer_analytics_list = [key for key in layer_analytics_list if layer_analytics_list[key]["quantized_layer_time"]/layer_analytics_list[key]["nn_layer_time"] > 0.95] + enable_analytics = False +except: + layer_analytics_list = [] + print("No layer_analytics.json found, or error in reading it.") + enable_analytics = True -# Replace nn.Linear layers with fp_quant.FPQuantLinear -replace_with_fp_quant_linear( - model, - fp_quant_linear_config=FPQuantConfig(), +from fp_quant.inference_lib.src.fp_quant import FPQuantLinear, FPQuantConfig, FPQuantDtype, replace_with_fp_quant_linear +fp_quant_config = FPQuantConfig(forward_dtype=FPQuantDtype.MXFP4, forward_method="abs_max", + backward_dtype=FPQuantDtype.BF16, hadamard_group_size=32, + modules_to_not_convert=[ + "x_embedder", # we should not quantize x_embedder. Otherwise the resulting image looks like noise. 
+ *layer_analytics_list + ], ) -``` \ No newline at end of file + +_, result, num_of_params=replace_with_fp_quant_linear(pipe.transformer, fp_quant_config, apply_pre_forward=True, enable_analytics=enable_analytics) +print("Transformer Replaced:", result, "Num of params:", num_of_params) + +pipe.transformer = torch.compile( + pipe.transformer, mode="max-autotune", fullgraph=False +) +pipe.vae.decode = torch.compile( + pipe.vae.decode, mode="max-autotune", fullgraph=True +) + +file_name = "images/1.png" + +input_image = load_image(file_name) + +height=1024 +width=1024 +input_image.resize((width, height)) +num_images_per_prompt = 1 + + +for _ in range(5): + images = pipe( + image=input_image, + prompt="Your prompt", + + + negative_prompt="blurry, low quality, bad quality, worst quality, deformed, distorted, worst quality", + guidance_scale=2.5, + height=height, + width=width, + max_area=height*width, + num_inference_steps=25, + generator=torch.manual_seed(441), + num_images_per_prompt=num_images_per_prompt + ).images + + + for i in range(num_images_per_prompt): + file_name_i = file_name.replace(".jpg", f"_{i}.jpg").replace(".png", f"_{i}.png") + images[i].save(file_name_i) + +~~~ From 45f1192bcd1afa997670bcee39290eb30696acd5 Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Sun, 21 Sep 2025 21:54:09 +0200 Subject: [PATCH 8/9] Update README.md --- inference_lib/README.md | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/inference_lib/README.md b/inference_lib/README.md index 5241ea1..65f7912 100644 --- a/inference_lib/README.md +++ b/inference_lib/README.md @@ -8,12 +8,13 @@ pipe = FluxKontextPipeline.from_pretrained("/home/cropy/flux_kontext", local_files_only=True, quantization_config=pipeline_quant_config, torch_dtype=torch.bfloat16) + pipe.to("cuda") # Apply Qutlass quantization to the transformer -# Read the layer analytics (if present) to compare each layer’s quantized runtime with the normal runtime. - try: + + # read layer_analytics.json with open("layer_analytics.json", "r") as f: layer_analytics_list = json.load(f) layer_analytics_list = [key for key in layer_analytics_list if layer_analytics_list[key]["quantized_layer_time"]/layer_analytics_list[key]["nn_layer_time"] > 0.95] @@ -27,7 +28,15 @@ from fp_quant.inference_lib.src.fp_quant import FPQuantLinear, FPQuantConfig, FP fp_quant_config = FPQuantConfig(forward_dtype=FPQuantDtype.MXFP4, forward_method="abs_max", backward_dtype=FPQuantDtype.BF16, hadamard_group_size=32, modules_to_not_convert=[ - "x_embedder", # we should not quantize x_embedder. Otherwise the resulting image looks like noise. 
+ # "pos_embed", + # "text_time_guidance_cls", + # "time_text_embed", + # "transformer_blocks", + # "single_transformer_blocks", + # "proj_out", + # "norm_out", + "x_embedder", + # "context_embedder", *layer_analytics_list ], ) @@ -55,11 +64,7 @@ num_images_per_prompt = 1 for _ in range(5): images = pipe( image=input_image, - prompt="Your prompt", - - - negative_prompt="blurry, low quality, bad quality, worst quality, deformed, distorted, worst quality", - guidance_scale=2.5, + prompt="Your prompt!", height=height, width=width, max_area=height*width, @@ -73,4 +78,5 @@ for _ in range(5): file_name_i = file_name.replace(".jpg", f"_{i}.jpg").replace(".png", f"_{i}.png") images[i].save(file_name_i) + ~~~ From 7cff8f56617835b6ac149c5df70e7995be2f3b2d Mon Sep 17 00:00:00 2001 From: Akif Faruk Nane Date: Mon, 22 Sep 2025 15:54:48 +0200 Subject: [PATCH 9/9] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index da777f3..706c0c0 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +See: [https://github.com/IST-DASLab/FP-Quant/pull/6](https://github.com/IST-DASLab/FP-Quant/pull/6) + # FP Format Quantization Harness This is a harness for efficient and accurate weight-and-activation quantization for low-bit FP/INT formats, with and without microscaling, including FP4, NVFP4, and MXFP. These formats are compatible with the NVIDIA Blackwell GPU architecture.
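
Below is a minimal usage sketch of the analytics-driven conversion flow these patches introduce, kept separate from the Flux Kontext example above. It is an illustration rather than part of the series: the top-level `fp_quant` import path, the toy `model`, and the exception handling are assumptions, while the config fields, the three-value return of `replace_with_fp_quant_linear`, and the `layer_analytics.json` schema follow the diffs above.

~~~python
import json

import torch

# Assumed import path; the Flux Kontext example above imports from
# fp_quant.inference_lib.src.fp_quant instead.
from fp_quant import FPQuantConfig, FPQuantDtype, replace_with_fp_quant_linear

# Toy stand-in for a real model: any module tree containing nn.Linear layers.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
)

# If a previous run (with analytics enabled) produced layer_analytics.json,
# keep layers whose MXFP4 kernel is not clearly faster than BF16
# (quantized/nn runtime ratio > 0.95) out of the conversion.
try:
    with open("layer_analytics.json") as f:
        analytics = json.load(f)
    slow_layers = [
        name
        for name, info in analytics.items()
        if info["quantized_layer_time"] / info["nn_layer_time"] > 0.95
    ]
    enable_analytics = False
except (FileNotFoundError, json.JSONDecodeError):
    # First run: enable analytics so each converted layer is benchmarked
    # (the JSON file is written once the first forward pass has reached them all).
    slow_layers = []
    enable_analytics = True

config = FPQuantConfig(
    forward_dtype=FPQuantDtype.MXFP4,
    forward_method="abs_max",
    backward_dtype=FPQuantDtype.BF16,
    hadamard_group_size=32,
    modules_to_not_convert=slow_layers,
)

# The helper now returns (model, has_been_replaced, num_of_params);
# apply_pre_forward=True quantizes weights eagerly and frees the BF16 copies.
model, replaced, num_of_params = replace_with_fp_quant_linear(
    model,
    fp_quant_linear_config=config,
    apply_pre_forward=True,
    enable_analytics=enable_analytics,
)
print(f"replaced={replaced}, converted params={num_of_params}")
~~~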