Skip to content
Open
3 changes: 2 additions & 1 deletion examples/diffusers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
check_conv_and_mha,
check_lora,
filter_func_default,
filter_func_flux_dev,
filter_func_ltx_video,
filter_func_wan_video,
load_calib_prompts,
Expand Down Expand Up @@ -138,7 +139,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
A filter function appropriate for the model type
"""
filter_func_map = {
ModelType.FLUX_DEV: filter_func_default,
ModelType.FLUX_DEV: filter_func_flux_dev,
ModelType.FLUX_SCHNELL: filter_func_default,
ModelType.SDXL_BASE: filter_func_default,
ModelType.SDXL_TURBO: filter_func_default,
Expand Down
6 changes: 6 additions & 0 deletions examples/diffusers/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def filter_func_ltx_video(name: str) -> bool:
return pattern.match(name) is not None


def filter_func_flux_dev(name: str) -> bool:
"""Filter function specifically for Flux-dev models."""
pattern = re.compile(r"(proj_out.*|.*(time_text_embed|context_embedder|x_embedder|norm_out).*)")
return pattern.match(name) is not None


def filter_func_wan_video(name: str) -> bool:
"""Filter function specifically for LTX-Video models."""
pattern = re.compile(r".*(patch_embedding|condition_embedder).*")
Expand Down
Loading
Loading