Skip to content

Commit 6fba91e

Browse files
committed
fix: clamp k to K_full_op in compute, remove dead total_k_steps
- Rust + Python: compute now uses min(k, K_full_op) / K_full_op so that the scale factor applied to base_cost never exceeds 1.0 when k > K_full_op in mixed-K subgraphs.
- Applies to matmul_compute_per_step, the mixed-K active_compute, and the simulation-path active_matmul_compute.
- Removed the dead total_k_steps variable from the Python mixed-K path.
Refs #22
1 parent 80221bb commit 6fba91e

2 files changed

Lines changed: 6 additions & 7 deletions

File tree

solution/agent/evaluator.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def compute_subgraph_latency(
532532
op = problem.ops[op_idx]
533533
if op.op_type == "MatMul":
534534
k_full_op = _k_full_for_op(op, problem)
535-
cost_per_step = op.base_cost * (k / k_full_op)
535+
cost_per_step = op.base_cost * (min(k, k_full_op) / k_full_op)
536536
matmul_compute_per_step += cost_per_step
537537
else:
538538
pointwise_compute += op.base_cost
@@ -668,12 +668,11 @@ def compute_subgraph_latency(
668668

669669
per_tile_lat = 0.0
670670
prev_end = 0
671-
total_k_steps = num_k_steps
672671

673672
for phase_idx, phase_end in enumerate(step_ends):
674673
# Active MatMuls: those whose step count >= phase_end.
675674
active_compute = sum(
676-
bc * (k / kf)
675+
bc * (min(k, kf) / kf)
677676
for kf, bc, _ in matmul_phase_info
678677
if math.ceil(kf / k) >= phase_end
679678
)
@@ -703,7 +702,7 @@ def compute_subgraph_latency(
703702

704703
if has_first:
705704
mem = full_load_lhs_time + pw_load_per_tile + active_k_strip
706-
# The first step is also the last only when total_k_steps == 1,
705+
# The first step is also the last only when num_k_steps == 1,
707706
# but that case is handled by the all_same_k_full branch above.
708707
per_tile_lat += max(active_compute, mem)
709708

@@ -842,7 +841,7 @@ def compute_subgraph_latency(
842841

843842
# For mixed-K: only MatMuls that haven't finished yet contribute compute.
844843
active_matmul_compute = sum(
845-
bc * (k / kf)
844+
bc * (min(k, kf) / kf)
846845
for kf, bc, _ in matmul_phase_info
847846
if k_step < math.ceil(kf / k)
848847
)

solution/backend/rust/src/latency.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub fn matmul_compute_per_step(
2929
let op = &problem.ops[op_idx];
3030
if op.is_matmul() {
3131
let op_k_full = k_full_for_matmul(op, &problem.tensors) as f64;
32-
total += op.base_cost as f64 * (k / op_k_full);
32+
total += op.base_cost as f64 * (k.min(op_k_full) / op_k_full);
3333
}
3434
}
3535
total
@@ -429,7 +429,7 @@ pub fn subgraph_latency(
429429
let active_compute: f64 = matmul_phases
430430
.iter()
431431
.filter(|(kf, _, _)| (*kf + k - 1) / k >= phase_end)
432-
.map(|(kf, cost, _)| cost * (k as f64 / *kf as f64))
432+
.map(|(kf, cost, _)| cost * ((k as f64).min(*kf as f64) / *kf as f64))
433433
.sum();
434434

435435
// Active k_strip: sum per-op contributions for active MatMuls only.

0 commit comments

Comments
 (0)