@@ -389,11 +389,11 @@ def compute_working_set(
389389 boundary_outputs = _boundary_outputs_for_subgraph (subgraph_ops , problem )
390390
391391 # Determine whether this is a split-K scenario.
392- # Use min (K_full) across all MatMuls, consistent with compute_subgraph_latency().
392+ # Use max (K_full) across all MatMuls, consistent with compute_subgraph_latency().
393393 matmul_ops = [op_idx for op_idx in subgraph_ops
394394 if problem .ops [op_idx ].op_type == "MatMul" ]
395395 if matmul_ops :
396- k_full = min (_k_full_for_op (problem .ops [op_idx ], problem ) for op_idx in matmul_ops )
396+ k_full = max (_k_full_for_op (problem .ops [op_idx ], problem ) for op_idx in matmul_ops )
397397 num_k_steps = math .ceil (k_full / k )
398398 else :
399399 k_full = 1
@@ -507,29 +507,36 @@ def compute_subgraph_latency(
507507 num_tiles_h = math .ceil (H_out / h )
508508 num_spatial_tiles = num_tiles_w * num_tiles_h
509509
510- # K-steps: derive from the minimum K_full across ALL MatMuls in the subgraph.
510+ # K-steps: derive from the maximum K_full across ALL MatMuls in the subgraph.
511+ # The subgraph runs until the longest reduction finishes (mixed-K support).
511512 # Internal MatMuls (whose output is ephemeral) still need k-steps.
512513 # If there is no MatMul at all, k is irrelevant: 1 k-step.
513514 matmul_ops = [op_idx for op_idx in subgraph_ops
514515 if problem .ops [op_idx ].op_type == "MatMul" ]
515516 if matmul_ops :
516- min_k_full = min (
517+ max_k_full = max (
517518 _k_full_for_op (problem .ops [op_idx ], problem ) for op_idx in matmul_ops
518519 )
519- num_k_steps = math .ceil (min_k_full / k )
520+ num_k_steps = math .ceil (max_k_full / k )
520521 else :
521522 num_k_steps = 1
522523
523524 is_split_k = num_k_steps > 1
524525
525- # Split compute: MatMul cost paid every k-step; Pointwise cost only on last k-step.
526+ # Split compute: MatMul cost paid every k-step it is active; Pointwise only on last k-step.
527+ # For mixed-K: each MatMul is active for ceil(its_K_full / k) steps.
528+ # matmul_compute_per_step is the total compute when ALL MatMuls are active (step 0 onward).
526529 matmul_compute_per_step = 0.0
527530 pointwise_compute = 0.0
531+ # Also collect per-op info for mixed-K phase calculation.
532+ matmul_phase_info : list [tuple [int , float ]] = [] # (k_full, base_cost)
528533 for op_idx in subgraph_ops :
529534 op = problem .ops [op_idx ]
530535 if op .op_type == "MatMul" :
531536 k_full_op = _k_full_for_op (op , problem )
532- matmul_compute_per_step += op .base_cost * (k / k_full_op )
537+ cost_per_step = op .base_cost * (k / k_full_op )
538+ matmul_compute_per_step += cost_per_step
539+ matmul_phase_info .append ((k_full_op , op .base_cost ))
533540 else :
534541 pointwise_compute += op .base_cost
535542
@@ -593,22 +600,74 @@ def compute_subgraph_latency(
593600 if traversal_order is None :
594601 if is_split_k :
595602 # Split-K mode: all spatial tiles are identical (no row-reuse).
596- # full_load LHS loaded once per tile. k_strip LHS + RHS loaded every k-step.
597- # First k-step: full_load + pw_load + k_strip_total
598- first_k_mem = full_load_lhs_time + pw_load_per_tile + k_strip_total_per_step
599- first_k_lat = max (matmul_compute_per_step , first_k_mem )
600-
601- # Interior k-steps: k_strip_total only
602- if num_k_steps > 2 :
603- interior_k_lat = max (matmul_compute_per_step , k_strip_total_per_step )
603+ #
604+ # Check if all MatMuls have the same K_full (fast path) or mixed-K (phase path).
605+ unique_k_fulls = set (kf for kf , _ in matmul_phase_info )
606+ all_same_k_full = len (unique_k_fulls ) <= 1
607+
608+ if all_same_k_full :
609+ # Fast path: uniform K_full — original formula.
610+ # full_load LHS loaded once per tile. k_strip LHS + RHS loaded every k-step.
611+ # First k-step: full_load + pw_load + k_strip_total
612+ first_k_mem = full_load_lhs_time + pw_load_per_tile + k_strip_total_per_step
613+ first_k_lat = max (matmul_compute_per_step , first_k_mem )
614+
615+ # Interior k-steps: k_strip_total only
616+ if num_k_steps > 2 :
617+ interior_k_lat = max (matmul_compute_per_step , k_strip_total_per_step )
618+ else :
619+ interior_k_lat = 0.0
620+
621+ # Last k-step: k_strip_total + eviction, compute includes PW
622+ last_k_mem = k_strip_total_per_step + out_evict_per_tile
623+ last_k_lat = max (matmul_compute_per_step + pointwise_compute , last_k_mem )
624+
625+ per_tile_lat = first_k_lat + max (0 , num_k_steps - 2 ) * interior_k_lat + last_k_lat
604626 else :
605- interior_k_lat = 0.0
627+ # Mixed-K path: compute phase-by-phase.
628+ # Each MatMul is active for ceil(its_K_full / k) steps.
629+ # Phases are defined by sorted unique step-end boundaries.
630+ step_ends = sorted (set (math .ceil (kf / k ) for kf , _ in matmul_phase_info ))
631+ # step_ends[-1] == num_k_steps
632+
633+ per_tile_lat = 0.0
634+ prev_end = 0
635+
636+ for phase_idx , phase_end in enumerate (step_ends ):
637+ # Active MatMuls: those whose step count >= phase_end
638+ active_compute = sum (
639+ bc * (k / kf )
640+ for kf , bc in matmul_phase_info
641+ if math .ceil (kf / k ) >= phase_end
642+ )
643+
644+ # Active k_strip: scale total by ratio of active compute to total compute.
645+ if matmul_compute_per_step > 0.0 :
646+ active_k_strip = k_strip_total_per_step * (active_compute / matmul_compute_per_step )
647+ else :
648+ active_k_strip = k_strip_total_per_step
649+
650+ phase_steps = phase_end - prev_end
651+ is_last_phase = (phase_idx == len (step_ends ) - 1 )
652+
653+ for step_offset in range (phase_steps ):
654+ global_step = prev_end + step_offset
655+ is_first_step = (global_step == 0 )
656+ is_last_step = is_last_phase and (step_offset == phase_steps - 1 )
606657
607- # Last k-step: k_strip_total + eviction, compute includes PW
608- last_k_mem = k_strip_total_per_step + out_evict_per_tile
609- last_k_lat = max (matmul_compute_per_step + pointwise_compute , last_k_mem )
658+ mem = (
659+ (full_load_lhs_time + pw_load_per_tile + active_k_strip )
660+ if is_first_step
661+ else active_k_strip
662+ )
663+ if is_last_step :
664+ mem += out_evict_per_tile
665+
666+ compute_this_step = active_compute + (pointwise_compute if is_last_step else 0.0 )
667+ per_tile_lat += max (compute_this_step , mem )
668+
669+ prev_end = phase_end
610670
611- per_tile_lat = first_k_lat + max (0 , num_k_steps - 2 ) * interior_k_lat + last_k_lat
612671 return num_spatial_tiles * per_tile_lat
613672
614673 else :
@@ -712,7 +771,13 @@ def compute_subgraph_latency(
712771 if t_idx not in tensors_to_retain_after :
713772 mem_out += (w * h ) / bw
714773
715- compute_this_step = matmul_compute_per_step + (pointwise_compute if is_last_k else 0.0 )
774+ # For mixed-K: only MatMuls that haven't finished yet contribute compute.
775+ active_matmul_compute = sum (
776+ bc * (k / kf )
777+ for kf , bc in matmul_phase_info
778+ if k_step < math .ceil (kf / k )
779+ )
780+ compute_this_step = active_matmul_compute + (pointwise_compute if is_last_k else 0.0 )
716781 memory_time = mem_in + mem_out
717782 step_latency = max (compute_this_step , memory_time )
718783 total_latency += step_latency
@@ -747,20 +812,18 @@ def evaluate(problem: Problem, solution: Solution) -> float:
747812 ops_in_sg = sg .ops
748813 gran = sg .granularity
749814
750- # Validate MatMul K_full consistency and k <= K_full
815+ # Validate k does not exceed the maximum K_full across all MatMuls.
816+ # Mixed-K subgraphs (MatMuls with different K_full values) are allowed.
751817 matmul_k_fulls = [
752818 _k_full_for_op (problem .ops [op_idx ], problem )
753819 for op_idx in ops_in_sg
754820 if problem .ops [op_idx ].op_type == "MatMul"
755821 ]
756822 if matmul_k_fulls :
757- if len (set (matmul_k_fulls )) > 1 :
758- raise ValidationError (
759- f"Subgraph { sg_idx } : MatMul ops have inconsistent K_full values: { matmul_k_fulls } "
760- )
761- if gran .k > matmul_k_fulls [0 ]:
823+ max_k_full = max (matmul_k_fulls )
824+ if gran .k > max_k_full :
762825 raise ValidationError (
763- f"Subgraph { sg_idx } : granularity k={ gran .k } exceeds K_full={ matmul_k_fulls [ 0 ] } "
826+ f"Subgraph { sg_idx } : granularity k={ gran .k } exceeds max K_full={ max_k_full } "
764827 )
765828
766829 if not check_oom (ops_in_sg , gran , problem , retained_tensors ):
0 commit comments