
Commit 80221bb

fix: per-op k_strip tracking, O(1) phases, inactive MatMul filtering
Rust + Python:

- Per-op k_strip: track the actual tensor-dimension-based k_strip per MatMul instead of a compute-ratio proxy (correct for mixed input dimensions); a toy illustration follows this message.
- O(1) per phase: replace the per-step loop with first/interior/last arithmetic, matching the documented O(distinct_K) complexity.
- Simulation path: skip k_strip loads for inactive MatMuls (those that have finished their reduction), using a tensor-to-active-steps mapping.

Refs #22
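As a toy illustration of the first bullet (all numbers invented; this is a sketch, not the evaluator's code), the old compute-ratio proxy can badly misestimate the active k_strip load when strip traffic is skewed across MatMuls:

# Two MatMuls with equal per-step compute but very different k_strip traffic.
strip_a, strip_b = 10.0, 90.0            # per-step k_strip load of ops A and B
compute_a = compute_b = 5.0              # equal per-step compute
total_strip = strip_a + strip_b          # 100.0
total_compute = compute_a + compute_b    # 10.0

# In a phase where only A is still reducing:
proxy = total_strip * (compute_a / total_compute)  # 50.0 -- old estimate
exact = strip_a                                    # 10.0 -- per-op tracking
assert (proxy, exact) == (50.0, 10.0)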
1 parent f0d86b3 commit 80221bb

2 files changed: 189 additions & 102 deletions

File tree

solution/agent/evaluator.py: 99 additions & 30 deletions
@@ -528,15 +528,12 @@ def compute_subgraph_latency(
     # matmul_compute_per_step is the total compute when ALL MatMuls are active (step 0 onward).
     matmul_compute_per_step = 0.0
     pointwise_compute = 0.0
-    # Also collect per-op info for mixed-K phase calculation.
-    matmul_phase_info: list[tuple[int, float]] = []  # (k_full, base_cost)
     for op_idx in subgraph_ops:
         op = problem.ops[op_idx]
         if op.op_type == "MatMul":
             k_full_op = _k_full_for_op(op, problem)
             cost_per_step = op.base_cost * (k / k_full_op)
             matmul_compute_per_step += cost_per_step
-            matmul_phase_info.append((k_full_op, op.base_cost))
         else:
             pointwise_compute += op.base_cost
 
@@ -580,6 +577,45 @@ def compute_subgraph_latency(
     # Total k_strip load per step (k_strip LHS + both RHS types)
     k_strip_total_per_step = k_strip_lhs_per_step + rhs_std_per_step + rhs_eph_per_step
 
+    # Per-MatMul k_strip contribution and per-tensor active-step-count for mixed-K.
+    # Each MatMul op contributes k_strip from its boundary inputs that are not retained:
+    #   - non-ephemeral LHS (in k_strip_lhs): h * k / bw
+    #   - non-ephemeral RHS (in rhs_standard): k * w / bw
+    #   - ephemeral RHS (in rhs_ephemeral): rhs.height * k / bw
+    # Deduplication mirrors _categorize_inputs (a tensor counted once for its first op).
+    #
+    # matmul_phase_info: list of (k_full, base_cost, k_strip_contribution_per_step).
+    # k_strip_tensor_active_steps: tensor_id -> step count of its owning MatMul.
+    # Used in the simulation path to load k_strip inputs only while their op is active.
+    matmul_phase_info: list[tuple[int, float, float]] = []
+    k_strip_tensor_active_steps: dict[int, int] = {}
+    _seen_k_strip: set[int] = set()
+    for op_idx in subgraph_ops:
+        op = problem.ops[op_idx]
+        if op.op_type != "MatMul":
+            continue
+        k_full_op = _k_full_for_op(op, problem)
+        op_steps = math.ceil(k_full_op / k)
+        lhs_idx = op.inputs[0]
+        rhs_idx = op.inputs[1]
+        op_k_strip = 0.0
+        if lhs_idx in k_strip_lhs and lhs_idx not in _seen_k_strip:
+            _seen_k_strip.add(lhs_idx)
+            if lhs_idx not in retained_tensors:
+                op_k_strip += (h * k) / bw
+                k_strip_tensor_active_steps[lhs_idx] = op_steps
+        if rhs_idx in rhs_standard and rhs_idx not in _seen_k_strip:
+            _seen_k_strip.add(rhs_idx)
+            if rhs_idx not in retained_tensors:
+                op_k_strip += (k * w) / bw
+                k_strip_tensor_active_steps[rhs_idx] = op_steps
+        if rhs_idx in rhs_ephemeral and rhs_idx not in _seen_k_strip:
+            _seen_k_strip.add(rhs_idx)
+            if rhs_idx not in retained_tensors:
+                op_k_strip += (problem.tensors[rhs_idx].height * k) / bw
+                k_strip_tensor_active_steps[rhs_idx] = op_steps
+        matmul_phase_info.append((k_full_op, op.base_cost, op_k_strip))
+
     # Pointwise inputs: w * h per spatial tile (first k-step).
     pw_load_per_tile = sum(
         (w * h) / bw
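For concreteness, a worked example of one op's per-step k_strip contribution; the dimensions h, w, k, bw and the ephemeral height are all made up:

h, w, k, bw = 64, 128, 32, 100.0

lhs_contrib = (h * k) / bw       # non-ephemeral LHS:  64 * 32 / 100 = 20.48
rhs_std_contrib = (k * w) / bw   # non-ephemeral RHS:  32 * 128 / 100 = 40.96
rhs_eph_contrib = (48 * k) / bw  # ephemeral RHS of height 48: 48 * 32 / 100 = 15.36

# A MatMul with a fresh (unseen, unretained) LHS and a standard RHS:
op_k_strip = lhs_contrib + rhs_std_contrib
assert abs(op_k_strip - 61.44) < 1e-9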
@@ -602,7 +638,7 @@ def compute_subgraph_latency(
     # Split-K mode: all spatial tiles are identical (no row-reuse).
     #
     # Check if all MatMuls have the same K_full (fast path) or mixed-K (phase path).
-    unique_k_fulls = set(kf for kf, _ in matmul_phase_info)
+    unique_k_fulls = set(kf for kf, _, _ in matmul_phase_info)
     all_same_k_full = len(unique_k_fulls) <= 1
 
     if all_same_k_full:
@@ -627,44 +663,71 @@ def compute_subgraph_latency(
         # Mixed-K path: compute phase-by-phase.
         # Each MatMul is active for ceil(its_K_full / k) steps.
         # Phases are defined by sorted unique step-end boundaries.
-        step_ends = sorted(set(math.ceil(kf / k) for kf, _ in matmul_phase_info))
+        step_ends = sorted(set(math.ceil(kf / k) for kf, _, _ in matmul_phase_info))
         # step_ends[-1] == num_k_steps
 
         per_tile_lat = 0.0
         prev_end = 0
+        total_k_steps = num_k_steps
 
         for phase_idx, phase_end in enumerate(step_ends):
-            # Active MatMuls: those whose step count >= phase_end
+            # Active MatMuls: those whose step count >= phase_end.
             active_compute = sum(
                 bc * (k / kf)
-                for kf, bc in matmul_phase_info
+                for kf, bc, _ in matmul_phase_info
                 if math.ceil(kf / k) >= phase_end
             )
 
-            # Active k_strip: scale total by ratio of active compute to total compute.
-            if matmul_compute_per_step > 0.0:
-                active_k_strip = k_strip_total_per_step * (active_compute / matmul_compute_per_step)
-            else:
-                active_k_strip = k_strip_total_per_step
+            # Active k_strip: sum per-op contributions for active MatMuls only.
+            # This is exact (no proxy ratio) because each op's contribution is
+            # precomputed from its actual tensor dimensions.
+            active_k_strip = sum(
+                ks
+                for kf, _, ks in matmul_phase_info
+                if math.ceil(kf / k) >= phase_end
+            )
 
             phase_steps = phase_end - prev_end
             is_last_phase = (phase_idx == len(step_ends) - 1)
 
-            for step_offset in range(phase_steps):
-                global_step = prev_end + step_offset
-                is_first_step = (global_step == 0)
-                is_last_step = is_last_phase and (step_offset == phase_steps - 1)
-
-                mem = (
-                    (full_load_lhs_time + pw_load_per_tile + active_k_strip)
-                    if is_first_step
-                    else active_k_strip
-                )
-                if is_last_step:
-                    mem += out_evict_per_tile
-
-                compute_this_step = active_compute + (pointwise_compute if is_last_step else 0.0)
-                per_tile_lat += max(compute_this_step, mem)
+            # O(1) per phase: classify steps as first, interior, or last.
+            # Special steps: global step 0 (loads full_load + pw_load) and
+            # global last step (evicts output, adds PW compute).
+            has_first = (prev_end == 0)
+            has_last = is_last_phase  # last phase always contains the last step
+
+            # Interior steps: all steps in this phase that are neither first nor last.
+            interior_count = phase_steps - (1 if has_first else 0) - (1 if has_last else 0)
+            # interior_count can be negative when a single step is both first and last.
+            interior_count = max(0, interior_count)
+
+            if has_first:
+                mem = full_load_lhs_time + pw_load_per_tile + active_k_strip
+                # The first step can also be the last (total_k_steps == 1);
+                # that case is corrected in the has_last branch below.
+                per_tile_lat += max(active_compute, mem)
+
+            if interior_count > 0:
+                interior_lat = max(active_compute, active_k_strip)
+                per_tile_lat += interior_count * interior_lat
+
+            if has_last:
+                # If the phase has only one step and it is also the first step,
+                # the first-step cost was already added above; skip the duplicate.
+                is_also_first = has_first and (phase_steps == 1)
+                if not is_also_first:
+                    mem_last = active_k_strip + out_evict_per_tile
+                    compute_last = active_compute + pointwise_compute
+                    per_tile_lat += max(compute_last, mem_last)
+                else:
+                    # Single-step phase that is both first and last: adjust the
+                    # first-step cost to include eviction and PW compute.
+                    mem_last = full_load_lhs_time + pw_load_per_tile + active_k_strip + out_evict_per_tile
+                    compute_last = active_compute + pointwise_compute
+                    # Undo the first-step contribution already added and replace it.
+                    first_mem = full_load_lhs_time + pw_load_per_tile + active_k_strip
+                    per_tile_lat -= max(active_compute, first_mem)
+                    per_tile_lat += max(compute_last, mem_last)
 
             prev_end = phase_end
 
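To sanity-check the phase arithmetic, here is a self-contained sketch (toy values; names like info, full_load, pw_load are stand-ins, not the evaluator's API) comparing the O(1)-per-phase computation against an explicit per-step reference loop:

import math

# Toy mixed-K setup: (k_full, base_cost, k_strip_contribution) per MatMul.
k = 16
info = [(64, 8.0, 1.5), (32, 6.0, 2.0)]   # op A: 4 steps, op B: 2 steps
full_load, pw_load, out_evict, pw_compute = 12.0, 3.0, 5.0, 4.0
num_steps = max(math.ceil(kf / k) for kf, _, _ in info)

def active(step):
    """Compute and k_strip of MatMuls still reducing at this k-step."""
    c = sum(bc * (k / kf) for kf, bc, _ in info if step < math.ceil(kf / k))
    s = sum(ks for kf, _, ks in info if step < math.ceil(kf / k))
    return c, s

# Reference: explicit O(num_steps) loop.
ref = 0.0
for step in range(num_steps):
    c, s = active(step)
    mem = s + (full_load + pw_load if step == 0 else 0.0)
    if step == num_steps - 1:
        mem += out_evict
        c += pw_compute
    ref += max(c, mem)

# O(1) per phase, mirroring the diff's first/interior/last arithmetic.
lat, prev_end = 0.0, 0
step_ends = sorted(set(math.ceil(kf / k) for kf, _, _ in info))
for i, end in enumerate(step_ends):
    c, s = active(end - 1)
    steps = end - prev_end
    has_first, has_last = prev_end == 0, i == len(step_ends) - 1
    interior = max(0, steps - int(has_first) - int(has_last))
    if has_first:
        lat += max(c, full_load + pw_load + s)
    lat += interior * max(c, s)
    if has_last and not (has_first and steps == 1):
        lat += max(c + pw_compute, s + out_evict)
    elif has_last:  # single step that is both first and last: replace first-step cost
        lat -= max(c, full_load + pw_load + s)
        lat += max(c + pw_compute, full_load + pw_load + s + out_evict)
    prev_end = end

assert abs(ref - lat) < 1e-9   # 32.0 == 32.0 for these toy numbers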
@@ -726,13 +789,16 @@ def compute_subgraph_latency(
                 resident_full_lhs[t_idx] = tile_row
 
             # ------- k_strip LHS (non-ephemeral-output MatMul) -------
-            # In split-K: loaded every k-step (no reuse).
+            # In split-K: loaded every k-step (no reuse), but only while the
+            # owning MatMul is still active (k_step < ceil(K_full_op / k)).
             # In spatial-only: row-reusable (same as full_load).
             for t_idx in k_strip_lhs:
                 if t_idx in retained_tensors:
                     continue
                 if is_split_k:
-                    mem_in += (h * k) / bw
+                    # Only load if the owning MatMul is still active this step.
+                    if k_step < k_strip_tensor_active_steps.get(t_idx, num_k_steps):
+                        mem_in += (h * k) / bw
                 elif is_first_k and resident_k_strip_lhs[t_idx] != tile_row:
                     mem_in += (h * k) / bw
                     resident_k_strip_lhs[t_idx] = tile_row
@@ -742,6 +808,9 @@ def compute_subgraph_latency(
             for t_idx in all_rhs:
                 if t_idx in retained_tensors:
                     continue
+                # Only load if the owning MatMul is still active this step.
+                if k_step >= k_strip_tensor_active_steps.get(t_idx, num_k_steps):
+                    continue
                 if t_idx in rhs_ephemeral:
                     mem_in += (problem.tensors[t_idx].height * k) / bw
                 else:
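One detail worth noting: the gating uses dict.get with num_k_steps as the default, so tensors without an entry remain always-active. A minimal sketch with a hypothetical tensor id 7:

k_strip_tensor_active_steps = {7: 2}   # tensor 7's MatMul finishes after 2 k-steps
num_k_steps = 4

for k_step in range(num_k_steps):
    still_active = k_step < k_strip_tensor_active_steps.get(7, num_k_steps)
    print(k_step, still_active)   # 0 True, 1 True, 2 False, 3 False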
@@ -774,7 +843,7 @@ def compute_subgraph_latency(
             # For mixed-K: only MatMuls that haven't finished yet contribute compute.
             active_matmul_compute = sum(
                 bc * (k / kf)
-                for kf, bc in matmul_phase_info
+                for kf, bc, _ in matmul_phase_info
                 if k_step < math.ceil(kf / k)
             )
             compute_this_step = active_matmul_compute + (pointwise_compute if is_last_k else 0.0)
