Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions olive/common/onnx_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,22 @@ def get_kv_info(io_config: dict) -> dict | None:
if kv_format is None:
return None

# find the number of layers
num_layers = 0
# find the actual layer indices (may be non-contiguous after pruning)
layer_indices = []
for i_name in io_config["input_names"]:
num_layers += int(re.match(kv_format, i_name) is not None)
m = re.match(kv_format, i_name)
if m:
idx = int(m.group(1))
if idx not in layer_indices:
layer_indices.append(idx)
layer_indices.sort()

past_names = []
present_to_past = {}
for k in ["key", "value"]:
past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in range(num_layers)])
past_names.extend([kv_options[kv_format][f"past_{k}"] % i for i in layer_indices])
present_to_past.update(
{
kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i
for i in range(num_layers)
}
{kv_options[kv_format][f"present_{k}"] % i: kv_options[kv_format][f"past_{k}"] % i for i in layer_indices}
)

past_shape = io_config["input_shapes"][io_config["input_names"].index(past_names[0])]
Expand Down
72 changes: 65 additions & 7 deletions olive/evaluator/lmeval_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,28 @@ def __init__(self, model_path: str, ep: str | None = None, ep_options: dict | No
if self.kv_info is None:
raise ValueError("Invalid io_config: kv_info not found")

# detect position_ids rank (e.g. 3 for mRoPE models like Qwen3.5)
self.position_ids_rank = 2
if "position_ids" in self.io_config["input_names"]:
idx = self.io_config["input_names"].index("position_ids")
self.position_ids_rank = len(self.io_config["input_shapes"][idx])

# detect hybrid state inputs (conv_state, recurrent_state for linear attention layers)
self.hybrid_states = {}
for idx, name in enumerate(self.io_config["input_names"]):
if "conv_state" in name or "recurrent_state" in name:
shape = self.io_config["input_shapes"][idx]
dtype = self.io_config["input_types"][idx]
self.hybrid_states[name] = {"shape": shape, "dtype": dtype}

# detect hybrid state outputs
self.hybrid_state_outputs = {}
for idx, name in enumerate(self.io_config["output_names"]):
if "conv_state" in name or "recurrent_state" in name:
shape = self.io_config["output_shapes"][idx]
dtype = self.io_config["output_types"][idx]
self.hybrid_state_outputs[name] = {"shape": shape, "dtype": dtype}

self._session = None
self._batch_size = None
self._buffers = None
Expand Down Expand Up @@ -331,17 +353,29 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor:
inputs_to_bind[name] = (self._buffers["inputs"][name], self.io_dtypes[name], shape)
if "position_ids" in self._buffers["inputs"]:
# need to reallocate since the position_ids tensor may be sliced
inputs_to_bind["position_ids"] = (
self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(),
self.io_dtypes["position_ids"],
(batch_size, seqlen),
)
if self.position_ids_rank == 3:
inputs_to_bind["position_ids"] = (
self._buffers["inputs"]["position_ids"][:, :batch_size, :seqlen].contiguous(),
self.io_dtypes["position_ids"],
(self._buffers["inputs"]["position_ids"].shape[0], batch_size, seqlen),
)
else:
inputs_to_bind["position_ids"] = (
self._buffers["inputs"]["position_ids"][:batch_size, :seqlen].contiguous(),
self.io_dtypes["position_ids"],
(batch_size, seqlen),
)
for name in self._buffers["kv_inputs"]:
inputs_to_bind[name] = (
self._buffers["kv_inputs"][name],
self.kv_info["dtype"],
(batch_size, self.kv_info["num_kv_heads"], 0, self.kv_info["head_size"]),
)
# hybrid state inputs (conv_state, recurrent_state)
for name, buf in self._buffers["hybrid_inputs"].items():
shape = list(buf.shape)
shape[0] = batch_size
inputs_to_bind[name] = (buf, self.hybrid_states[name]["dtype"], tuple(shape))
for name, (buffer, dtype, shape) in inputs_to_bind.items():
io_binding.bind_input(
name,
Expand All @@ -363,6 +397,11 @@ def run(self, input_ids: torch.Tensor) -> torch.Tensor:
self.kv_info["dtype"],
(batch_size, self.kv_info["num_kv_heads"], seqlen, self.kv_info["head_size"]),
)
# hybrid state outputs (conv_state, recurrent_state)
for name, buf in self._buffers["hybrid_outputs"].items():
shape = list(buf.shape)
shape[0] = batch_size
outputs_to_bind[name] = (buf, self.hybrid_state_outputs[name]["dtype"], tuple(shape))
for name, (buffer, dtype, shape) in outputs_to_bind.items():
io_binding.bind_output(
name,
Expand Down Expand Up @@ -418,11 +457,16 @@ def initialize_buffers(self, batch_size: int, max_length: int):
)
}
if self.io_dtypes.get("position_ids") is not None:
inputs["position_ids"] = (
pos_ids = (
torch.arange(max_length, dtype=getattr(torch, self.io_dtypes["position_ids"]), device=self.device)
.unsqueeze(0)
.expand(batch_size, -1)
)
if self.position_ids_rank == 3:
# mRoPE: expand to [mrope_sections, batch_size, seq_len]
mrope_sections = self.io_config["input_shapes"][self.io_config["input_names"].index("position_ids")][0]
pos_ids = pos_ids.unsqueeze(0).expand(mrope_sections, -1, -1)
inputs["position_ids"] = pos_ids
if self.io_dtypes.get("past_seq_len") is not None:
inputs["past_seq_len"] = (
torch.tensor(max_length - 1, dtype=getattr(torch, self.io_dtypes["past_seq_len"]), device=self.device)
Expand Down Expand Up @@ -457,6 +501,20 @@ def initialize_buffers(self, batch_size: int, max_length: int):
}

self._buffers = {"inputs": inputs, "outputs": outputs, "kv_inputs": kv_inputs, "kv_outputs": kv_outputs}

# hybrid state buffers (conv_state, recurrent_state) - zero-initialized
hybrid_inputs = {}
for name, info in self.hybrid_states.items():
# Replace symbolic 'batch_size' with actual batch_size
shape = [batch_size if s == "batch_size" else s for s in info["shape"]]
hybrid_inputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device)
hybrid_outputs = {}
for name, info in self.hybrid_state_outputs.items():
shape = [batch_size if s == "batch_size" else s for s in info["shape"]]
hybrid_outputs[name] = torch.zeros(shape, dtype=getattr(torch, info["dtype"]), device=self.device)
self._buffers["hybrid_inputs"] = hybrid_inputs
self._buffers["hybrid_outputs"] = hybrid_outputs

