@@ -528,15 +528,12 @@ def compute_subgraph_latency(
     # matmul_compute_per_step is the total compute when ALL MatMuls are active (step 0 onward).
     matmul_compute_per_step = 0.0
     pointwise_compute = 0.0
-    # Also collect per-op info for mixed-K phase calculation.
-    matmul_phase_info: list[tuple[int, float]] = []  # (k_full, base_cost)
     for op_idx in subgraph_ops:
         op = problem.ops[op_idx]
         if op.op_type == "MatMul":
             k_full_op = _k_full_for_op(op, problem)
             cost_per_step = op.base_cost * (k / k_full_op)
             matmul_compute_per_step += cost_per_step
-            matmul_phase_info.append((k_full_op, op.base_cost))
         else:
             pointwise_compute += op.base_cost
 
@@ -580,6 +577,45 @@ def compute_subgraph_latency(
     # Total k_strip load per step (k_strip LHS + both RHS types).
     k_strip_total_per_step = k_strip_lhs_per_step + rhs_std_per_step + rhs_eph_per_step
 
+    # Per-MatMul k_strip contribution and per-tensor active step count for mixed-K.
+    # Each MatMul op contributes k_strip from its boundary inputs that are not retained:
+    #   - non-ephemeral LHS (in k_strip_lhs): h * k / bw
+    #   - non-ephemeral RHS (in rhs_standard): k * w / bw
+    #   - ephemeral RHS (in rhs_ephemeral): rhs.height * k / bw
+    # Deduplication mirrors _categorize_inputs (each tensor is counted once, for its first op).
+    #
+    # matmul_phase_info: list of (k_full, base_cost, k_strip_contribution_per_step).
+    # k_strip_tensor_active_steps: tensor_id -> step count of its owning MatMul;
+    # used in the simulation path to load k_strip inputs only while their op is active.
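+    # Illustrative example (hypothetical shapes): with k = 64, a MatMul with
+    # K_full = 256 has op_steps = ceil(256 / 64) = 4; if its unretained LHS
+    # (h rows) and standard RHS (w cols) are both seen first by this op, it
+    # contributes (h * k + k * w) / bw per step, appended as
+    # (256, base_cost, that_sum), with both tensors mapped to 4 active steps.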
+    matmul_phase_info: list[tuple[int, float, float]] = []
+    k_strip_tensor_active_steps: dict[int, int] = {}
+    _seen_k_strip: set[int] = set()
+    for op_idx in subgraph_ops:
+        op = problem.ops[op_idx]
+        if op.op_type != "MatMul":
+            continue
+        k_full_op = _k_full_for_op(op, problem)
+        op_steps = math.ceil(k_full_op / k)
+        lhs_idx = op.inputs[0]
+        rhs_idx = op.inputs[1]
+        op_k_strip = 0.0
+        if lhs_idx in k_strip_lhs and lhs_idx not in _seen_k_strip:
+            _seen_k_strip.add(lhs_idx)
+            if lhs_idx not in retained_tensors:
+                op_k_strip += (h * k) / bw
+            k_strip_tensor_active_steps[lhs_idx] = op_steps
+        if rhs_idx in rhs_standard and rhs_idx not in _seen_k_strip:
+            _seen_k_strip.add(rhs_idx)
+            if rhs_idx not in retained_tensors:
+                op_k_strip += (k * w) / bw
+            k_strip_tensor_active_steps[rhs_idx] = op_steps
+        if rhs_idx in rhs_ephemeral and rhs_idx not in _seen_k_strip:
+            _seen_k_strip.add(rhs_idx)
+            if rhs_idx not in retained_tensors:
+                op_k_strip += (problem.tensors[rhs_idx].height * k) / bw
+            k_strip_tensor_active_steps[rhs_idx] = op_steps
+        matmul_phase_info.append((k_full_op, op.base_cost, op_k_strip))
+
     # Pointwise inputs: w * h per spatial tile (first k-step).
     pw_load_per_tile = sum(
         (w * h) / bw
@@ -602,7 +638,7 @@ def compute_subgraph_latency(
     # Split-K mode: all spatial tiles are identical (no row-reuse).
     #
     # Check if all MatMuls have the same K_full (fast path) or mixed-K (phase path).
-    unique_k_fulls = set(kf for kf, _ in matmul_phase_info)
+    unique_k_fulls = set(kf for kf, _, _ in matmul_phase_info)
     all_same_k_full = len(unique_k_fulls) <= 1
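+    # e.g., K_fulls {512, 512} take the fast path and {256, 512} the phase path.
+    # Mixed K_fulls that all fit in one k-step (ceil(kf / k) == 1) also take
+    # the phase path; the is_also_first adjustment below handles that case.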
 
     if all_same_k_full:
@@ -627,44 +663,71 @@ def compute_subgraph_latency(
         # Mixed-K path: compute phase-by-phase.
         # Each MatMul is active for ceil(its_K_full / k) steps.
         # Phases are defined by sorted unique step-end boundaries.
-        step_ends = sorted(set(math.ceil(kf / k) for kf, _ in matmul_phase_info))
+        step_ends = sorted(set(math.ceil(kf / k) for kf, _, _ in matmul_phase_info))
         # step_ends[-1] == num_k_steps
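+        # Illustrative example (hypothetical K_fulls): with k = 64 and ops of
+        # K_full 128 and 320, step counts are {2, 5}, so step_ends == [2, 5]:
+        # phase 0 covers steps 0-1 (both ops active), phase 1 covers steps
+        # 2-4 (only the K_full=320 op computes and loads its k_strip inputs).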
 
         per_tile_lat = 0.0
         prev_end = 0
+        total_k_steps = num_k_steps
 
         for phase_idx, phase_end in enumerate(step_ends):
-            # Active MatMuls: those whose step count >= phase_end
+            # Active MatMuls: those whose step count >= phase_end.
             active_compute = sum(
                 bc * (k / kf)
-                for kf, bc in matmul_phase_info
+                for kf, bc, _ in matmul_phase_info
                 if math.ceil(kf / k) >= phase_end
             )
 
-            # Active k_strip: scale total by ratio of active compute to total compute.
-            if matmul_compute_per_step > 0.0:
-                active_k_strip = k_strip_total_per_step * (active_compute / matmul_compute_per_step)
-            else:
-                active_k_strip = k_strip_total_per_step
+            # Active k_strip: sum per-op contributions for active MatMuls only.
+            # This is exact (no proxy ratio) because each op's contribution is
+            # precomputed from its actual tensor dimensions.
+            active_k_strip = sum(
+                ks
+                for kf, _, ks in matmul_phase_info
+                if math.ceil(kf / k) >= phase_end
+            )
 
             phase_steps = phase_end - prev_end
             is_last_phase = (phase_idx == len(step_ends) - 1)
 
-            for step_offset in range(phase_steps):
-                global_step = prev_end + step_offset
-                is_first_step = (global_step == 0)
-                is_last_step = is_last_phase and (step_offset == phase_steps - 1)
-
-                mem = (
-                    (full_load_lhs_time + pw_load_per_tile + active_k_strip)
-                    if is_first_step
-                    else active_k_strip
-                )
-                if is_last_step:
-                    mem += out_evict_per_tile
-
-                compute_this_step = active_compute + (pointwise_compute if is_last_step else 0.0)
-                per_tile_lat += max(compute_this_step, mem)
+            # O(1) per phase: classify steps as first, interior, or last.
+            # Special steps: global step 0 (loads full_load + pw_load) and the
+            # global last step (evicts output, adds PW compute).
+            has_first = (prev_end == 0)
+            has_last = is_last_phase  # the last phase always contains the last step
+
+            # Interior steps: all steps in this phase that are neither first nor last.
+            interior_count = phase_steps - (1 if has_first else 0) - (1 if has_last else 0)
+            # interior_count can be negative when a single step is both first and last.
+            interior_count = max(0, interior_count)
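+            # e.g., continuing the step_ends example above: phase 0 ([0, 2),
+            # not last) gives has_first=True, has_last=False, interior_count
+            # = 2 - 1 = 1; phase 1 ([2, 5), last) gives interior_count = 3 - 1 = 2.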
+
+            if has_first:
+                mem = full_load_lhs_time + pw_load_per_tile + active_k_strip
+                # The first step can also be the global last step (total_k_steps
+                # == 1 with mixed K_fulls); that case is corrected by the
+                # is_also_first adjustment in the has_last branch below.
+                per_tile_lat += max(active_compute, mem)
+
+            if interior_count > 0:
+                interior_lat = max(active_compute, active_k_strip)
+                per_tile_lat += interior_count * interior_lat
+
+            if has_last:
+                # If the phase has only one step and it's also the first step,
+                # we already added the first step above; skip the duplicate.
+                is_also_first = has_first and (phase_steps == 1)
+                if not is_also_first:
+                    mem_last = active_k_strip + out_evict_per_tile
+                    compute_last = active_compute + pointwise_compute
+                    per_tile_lat += max(compute_last, mem_last)
+                else:
+                    # Single-step phase that is both first and last: adjust the
+                    # first-step cost to include eviction and PW compute.
+                    mem_last = full_load_lhs_time + pw_load_per_tile + active_k_strip + out_evict_per_tile
+                    compute_last = active_compute + pointwise_compute
+                    # Undo the first-step contribution already added and replace it.
+                    first_mem = full_load_lhs_time + pw_load_per_tile + active_k_strip
+                    per_tile_lat -= max(active_compute, first_mem)
+                    per_tile_lat += max(compute_last, mem_last)
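+                    # Net effect with hypothetical costs (full + pw + strip = 10,
+                    # evict = 2, compute = 6, pointwise = 3): the first-step term
+                    # max(6, 10) = 10 is replaced by max(6 + 3, 10 + 2) = 12.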
 
             prev_end = phase_end
 
@@ -726,13 +789,16 @@ def compute_subgraph_latency(
                     resident_full_lhs[t_idx] = tile_row
 
             # ------- k_strip LHS (non-ephemeral-output MatMul) -------
-            # In split-K: loaded every k-step (no reuse).
+            # In split-K: loaded every k-step (no reuse), but only while the
+            # owning MatMul is still active (k_step < ceil(K_full_op / k)).
             # In spatial-only: row-reusable (same as full_load).
             for t_idx in k_strip_lhs:
                 if t_idx in retained_tensors:
                     continue
                 if is_split_k:
-                    mem_in += (h * k) / bw
+                    # Only load if the owning MatMul is still active this step.
+                    if k_step < k_strip_tensor_active_steps.get(t_idx, num_k_steps):
+                        mem_in += (h * k) / bw
                 elif is_first_k and resident_k_strip_lhs[t_idx] != tile_row:
                     mem_in += (h * k) / bw
                     resident_k_strip_lhs[t_idx] = tile_row
@@ -742,6 +808,9 @@ def compute_subgraph_latency(
             for t_idx in all_rhs:
                 if t_idx in retained_tensors:
                     continue
+                # Only load if the owning MatMul is still active this step.
+                if k_step >= k_strip_tensor_active_steps.get(t_idx, num_k_steps):
+                    continue
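+                # e.g., if this tensor's owning MatMul spans 2 of 5 k-steps
+                # (k_strip_tensor_active_steps[t_idx] == 2), it is loaded at
+                # k_steps 0 and 1 only; tensors missing from the dict default
+                # to num_k_steps and are loaded every step.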
                 if t_idx in rhs_ephemeral:
                     mem_in += (problem.tensors[t_idx].height * k) / bw
                 else:
@@ -774,7 +843,7 @@ def compute_subgraph_latency(
             # For mixed-K: only MatMuls that haven't finished yet contribute compute.
             active_matmul_compute = sum(
                 bc * (k / kf)
-                for kf, bc in matmul_phase_info
+                for kf, bc, _ in matmul_phase_info
                 if k_step < math.ceil(kf / k)
             )
             compute_this_step = active_matmul_compute + (pointwise_compute if is_last_k else 0.0)