diff --git a/olive/common/onnx_io.py b/olive/common/onnx_io.py
index 42db080a3..7be04ce10 100644
--- a/olive/common/onnx_io.py
+++ b/olive/common/onnx_io.py
@@ -89,20 +89,22 @@ def get_kv_info(io_config: dict) -> dict | None:
     if kv_format is None:
         return None
 
-    # find the number of layers
-    num_layers = 0
+    # find the actual layer indices (may be non-contiguous after pruning)
+    layer_indices = []
     for i_name in io_config["input_names"]:
-        num_layers += int(re.match(kv_format, i_name) is not None)
+        m = re.match(kv_format, i_name)
+        if m:
+            idx = int(m.group(1))
+            if idx not in layer_indices:
+                layer_indices.append(idx)
+    layer_indices.sort()
 
     past_names = []
     present_to_past = {}
     for k in ["key", "value"]:
-        past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in range(num_layers)])
+        past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in layer_indices])
         present_to_past.update(
-            {
-                kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i
-                for i in range(num_layers)
-            }
+            {kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i for i in layer_indices}
         )
 
     past_shape = io_config["input_shapes"][io_config["input_names"].index(past_names[0])]
diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index c4a158533..2faa20fc6 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -293,6 +293,28 @@ def __init__(self, model_path: str, ep: str | None = None, ep_options: dict | No
         if self.kv_info is None:
             raise ValueError("Invalid io_config: kv_info not found")
 
+        # detect position_ids rank (e.g. 3 for mRoPE models like Qwen3.5)
+        self.position_ids_rank = 2
+        if "position_ids" in self.io_config["input_names"]:
+            idx = self.io_config["input_names"].index("position_ids")
+            self.position_ids_rank = len(self.io_config["input_shapes"][idx])
+
+        # detect hybrid state inputs (conv_state, recurrent_state for linear attention layers)
+        self.hybrid_states = {}
+        for idx, name in enumerate(self.io_config["input_names"]):
+            if "conv_state" in name or "recurrent_state" in name:
+                shape = self.io_config["input_shapes"][idx]
+                dtype = self.io_config["input_types"][idx]
+                self.hybrid_states[name] = {"shape": shape, "dtype": dtype}
+
+        # detect hybrid state outputs
+        self.hybrid_state_outputs = {}
+        for idx, name in enumerate(self.io_config["output_names"]):
+            if "conv_state" in name or "recurrent_state" in name:
+                shape = self.io_config["output_shapes"][idx]
+                dtype = self.io_config["output_types"][idx]
+                self.hybrid_state_outputs[name] = {"shape": shape, "dtype": dtype}
+
         self._session = None
         self._batch_size = None
         self._buffers = None
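
For reference, the substring-based detection above reduces to the following self-contained sketch; the io_config values are hypothetical, chosen to show both the rank-3 position_ids case and the hybrid-state case:

io_config = {
    "input_names": ["input_ids", "position_ids", "conv_state_0", "recurrent_state_0"],
    "input_shapes": [
        ["batch_size", "seq_len"],
        [3, "batch_size", "seq_len"],  # rank-3 position_ids (mRoPE-style)
        ["batch_size", 4, 1024],       # hypothetical conv state shape
        ["batch_size", 16, 64, 64],    # hypothetical recurrent state shape
    ],
    "input_types": ["int64", "int64", "float16", "float16"],
}

position_ids_rank = len(io_config["input_shapes"][io_config["input_names"].index("position_ids")])
hybrid_states = {
    name: {"shape": io_config["input_shapes"][i], "dtype": io_config["input_types"][i]}
    for i, name in enumerate(io_config["input_names"])
    if "conv_state" in name or "recurrent_state" in name
}
assert position_ids_rank == 3
assert set(hybrid_states) == {"conv_state_0", "recurrent_state_0"}
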
@@ -331,17 +353,29 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor:
             inputs_to_bind[name] = (self._buffers["inputs"][name], self.io_dtypes[name], shape)
         if "position_ids" in self._buffers["inputs"]:
             # need to reallocate since the position_ids tensor may be sliced
-            inputs_to_bind["position_ids"] = (
-                self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(),
-                self.io_dtypes["position_ids"],
-                (batch_size, seqlen),
-            )
+            if self.position_ids_rank == 3:
+                inputs_to_bind["position_ids"] = (
+                    self._buffers["inputs"]["position_ids"][:, :batch_size, :seqlen].contiguous(),
+                    self.io_dtypes["position_ids"],
+                    (self._buffers["inputs"]["position_ids"].shape[0], batch_size, seqlen),
+                )
+            else:
+                inputs_to_bind["position_ids"] = (
+                    self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(),
+                    self.io_dtypes["position_ids"],
+                    (batch_size, seqlen),
+                )
         for name in self._buffers["kv_inputs"]:
             inputs_to_bind[name] = (
                 self._buffers["kv_inputs"][name],
                 self.kv_info["dtype"],
                 (batch_size, self.kv_info["num_kv_heads"], 0, self.kv_info["head_size"]),
             )
+        # hybrid state inputs (conv_state, recurrent_state)
+        for name, buf in self._buffers["hybrid_inputs"].items():
+            shape = list(buf.shape)
+            shape[0] = batch_size
+            inputs_to_bind[name] = (buf, self.hybrid_states[name]["dtype"], tuple(shape))
         for name, (buffer, dtype, shape) in inputs_to_bind.items():
             io_binding.bind_input(
                 name,
@@ -363,6 +397,11 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor:
                 self.kv_info["dtype"],
                 (batch_size, self.kv_info["num_kv_heads"], seqlen, self.kv_info["head_size"]),
             )
+        # hybrid state outputs (conv_state, recurrent_state)
+        for name, buf in self._buffers["hybrid_outputs"].items():
+            shape = list(buf.shape)
+            shape[0] = batch_size
+            outputs_to_bind[name] = (buf, self.hybrid_state_outputs[name]["dtype"], tuple(shape))
         for name, (buffer, dtype, shape) in outputs_to_bind.items():
             io_binding.bind_output(
                 name,
@@ -418,11 +457,16 @@ def initialize_buffers(self, batch_size: int, max_length: int):
             )
         }
         if self.io_dtypes.get("position_ids") is not None:
-            inputs["position_ids"] = (
+            pos_ids = (
                 torch.arange(max_length, dtype=getattr(torch, self.io_dtypes["position_ids"]), device=self.device)
                 .unsqueeze(0)
                 .expand(batch_size, -1)
            )
+            if self.position_ids_rank == 3:
+                # mRoPE: expand to [mrope_sections, batch_size, seq_len]
+                mrope_sections = self.io_config["input_shapes"][self.io_config["input_names"].index("position_ids")][0]
+                pos_ids = pos_ids.unsqueeze(0).expand(mrope_sections, -1, -1)
+            inputs["position_ids"] = pos_ids
         if self.io_dtypes.get("past_seq_len") is not None:
             inputs["past_seq_len"] = (
                 torch.tensor(max_length - 1, dtype=getattr(torch, self.io_dtypes["past_seq_len"]), device=self.device)
@@ -457,6 +501,20 @@ def initialize_buffers(self, batch_size: int, max_length: int):
             }
 
         self._buffers = {"inputs": inputs, "outputs": outputs, "kv_inputs": kv_inputs, "kv_outputs": kv_outputs}
+
+        # hybrid state buffers (conv_state, recurrent_state) - zero-initialized
+        hybrid_inputs = {}
+        for name, info in self.hybrid_states.items():
+            # Replace symbolic 'batch_size' with actual batch_size
+            shape = [batch_size if s == "batch_size" else s for s in info["shape"]]
+            hybrid_inputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device)
+        hybrid_outputs = {}
+        for name, info in self.hybrid_state_outputs.items():
+            shape = [batch_size if s == "batch_size" else s for s in info["shape"]]
+            hybrid_outputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device)
+        self._buffers["hybrid_inputs"] = hybrid_inputs
+        self._buffers["hybrid_outputs"] = hybrid_outputs
+
         self._batch_size = batch_size
 
@@ -539,7 +597,7 @@ def _detect_full_logits(self) -> bool:
     def eot_token_id(self):
        return self._eot_token_id
 
-    def tok_encode(self, string: str, **kwargs) -> list[int]:
+    def tok_encode(self, string: str, add_special_tokens: bool | None = None, **kwargs) -> list[int]:
         """Tokenize a string using the model's tokenizer and return a list of token IDs."""
         return self.tokenizer.encode(string).tolist()
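
A minimal sketch of the rank-3 position_ids handling above; the sizes are assumed, with mrope_sections = 3 mirroring a [3, batch_size, seq_len] input shape:

import torch

batch_size, max_length, mrope_sections = 2, 8, 3
pos_ids = torch.arange(max_length, dtype=torch.int64).unsqueeze(0).expand(batch_size, -1)
pos_ids = pos_ids.unsqueeze(0).expand(mrope_sections, -1, -1)
assert pos_ids.shape == (mrope_sections, batch_size, max_length)
# Slicing a batch/sequence window keeps the section axis intact, matching run():
window = pos_ids[:, :1, :4].contiguous()
assert window.shape == (mrope_sections, 1, 4)
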
"alias": + continue + if not isinstance(v, (int, float)): + continue + if "," in mf: m, _ = mf.split(",", 1) - if not m.endswith("_stderr"): - task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True) + else: + m = mf + if not m.endswith("_stderr"): + task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True) metrics[task_name] = MetricResult.model_validate(task_metrics) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 1c66c7a5e..415d5c49f 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -2172,7 +2172,7 @@ def __call__(self, model: onnx.ModelProto): ]: return dag.model - if embed_op_type == "Gather": + if embed_op_type == "Gather" and lm_head_op_type == "MatMul": return self.handle_unquantized(dag, embed_name, lm_head_name) return self.handle_quantized(dag, embed_name, lm_head_name) @@ -2371,6 +2371,262 @@ def equal_weights(self, dag: OnnxDAG, init0: str, init1: str, transpose: bool = return np.array_equal(arr0.ravel(), arr1.ravel()) +def _find_embed_node(model, op_type, label): + """Find the embed_tokens node of the given op_type and its index.""" + for i, node in enumerate(model.graph.node): + if node.op_type == op_type and "embed_tokens" in node.name: + return node, i + logger.warning("No embed_tokens %s node found, skipping %s", op_type, label) + return None, None + + +def _find_lm_head_node(model): + """Find the lm_head MatMulNBits node and its index.""" + for i, node in enumerate(model.graph.node): + if node.op_type == "MatMulNBits" and "lm_head" in node.name: + return node, i + logger.warning("No lm_head MatMulNBits found") + return None, None + + +def _find_initializer(model, name): + """Find an initializer by name.""" + for init in model.graph.initializer: + if init.name == name: + return init + return None + + +def _get_node_attrs(node, *attr_names): + """Extract integer attributes from a node by name.""" + result = {} + for attr in node.attribute: + if attr.name in attr_names: + result[attr.name] = attr.i + return result + + +def _ensure_msft_opset(model): + """Ensure com.microsoft opset import is present in the model.""" + for opset in model.opset_import: + if opset.domain == "com.microsoft": + return + model.opset_import.append(onnx.helper.make_opsetid("com.microsoft", 1)) + + +class QuantizeEmbeddingInt8(ProtoSurgeon): + """Quantize FP16 embedding to INT8 using GatherBlockQuantized. + + Replaces the Gather op for embed_tokens with a GatherBlockQuantized op + that uses per-block INT8 quantization (block_size=32). 
+
+
+class ShareEmbeddingLmHead(ProtoSurgeon):
+    """Share INT8 embedding weight with lm_head by converting lm_head to INT8 MatMulNBits.
+
+    Must be applied AFTER QuantizeEmbeddingInt8. Replaces the lm_head's INT4
+    MatMulNBits with an INT8 MatMulNBits that references the same quantized
+    weight as the embedding's GatherBlockQuantized, eliminating duplicate storage.
+    """
+
+    def __call__(self, model: onnx.ModelProto):
+        from onnx import numpy_helper
+
+        # Find embedding GatherBlockQuantized
+        gbq_node, _ = _find_embed_node(model, "GatherBlockQuantized", "ShareEmbeddingLmHead")
+        if gbq_node is None:
+            return model
+
+        attrs = _get_node_attrs(gbq_node, "bits", "block_size")
+        gbq_bits = attrs.get("bits", 8)
+        gbq_block_size = attrs.get("block_size", 32)
+
+        if gbq_bits != 8:
+            logger.warning("Embedding is not INT8, cannot share with lm_head")
+            return model
+
+        # Get embedding weight, scales, zero_points names
+        embed_weight_name = gbq_node.input[0]
+        embed_scales_name = gbq_node.input[2]
+        embed_zp_name = gbq_node.input[3] if len(gbq_node.input) > 3 else None
+
+        # Get embedding weight shape to determine K and N
+        embed_weight_init = _find_initializer(model, embed_weight_name)
+        if embed_weight_init is None:
+            logger.warning("Could not find embedding weight initializer")
+            return model
+
+        embed_weight = numpy_helper.to_array(embed_weight_init)
+        vocab_size, hidden_size = embed_weight.shape  # [V, H] for INT8
+        num_blocks = hidden_size // gbq_block_size
+
+        # Find lm_head MatMulNBits node
+        lm_head_node, lm_head_idx = _find_lm_head_node(model)
+        if lm_head_node is None:
+            return model
+
+        # Check if already shared (idempotency): lm_head weight input references embedding weight
+        lm_head_weight_input = lm_head_node.input[1]
+        if embed_weight_name in lm_head_weight_input or lm_head_node.input[2] == embed_scales_name:
+            logger.info("lm_head already shares weights with embedding, skipping ShareEmbeddingLmHead")
+            return model
+
+        # Get old lm_head attributes
+        old_attrs = _get_node_attrs(lm_head_node, "K", "N", "bits", "block_size")
+        logger.info(
+            "Sharing embedding with lm_head: lm_head INT%d (%dx%d, bs=%d) -> INT8 (shared with embedding)",
+            old_attrs.get("bits", 0),
+            old_attrs.get("N", 0),
+            old_attrs.get("K", 0),
+            old_attrs.get("block_size", 0),
+        )
+
+        # Remove old lm_head weight initializers
+        old_init_names = set(lm_head_node.input[1:])  # weight, scales, [zp], [g_idx]
+        to_remove = [init for init in model.graph.initializer if init.name in old_init_names]
+        for init in to_remove:
+            model.graph.initializer.remove(init)
+
+        # MatMulNBits needs [N, K_blocks, block_size] but the GBQ weight is [V, H].
+        # Add a Reshape node to convert, referencing the SAME embedding weight.
+        reshape_shape_name = "lm_head.MatMulNBits.reshape_shape"
+        reshape_shape = np.array([vocab_size, num_blocks, gbq_block_size], dtype=np.int64)
+        model.graph.initializer.append(numpy_helper.from_array(reshape_shape, name=reshape_shape_name))
+
+        reshape_output = "lm_head.MatMulNBits.reshaped_weight"
+        reshape_node = onnx.helper.make_node(
+            "Reshape",
+            inputs=[embed_weight_name, reshape_shape_name],
+            outputs=[reshape_output],
+            name="lm_head/Reshape_shared_weight",
+        )
+        model.graph.node.insert(lm_head_idx, reshape_node)
+
+        # Scales and zp: reuse embedding's directly
+        inputs = [lm_head_node.input[0], reshape_output, embed_scales_name]
+        if embed_zp_name:
+            inputs.append(embed_zp_name)
+
+        # Ensure com.microsoft opset is declared
+        _ensure_msft_opset(model)
+
+        # Create new INT8 MatMulNBits node
+        new_lm_head = onnx.helper.make_node(
+            "MatMulNBits",
+            inputs=inputs,
+            outputs=lm_head_node.output,
+            name=lm_head_node.name,
+            domain="com.microsoft",
+            bits=8,
+            block_size=gbq_block_size,
+            K=hidden_size,
+            N=vocab_size,
+        )
+        # Copy accuracy_level if present
+        for attr in lm_head_node.attribute:
+            if attr.name == "accuracy_level":
+                new_lm_head.attribute.append(attr)
+
+        model.graph.node.remove(lm_head_node)
+        model.graph.node.insert(lm_head_idx + 1, new_lm_head)
+
+        logger.info("lm_head now uses INT8 MatMulNBits (shared quantization with embedding)")
+        return model
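
Why a bare Reshape is enough to share the weight: for 8-bit values quantized in blocks along axis 1, [V, H] -> [V, H/block_size, block_size] is a pure re-grouping, so each lm_head block reuses the embedding's scale unchanged. A small sketch with assumed sizes:

import numpy as np

V, H, bs = 4, 64, 32
w = np.arange(V * H, dtype=np.uint8).reshape(V, H)
reshaped = w.reshape(V, H // bs, bs)  # what the inserted Reshape node does
# Block (v, k) of the lm_head weight is exactly columns [k*bs, (k+1)*bs) of row v:
assert np.array_equal(reshaped[1, 1], w[1, 32:64])
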
Replaces the lm_head's INT4 + MatMulNBits with an INT8 MatMulNBits that references the same quantized + weight as the embedding's GatherBlockQuantized, eliminating duplicate storage. + """ + + def __call__(self, model: onnx.ModelProto): + from onnx import numpy_helper + + # Find embedding GatherBlockQuantized + gbq_node, _ = _find_embed_node(model, "GatherBlockQuantized", "ShareEmbeddingLmHead") + if gbq_node is None: + return model + + attrs = _get_node_attrs(gbq_node, "bits", "block_size") + gbq_bits = attrs.get("bits", 8) + gbq_block_size = attrs.get("block_size", 32) + + if gbq_bits != 8: + logger.warning("Embedding is not INT8, cannot share with lm_head") + return model + + # Get embedding weight, scales, zero_points names + embed_weight_name = gbq_node.input[0] + embed_scales_name = gbq_node.input[2] + embed_zp_name = gbq_node.input[3] if len(gbq_node.input) > 3 else None + + # Get embedding weight shape to determine K and N + embed_weight_init = _find_initializer(model, embed_weight_name) + if embed_weight_init is None: + logger.warning("Could not find embedding weight initializer") + return model + + embed_weight = numpy_helper.to_array(embed_weight_init) + + vocab_size, hidden_size = embed_weight.shape # [V, H] for INT8 + num_blocks = hidden_size // gbq_block_size + + # Find lm_head MatMulNBits node + lm_head_node, lm_head_idx = _find_lm_head_node(model) + if lm_head_node is None: + return model + + # Check if already shared (idempotency): lm_head weight input references embedding weight + lm_head_weight_input = lm_head_node.input[1] + if embed_weight_name in lm_head_weight_input or lm_head_node.input[2] == embed_scales_name: + logger.info("lm_head already shares weights with embedding, skipping ShareEmbeddingLmHead") + return model + + # Get old lm_head attributes + old_attrs = _get_node_attrs(lm_head_node, "K", "N", "bits", "block_size") + + logger.info( + "Sharing embedding with lm_head: lm_head INT%d (%dx%d, bs=%d) -> INT8 (shared with embedding)", + old_attrs.get("bits", 0), + old_attrs.get("N", 0), + old_attrs.get("K", 0), + old_attrs.get("block_size", 0), + ) + + # Remove old lm_head weight initializers + old_init_names = set(lm_head_node.input[1:]) # weight, scales, [zp], [g_idx] + to_remove = [init for init in model.graph.initializer if init.name in old_init_names] + for init in to_remove: + model.graph.initializer.remove(init) + + # MatMulNBits needs [N, K_blocks, block_size] but GBQ weight is [V, H]. + # Add a Reshape node to convert, referencing the SAME embedding weight. 
diff --git a/test/passes/onnx/test_quantize_embedding.py b/test/passes/onnx/test_quantize_embedding.py
new file mode 100644
index 000000000..db3f2e83f
--- /dev/null
+++ b/test/passes/onnx/test_quantize_embedding.py
@@ -0,0 +1,195 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+import numpy as np
+from onnx import TensorProto, helper, numpy_helper
+
+from olive.passes.onnx.graph_surgeries import QuantizeEmbeddingInt8, ShareEmbeddingLmHead
+
+
+def _make_model_with_fp16_embed(vocab_size=64, hidden_size=64, block_size=32):
+    """Create a minimal ONNX model with FP16 Gather embedding and INT4 MatMulNBits lm_head."""
+    # Embedding: Gather with FP16 weight
+    embed_weight = np.random.randn(vocab_size, hidden_size).astype(np.float16)
+    embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight")
+
+    input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["batch_size", "seq_len"])
+
+    gather_node = helper.make_node(
+        "Gather",
+        inputs=["model.embed_tokens.weight", "input_ids"],
+        outputs=["embed_output"],
+        name="/model/embed_tokens/Gather",
+    )
+
+    # lm_head: MatMulNBits with INT4 weight
+    num_blocks = hidden_size // block_size
+    lm_weight = np.random.randint(0, 256, (vocab_size, num_blocks, block_size // 2), dtype=np.uint8)
+    lm_scales = np.random.randn(vocab_size, num_blocks).astype(np.float16) * 0.01
+    lm_zp = np.full((vocab_size, num_blocks), 8, dtype=np.uint8)
+
+    lm_weight_init = numpy_helper.from_array(lm_weight, name="lm_head.MatMul_Q4.qweight")
+    lm_scales_init = numpy_helper.from_array(lm_scales, name="lm_head.MatMul_Q4.scales")
+    lm_zp_init = numpy_helper.from_array(lm_zp, name="lm_head.MatMul_Q4.zp")
+
+    lm_head_node = helper.make_node(
+        "MatMulNBits",
+        inputs=["embed_output", "lm_head.MatMul_Q4.qweight", "lm_head.MatMul_Q4.scales", "lm_head.MatMul_Q4.zp"],
+        outputs=["logits"],
+        name="/lm_head/MatMulNBits",
+        domain="com.microsoft",
+        bits=4,
+        block_size=block_size,
+        K=hidden_size,
+        N=vocab_size,
+    )
+
+    graph = helper.make_graph(
+        [gather_node, lm_head_node],
+        "test",
+        [input_ids],
+        [helper.make_tensor_value_info("logits", TensorProto.FLOAT16, ["batch_size", "seq_len", vocab_size])],
+        initializer=[embed_init, lm_weight_init, lm_scales_init, lm_zp_init],
+    )
+    return helper.make_model(
+        graph,
+        opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)],
+    )
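
Rough size accounting for the default fixture, as illustrative arithmetic only (the fixture's FP16 embedding versus the tensors QuantizeEmbeddingInt8 produces):

V, H, bs = 64, 64, 32
fp16_bytes = V * H * 2
# uint8 weight + fp16 scales + uint8 zero points
int8_bytes = V * H + V * (H // bs) * 2 + V * (H // bs)
print(fp16_bytes, int8_bytes)  # 8192, 4480
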
+
+
+class TestQuantizeEmbeddingInt8:
+    def test_replaces_gather_with_gbq(self):
+        model = _make_model_with_fp16_embed()
+        surgery = QuantizeEmbeddingInt8()
+        result = surgery(model)
+
+        # Verify Gather is replaced with GatherBlockQuantized
+        node_types = [n.op_type for n in result.graph.node]
+        assert "Gather" not in node_types or all(
+            "embed_tokens" not in n.name for n in result.graph.node if n.op_type == "Gather"
+        )
+        gbq_nodes = [n for n in result.graph.node if n.op_type == "GatherBlockQuantized"]
+        assert len(gbq_nodes) == 1
+
+        gbq = gbq_nodes[0]
+        attrs = {a.name: a.i for a in gbq.attribute}
+        assert attrs["bits"] == 8
+        assert attrs["block_size"] == 32
+
+        # Verify zero_point input exists (4 inputs: weight, input_ids, scales, zp)
+        assert len(gbq.input) == 4
+
+    def test_reduces_weight_size(self):
+        model = _make_model_with_fp16_embed(vocab_size=256, hidden_size=128)
+        surgery = QuantizeEmbeddingInt8()
+        result = surgery(model)
+
+        # FP16 weight should be removed
+        fp16_names = [init.name for init in result.graph.initializer if init.name == "model.embed_tokens.weight"]
+        assert len(fp16_names) == 0
+
+        # INT8 weight should exist
+        int8_names = [init.name for init in result.graph.initializer if "_Q8" in init.name]
+        assert len(int8_names) == 1
+
+    def test_skips_non_fp16(self):
+        model = _make_model_with_fp16_embed()
+        surgery = QuantizeEmbeddingInt8()
+
+        # First pass: quantize to INT8
+        result1 = surgery(model)
+        # Second pass: should skip (already quantized)
+        result2 = surgery(result1)
+
+        # Should still have exactly 1 GBQ node
+        gbq_count = sum(1 for n in result2.graph.node if n.op_type == "GatherBlockQuantized")
+        assert gbq_count == 1
+
+    def test_skips_when_hidden_not_divisible(self):
+        # hidden_size=33, not divisible by block_size=32
+        embed_weight = np.random.randn(64, 33).astype(np.float16)
+        embed_init = numpy_helper.from_array(embed_weight, name="model.embed_tokens.weight")
+        input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, [1])
+        output = helper.make_tensor_value_info("out", TensorProto.FLOAT16, [1, 33])
+        gather = helper.make_node(
+            "Gather", ["model.embed_tokens.weight", "input_ids"], ["out"], name="/model/embed_tokens/Gather"
+        )
+        graph = helper.make_graph([gather], "test", [input_ids], [output], initializer=[embed_init])
+        model = helper.make_model(
+            graph, opset_imports=[helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)]
+        )
+
+        surgery = QuantizeEmbeddingInt8()
+        result = surgery(model)
+
+        # Should still have Gather (not replaced)
+        assert any(n.op_type == "Gather" for n in result.graph.node)
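
A possible follow-up check, sketched here rather than part of the test suite, that dequantizes the tensors produced by QuantizeEmbeddingInt8 and bounds the reconstruction error; the initializer names follow the surgery's "_Q8"/"_scales" convention:

import numpy as np
from onnx import numpy_helper

model = _make_model_with_fp16_embed()
orig = numpy_helper.to_array(
    next(i for i in model.graph.initializer if i.name == "model.embed_tokens.weight")
).astype(np.float32)
result = QuantizeEmbeddingInt8()(model)
inits = {i.name: numpy_helper.to_array(i) for i in result.graph.initializer}
q = inits["model.embed_tokens.weight_Q8"].astype(np.float32) - 128  # undo zero point
s = inits["model.embed_tokens.weight_scales"].astype(np.float32)
deq = (q.reshape(64, 2, 32) * s[:, :, None]).reshape(64, 64)
# Per-element error is at most about half a (fp16-rounded) scale step
assert np.abs(deq - orig).max() <= s.max()
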
+
+
+class TestShareEmbeddingLmHead:
+    def test_shares_weight(self):
+        model = _make_model_with_fp16_embed()
+
+        # First quantize embedding to INT8
+        q_surgery = QuantizeEmbeddingInt8()
+        model = q_surgery(model)
+
+        # Then share
+        s_surgery = ShareEmbeddingLmHead()
+        result = s_surgery(model)
+
+        # lm_head should now be INT8
+        lm_head = next(n for n in result.graph.node if n.op_type == "MatMulNBits" and "lm_head" in n.name)
+        attrs = {a.name: a.i for a in lm_head.attribute}
+        assert attrs["bits"] == 8
+
+        # Should have a Reshape node for weight sharing
+        reshape_nodes = [n for n in result.graph.node if "Reshape_shared" in n.name]
+        assert len(reshape_nodes) == 1
+
+        # Reshape should reference the embedding weight
+        reshape = reshape_nodes[0]
+        assert "embed_tokens" in reshape.input[0]
+
+        # lm_head should use the shared scales
+        assert "embed_tokens" in lm_head.input[2]  # scales
+
+    def test_idempotent(self):
+        model = _make_model_with_fp16_embed()
+        q_surgery = QuantizeEmbeddingInt8()
+        model = q_surgery(model)
+
+        s_surgery = ShareEmbeddingLmHead()
+        result1 = s_surgery(model)
+        # Applying again should be a no-op
+        result2 = s_surgery(result1)
+
+        # Should still have exactly 1 Reshape_shared node
+        reshape_count = sum(1 for n in result2.graph.node if "Reshape_shared" in n.name)
+        assert reshape_count == 1
+
+    def test_skips_without_gbq(self):
+        model = _make_model_with_fp16_embed()
+        # Don't quantize embedding first
+        s_surgery = ShareEmbeddingLmHead()
+        result = s_surgery(model)
+
+        # Should be unchanged: still has Gather
+        assert any(n.op_type == "Gather" for n in result.graph.node)
+
+    def test_removes_old_lm_head_weights(self):
+        model = _make_model_with_fp16_embed()
+        q_surgery = QuantizeEmbeddingInt8()
+        model = q_surgery(model)
+
+        s_surgery = ShareEmbeddingLmHead()
+        result = s_surgery(model)
+
+        new_init_names = {init.name for init in result.graph.initializer}
+
+        # Old lm_head weights should be removed
+        assert "lm_head.MatMul_Q4.qweight" not in new_init_names
+        assert "lm_head.MatMul_Q4.scales" not in new_init_names
+        assert "lm_head.MatMul_Q4.zp" not in new_init_names
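
An end-to-end sketch using the fixture above; it applies both surgeries in order, and the graph-shape assertions follow from the node insertions performed by the surgeries:

from olive.passes.onnx.graph_surgeries import QuantizeEmbeddingInt8, ShareEmbeddingLmHead

model = _make_model_with_fp16_embed()
model = QuantizeEmbeddingInt8()(model)
model = ShareEmbeddingLmHead()(model)

op_types = [n.op_type for n in model.graph.node]
assert op_types == ["GatherBlockQuantized", "Reshape", "MatMulNBits"]
# The quantized weight appears once and feeds both the embedding and lm_head:
consumers = [n.name for n in model.graph.node if "model.embed_tokens.weight_Q8" in n.input]
assert len(consumers) == 2
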