self._batch_size = batch_size


Expand Down Expand Up @@ -539,7 +597,7 @@ def _detect_full_logits(self) -> bool:
def eot_token_id(self):
return self._eot_token_id

def tok_encode(self, string: str, **kwargs) -> list[int]:
def tok_encode(self, string: str, add_special_tokens: bool | None = None, **kwargs) -> list[int]:
    """Tokenize a string using the model's tokenizer and return a list of token IDs.

    NOTE(review): ``add_special_tokens`` and any extra ``**kwargs`` are accepted
    for call-site compatibility (presumably the lm-eval harness passes them —
    verify against the caller) but are currently ignored; tokenization behavior
    is whatever ``self.tokenizer.encode`` does by default.
    """
    return self.tokenizer.encode(string).tolist()

Expand Down
12 changes: 9 additions & 3 deletions olive/evaluator/olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,10 +1612,16 @@ def evaluate(

task_metrics = {}
for mf, v in metric_items:
if mf != "alias":
if mf == "alias":
continue
if not isinstance(v, (int, float)):
continue
if "," in mf:
m, _ = mf.split(",", 1)
if not m.endswith("_stderr"):
task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True)
else:
m = mf
if not m.endswith("_stderr"):
task_metrics[m] = SubMetricResult(value=v, priority=-1, higher_is_better=True)

metrics[task_name] = MetricResult.model_validate(task_metrics)

Expand Down
Loading
Loading