From 5f66048f9b9d1f12ddc59b2d2e712d4888204fdc Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Wed, 15 Apr 2026 11:50:40 -0700 Subject: [PATCH 1/3] Update tie-word embedding surgery --- olive/passes/onnx/graph_surgeries.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 533c0c2c5..db32c99b7 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -1924,11 +1924,18 @@ class TieWordEmbeddings(ProtoSurgeon): def __call__(self, model: onnx.ModelProto): dag = OnnxDAG(model) - if not dag.is_input("input_ids") or not dag.is_output("logits"): + # support both "input_ids" and "input_embeds" as input names + input_name = None + for candidate in ("input_ids", "input_embeds"): + if candidate in dag.ios and dag.is_input(candidate): + input_name = candidate + break + + if input_name is None or "logits" not in dag.ios or not dag.is_output("logits"): return dag.model embed_name, embed_op_type = self.get_name_op_type( - dag, dag.get_consumers("input_ids"), ["Gather", "GatherBlockQuantized"], 0 + dag, dag.get_consumers(input_name), ["Gather", "GatherBlockQuantized"], 0 ) if embed_name is None: return dag.model From 14ec3283ee8894b637ac3d44d5d56d54307217ce Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Thu, 14 May 2026 00:12:56 +0000 Subject: [PATCH 2/3] Add QuantizeEmbeddingInt8 and ShareEmbeddingLmHead graph surgeries for INT8 embedding quantization --- olive/common/onnx_io.py | 15 +- olive/evaluator/lmeval_ort.py | 76 ++++++- olive/evaluator/olive_evaluator.py | 12 +- olive/passes/onnx/graph_surgeries.py | 232 +++++++++++++++++++- olive/passes/onnx/model_builder.py | 5 - test/passes/onnx/test_quantize_embedding.py | 185 ++++++++++++++++ 6 files changed, 505 insertions(+), 20 deletions(-) create mode 100644 test/passes/onnx/test_quantize_embedding.py diff --git a/olive/common/onnx_io.py b/olive/common/onnx_io.py index 42db080a3..da718f7bd 100644 --- a/olive/common/onnx_io.py +++ b/olive/common/onnx_io.py @@ -89,19 +89,24 @@ def get_kv_info(io_config: dict) -> dict | None: if kv_format is None: return None - # find the number of layers - num_layers = 0 + # find the actual layer indices (may be non-contiguous after pruning) + layer_indices = [] for i_name in io_config["input_names"]: - num_layers += int(re.match(kv_format, i_name) is not None) + m = re.match(kv_format, i_name) + if m: + idx = int(m.group(1)) + if idx not in layer_indices: + layer_indices.append(idx) + layer_indices.sort() past_names = [] present_to_past = {} for k in ["key", "value"]: - past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in range(num_layers)]) + past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in layer_indices]) present_to_past.update( { kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i - for i in range(num_layers) + for i in layer_indices } ) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index fd69b066e..c874ded3e 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -293,6 +293,28 @@ def __init__(self, model_path: str, ep: str | None = None, ep_options: dict | No if self.kv_info is None: raise ValueError("Invalid io_config: kv_info not found") + # detect position_ids rank (e.g. 3 for mRoPE models like Qwen3.5) + self.position_ids_rank = 2 + if "position_ids" in self.io_config["input_names"]: + idx = self.io_config["input_names"].index("position_ids") + self.position_ids_rank = len(self.io_config["input_shapes"][idx]) + + # detect hybrid state inputs (conv_state, recurrent_state for linear attention layers) + self.hybrid_states = {} + for idx, name in enumerate(self.io_config["input_names"]): + if "conv_state" in name or "recurrent_state" in name: + shape = self.io_config["input_shapes"][idx] + dtype = self.io_config["input_types"][idx] + self.hybrid_states[name] = {"shape": shape, "dtype": dtype} + + # detect hybrid state outputs + self.hybrid_state_outputs = {} + for idx, name in enumerate(self.io_config["output_names"]): + if "conv_state" in name or "recurrent_state" in name: + shape = self.io_config["output_shapes"][idx] + dtype = self.io_config["output_types"][idx] + self.hybrid_state_outputs[name] = {"shape": shape, "dtype": dtype} + self._session = None self._batch_size = None self._buffers = None @@ -331,17 +353,29 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor: inputs_to_bind[name] = (self._buffers["inputs"][name], self.io_dtypes[name], shape) if "position_ids" in self._buffers["inputs"]: # need to reallocate since the position_ids tensor may be sliced - inputs_to_bind["position_ids"] = ( - self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(), - self.io_dtypes["position_ids"], - (batch_size, seqlen), - ) + if self.position_ids_rank == 3: + inputs_to_bind["position_ids"] = ( + self._buffers["inputs"]["position_ids"][:, :batch_size, :seqlen].contiguous(), + self.io_dtypes["position_ids"], + (self._buffers["inputs"]["position_ids"].shape[0], batch_size, seqlen), + ) + else: + inputs_to_bind["position_ids"] = ( + self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(), + self.io_dtypes["position_ids"], + (batch_size, seqlen), + ) for name in self._buffers["kv_inputs"]: inputs_to_bind[name] = ( self._buffers["kv_inputs"][name], self.kv_info["dtype"], (batch_size, self.kv_info["num_kv_heads"], 0, self.kv_info["head_size"]), ) + # hybrid state inputs (conv_state, recurrent_state) + for name, buf in self._buffers["hybrid_inputs"].items(): + shape = list(buf.shape) + shape[0] = batch_size + inputs_to_bind[name] = (buf, self.hybrid_states[name]["dtype"], tuple(shape)) for name, (buffer, dtype, shape) in inputs_to_bind.items(): io_binding.bind_input( name, @@ -363,6 +397,11 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor: self.kv_info["dtype"], (batch_size, self.kv_info["num_kv_heads"], seqlen, self.kv_info["head_size"]), ) + # hybrid state outputs (conv_state, recurrent_state) + for name, buf in self._buffers["hybrid_outputs"].items(): + shape = list(buf.shape) + shape[0] = batch_size + outputs_to_bind[name] = (buf, self.hybrid_state_outputs[name]["dtype"], tuple(shape)) for name, (buffer, dtype, shape) in outputs_to_bind.items(): io_binding.bind_output( name, @@ -418,11 +457,18 @@ def initialize_buffers(self, batch_size: int, max_length: int): ) } if self.io_dtypes.get("position_ids") is not None: - inputs["position_ids"] = ( + pos_ids = ( torch.arange(max_length, dtype=getattr(torch, self.io_dtypes["position_ids"]), device=self.device) .unsqueeze(0) .expand(batch_size, -1) ) + if self.position_ids_rank == 3: + # mRoPE: expand to [mrope_sections, batch_size, seq_len] + mrope_sections = self.io_config["input_shapes"][ + self.io_config["input_names"].index("position_ids") + ][0] + pos_ids = pos_ids.unsqueeze(0).expand(mrope_sections, -1, -1) + inputs["position_ids"] = pos_ids if self.io_dtypes.get("past_seq_len") is not None: inputs["past_seq_len"] = ( torch.tensor(max_length - 1, dtype=getattr(torch, self.io_dtypes["past_seq_len"]), device=self.device) @@ -457,6 +503,24 @@ def initialize_buffers(self, batch_size: int, max_length: int): } self._buffers = {"inputs": inputs, "outputs": outputs, "kv_inputs": kv_inputs, "kv_outputs": kv_outputs} + + # hybrid state buffers (conv_state, recurrent_state) - zero-initialized + hybrid_inputs = {} + for name, info in self.hybrid_states.items(): + # Replace symbolic 'batch_size' with actual batch_size + shape = [batch_size if s == "batch_size" else s for s in info["shape"]] + hybrid_inputs[name] = torch.zeros( + shape, dtype=getattr(torch, info["dtype"]), device=self.device + ) + hybrid_outputs = {} + for name, info in self.hybrid_state_outputs.items(): + shape = [batch_size if s == "batch_size" else s for s in info["shape"]] + hybrid_outputs[name] = torch.zeros( + shape, dtype=getattr(torch, info["dtype"]), device=self.device + ) + self._buffers["hybrid_inputs"] = hybrid_inputs + self._buffers["hybrid_outputs"] = hybrid_outputs + self._batch_size = batch_size diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 87541ebe5..ec5c465dc 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -1115,10 +1115,16 @@ def evaluate( task_metrics = {} for mf, v in metric_items: - if mf != "alias": + if mf == "alias": + continue + if not isinstance(v, (int, float)): + continue + if "," in mf: m, _ = mf.split(",", 1) - if not m.endswith("_stderr"): - task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True) + else: + m = mf + if not m.endswith("_stderr"): + task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True) metrics[task_name] = MetricResult.model_validate(task_metrics) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index db32c99b7..1ea738e01 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -1953,7 +1953,7 @@ def __call__(self, model: onnx.ModelProto): ]: return dag.model - if embed_op_type == "Gather": + if embed_op_type == "Gather" and lm_head_op_type == "MatMul": return self.handle_unquantized(dag, embed_name, lm_head_name) return self.handle_quantized(dag, embed_name, lm_head_name) @@ -2152,6 +2152,236 @@ def equal_weights(self, dag: OnnxDAG, init0: str, init1: str, transpose: bool = return np.array_equal(arr0.ravel(), arr1.ravel()) +def _find_embed_node(model, op_type, label): + """Find the embed_tokens node of the given op_type and its index.""" + for i, node in enumerate(model.graph.node): + if node.op_type == op_type and "embed_tokens" in node.name: + return node, i + logger.warning("No embed_tokens %s node found, skipping %s", op_type, label) + return None, None + + +def _find_lm_head_node(model): + """Find the lm_head MatMulNBits node and its index.""" + for i, node in enumerate(model.graph.node): + if node.op_type == "MatMulNBits" and "lm_head" in node.name: + return node, i + logger.warning("No lm_head MatMulNBits found") + return None, None + + +def _find_initializer(model, name): + """Find an initializer by name.""" + for init in model.graph.initializer: + if init.name == name: + return init + return None + + +def _get_node_attrs(node, *attr_names): + """Extract integer attributes from a node by name.""" + result = {} + for attr in node.attribute: + if attr.name in attr_names: + result[attr.name] = attr.i + return result + + +class QuantizeEmbeddingInt8(ProtoSurgeon): + """Quantize FP16 embedding to INT8 using GatherBlockQuantized. + + Replaces the Gather op for embed_tokens with a GatherBlockQuantized op + that uses per-block INT8 quantization (block_size=32). + """ + + def __call__(self, model: onnx.ModelProto): + import numpy as np + from onnx import numpy_helper, helper + + # Find embedding Gather node and weight + gather_node, gather_idx = _find_embed_node(model, "Gather", "QuantizeEmbeddingInt8") + if gather_node is None: + return model + + embed_init = _find_initializer(model, gather_node.input[0]) + + if embed_init is None: + logger.warning("Embedding weight initializer not found, skipping QuantizeEmbeddingInt8") + return model + + # Check if already quantized + if embed_init.data_type not in (onnx.TensorProto.FLOAT16, onnx.TensorProto.FLOAT): + logger.info("Embedding is not FP16/FP32, skipping QuantizeEmbeddingInt8") + return model + + embed = numpy_helper.to_array(embed_init).astype(np.float32) + vocab_size, hidden_size = embed.shape + block_size = 32 + + if hidden_size % block_size != 0: + logger.warning("hidden_size %d not divisible by block_size %d, skipping", hidden_size, block_size) + return model + + num_blocks = hidden_size // block_size + + logger.info( + "Quantizing embedding %s (%dx%d) from FP16 to INT8 (block_size=%d)", + embed_init.name, vocab_size, hidden_size, block_size, + ) + + # Per-block INT8 quantization (asymmetric with zero_point=128 for GatherBlockQuantized) + blocked = embed.reshape(vocab_size, num_blocks, block_size) + scales = (np.abs(blocked).max(axis=2) / 127.0).astype(np.float16) + scales_f32 = scales.astype(np.float32) + # Avoid division by zero + scales_f32 = np.where(scales_f32 == 0, 1.0, scales_f32) + q = np.clip(np.round(blocked / scales_f32[:, :, None]), -128, 127).astype(np.int8) + # GatherBlockQuantized expects unsigned uint8 with zero_point offset + q_uint8 = (q.astype(np.int16) + 128).astype(np.uint8) + q_flat = q_uint8.reshape(vocab_size, hidden_size) + # Zero point tensor: 128 for all blocks (symmetric around 128) + zero_points = np.full((vocab_size, num_blocks), 128, dtype=np.uint8) + + old_size_mb = embed.nbytes / (1024 * 1024) + new_size_mb = (q_flat.nbytes + scales.nbytes + zero_points.nbytes) / (1024 * 1024) + logger.info("Embedding: %.0f MB -> %.0f MB (saved %.0f MB)", old_size_mb, new_size_mb, old_size_mb - new_size_mb) + + # Create new initializers + qweight_name = embed_init.name + "_Q8" + scales_name = embed_init.name + "_scales" + zp_name = embed_init.name + "_zp" + model.graph.initializer.append(numpy_helper.from_array(q_flat, name=qweight_name)) + model.graph.initializer.append(numpy_helper.from_array(scales, name=scales_name)) + model.graph.initializer.append(numpy_helper.from_array(zero_points, name=zp_name)) + + # Create GatherBlockQuantized node + gbq_node = helper.make_node( + "GatherBlockQuantized", + inputs=[qweight_name, gather_node.input[1], scales_name, zp_name], + outputs=gather_node.output, + name=gather_node.name.replace("Gather", "GatherBlockQuantized"), + domain="com.microsoft", + bits=8, + block_size=block_size, + gather_axis=0, + quantize_axis=1, + ) + + # Replace Gather with GatherBlockQuantized + model.graph.node.remove(gather_node) + model.graph.node.insert(gather_idx, gbq_node) + + # Remove old FP16 embedding weight + model.graph.initializer.remove(embed_init) + + logger.info("Replaced Gather with GatherBlockQuantized (INT8)") + return model + + +class ShareEmbeddingLmHead(ProtoSurgeon): + """Share INT8 embedding weight with lm_head by converting lm_head to INT8 MatMulNBits. + + Must be applied AFTER QuantizeEmbeddingInt8. Replaces the lm_head's INT4 + MatMulNBits with an INT8 MatMulNBits that references the same quantized + weight as the embedding's GatherBlockQuantized, eliminating duplicate storage. + """ + + def __call__(self, model: onnx.ModelProto): + import numpy as np + from onnx import numpy_helper, helper + + # Find embedding GatherBlockQuantized + gbq_node, _ = _find_embed_node(model, "GatherBlockQuantized", "ShareEmbeddingLmHead") + if gbq_node is None: + return model + + attrs = _get_node_attrs(gbq_node, "bits", "block_size") + gbq_bits = attrs.get("bits", 8) + gbq_block_size = attrs.get("block_size", 32) + + if gbq_bits != 8: + logger.warning("Embedding is not INT8, cannot share with lm_head") + return model + + # Get embedding weight, scales, zero_points names + embed_weight_name = gbq_node.input[0] + embed_scales_name = gbq_node.input[2] + embed_zp_name = gbq_node.input[3] if len(gbq_node.input) > 3 else None + + # Get embedding weight shape to determine K and N + embed_weight_init = _find_initializer(model, embed_weight_name) + if embed_weight_init is None: + logger.warning("Could not find embedding weight initializer") + return model + + embed_weight = numpy_helper.to_array(embed_weight_init) + + vocab_size, hidden_size = embed_weight.shape # [V, H] for INT8 + num_blocks = hidden_size // gbq_block_size + + # Find lm_head MatMulNBits node + lm_head_node, lm_head_idx = _find_lm_head_node(model) + if lm_head_node is None: + return model + + # Get old lm_head attributes + old_attrs = _get_node_attrs(lm_head_node, "K", "N", "bits", "block_size") + + logger.info( + "Sharing embedding with lm_head: lm_head INT%d (%dx%d, bs=%d) -> INT8 (shared with embedding)", + old_attrs.get("bits", 0), old_attrs.get("N", 0), old_attrs.get("K", 0), old_attrs.get("block_size", 0), + ) + + # Remove old lm_head weight initializers + old_init_names = set(lm_head_node.input[1:]) # weight, scales, [zp], [g_idx] + to_remove = [init for init in model.graph.initializer if init.name in old_init_names] + for init in to_remove: + model.graph.initializer.remove(init) + + # MatMulNBits needs [N, K_blocks, block_size] but GBQ weight is [V, H]. + # Add a Reshape node to convert, referencing the SAME embedding weight. + reshape_shape_name = "lm_head.MatMulNBits.reshape_shape" + reshape_shape = np.array([vocab_size, num_blocks, gbq_block_size], dtype=np.int64) + model.graph.initializer.append(numpy_helper.from_array(reshape_shape, name=reshape_shape_name)) + + reshape_output = "lm_head.MatMulNBits.reshaped_weight" + reshape_node = helper.make_node( + "Reshape", + inputs=[embed_weight_name, reshape_shape_name], + outputs=[reshape_output], + name="lm_head/Reshape_shared_weight", + ) + model.graph.node.insert(lm_head_idx, reshape_node) + + # Scales and zp: reuse embedding's directly + inputs = [lm_head_node.input[0], reshape_output, embed_scales_name] + if embed_zp_name: + inputs.append(embed_zp_name) + + # Create new INT8 MatMulNBits node + new_lm_head = helper.make_node( + "MatMulNBits", + inputs=inputs, + outputs=lm_head_node.output, + name=lm_head_node.name, + domain="com.microsoft", + bits=8, + block_size=gbq_block_size, + K=hidden_size, + N=vocab_size, + ) + # Copy accuracy_level if present + for attr in lm_head_node.attribute: + if attr.name == "accuracy_level": + new_lm_head.attribute.append(attr) + + model.graph.node.remove(lm_head_node) + model.graph.node.insert(lm_head_idx + 1, new_lm_head) + + logger.info("lm_head now uses INT8 MatMulNBits (shared quantization with embedding)") + return model + + class ReciprocalMulToDiv(ProtoSurgeon): """Replace Reciprocal(x) * a with Div(a, x). diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index e2539feca..5507aefbb 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -520,11 +520,6 @@ def patched_make_embedding(self, embedding): import onnx_ir as ir basename = "/model/embed_tokens" - if getattr(self, "int4_tied_embeddings", False) or getattr(self, "shared_embeddings", False): - logger.debug( - "int4_tied_embedding/shared_embeddings is set to True but will be ignored. Use TieWordEmbeddings graph surgery to tie" - " embeddings." - ) if hasattr(embedding, "qweight"): qweight = "model.embed_tokens.qweight" diff --git a/test/passes/onnx/test_quantize_embedding.py b/test/passes/onnx/test_quantize_embedding.py new file mode 100644 index 000000000..ac0285fe5 --- /dev/null +++ b/test/passes/onnx/test_quantize_embedding.py @@ -0,0 +1,185 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import numpy as np +import onnx +import pytest +from onnx import TensorProto, helper, numpy_helper + +from olive.passes.onnx.graph_surgeries import QuantizeEmbeddingInt8, ShareEmbeddingLmHead + + +def _make_model_with_fp16_embed(vocab_size=64, hidden_size=64, block_size=32): + """Create a minimal ONNX model with FP16 Gather embedding and INT4 MatMulNBits lm_head.""" + # Embedding: Gather with FP16 weight + embed_weight = np.random.randn(vocab_size, hidden_size).astype(np.float16) + embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight") + + input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["batch_size", "seq_len"]) + gather_output = helper.make_tensor_value_info("embed_output", TensorProto.FLOAT16, ["batch_size", "seq_len", hidden_size]) + + gather_node = helper.make_node( + "Gather", inputs=["model.embed_tokens.weight", "input_ids"], + outputs=["embed_output"], name="/model/embed_tokens/Gather" + ) + + # lm_head: MatMulNBits with INT4 weight + num_blocks = hidden_size // block_size + lm_weight = np.random.randint(0, 255, (vocab_size, num_blocks, block_size // 2), dtype=np.uint8) + lm_scales = np.random.randn(vocab_size, num_blocks).astype(np.float16) * 0.01 + lm_zp = np.full((vocab_size, num_blocks), 8, dtype=np.uint8) + + lm_weight_init = numpy_helper.from_array(lm_weight, name="lm_head.MatMul_Q4.qweight") + lm_scales_init = numpy_helper.from_array(lm_scales, name="lm_head.MatMul_Q4.scales") + lm_zp_init = numpy_helper.from_array(lm_zp, name="lm_head.MatMul_Q4.zp") + + lm_head_node = helper.make_node( + "MatMulNBits", + inputs=["embed_output", "lm_head.MatMul_Q4.qweight", "lm_head.MatMul_Q4.scales", "lm_head.MatMul_Q4.zp"], + outputs=["logits"], + name="/lm_head/MatMulNBits", + domain="com.microsoft", + bits=4, block_size=block_size, K=hidden_size, N=vocab_size, + ) + + logits = helper.make_tensor_value_info("logits", TensorProto.FLOAT16, ["batch_size", "seq_len", vocab_size]) + + graph = helper.make_graph( + [gather_node, lm_head_node], "test", + [input_ids], [logits], + initializer=[embed_init, lm_weight_init, lm_scales_init, lm_zp_init], + ) + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)], + ) + return model + + +class TestQuantizeEmbeddingInt8: + def test_replaces_gather_with_gbq(self): + model = _make_model_with_fp16_embed() + surgery = QuantizeEmbeddingInt8() + result = surgery(model) + + # Verify Gather is replaced with GatherBlockQuantized + node_types = [n.op_type for n in result.graph.node] + assert "Gather" not in node_types or all( + "embed_tokens" not in n.name for n in result.graph.node if n.op_type == "Gather" + ) + gbq_nodes = [n for n in result.graph.node if n.op_type == "GatherBlockQuantized"] + assert len(gbq_nodes) == 1 + + gbq = gbq_nodes[0] + attrs = {a.name: a.i for a in gbq.attribute} + assert attrs["bits"] == 8 + assert attrs["block_size"] == 32 + + # Verify zero_point input exists (4 inputs: weight, input_ids, scales, zp) + assert len(gbq.input) == 4 + + def test_reduces_weight_size(self): + model = _make_model_with_fp16_embed(vocab_size=256, hidden_size=128) + surgery = QuantizeEmbeddingInt8() + + # Get FP16 weight size before + fp16_size = sum( + np.prod(list(init.dims)) * 2 # FP16 = 2 bytes + for init in model.graph.initializer + if init.name == "model.embed_tokens.weight" + ) + + result = surgery(model) + + # FP16 weight should be removed + fp16_names = [init.name for init in result.graph.initializer if init.name == "model.embed_tokens.weight"] + assert len(fp16_names) == 0 + + # INT8 weight should exist + int8_names = [init.name for init in result.graph.initializer if "_Q8" in init.name] + assert len(int8_names) == 1 + + def test_skips_non_fp16(self): + model = _make_model_with_fp16_embed() + surgery = QuantizeEmbeddingInt8() + + # First pass: quantize to INT8 + result1 = surgery(model) + # Second pass: should skip (already quantized) + result2 = surgery(result1) + + # Should still have exactly 1 GBQ node + gbq_count = sum(1 for n in result2.graph.node if n.op_type == "GatherBlockQuantized") + assert gbq_count == 1 + + def test_skips_when_hidden_not_divisible(self): + # hidden_size=33, not divisible by block_size=32 + embed_weight = np.random.randn(64, 33).astype(np.float16) + embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight") + input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, [1]) + output = helper.make_tensor_value_info("out", TensorProto.FLOAT16, [1, 33]) + gather = helper.make_node("Gather", ["model.embed_tokens.weight", "input_ids"], ["out"], name="/model/embed_tokens/Gather") + graph = helper.make_graph([gather], "test", [input_ids], [output], initializer=[embed_init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)]) + + surgery = QuantizeEmbeddingInt8() + result = surgery(model) + + # Should still have Gather (not replaced) + assert any(n.op_type == "Gather" for n in result.graph.node) + + +class TestShareEmbeddingLmHead: + def test_shares_weight(self): + model = _make_model_with_fp16_embed() + + # First quantize embedding to INT8 + q_surgery = QuantizeEmbeddingInt8() + model = q_surgery(model) + + # Then share + s_surgery = ShareEmbeddingLmHead() + result = s_surgery(model) + + # lm_head should now be INT8 + lm_head = [n for n in result.graph.node if n.op_type == "MatMulNBits" and "lm_head" in n.name][0] + attrs = {a.name: a.i for a in lm_head.attribute} + assert attrs["bits"] == 8 + + # Should have a Reshape node for weight sharing + reshape_nodes = [n for n in result.graph.node if "Reshape_shared" in n.name] + assert len(reshape_nodes) == 1 + + # Reshape should reference the embedding weight + reshape = reshape_nodes[0] + assert "embed_tokens" in reshape.input[0] + + # lm_head should use shared scales + assert "embed_tokens" in lm_head.input[2] # scales + + def test_skips_without_gbq(self): + model = _make_model_with_fp16_embed() + # Don't quantize embedding first + s_surgery = ShareEmbeddingLmHead() + result = s_surgery(model) + + # Should be unchanged — still has Gather + assert any(n.op_type == "Gather" for n in result.graph.node) + + def test_removes_old_lm_head_weights(self): + model = _make_model_with_fp16_embed() + q_surgery = QuantizeEmbeddingInt8() + model = q_surgery(model) + + old_init_names = {init.name for init in model.graph.initializer} + + s_surgery = ShareEmbeddingLmHead() + result = s_surgery(model) + + new_init_names = {init.name for init in result.graph.initializer} + + # Old lm_head weights should be removed + assert "lm_head.MatMul_Q4.qweight" not in new_init_names + assert "lm_head.MatMul_Q4.scales" not in new_init_names + assert "lm_head.MatMul_Q4.zp" not in new_init_names From cbb973a3b7782dc58ca2b6af78d60755abcd6123 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Thu, 14 May 2026 03:58:04 +0000 Subject: [PATCH 3/3] Fix comments --- olive/common/onnx_io.py | 5 +- olive/evaluator/lmeval_ort.py | 14 ++--- olive/passes/onnx/graph_surgeries.py | 46 ++++++++++++---- test/passes/onnx/test_quantize_embedding.py | 58 ++++++++++++--------- 4 files changed, 75 insertions(+), 48 deletions(-) diff --git a/olive/common/onnx_io.py b/olive/common/onnx_io.py index da718f7bd..7be04ce10 100644 --- a/olive/common/onnx_io.py +++ b/olive/common/onnx_io.py @@ -104,10 +104,7 @@ def get_kv_info(io_config: dict) -> dict | None: for k in ["key", "value"]: past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in layer_indices]) present_to_past.update( - { - kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i - for i in layer_indices - } + {kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i for i in layer_indices} ) past_shape = io_config["input_shapes"][io_config["input_names"].index(past_names[0])] diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index c5f3791c7..2faa20fc6 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -464,9 +464,7 @@ def initialize_buffers(self, batch_size: int, max_length: int): ) if self.position_ids_rank == 3: # mRoPE: expand to [mrope_sections, batch_size, seq_len] - mrope_sections = self.io_config["input_shapes"][ - self.io_config["input_names"].index("position_ids") - ][0] + mrope_sections = self.io_config["input_shapes"][self.io_config["input_names"].index("position_ids")][0] pos_ids = pos_ids.unsqueeze(0).expand(mrope_sections, -1, -1) inputs["position_ids"] = pos_ids if self.io_dtypes.get("past_seq_len") is not None: @@ -509,15 +507,11 @@ def initialize_buffers(self, batch_size: int, max_length: int): for name, info in self.hybrid_states.items(): # Replace symbolic 'batch_size' with actual batch_size shape = [batch_size if s == "batch_size" else s for s in info["shape"]] - hybrid_inputs[name] = torch.zeros( - shape, dtype=getattr(torch, info["dtype"]), device=self.device - ) + hybrid_inputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device) hybrid_outputs = {} for name, info in self.hybrid_state_outputs.items(): shape = [batch_size if s == "batch_size" else s for s in info["shape"]] - hybrid_outputs[name] = torch.zeros( - shape, dtype=getattr(torch, info["dtype"]), device=self.device - ) + hybrid_outputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device) self._buffers["hybrid_inputs"] = hybrid_inputs self._buffers["hybrid_outputs"] = hybrid_outputs @@ -603,7 +597,7 @@ def _detect_full_logits(self) -> bool: def eot_token_id(self): return self._eot_token_id - def tok_encode(self, string: str, **kwargs) -> list[int]: + def tok_encode(self, string: str, add_special_tokens: bool | None = None, **kwargs) -> list[int]: """Tokenize a string using the model's tokenizer and return a list of token IDs.""" return self.tokenizer.encode(string).tolist() diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 34e4606ec..415d5c49f 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -2406,6 +2406,14 @@ def _get_node_attrs(node, *attr_names): return result +def _ensure_msft_opset(model): + """Ensure com.microsoft opset import is present in the model.""" + for opset in model.opset_import: + if opset.domain == "com.microsoft": + return + model.opset_import.append(onnx.helper.make_opsetid("com.microsoft", 1)) + + class QuantizeEmbeddingInt8(ProtoSurgeon): """Quantize FP16 embedding to INT8 using GatherBlockQuantized. @@ -2414,8 +2422,7 @@ class QuantizeEmbeddingInt8(ProtoSurgeon): """ def __call__(self, model: onnx.ModelProto): - import numpy as np - from onnx import numpy_helper, helper + from onnx import numpy_helper # Find embedding Gather node and weight gather_node, gather_idx = _find_embed_node(model, "Gather", "QuantizeEmbeddingInt8") @@ -2445,7 +2452,10 @@ def __call__(self, model: onnx.ModelProto): logger.info( "Quantizing embedding %s (%dx%d) from FP16 to INT8 (block_size=%d)", - embed_init.name, vocab_size, hidden_size, block_size, + embed_init.name, + vocab_size, + hidden_size, + block_size, ) # Per-block INT8 quantization (asymmetric with zero_point=128 for GatherBlockQuantized) @@ -2463,7 +2473,9 @@ def __call__(self, model: onnx.ModelProto): old_size_mb = embed.nbytes / (1024 * 1024) new_size_mb = (q_flat.nbytes + scales.nbytes + zero_points.nbytes) / (1024 * 1024) - logger.info("Embedding: %.0f MB -> %.0f MB (saved %.0f MB)", old_size_mb, new_size_mb, old_size_mb - new_size_mb) + logger.info( + "Embedding: %.0f MB -> %.0f MB (saved %.0f MB)", old_size_mb, new_size_mb, old_size_mb - new_size_mb + ) # Create new initializers qweight_name = embed_init.name + "_Q8" @@ -2473,8 +2485,11 @@ def __call__(self, model: onnx.ModelProto): model.graph.initializer.append(numpy_helper.from_array(scales, name=scales_name)) model.graph.initializer.append(numpy_helper.from_array(zero_points, name=zp_name)) + # Ensure com.microsoft opset is declared + _ensure_msft_opset(model) + # Create GatherBlockQuantized node - gbq_node = helper.make_node( + gbq_node = onnx.helper.make_node( "GatherBlockQuantized", inputs=[qweight_name, gather_node.input[1], scales_name, zp_name], outputs=gather_node.output, @@ -2506,8 +2521,7 @@ class ShareEmbeddingLmHead(ProtoSurgeon): """ def __call__(self, model: onnx.ModelProto): - import numpy as np - from onnx import numpy_helper, helper + from onnx import numpy_helper # Find embedding GatherBlockQuantized gbq_node, _ = _find_embed_node(model, "GatherBlockQuantized", "ShareEmbeddingLmHead") @@ -2543,12 +2557,21 @@ def __call__(self, model: onnx.ModelProto): if lm_head_node is None: return model + # Check if already shared (idempotency): lm_head weight input references embedding weight + lm_head_weight_input = lm_head_node.input[1] + if embed_weight_name in lm_head_weight_input or lm_head_node.input[2] == embed_scales_name: + logger.info("lm_head already shares weights with embedding, skipping ShareEmbeddingLmHead") + return model + # Get old lm_head attributes old_attrs = _get_node_attrs(lm_head_node, "K", "N", "bits", "block_size") logger.info( "Sharing embedding with lm_head: lm_head INT%d (%dx%d, bs=%d) -> INT8 (shared with embedding)", - old_attrs.get("bits", 0), old_attrs.get("N", 0), old_attrs.get("K", 0), old_attrs.get("block_size", 0), + old_attrs.get("bits", 0), + old_attrs.get("N", 0), + old_attrs.get("K", 0), + old_attrs.get("block_size", 0), ) # Remove old lm_head weight initializers @@ -2564,7 +2587,7 @@ def __call__(self, model: onnx.ModelProto): model.graph.initializer.append(numpy_helper.from_array(reshape_shape, name=reshape_shape_name)) reshape_output = "lm_head.MatMulNBits.reshaped_weight" - reshape_node = helper.make_node( + reshape_node = onnx.helper.make_node( "Reshape", inputs=[embed_weight_name, reshape_shape_name], outputs=[reshape_output], @@ -2577,8 +2600,11 @@ def __call__(self, model: onnx.ModelProto): if embed_zp_name: inputs.append(embed_zp_name) + # Ensure com.microsoft opset is declared + _ensure_msft_opset(model) + # Create new INT8 MatMulNBits node - new_lm_head = helper.make_node( + new_lm_head = onnx.helper.make_node( "MatMulNBits", inputs=inputs, outputs=lm_head_node.output, diff --git a/test/passes/onnx/test_quantize_embedding.py b/test/passes/onnx/test_quantize_embedding.py index ac0285fe5..db3f2e83f 100644 --- a/test/passes/onnx/test_quantize_embedding.py +++ b/test/passes/onnx/test_quantize_embedding.py @@ -3,8 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import numpy as np -import onnx -import pytest from onnx import TensorProto, helper, numpy_helper from olive.passes.onnx.graph_surgeries import QuantizeEmbeddingInt8, ShareEmbeddingLmHead @@ -17,11 +15,12 @@ def _make_model_with_fp16_embed(vocab_size=64, hidden_size=64, block_size=32): embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight") input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["batch_size", "seq_len"]) - gather_output = helper.make_tensor_value_info("embed_output", TensorProto.FLOAT16, ["batch_size", "seq_len", hidden_size]) gather_node = helper.make_node( - "Gather", inputs=["model.embed_tokens.weight", "input_ids"], - outputs=["embed_output"], name="/model/embed_tokens/Gather" + "Gather", + inputs=["model.embed_tokens.weight", "input_ids"], + outputs=["embed_output"], + name="/model/embed_tokens/Gather", ) # lm_head: MatMulNBits with INT4 weight @@ -40,21 +39,23 @@ def _make_model_with_fp16_embed(vocab_size=64, hidden_size=64, block_size=32): outputs=["logits"], name="/lm_head/MatMulNBits", domain="com.microsoft", - bits=4, block_size=block_size, K=hidden_size, N=vocab_size, + bits=4, + block_size=block_size, + K=hidden_size, + N=vocab_size, ) - logits = helper.make_tensor_value_info("logits", TensorProto.FLOAT16, ["batch_size", "seq_len", vocab_size]) - graph = helper.make_graph( - [gather_node, lm_head_node], "test", - [input_ids], [logits], + [gather_node, lm_head_node], + "test", + [input_ids], + [helper.make_tensor_value_info("logits", TensorProto.FLOAT16, ["batch_size", "seq_len", vocab_size])], initializer=[embed_init, lm_weight_init, lm_scales_init, lm_zp_init], ) - model = helper.make_model( + return helper.make_model( graph, opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)], ) - return model class TestQuantizeEmbeddingInt8: @@ -83,13 +84,6 @@ def test_reduces_weight_size(self): model = _make_model_with_fp16_embed(vocab_size=256, hidden_size=128) surgery = QuantizeEmbeddingInt8() - # Get FP16 weight size before - fp16_size = sum( - np.prod(list(init.dims)) * 2 # FP16 = 2 bytes - for init in model.graph.initializer - if init.name == "model.embed_tokens.weight" - ) - result = surgery(model) # FP16 weight should be removed @@ -119,9 +113,13 @@ def test_skips_when_hidden_not_divisible(self): embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight") input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, [1]) output = helper.make_tensor_value_info("out", TensorProto.FLOAT16, [1, 33]) - gather = helper.make_node("Gather", ["model.embed_tokens.weight", "input_ids"], ["out"], name="/model/embed_tokens/Gather") + gather = helper.make_node( + "Gather", ["model.embed_tokens.weight", "input_ids"], ["out"], name="/model/embed_tokens/Gather" + ) graph = helper.make_graph([gather], "test", [input_ids], [output], initializer=[embed_init]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)]) + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)] + ) surgery = QuantizeEmbeddingInt8() result = surgery(model) @@ -143,7 +141,7 @@ def test_shares_weight(self): result = s_surgery(model) # lm_head should now be INT8 - lm_head = [n for n in result.graph.node if n.op_type == "MatMulNBits" and "lm_head" in n.name][0] + lm_head = next(n for n in result.graph.node if n.op_type == "MatMulNBits" and "lm_head" in n.name) attrs = {a.name: a.i for a in lm_head.attribute} assert attrs["bits"] == 8 @@ -158,6 +156,20 @@ def test_shares_weight(self): # lm_head should use shared scales assert "embed_tokens" in lm_head.input[2] # scales + def test_idempotent(self): + model = _make_model_with_fp16_embed() + q_surgery = QuantizeEmbeddingInt8() + model = q_surgery(model) + + s_surgery = ShareEmbeddingLmHead() + result1 = s_surgery(model) + # Applying again should be a no-op + result2 = s_surgery(result1) + + # Should still have exactly 1 Reshape_shared node + reshape_count = sum(1 for n in result2.graph.node if "Reshape_shared" in n.name) + assert reshape_count == 1 + def test_skips_without_gbq(self): model = _make_model_with_fp16_embed() # Don't quantize embedding first @@ -172,8 +184,6 @@ def test_removes_old_lm_head_weights(self): q_surgery = QuantizeEmbeddingInt8() model = q_surgery(model) - old_init_names = {init.name for init in model.graph.initializer} - s_surgery = ShareEmbeddingLmHead() result = s_surgery(model)