Skip to content

Commit f0d86b3

Browse files
committed
feat: allow mixed-K fusion (remove K_full consistency constraint)
Rust + Python changes: `compute_num_k_steps` now derives the step count from max(K_full) instead of min(K_full); `find_k_full`/`find_split_k` likewise use max(K_full) for their search range; the K_full consistency check was removed from fusion in both tracks; the evaluator now only validates k <= max(K_full) rather than requiring K_full consistency; and the latency model computes split-K latency phase-by-phase for mixed-K subgraphs (each MatMul is active for ceil(its_K_full / k) steps only), while the uniform-K case keeps the existing formula, so there is no regression. All 16 tests pass. Benchmark latencies are unchanged because cost-based fusion correctly determines that mixed-K fused latency is not better than splitting for these benchmarks' tensor dimensions. Closes #22.
1 parent 409a128 commit f0d86b3

7 files changed

Lines changed: 259 additions & 99 deletions

File tree

solution/agent/evaluator.py

Lines changed: 91 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,11 @@ def compute_working_set(
389389
boundary_outputs = _boundary_outputs_for_subgraph(subgraph_ops, problem)
390390

391391
# Determine whether this is a split-K scenario.
392-
# Use min(K_full) across all MatMuls, consistent with compute_subgraph_latency().
392+
# Use max(K_full) across all MatMuls, consistent with compute_subgraph_latency().
393393
matmul_ops = [op_idx for op_idx in subgraph_ops
394394
if problem.ops[op_idx].op_type == "MatMul"]
395395
if matmul_ops:
396-
k_full = min(_k_full_for_op(problem.ops[op_idx], problem) for op_idx in matmul_ops)
396+
k_full = max(_k_full_for_op(problem.ops[op_idx], problem) for op_idx in matmul_ops)
397397
num_k_steps = math.ceil(k_full / k)
398398
else:
399399
k_full = 1
@@ -507,29 +507,36 @@ def compute_subgraph_latency(
507507
num_tiles_h = math.ceil(H_out / h)
508508
num_spatial_tiles = num_tiles_w * num_tiles_h
509509

510-
# K-steps: derive from the minimum K_full across ALL MatMuls in the subgraph.
510+
# K-steps: derive from the maximum K_full across ALL MatMuls in the subgraph.
511+
# The subgraph runs until the longest reduction finishes (mixed-K support).
511512
# Internal MatMuls (whose output is ephemeral) still need k-steps.
512513
# If there is no MatMul at all, k is irrelevant: 1 k-step.
513514
matmul_ops = [op_idx for op_idx in subgraph_ops
514515
if problem.ops[op_idx].op_type == "MatMul"]
515516
if matmul_ops:
516-
min_k_full = min(
517+
max_k_full = max(
517518
_k_full_for_op(problem.ops[op_idx], problem) for op_idx in matmul_ops
518519
)
519-
num_k_steps = math.ceil(min_k_full / k)
520+
num_k_steps = math.ceil(max_k_full / k)
520521
else:
521522
num_k_steps = 1
522523

523524
is_split_k = num_k_steps > 1
524525

525-
# Split compute: MatMul cost paid every k-step; Pointwise cost only on last k-step.
526+
# Split compute: MatMul cost paid every k-step it is active; Pointwise only on last k-step.
527+
# For mixed-K: each MatMul is active for ceil(its_K_full / k) steps.
528+
# matmul_compute_per_step is the total compute when ALL MatMuls are active (step 0 onward).
526529
matmul_compute_per_step = 0.0
527530
pointwise_compute = 0.0
531+
# Also collect per-op info for mixed-K phase calculation.
532+
matmul_phase_info: list[tuple[int, float]] = [] # (k_full, base_cost)
528533
for op_idx in subgraph_ops:
529534
op = problem.ops[op_idx]
530535
if op.op_type == "MatMul":
531536
k_full_op = _k_full_for_op(op, problem)
532-
matmul_compute_per_step += op.base_cost * (k / k_full_op)
537+
cost_per_step = op.base_cost * (k / k_full_op)
538+
matmul_compute_per_step += cost_per_step
539+
matmul_phase_info.append((k_full_op, op.base_cost))
533540
else:
534541
pointwise_compute += op.base_cost
535542

@@ -593,22 +600,74 @@ def compute_subgraph_latency(
593600
if traversal_order is None:
594601
if is_split_k:
595602
# Split-K mode: all spatial tiles are identical (no row-reuse).
596-
# full_load LHS loaded once per tile. k_strip LHS + RHS loaded every k-step.
597-
# First k-step: full_load + pw_load + k_strip_total
598-
first_k_mem = full_load_lhs_time + pw_load_per_tile + k_strip_total_per_step
599-
first_k_lat = max(matmul_compute_per_step, first_k_mem)
600-
601-
# Interior k-steps: k_strip_total only
602-
if num_k_steps > 2:
603-
interior_k_lat = max(matmul_compute_per_step, k_strip_total_per_step)
603+
#
604+
# Check if all MatMuls have the same K_full (fast path) or mixed-K (phase path).
605+
unique_k_fulls = set(kf for kf, _ in matmul_phase_info)
606+
all_same_k_full = len(unique_k_fulls) <= 1
607+
608+
if all_same_k_full:
609+
# Fast path: uniform K_full — original formula.
610+
# full_load LHS loaded once per tile. k_strip LHS + RHS loaded every k-step.
611+
# First k-step: full_load + pw_load + k_strip_total
612+
first_k_mem = full_load_lhs_time + pw_load_per_tile + k_strip_total_per_step
613+
first_k_lat = max(matmul_compute_per_step, first_k_mem)
614+
615+
# Interior k-steps: k_strip_total only
616+
if num_k_steps > 2:
617+
interior_k_lat = max(matmul_compute_per_step, k_strip_total_per_step)
618+
else:
619+
interior_k_lat = 0.0
620+
621+
# Last k-step: k_strip_total + eviction, compute includes PW
622+
last_k_mem = k_strip_total_per_step + out_evict_per_tile
623+
last_k_lat = max(matmul_compute_per_step + pointwise_compute, last_k_mem)
624+
625+
per_tile_lat = first_k_lat + max(0, num_k_steps - 2) * interior_k_lat + last_k_lat
604626
else:
605-
interior_k_lat = 0.0
627+
# Mixed-K path: compute phase-by-phase.
628+
# Each MatMul is active for ceil(its_K_full / k) steps.
629+
# Phases are defined by sorted unique step-end boundaries.
630+
step_ends = sorted(set(math.ceil(kf / k) for kf, _ in matmul_phase_info))
631+
# step_ends[-1] == num_k_steps
632+
633+
per_tile_lat = 0.0
634+
prev_end = 0
635+
636+
for phase_idx, phase_end in enumerate(step_ends):
637+
# Active MatMuls: those whose step count >= phase_end
638+
active_compute = sum(
639+
bc * (k / kf)
640+
for kf, bc in matmul_phase_info
641+
if math.ceil(kf / k) >= phase_end
642+
)
643+
644+
# Active k_strip: scale total by ratio of active compute to total compute.
645+
if matmul_compute_per_step > 0.0:
646+
active_k_strip = k_strip_total_per_step * (active_compute / matmul_compute_per_step)
647+
else:
648+
active_k_strip = k_strip_total_per_step
649+
650+
phase_steps = phase_end - prev_end
651+
is_last_phase = (phase_idx == len(step_ends) - 1)
652+
653+
for step_offset in range(phase_steps):
654+
global_step = prev_end + step_offset
655+
is_first_step = (global_step == 0)
656+
is_last_step = is_last_phase and (step_offset == phase_steps - 1)
606657

607-
# Last k-step: k_strip_total + eviction, compute includes PW
608-
last_k_mem = k_strip_total_per_step + out_evict_per_tile
609-
last_k_lat = max(matmul_compute_per_step + pointwise_compute, last_k_mem)
658+
mem = (
659+
(full_load_lhs_time + pw_load_per_tile + active_k_strip)
660+
if is_first_step
661+
else active_k_strip
662+
)
663+
if is_last_step:
664+
mem += out_evict_per_tile
665+
666+
compute_this_step = active_compute + (pointwise_compute if is_last_step else 0.0)
667+
per_tile_lat += max(compute_this_step, mem)
668+
669+
prev_end = phase_end
610670

611-
per_tile_lat = first_k_lat + max(0, num_k_steps - 2) * interior_k_lat + last_k_lat
612671
return num_spatial_tiles * per_tile_lat
613672

614673
else:
@@ -712,7 +771,13 @@ def compute_subgraph_latency(
712771
if t_idx not in tensors_to_retain_after:
713772
mem_out += (w * h) / bw
714773

715-
compute_this_step = matmul_compute_per_step + (pointwise_compute if is_last_k else 0.0)
774+
# For mixed-K: only MatMuls that haven't finished yet contribute compute.
775+
active_matmul_compute = sum(
776+
bc * (k / kf)
777+
for kf, bc in matmul_phase_info
778+
if k_step < math.ceil(kf / k)
779+
)
780+
compute_this_step = active_matmul_compute + (pointwise_compute if is_last_k else 0.0)
716781
memory_time = mem_in + mem_out
717782
step_latency = max(compute_this_step, memory_time)
718783
total_latency += step_latency
@@ -747,20 +812,18 @@ def evaluate(problem: Problem, solution: Solution) -> float:
747812
ops_in_sg = sg.ops
748813
gran = sg.granularity
749814

750-
# Validate MatMul K_full consistency and k <= K_full
815+
# Validate k does not exceed the maximum K_full across all MatMuls.
816+
# Mixed-K subgraphs (MatMuls with different K_full values) are allowed.
751817
matmul_k_fulls = [
752818
_k_full_for_op(problem.ops[op_idx], problem)
753819
for op_idx in ops_in_sg
754820
if problem.ops[op_idx].op_type == "MatMul"
755821
]
756822
if matmul_k_fulls:
757-
if len(set(matmul_k_fulls)) > 1:
758-
raise ValidationError(
759-
f"Subgraph {sg_idx}: MatMul ops have inconsistent K_full values: {matmul_k_fulls}"
760-
)
761-
if gran.k > matmul_k_fulls[0]:
823+
max_k_full = max(matmul_k_fulls)
824+
if gran.k > max_k_full:
762825
raise ValidationError(
763-
f"Subgraph {sg_idx}: granularity k={gran.k} exceeds K_full={matmul_k_fulls[0]}"
826+
f"Subgraph {sg_idx}: granularity k={gran.k} exceeds max K_full={max_k_full}"
764827
)
765828

766829
if not check_oom(ops_in_sg, gran, problem, retained_tensors):

solution/agent/scheduler.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,6 @@ def _cached_best(ops: list[int], cache: dict) -> tuple[Granularity, float]:
268268

269269
merged = sg_ops[i] + sg_ops[i + 1]
270270

271-
# K_full consistency: all MatMuls must share the same K_full
272-
matmul_k_fulls = [
273-
_k_full_for_op(problem.ops[o], problem)
274-
for o in merged if problem.ops[o].op_type == "MatMul"
275-
]
276-
if matmul_k_fulls and len(set(matmul_k_fulls)) > 1:
277-
rejected_merges.add(merge_key)
278-
new_sg_ops.append(sg_ops[i])
279-
i += 1
280-
continue
281-
282271
# Boundary output dimension consistency
283272
boundary_outs = list(_boundary_outputs_for_subgraph(merged, problem))
284273
if boundary_outs:

solution/backend/rust/src/evaluate.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ pub fn evaluate(problem: &Problem, solution: &Solution) -> Result<EvaluateResult
7171
}
7272
}
7373

74-
// Validate MatMul K_full consistency and k <= K_full
75-
let matmul_k_fulls: Vec<i64> = sg.ops.iter()
74+
// Validate k does not exceed the maximum K_full across all MatMuls.
75+
// Mixed-K subgraphs (MatMuls with different K_full values) are allowed;
76+
// k only needs to be <= max(K_full) so the step count is well-defined.
77+
let max_k_full: Option<i64> = sg.ops.iter()
7678
.filter_map(|&op_idx| {
7779
let op = &problem.ops[op_idx];
7880
if op.is_matmul() {
@@ -81,20 +83,12 @@ pub fn evaluate(problem: &Problem, solution: &Solution) -> Result<EvaluateResult
8183
None
8284
}
8385
})
84-
.collect();
85-
if !matmul_k_fulls.is_empty() {
86-
// All MatMuls in a subgraph must share the same K_full
87-
if !matmul_k_fulls.iter().all(|&kf| kf == matmul_k_fulls[0]) {
86+
.max();
87+
if let Some(kf_max) = max_k_full {
88+
if sg.granularity.k > kf_max {
8889
return Err(format!(
89-
"Subgraph {sg_idx}: MatMul ops have inconsistent K_full values: {:?}",
90-
matmul_k_fulls
91-
));
92-
}
93-
// k must not exceed K_full
94-
if sg.granularity.k > matmul_k_fulls[0] {
95-
return Err(format!(
96-
"Subgraph {sg_idx}: granularity k={} exceeds K_full={}",
97-
sg.granularity.k, matmul_k_fulls[0]
90+
"Subgraph {sg_idx}: granularity k={} exceeds max K_full={}",
91+
sg.granularity.k, kf_max
9892
));
9993
}
10094
}

0 commit comments

Comments (0)