
Enable NextStepDiffusion and support multi-device tuning for diffusion#1640

Open
xin3he wants to merge 18 commits into main from xinhe/3-30a

Conversation

@xin3he
Contributor

@xin3he xin3he commented Mar 30, 2026

Description

fix nextstep loading issue

example_prompt = "A REALISTIC PHOTOGRAPH OF A WALL WITH \"TOWARD AUTOREGRESSIVE IMAGE GENERATION WITH CONTINUOUS TOKENS AT SCALE\" PROMINENTLY DISPLAYED"

Raw model output:

[image]

W4A16 model output with torch backend on CPU:

[image]

W4A16 model output with gptqmodel:marlin backend on CUDA:

[image]

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring
  • Other (please specify):

Related Issues

Fixes or relates to #

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.

Signed-off-by: Xin He <xin3.he@intel.com>
Contributor

Copilot AI left a comment


Pull request overview

Fixes model loading for the “nextstep” model type by selecting an appropriate AutoModel loader, and adjusts multimodal key detection to recognize “image”-named components.

Changes:

  • Force AutoModel for model_type == "nextstep" during MLLM model loading.
  • Add "image" to MM_KEYS to broaden multimodal component detection.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Changed files:

  • auto_round/utils/model.py: Adds a NextStep-specific loader class override to resolve loading failures.
  • auto_round/utils/common.py: Extends multimodal key matching to include "image" for downstream detection/mapping.

Signed-off-by: Xin He <xin3.he@intel.com>
@xin3he
Contributor Author

xin3he commented Mar 30, 2026

The exllama backend has an accuracy issue for nextstep generation.
The marlin backend requires the gptqmodel main branch, so that is fixed in this PR.
cc @wenhuach21

@xin3he xin3he requested a review from wenhuach21 March 30, 2026 13:57
@wenhuach21
Contributor

Better to add next_step to the MLLM support matrix.

@xin3he
Contributor Author

xin3he commented Mar 31, 2026

I need to upstream a model before updating the support matrix (requires model link).

@wenhuach21
Contributor

I need to upstream a model before updating the support matrix (requires model link).

If the model’s license allows upstreaming, we can upload it. Otherwise, we can leave the link blank.

@xin3he xin3he marked this pull request as draft April 3, 2026 01:52
@xin3he
Contributor Author

xin3he commented Apr 3, 2026

The status has been reverted to "Draft", as only RTN is currently supported; upstream adaptation and optimization work is underway.

xin3he added 2 commits April 7, 2026 12:23
Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: Xin He <xin3.he@intel.com>
xin3he and others added 5 commits April 8, 2026 02:46
… gptqmodel fix

Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: Xin He <xin3.he@intel.com>
…imports

Signed-off-by: Xin He <xin3.he@intel.com>
@xin3he xin3he changed the title fix nextstep loading issue Enable NextStepDiffusion and support multi-device tuning for diffusion Apr 8, 2026
@xin3he xin3he requested a review from changwangss April 8, 2026 07:51
@xin3he xin3he marked this pull request as ready for review April 8, 2026 07:54
@xin3he
Contributor Author

xin3he commented Apr 8, 2026

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

xin3he and others added 7 commits April 9, 2026 07:42
Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: Xin He <xin3.he@intel.com>
**kwargs,
):
logger.warning("Diffusion model quantization is experimental and is only validated on Flux models.")
if dataset == "NeelNanda/pile-10k":
Contributor


This is not very robust; I suspect that none of our supported LLM datasets are suitable for diffusion models.
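One way to address the robustness concern above is a sentinel default, so "the user left the default" is distinguishable from "the user explicitly passed the LLM default dataset". A minimal sketch; the names `_UNSET`, `DIFFUSION_DEFAULT_DATASET`, and `resolve_dataset` are illustrative, not AutoRound APIs:

```python
# Sentinel default for the dataset argument. Comparing against the literal
# string "NeelNanda/pile-10k" cannot tell an explicit choice from the default.
_UNSET = object()
DIFFUSION_DEFAULT_DATASET = "coco-captions"  # hypothetical diffusion-friendly default

def resolve_dataset(dataset=_UNSET):
    """Return the dataset to use for diffusion tuning.

    If the caller did not pass a dataset at all, silently switch to a
    diffusion-appropriate default instead of an LLM text corpus.
    """
    if dataset is _UNSET:
        return DIFFUSION_DEFAULT_DATASET
    return dataset
```

With this pattern, `resolve_dataset()` yields the diffusion default, while `resolve_dataset("NeelNanda/pile-10k")` still honors an explicit user choice.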

"""
# Replace special characters to make the folder name filesystem-safe
sanitized_format = format.get_backend_name().replace(":", "-").replace("_", "-")
if hasattr(self.model, "config") and getattr(self.model.config, "model_type", None) == "nextstep":
Contributor


This is very tricky. It would be better to handle this in the special-model handling code.
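The reviewer's suggestion amounts to moving the inline `model_type == "nextstep"` check behind a per-model hook. A minimal sketch of that shape; `SAVE_HOOKS`, `register_save_hook`, and `run_pre_save` are hypothetical names, not existing AutoRound functions:

```python
# Registry of model-specific pre-save behavior, keyed by model_type.
# Special models register themselves; the common save path stays generic.
SAVE_HOOKS = {}

def register_save_hook(model_type):
    """Decorator: register fn as the pre-save hook for model_type."""
    def deco(fn):
        SAVE_HOOKS[model_type] = fn
        return fn
    return deco

@register_save_hook("nextstep")
def _nextstep_pre_save(model, output_dir):
    # NextStep-specific preparation would go here.
    return f"prepared {output_dir} for nextstep"

def run_pre_save(model_type, model, output_dir):
    """Common save path calls this; models without a hook are a no-op."""
    hook = SAVE_HOOKS.get(model_type)
    return hook(model, output_dir) if hook else None
```

The common `save_quantized` code then never mentions any specific model type.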

return super().save_quantized(output_dir, format=format, inplace=inplace, **kwargs)

compressed_model = None
if hasattr(self.model, "config") and getattr(self.model.config, "model_type", None) == "nextstep":
Contributor


The same tricky issue: we do not handle model-specific issues in the common code. Better to name it as a specific behavior and write a function/class that handles it for all models with the same behavior.


if isinstance(model, DiffusionPipeline):
pipe = model
_device_map = 0 if device_map is None else device_map
Contributor


Only wrap the code that may throw exceptions in the try block; I guess that is the from diffusers.pipelines.pipeline_utils import DiffusionPipeline here.


# This function is designed for Auto Scheme and Diffusion Pipeline,
# which requires dispatching the whole model on all available devices.
def dispatch_model_by_all_available_devices(
Contributor


Could we consolidate this with the other function in auto-scheme?

try:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
Contributor


trust_remote_code should follow AutoRound's setting. We have disable_trust

config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
model_type = getattr(config, "model_type", "")
# A special case for NextStep
if model_type == "nextstep":
Contributor

@wenhuach21 wenhuach21 Apr 10, 2026


Same issue: you could register the model type in handle_special_model.py or in the diffusion folder.



def load_next_step_diffusion(pretrained_model_name_or_path, device_str):
from models.gen_pipeline import NextStepPipeline # pylint: disable=E0401
Contributor


Better to create a new file or folder that handles special-model loading; you could use a registry or something similar. Then other developers only need to call load_mllm_model to load all of our supported models.
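The registry-based loading suggested above could look like the following sketch. `load_mllm_model` appears in the review comment; the registry machinery (`_LOADERS`, `register_loader`) and the loader signature are hypothetical:

```python
# Loader registry keyed by model_type. Special models (like NextStep)
# register a loader; callers only ever use load_mllm_model.
_LOADERS = {}

def register_loader(model_type):
    """Decorator: register fn as the loader for model_type."""
    def deco(fn):
        _LOADERS[model_type] = fn
        return fn
    return deco

@register_loader("nextstep")
def load_next_step(path, device_str):
    # The real loader would construct a NextStepPipeline here.
    return ("nextstep-pipeline", path, device_str)

def load_mllm_model(path, model_type, device_str, default_loader=None):
    """Single entry point: dispatch to a registered loader or a default."""
    loader = _LOADERS.get(model_type, default_loader)
    if loader is None:
        raise ValueError(f"No loader registered for model_type={model_type!r}")
    return loader(path, device_str)
```

New special models then add one registered function instead of another branch in common code.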

assert device in environ_mapping, f"Device {device} not supported for vllm tensor parallelism."
environ_name = environ_mapping[device]
assert device in DEVICE_ENVIRON_VARIABLE_MAPPING, f"Device {device} not supported for vllm tensor parallelism."
environ_name = DEVICE_ENVIRON_VARIABLE_MAPPING[device]
Contributor


Suggest renaming to env_name.
