diff --git a/solution/agent/evaluator.py b/solution/agent/evaluator.py index 3ed560c..48ee3ee 100644 --- a/solution/agent/evaluator.py +++ b/solution/agent/evaluator.py @@ -389,11 +389,11 @@ def compute_working_set( boundary_outputs = _boundary_outputs_for_subgraph(subgraph_ops, problem) # Determine whether this is a split-K scenario. - # Use min(K_full) across all MatMuls, consistent with compute_subgraph_latency(). + # Use max(K_full) across all MatMuls, consistent with compute_subgraph_latency(). matmul_ops = [op_idx for op_idx in subgraph_ops if problem.ops[op_idx].op_type == "MatMul"] if matmul_ops: - k_full = min(_k_full_for_op(problem.ops[op_idx], problem) for op_idx in matmul_ops) + k_full = max(_k_full_for_op(problem.ops[op_idx], problem) for op_idx in matmul_ops) num_k_steps = math.ceil(k_full / k) else: k_full = 1 @@ -507,29 +507,33 @@ def compute_subgraph_latency( num_tiles_h = math.ceil(H_out / h) num_spatial_tiles = num_tiles_w * num_tiles_h - # K-steps: derive from the minimum K_full across ALL MatMuls in the subgraph. + # K-steps: derive from the maximum K_full across ALL MatMuls in the subgraph. + # The subgraph runs until the longest reduction finishes (mixed-K support). # Internal MatMuls (whose output is ephemeral) still need k-steps. # If there is no MatMul at all, k is irrelevant: 1 k-step. matmul_ops = [op_idx for op_idx in subgraph_ops if problem.ops[op_idx].op_type == "MatMul"] if matmul_ops: - min_k_full = min( + max_k_full = max( _k_full_for_op(problem.ops[op_idx], problem) for op_idx in matmul_ops ) - num_k_steps = math.ceil(min_k_full / k) + num_k_steps = math.ceil(max_k_full / k) else: num_k_steps = 1 is_split_k = num_k_steps > 1 - # Split compute: MatMul cost paid every k-step; Pointwise cost only on last k-step. + # Split compute: MatMul cost paid every k-step it is active; Pointwise only on last k-step. + # For mixed-K: each MatMul is active for ceil(its_K_full / k) steps. 
+ # matmul_compute_per_step is the total compute when ALL MatMuls are active (step 0 onward). matmul_compute_per_step = 0.0 pointwise_compute = 0.0 for op_idx in subgraph_ops: op = problem.ops[op_idx] if op.op_type == "MatMul": k_full_op = _k_full_for_op(op, problem) - matmul_compute_per_step += op.base_cost * (k / k_full_op) + cost_per_step = op.base_cost * (min(k, k_full_op) / k_full_op) + matmul_compute_per_step += cost_per_step else: pointwise_compute += op.base_cost @@ -573,6 +577,51 @@ def compute_subgraph_latency( # Total k_strip load per step (k_strip LHS + both RHS types) k_strip_total_per_step = k_strip_lhs_per_step + rhs_std_per_step + rhs_eph_per_step + # Per-MatMul k_strip contribution and per-tensor active-step-count for mixed-K. + # Each MatMul op contributes k_strip from its boundary inputs that are not retained: + # - non-ephemeral LHS (in k_strip_lhs): h * k_eff / bw + # - non-ephemeral RHS (in rhs_standard): k_eff * w / bw + # - ephemeral RHS (in rhs_ephemeral): rhs.height * k_eff / bw + # where k_eff = min(k, K_full_op) for each MatMul + # Deduplication mirrors _categorize_inputs (a tensor counted once for its first op). + # + # matmul_phase_info: list of (k_full, base_cost, k_strip_contribution_per_step) + # k_strip_tensor_active_steps: tensor_id -> step count of its owning MatMul. + # Used in the simulation path to load k_strip inputs only while their op is active. 
+ matmul_phase_info: list[tuple[int, float, float]] = [] + k_strip_tensor_active_steps: dict[int, int] = {} + k_strip_tensor_k_eff: dict[int, int] = {} # tensor_id -> min(k, K_full_op) + _seen_k_strip: set[int] = set() + for op_idx in subgraph_ops: + op = problem.ops[op_idx] + if op.op_type != "MatMul": + continue + k_full_op = _k_full_for_op(op, problem) + op_steps = math.ceil(k_full_op / k) + k_eff = min(k, k_full_op) # clamp k to this op's K_full + lhs_idx = op.inputs[0] + rhs_idx = op.inputs[1] + op_k_strip = 0.0 + if lhs_idx in k_strip_lhs and lhs_idx not in _seen_k_strip: + _seen_k_strip.add(lhs_idx) + if lhs_idx not in retained_tensors: + op_k_strip += (h * k_eff) / bw + k_strip_tensor_active_steps[lhs_idx] = op_steps + k_strip_tensor_k_eff[lhs_idx] = k_eff + if rhs_idx in rhs_standard and rhs_idx not in _seen_k_strip: + _seen_k_strip.add(rhs_idx) + if rhs_idx not in retained_tensors: + op_k_strip += (k_eff * w) / bw + k_strip_tensor_active_steps[rhs_idx] = op_steps + k_strip_tensor_k_eff[rhs_idx] = k_eff + if rhs_idx in rhs_ephemeral and rhs_idx not in _seen_k_strip: + _seen_k_strip.add(rhs_idx) + if rhs_idx not in retained_tensors: + op_k_strip += (problem.tensors[rhs_idx].height * k_eff) / bw + k_strip_tensor_active_steps[rhs_idx] = op_steps + k_strip_tensor_k_eff[rhs_idx] = k_eff + matmul_phase_info.append((k_full_op, op.base_cost, op_k_strip)) + # Pointwise inputs: w * h per spatial tile (first k-step). pw_load_per_tile = sum( (w * h) / bw @@ -593,22 +642,100 @@ def compute_subgraph_latency( if traversal_order is None: if is_split_k: # Split-K mode: all spatial tiles are identical (no row-reuse). - # full_load LHS loaded once per tile. k_strip LHS + RHS loaded every k-step. 
- # First k-step: full_load + pw_load + k_strip_total - first_k_mem = full_load_lhs_time + pw_load_per_tile + k_strip_total_per_step - first_k_lat = max(matmul_compute_per_step, first_k_mem) - - # Interior k-steps: k_strip_total only - if num_k_steps > 2: - interior_k_lat = max(matmul_compute_per_step, k_strip_total_per_step) - else: - interior_k_lat = 0.0 + # + # Check if all MatMuls have the same K_full (fast path) or mixed-K (phase path). + unique_k_fulls = set(kf for kf, _, _ in matmul_phase_info) + all_same_k_full = len(unique_k_fulls) <= 1 + + if all_same_k_full: + # Fast path: uniform K_full — original formula. + # full_load LHS loaded once per tile. k_strip LHS + RHS loaded every k-step. + # First k-step: full_load + pw_load + k_strip_total + first_k_mem = full_load_lhs_time + pw_load_per_tile + k_strip_total_per_step + first_k_lat = max(matmul_compute_per_step, first_k_mem) + + # Interior k-steps: k_strip_total only + if num_k_steps > 2: + interior_k_lat = max(matmul_compute_per_step, k_strip_total_per_step) + else: + interior_k_lat = 0.0 - # Last k-step: k_strip_total + eviction, compute includes PW - last_k_mem = k_strip_total_per_step + out_evict_per_tile - last_k_lat = max(matmul_compute_per_step + pointwise_compute, last_k_mem) + # Last k-step: k_strip_total + eviction, compute includes PW + last_k_mem = k_strip_total_per_step + out_evict_per_tile + last_k_lat = max(matmul_compute_per_step + pointwise_compute, last_k_mem) + + per_tile_lat = first_k_lat + max(0, num_k_steps - 2) * interior_k_lat + last_k_lat + else: + # Mixed-K path: compute phase-by-phase. + # Each MatMul is active for ceil(its_K_full / k) steps. + # Phases are defined by sorted unique step-end boundaries. + step_ends = sorted(set(math.ceil(kf / k) for kf, _, _ in matmul_phase_info)) + # step_ends[-1] == num_k_steps + + per_tile_lat = 0.0 + prev_end = 0 + + for phase_idx, phase_end in enumerate(step_ends): + # Active MatMuls: those whose step count >= phase_end. 
+ active_compute = sum( + bc * (min(k, kf) / kf) + for kf, bc, _ in matmul_phase_info + if math.ceil(kf / k) >= phase_end + ) + + # Active k_strip: sum per-op contributions for active MatMuls only. + # This is exact (no proxy ratio) because each op's contribution is + # precomputed from its actual tensor dimensions. + active_k_strip = sum( + ks + for kf, _, ks in matmul_phase_info + if math.ceil(kf / k) >= phase_end + ) + + phase_steps = phase_end - prev_end + is_last_phase = (phase_idx == len(step_ends) - 1) + + # O(1) per phase: classify steps as first, interior, or last. + # Special steps: global step 0 (loads full_load + pw_load) and + # global last step (evicts output, adds PW compute). + has_first = (prev_end == 0) + has_last = is_last_phase # last phase always contains the last step + + # Interior steps: all steps in this phase that are neither first nor last. + interior_count = phase_steps - (1 if has_first else 0) - (1 if has_last else 0) + # interior_count can be negative when a single step is both first and last. + interior_count = max(0, interior_count) + + if has_first: + mem = full_load_lhs_time + pw_load_per_tile + active_k_strip + # The first step is also the last only when num_k_steps == 1, + # but that case is handled by the all_same_k_full branch above. + per_tile_lat += max(active_compute, mem) + + if interior_count > 0: + interior_lat = max(active_compute, active_k_strip) + per_tile_lat += interior_count * interior_lat + + if has_last: + # If phase has only one step and it's also the first step, + # we already added the first step above; skip duplicate. + is_also_first = has_first and (phase_steps == 1) + if not is_also_first: + mem_last = active_k_strip + out_evict_per_tile + compute_last = active_compute + pointwise_compute + per_tile_lat += max(compute_last, mem_last) + else: + # Single-step phase that is both first and last: adjust the + # first-step cost to include eviction and PW compute. 
+ mem_last = full_load_lhs_time + pw_load_per_tile + active_k_strip + out_evict_per_tile + compute_last = active_compute + pointwise_compute + # Undo the first-step contribution already added and replace it. + first_mem = full_load_lhs_time + pw_load_per_tile + active_k_strip + per_tile_lat -= max(active_compute, first_mem) + per_tile_lat += max(compute_last, mem_last) + + prev_end = phase_end - per_tile_lat = first_k_lat + max(0, num_k_steps - 2) * interior_k_lat + last_k_lat return num_spatial_tiles * per_tile_lat else: @@ -667,15 +794,18 @@ def compute_subgraph_latency( resident_full_lhs[t_idx] = tile_row # ------- k_strip LHS (non-ephemeral-output MatMul) ------- - # In split-K: loaded every k-step (no reuse). + # In split-K: loaded every k-step (no reuse), but only while the + # owning MatMul is still active (k_step < ceil(K_full_op / k)). # In spatial-only: row-reusable (same as full_load). for t_idx in k_strip_lhs: if t_idx in retained_tensors: continue + ke = k_strip_tensor_k_eff.get(t_idx, k) if is_split_k: - mem_in += (h * k) / bw + if k_step < k_strip_tensor_active_steps.get(t_idx, num_k_steps): + mem_in += (h * ke) / bw elif is_first_k and resident_k_strip_lhs[t_idx] != tile_row: - mem_in += (h * k) / bw + mem_in += (h * ke) / bw resident_k_strip_lhs[t_idx] = tile_row # ------- RHS tensors ------- @@ -683,20 +813,25 @@ def compute_subgraph_latency( for t_idx in all_rhs: if t_idx in retained_tensors: continue + # Only load if the owning MatMul is still active this step. 
+ if k_step >= k_strip_tensor_active_steps.get(t_idx, num_k_steps): + continue + ke = k_strip_tensor_k_eff.get(t_idx, k) if t_idx in rhs_ephemeral: - mem_in += (problem.tensors[t_idx].height * k) / bw + mem_in += (problem.tensors[t_idx].height * ke) / bw else: - mem_in += (k * w) / bw + mem_in += (ke * w) / bw else: if is_first_k: for t_idx in all_rhs: if t_idx in retained_tensors: continue if resident_rhs[t_idx] != tile_col: + ke = k_strip_tensor_k_eff.get(t_idx, k) if t_idx in rhs_ephemeral: - mem_in += (problem.tensors[t_idx].height * k) / bw + mem_in += (problem.tensors[t_idx].height * ke) / bw else: - mem_in += (k * w) / bw + mem_in += (ke * w) / bw resident_rhs[t_idx] = tile_col # ------- Pointwise inputs ------- @@ -712,7 +847,13 @@ def compute_subgraph_latency( if t_idx not in tensors_to_retain_after: mem_out += (w * h) / bw - compute_this_step = matmul_compute_per_step + (pointwise_compute if is_last_k else 0.0) + # For mixed-K: only MatMuls that haven't finished yet contribute compute. + active_matmul_compute = sum( + bc * (min(k, kf) / kf) + for kf, bc, _ in matmul_phase_info + if k_step < math.ceil(kf / k) + ) + compute_this_step = active_matmul_compute + (pointwise_compute if is_last_k else 0.0) memory_time = mem_in + mem_out step_latency = max(compute_this_step, memory_time) total_latency += step_latency @@ -747,20 +888,18 @@ def evaluate(problem: Problem, solution: Solution) -> float: ops_in_sg = sg.ops gran = sg.granularity - # Validate MatMul K_full consistency and k <= K_full + # Validate k does not exceed the maximum K_full across all MatMuls. + # Mixed-K subgraphs (MatMuls with different K_full values) are allowed. 
matmul_k_fulls = [ _k_full_for_op(problem.ops[op_idx], problem) for op_idx in ops_in_sg if problem.ops[op_idx].op_type == "MatMul" ] if matmul_k_fulls: - if len(set(matmul_k_fulls)) > 1: - raise ValidationError( - f"Subgraph {sg_idx}: MatMul ops have inconsistent K_full values: {matmul_k_fulls}" - ) - if gran.k > matmul_k_fulls[0]: + max_k_full = max(matmul_k_fulls) + if gran.k > max_k_full: raise ValidationError( - f"Subgraph {sg_idx}: granularity k={gran.k} exceeds K_full={matmul_k_fulls[0]}" + f"Subgraph {sg_idx}: granularity k={gran.k} exceeds max K_full={max_k_full}" ) if not check_oom(ops_in_sg, gran, problem, retained_tensors): diff --git a/solution/agent/scheduler.py b/solution/agent/scheduler.py index 28a8b19..04a0826 100644 --- a/solution/agent/scheduler.py +++ b/solution/agent/scheduler.py @@ -268,17 +268,6 @@ def _cached_best(ops: list[int], cache: dict) -> tuple[Granularity, float]: merged = sg_ops[i] + sg_ops[i + 1] - # K_full consistency: all MatMuls must share the same K_full - matmul_k_fulls = [ - _k_full_for_op(problem.ops[o], problem) - for o in merged if problem.ops[o].op_type == "MatMul" - ] - if matmul_k_fulls and len(set(matmul_k_fulls)) > 1: - rejected_merges.add(merge_key) - new_sg_ops.append(sg_ops[i]) - i += 1 - continue - # Boundary output dimension consistency boundary_outs = list(_boundary_outputs_for_subgraph(merged, problem)) if boundary_outs: diff --git a/solution/backend/rust/src/evaluate.rs b/solution/backend/rust/src/evaluate.rs index 002a6ba..aa19439 100644 --- a/solution/backend/rust/src/evaluate.rs +++ b/solution/backend/rust/src/evaluate.rs @@ -71,8 +71,10 @@ pub fn evaluate(problem: &Problem, solution: &Solution) -> Result = sg.ops.iter() + // Validate k does not exceed the maximum K_full across all MatMuls. + // Mixed-K subgraphs (MatMuls with different K_full values) are allowed; + // k only needs to be <= max(K_full) so the step count is well-defined. 
+ let max_k_full: Option = sg.ops.iter() .filter_map(|&op_idx| { let op = &problem.ops[op_idx]; if op.is_matmul() { @@ -81,20 +83,12 @@ pub fn evaluate(problem: &Problem, solution: &Solution) -> Result kf_max { return Err(format!( - "Subgraph {sg_idx}: MatMul ops have inconsistent K_full values: {:?}", - matmul_k_fulls - )); - } - // k must not exceed K_full - if sg.granularity.k > matmul_k_fulls[0] { - return Err(format!( - "Subgraph {sg_idx}: granularity k={} exceeds K_full={}", - sg.granularity.k, matmul_k_fulls[0] + "Subgraph {sg_idx}: granularity k={} exceeds max K_full={}", + sg.granularity.k, kf_max )); } } diff --git a/solution/backend/rust/src/latency.rs b/solution/backend/rust/src/latency.rs index a1c34a5..d70df26 100644 --- a/solution/backend/rust/src/latency.rs +++ b/solution/backend/rust/src/latency.rs @@ -29,7 +29,7 @@ pub fn matmul_compute_per_step( let op = &problem.ops[op_idx]; if op.is_matmul() { let op_k_full = k_full_for_matmul(op, &problem.tensors) as f64; - total += op.base_cost as f64 * (k / op_k_full); + total += op.base_cost as f64 * (k.min(op_k_full) / op_k_full); } } total @@ -155,6 +155,10 @@ fn build_memory_plan( && !dag.tensor_consumers[out_t].is_empty() && dag.tensor_consumers[out_t].iter().all(|c| op_set.contains(c)); + // Effective k for this op: clamp to its K_full (can't load more than exists) + let op_k_full = k_full_for_matmul(op, &problem.tensors); + let k_eff = k.min(op_k_full); + // LHS input let lhs_boundary = !dag.tensor_producer[lhs_idx] .map(|p| op_set.contains(&p)) @@ -164,15 +168,13 @@ fn build_memory_plan( if previously_retained.contains(&lhs_idx) { pre_retained.push(lhs_idx); } else if output_ephemeral { - // Upstream LHS: we need a ROW STRIP = h rows * full K_full_Op0 width - // h = output height of the FINAL output (the subgraph output) - // K_full_Op0 = lhs.width (the full reduction dimension of this upstream op) + // Upstream LHS: ROW STRIP = h * K_full (full reduction width) let lhs_width = 
problem.tensors[lhs_idx].width; let row_strip_size = h * lhs_width; full_load.push((lhs_idx, row_strip_size)); } else { - // Standard LHS slice = h * k - k_strip.push((lhs_idx, h * k)); + // Standard LHS slice = h * k_eff + k_strip.push((lhs_idx, h * k_eff)); } } @@ -185,13 +187,12 @@ fn build_memory_plan( if previously_retained.contains(&rhs_idx) { pre_retained.push(rhs_idx); } else if output_ephemeral { - // Upstream RHS: col strip of the intermediate = K_full_Op0 * k - // = rhs.height * k (since rhs.height = K_full_Op0 for this upstream op) + // Upstream RHS: rhs.height * k_eff let rhs_height = problem.tensors[rhs_idx].height; - k_strip.push((rhs_idx, rhs_height * k)); + k_strip.push((rhs_idx, rhs_height * k_eff)); } else { - // Standard RHS slice = k * w - k_strip.push((rhs_idx, k * w)); + // Standard RHS slice = k_eff * w + k_strip.push((rhs_idx, k_eff * w)); } } } @@ -213,14 +214,15 @@ fn build_memory_plan( } } -/// Compute num_k_steps for a subgraph: ceil(min_K_full / k) across all MatMuls. +/// Compute num_k_steps for a subgraph: ceil(max_K_full / k) across all MatMuls. +/// Uses MAX so the subgraph runs until the longest reduction finishes. /// Returns 1 if there are no MatMul ops. pub fn compute_num_k_steps( subgraph_ops: &[usize], k: i64, problem: &Problem, ) -> i64 { - let min_k_full: Option = subgraph_ops + let max_k_full: Option = subgraph_ops .iter() .filter_map(|&op_idx| { let op = &problem.ops[op_idx]; @@ -230,11 +232,11 @@ pub fn compute_num_k_steps( None } }) - .min(); + .max(); if k <= 0 { return 1; // Guard against division by zero from malformed input } - match min_k_full { + match max_k_full { Some(kf) => (kf + k - 1) / k, None => 1, } @@ -335,31 +337,166 @@ pub fn subgraph_latency( if num_k_steps > 1 { // Split-K mode: all spatial tiles are identical. 
// - // First k-step of each spatial tile: - // load = full_load_total + pw_load_total + k_strip_total - // compute = matmul_compute - let first_k_mem = (full_load_total + pw_load_total + k_strip_total) as f64 / bw; - let first_k_lat = f64::max(matmul_compute, first_k_mem); - - // Interior k-steps (steps 2 .. num_k_steps-1): - // load = k_strip_total only - // compute = matmul_compute - let interior_k_lat = if num_k_steps > 2 { - let interior_mem = k_strip_total as f64 / bw; - f64::max(matmul_compute, interior_mem) + // Mixed-K support: MatMuls with different K_full values finish at different steps. + // We compute latency in phases where each phase has a different set of active MatMuls. + // + // Build a lookup from tensor_id -> k_strip_size from the memory plan. + // Used to compute per-MatMul k_strip contributions accurately (no compute-ratio proxy). + let k_strip_map: std::collections::HashMap = + plan.k_strip.iter().map(|&(t, sz)| (t, sz)).collect(); + + // Collect (K_full, base_cost, k_strip_contribution) tuples for all MatMul ops. + // k_strip_contribution is the sum of k_strip sizes for this op's boundary LHS/RHS + // inputs (using the same deduplication logic as build_memory_plan: each tensor + // is counted for the first MatMul op that claims it). 
+ let mut k_strip_seen: std::collections::HashSet = std::collections::HashSet::new(); + let matmul_phases: Vec<(i64, f64, i64)> = subgraph_ops + .iter() + .filter_map(|&op_idx| { + let op = &problem.ops[op_idx]; + if !op.is_matmul() { + return None; + } + let kf = k_full_for_matmul(op, &problem.tensors); + let lhs_idx = op.inputs[0]; + let rhs_idx = op.inputs[1]; + let mut op_k_strip: i64 = 0; + if !k_strip_seen.contains(&lhs_idx) { + if let Some(&sz) = k_strip_map.get(&lhs_idx) { + op_k_strip += sz; + } + k_strip_seen.insert(lhs_idx); + } + if !k_strip_seen.contains(&rhs_idx) { + if let Some(&sz) = k_strip_map.get(&rhs_idx) { + op_k_strip += sz; + } + k_strip_seen.insert(rhs_idx); + } + Some((kf, op.base_cost as f64, op_k_strip)) + }) + .collect(); + + // Check if all MatMuls have identical K_full (fast path, existing formula). + let all_same_k_full = matmul_phases.windows(2).all(|w| w[0].0 == w[1].0); + + let per_tile_lat = if all_same_k_full { + // Fast path: uniform K_full — use original formula. + // + // First k-step: load = full_load_total + pw_load_total + k_strip_total + let first_k_mem = (full_load_total + pw_load_total + k_strip_total) as f64 / bw; + let first_k_lat = f64::max(matmul_compute, first_k_mem); + + // Interior k-steps: load = k_strip_total only + let interior_k_lat = if num_k_steps > 2 { + let interior_mem = k_strip_total as f64 / bw; + f64::max(matmul_compute, interior_mem) + } else { + 0.0 + }; + + // Last k-step: load = k_strip_total, evict output, compute includes PW + let last_k_mem = (k_strip_total + plan.out_evict_size) as f64 / bw; + let last_k_lat = f64::max(matmul_compute + pw_compute, last_k_mem); + + first_k_lat + (num_k_steps - 2).max(0) as f64 * interior_k_lat + last_k_lat } else { - 0.0 - }; + // Mixed-K path: compute phase-by-phase. + // + // Phases are defined by sorted unique step-end boundaries (when each MatMul finishes). 
+ // + // Example: K_full = [4, 8], k = 2 + // MatMul-A finishes at step ceil(4/2)=2, MatMul-B at step ceil(8/2)=4 + // Phase 1: steps 0..2 — both active (2 steps) + // Phase 2: steps 2..4 — only MatMul-B active (2 steps) + // + // Within a phase all steps have identical cost except: + // - Global step 0: loads full_load_total + pw_load_total additionally + // - Global last step: evicts output + adds PW compute + // Replace the per-step loop with O(1) per-phase arithmetic. + let mut step_ends: Vec = matmul_phases + .iter() + .map(|(kf, _, _)| (*kf + k - 1) / k) + .collect(); + step_ends.sort_unstable(); + step_ends.dedup(); + // step_ends.last() == num_k_steps (max) + + let mut per_tile_lat = 0.0_f64; + let mut prev_end: i64 = 0; + + for (phase_idx, &phase_end) in step_ends.iter().enumerate() { + // Active MatMuls: those with step_count >= phase_end. + let active_compute: f64 = matmul_phases + .iter() + .filter(|(kf, _, _)| (*kf + k - 1) / k >= phase_end) + .map(|(kf, cost, _)| cost * ((k as f64).min(*kf as f64) / *kf as f64)) + .sum(); + + // Active k_strip: sum per-op contributions for active MatMuls only. + // This is exact because each op's contribution was precomputed from + // its actual tensor dimensions, not from a compute-ratio proxy. + let active_k_strip_elems: i64 = matmul_phases + .iter() + .filter(|(kf, _, _)| (*kf + k - 1) / k >= phase_end) + .map(|(_, _, ks)| ks) + .sum(); + let active_k_strip = active_k_strip_elems as f64 / bw; + + let phase_steps = phase_end - prev_end; + let is_last_phase = phase_idx == step_ends.len() - 1; + + // O(1) per phase: classify steps as first, interior, or last. + // Special steps: global step 0 (loads full_load + pw_load) and + // global last step (evicts output, adds PW compute). + let has_first = prev_end == 0; + let has_last = is_last_phase; // last phase always contains the last step + + // Interior steps: all steps in this phase that are neither first nor last. 
+ let interior_count = (phase_steps + - if has_first { 1 } else { 0 } + - if has_last { 1 } else { 0 }) + .max(0); + + if has_first { + let mem = (full_load_total + pw_load_total) as f64 / bw + active_k_strip; + // First step is also last only when num_k_steps == 1, which is + // impossible here (we are in the num_k_steps > 1 branch). + per_tile_lat += f64::max(active_compute, mem); + } + + if interior_count > 0 { + let interior_lat = f64::max(active_compute, active_k_strip); + per_tile_lat += interior_count as f64 * interior_lat; + } - // Last k-step of each spatial tile: - // load = k_strip_total - // evict = out_evict_size - // compute = matmul_compute + pw_compute - let last_k_mem = (k_strip_total + plan.out_evict_size) as f64 / bw; - let last_k_lat = f64::max(matmul_compute + pw_compute, last_k_mem); + if has_last { + // If this phase has exactly one step and it is also the first step, + // we already accounted for it above; replace that cost with the + // combined first+last cost. + let is_also_first = has_first && phase_steps == 1; + if !is_also_first { + let mem_last = active_k_strip + plan.out_evict_size as f64 / bw; + let compute_last = active_compute + pw_compute; + per_tile_lat += f64::max(compute_last, mem_last); + } else { + // Single-step phase that is both first and last. + // Undo the first-step cost already added, then add combined cost. 
+ let first_mem = (full_load_total + pw_load_total) as f64 / bw + active_k_strip; + per_tile_lat -= f64::max(active_compute, first_mem); + let mem_last = (full_load_total + pw_load_total) as f64 / bw + + active_k_strip + + plan.out_evict_size as f64 / bw; + let compute_last = active_compute + pw_compute; + per_tile_lat += f64::max(compute_last, mem_last); + } + } + + prev_end = phase_end; + } - // num_k_steps >= 2 here (outer branch guarantees it) - let per_tile_lat = first_k_lat + (num_k_steps - 2).max(0) as f64 * interior_k_lat + last_k_lat; + per_tile_lat + }; num_spatial_tiles as f64 * per_tile_lat } else { diff --git a/solution/backend/rust/src/main.rs b/solution/backend/rust/src/main.rs index 1bf4d36..7a24724 100644 --- a/solution/backend/rust/src/main.rs +++ b/solution/backend/rust/src/main.rs @@ -142,7 +142,7 @@ fn run_evaluate(args: &[String]) { #[cfg(test)] mod tests { use super::*; - use crate::latency::subgraph_latency; + use crate::latency::{compute_num_k_steps, subgraph_latency}; use crate::models::Granularity; use std::collections::HashSet; @@ -630,6 +630,54 @@ mod tests { ); } + // Mixed-K: two MatMuls with different K_full fused in one subgraph + #[test] + fn test_mixed_k_two_matmuls() { + // Op0: MatMul K=64, Tensor0(64x128) @ Tensor1(128x64) -> Tensor2(128x128), cost=1000 + // Op1: MatMul K=128, Tensor2(128x128) @ Tensor3(128x128) -> Tensor4(128x128), cost=2000 + // Fused [0,1] with k=32: + // Op0: ceil(64/32)=2 k-steps, compute/step = 1000*(32/64) = 500 + // Op1: ceil(128/32)=4 k-steps, compute/step = 2000*(32/128) = 500 + // Total k-steps = max(2,4) = 4 + // Phase 1 (steps 0-1): both active, compute = 500+500 = 1000 + // Phase 2 (steps 2-3): only Op1 active, compute = 500 + let json = r#"{ + "widths": [64,128,128,128,128], + "heights": [128,64,128,128,128], + "inputs": [[0,1],[2,3]], + "outputs": [[2],[4]], + "base_costs": [1000, 2000], + "op_types": ["MatMul","MatMul"], + "fast_memory_capacity": 100000, + "slow_memory_bandwidth": 10, + 
"native_granularity": [128, 128] + }"#; + let problem = parse_problem(json).unwrap(); + let dag = DagInfo::build(&problem).unwrap(); + + let gran = Granularity { w: 128, h: 128, k: 32 }; + let lat = subgraph_latency(&[0, 1], &gran, &[], &HashSet::new(), &problem, &dag); + + // 4 k-steps, 1 spatial tile + // Verify latency > 0 and is reasonable (phases produce different costs) + assert!(lat > 0.0, "Mixed-K latency must be positive, got {lat}"); + + // Compare with k=64 (Op0 does 1 step, Op1 does 2 steps) + let gran64 = Granularity { w: 128, h: 128, k: 64 }; + let lat64 = subgraph_latency(&[0, 1], &gran64, &[], &HashSet::new(), &problem, &dag); + assert!(lat64 > 0.0, "Mixed-K k=64 latency must be positive, got {lat64}"); + + // k=128: Op0 does 1 step (k>K_full clamped), Op1 does 1 step + let gran128 = Granularity { w: 128, h: 128, k: 128 }; + let lat128 = subgraph_latency(&[0, 1], &gran128, &[], &HashSet::new(), &problem, &dag); + assert!(lat128 > 0.0, "Mixed-K k=128 latency must be positive, got {lat128}"); + + // Verify num_k_steps = ceil(max(64,128)/k) + assert_eq!(compute_num_k_steps(&[0, 1], 32, &problem), 4); // ceil(128/32) + assert_eq!(compute_num_k_steps(&[0, 1], 64, &problem), 2); // ceil(128/64) + assert_eq!(compute_num_k_steps(&[0, 1], 128, &problem), 1); // ceil(128/128) + } + // Benchmark validation: all 5 benchmark solutions should cover all ops and have non-negative latencies #[test] fn test_benchmark_solutions_validity() { diff --git a/solution/backend/rust/src/memory.rs b/solution/backend/rust/src/memory.rs index f2ab87d..763ddbc 100644 --- a/solution/backend/rust/src/memory.rs +++ b/solution/backend/rust/src/memory.rs @@ -143,8 +143,9 @@ pub fn check_oom( ) <= problem.fast_memory_capacity } -/// Find the largest k (power-of-2 downward from min K_full) that fits in memory. -/// Uses min K_full across all MatMuls so k never exceeds any op's reduction dimension. +/// Find the largest k (power-of-2 downward from max K_full) that fits in memory. 
+/// Uses max K_full across all MatMuls as the upper search bound; the working-set +/// calculator handles per-op slice sizes correctly for mixed-K subgraphs. pub fn find_split_k( subgraph_ops: &[usize], granularity: &Granularity, @@ -153,8 +154,10 @@ pub fn find_split_k( problem: &Problem, dag: &DagInfo, ) -> Option { - // Find the MINIMUM K_full across all MatMul ops so k is valid for every op - let min_k_full = subgraph_ops + // Find the MAXIMUM K_full across all MatMul ops as the upper search bound. + // For mixed-K subgraphs k candidates start from the largest K_full and halve + // downward; the memory check accounts for actual per-op slice sizes. + let max_k_full = subgraph_ops .iter() .filter_map(|&op_idx| { let op = &problem.ops[op_idx]; @@ -164,9 +167,9 @@ pub fn find_split_k( None } }) - .min(); + .max(); - let k_full = match min_k_full { + let k_full = match max_k_full { Some(kf) => kf, None => { // No MatMul ops -- try with k=1 for pointwise-only diff --git a/solution/backend/rust/src/optimizer/fusion.rs b/solution/backend/rust/src/optimizer/fusion.rs index 62c4bb6..73f4010 100644 --- a/solution/backend/rust/src/optimizer/fusion.rs +++ b/solution/backend/rust/src/optimizer/fusion.rs @@ -12,7 +12,7 @@ use crate::latency::subgraph_latency; use crate::memory::find_split_k; use crate::models::{Granularity, Problem, SubgraphDef}; use crate::optimizer::granularity::search_best_granularity; -use crate::parser::{k_full_for_matmul, native_granularity_for_subgraph}; +use crate::parser::native_granularity_for_subgraph; /// Attempt to fuse the subgraphs in topological order. /// Returns a new list of subgraphs where adjacent groups have been merged @@ -69,15 +69,7 @@ pub fn greedy_fusion( }) }; - // K_full consistency: all MatMuls must share the same K_full. 
- let matmul_k_fulls: Vec = merged.iter() - .filter(|&&op_idx| problem.ops[op_idx].is_matmul()) - .map(|&op_idx| k_full_for_matmul(&problem.ops[op_idx], &problem.tensors)) - .collect(); - let k_full_consistent = matmul_k_fulls.is_empty() - || matmul_k_fulls.iter().all(|&kf| kf == matmul_k_fulls[0]); - - if dims_consistent && k_full_consistent { + if dims_consistent { if find_feasible_granularity(&merged, &retained_before, problem, dag).is_some() { let base_merged = native_granularity_for_subgraph(&merged, problem); diff --git a/solution/backend/rust/src/optimizer/granularity.rs b/solution/backend/rust/src/optimizer/granularity.rs index 882e1b1..83564f8 100644 --- a/solution/backend/rust/src/optimizer/granularity.rs +++ b/solution/backend/rust/src/optimizer/granularity.rs @@ -11,8 +11,9 @@ use crate::memory::{check_oom, find_split_k}; use crate::models::{Granularity, Problem, SubgraphDef}; use crate::parser::k_full_for_matmul; -/// Find K_full for a subgraph: the minimum K_full across ALL MatMuls in the subgraph. -/// Internal MatMuls (ephemeral output) still drive k-step counts, so we must consider them. +/// Find K_full for a subgraph: the maximum K_full across ALL MatMuls in the subgraph. +/// Uses max so k candidates extend up to the longest reduction dimension, allowing +/// all k steps to be explored for mixed-K subgraphs. fn find_k_full(ops: &[usize], problem: &Problem, _dag: &DagInfo) -> Option { ops.iter() .filter_map(|&op_idx| { @@ -23,7 +24,7 @@ fn find_k_full(ops: &[usize], problem: &Problem, _dag: &DagInfo) -> Option None } }) - .min() + .max() } /// Generate candidate w/h values for a given tensor dimension. 
diff --git a/solution/docs/architecture/system-design.md b/solution/docs/architecture/system-design.md index c946411..e6c43ed 100644 --- a/solution/docs/architecture/system-design.md +++ b/solution/docs/architecture/system-design.md @@ -180,6 +180,8 @@ Note: `latency(A)` already includes evicting boundary outputs to DRAM, and `late This prevents harmful fusions where forcing a shared granularity degrades latency more than the DRAM savings from making intermediates ephemeral. +**Mixed-K fusion (Issue #22):** The fusion stage no longer rejects merges between ops with different K_full values. MatMuls with different reduction dimensions can coexist in the same subgraph. The cost-based criterion remains the gatekeeper: a mixed-K fusion is accepted only when the fused latency (computed under the mixed-K execution model) is lower than the split latency. See "Mixed-K Execution Model" below for how the latency is computed. + ### Stage 7: Granularity Search -- Full (w, h, k) Search The granularity search explores a three-dimensional candidate space for each subgraph: @@ -188,10 +190,10 @@ The granularity search explores a three-dimensional candidate space for each sub ``` w candidates: powers of 2 from 1 up to output_width h candidates: powers of 2 from 1 up to output_height -k candidates: K_cap, K_cap/2, K_cap/4, ..., 1 (powers of 2, descending) +k candidates: K_max, K_max/2, K_max/4, ..., 1 (powers of 2, descending) ``` -Where `K_cap = K_full` (the shared reduction dimension across all MatMuls in the subgraph). **Invariant: all MatMul ops within a single subgraph must share the same K_full.** This invariant is enforced during fusion (ops with different K_full are not merged) and validated during evaluation. It ensures the subgraph has a single, well-defined k-step loop. For Pointwise-only subgraphs, k is fixed at 1 and only (w, h) is searched. +Where `K_max = max(K_full_op for each MatMul op in the subgraph)`. 
**Mixed-K subgraphs are permitted (Issue #22):** MatMul ops within a single subgraph may have different K_full values. Each MatMul individually requires `ceil(K_full_op / k)` k-steps, and the subgraph runs for `ceil(K_max / k)` total k-steps (driven by the MatMul with the largest reduction dimension). MatMuls that finish their reduction before the final k-step contribute zero compute and zero memory traffic on remaining steps. For Pointwise-only subgraphs, k is fixed at 1 and only (w, h) is searched. **For each (w, h, k) candidate:** 1. Compute the working set (input slices + output slices + retained tensors) @@ -212,10 +214,10 @@ Instead of iterating over every tile and k-step to sum per-step roofline costs ( ``` num_rows = ceil(H_out / h) num_cols = ceil(W_out / w) -num_k_steps = ceil(K_full / k) +num_k_steps = ceil(K_max / k) -- where K_max = max(K_full) across all MatMuls ``` -**Spatial-only** (num_k_steps == 1) — row-reuse applies: +**Spatial-only** (num_k_steps == 1) -- row-reuse applies: ``` first_col_latency = max(compute, (full_load + k_strip_lhs + rhs + pw + evict) / bw) @@ -223,8 +225,9 @@ other_col_latency = max(compute, (rhs + pw + evict) / bw) total_latency = num_rows * (first_col_latency + (num_cols - 1) * other_col_latency) ``` -**Split-K** (num_k_steps > 1) — all tiles identical, no row-reuse: +**Split-K / Mixed-K** (num_k_steps > 1) -- per-step compute and memory vary by step: +For **uniform-K subgraphs** (all MatMuls share the same K_full), the formula remains: ``` first_k = max(matmul_compute, (full_load + pw_load + k_strip) / bw) interior = max(matmul_compute, k_strip / bw) @@ -233,20 +236,96 @@ per_tile = first_k + max(0, num_k_steps - 2) * interior + last_k total_latency = num_spatial_tiles * per_tile ``` +For **mixed-K subgraphs** (MatMuls with different K_full values), the per-step cost varies because some MatMuls finish their reduction before others. 
The execution has distinct phases:
+
+```
+K_full values (sorted descending): K1 >= K2 >= ... >= Kn
+num_k_steps_i = ceil(Ki / k) -- steps for MatMul i
+num_k_steps = max(num_k_steps_i) = ceil(K1 / k) -- total subgraph steps
+
+Phase boundaries occur at each distinct num_k_steps_i value.
+Within each phase, the set of active MatMuls is constant.
+
+For step_idx in 0..num_k_steps:
+    active_matmuls = { op | step_idx < ceil(K_full_op / k) }
+    matmul_compute = sum(op.base_cost * (min(k, K_full_op) / K_full_op) for op in active_matmuls)
+    matmul_memory = sum(k_strip_sizes for op in active_matmuls)
+
+    if step_idx == num_k_steps - 1: -- last step
+        compute = matmul_compute + pw_compute
+        memory includes eviction
+    else:
+        compute = matmul_compute
+
+    step_latency = max(compute, memory / bw)
+```
+
+Note the `min(k, K_full_op)` clamp: k candidates range up to K_max, so a candidate k may exceed a smaller MatMul's K_full; that op runs a single k-step at its full base_cost, never more. This matches the `min(k, k_full_op)` clamp in the evaluator implementation.
+
+The closed-form optimization groups consecutive steps with the same set of active MatMuls into phases, computing each phase's contribution in O(1). The total number of phases equals the number of distinct num_k_steps_i values (at most the number of distinct K_full values), so the per-candidate cost is O(distinct_K_values) rather than O(num_k_steps).
+
 For **Pointwise-only subgraphs** (num_k_steps = 1, no MatMul reuse patterns), the formula simplifies to `num_tiles * max(compute, memory)` since all tiles are identical.
 
-This reduces candidate evaluation from O(tiles * k_steps) to O(1), making the entire granularity search O(candidates) where candidates = O(log W * log H * log K_full).
+This reduces candidate evaluation from O(tiles * k_steps) to O(1) for uniform-K and O(distinct_K) for mixed-K, making the entire granularity search O(candidates) where candidates = O(log W * log H * log K_max). 
-**Search complexity:** For a subgraph with output dimensions W x H and reduction K_full: +**Search complexity:** For a subgraph with output dimensions W x H and reduction K_max: - w candidates: O(log W) - h candidates: O(log H) -- k candidates: O(log K_full) -- Total: O(log W * log H * log K_full) candidates per subgraph -- Each candidate evaluation is O(1) via closed-form +- k candidates: O(log K_max) +- Total: O(log W * log H * log K_max) candidates per subgraph +- Each candidate evaluation is O(1) for uniform-K or O(distinct_K) for mixed-K via closed-form - Total search time remains well under 100 milliseconds for all benchmarks --- +## Mixed-K Execution Model (Issue #22 / ADR-006) + +When a subgraph contains MatMul ops with different K_full values, the subgraph uses a **mixed-K execution model**. This section describes the semantics precisely. + +### Motivation + +The previous design enforced a K_full consistency invariant: all MatMuls in a subgraph had to share the same K_full. This was more restrictive than the problem statement requires. The problem says each MatMul individually processes its dot product in `ceil(K_full_op / k)` k-steps. Different MatMuls can have different K_full values in the same subgraph. + +The K_full consistency constraint prevented fusing adjacent MatMul+Pointwise chains when the MatMuls had different reduction dimensions (e.g., K=1024 and K=4096). This created unnecessary DRAM boundaries between subgraphs, accounting for approximately 30% of total latency on benchmarks 1 and 9. + +### Execution Semantics + +For a subgraph with granularity (w, h, k) containing MatMuls with K_full values {K1, K2, ..., Kn}: + +``` +K_max = max(K1, K2, ..., Kn) +num_k_steps = ceil(K_max / k) +``` + +For each spatial tile, the subgraph runs `num_k_steps` k-steps: + +| Step Index | Active MatMuls | Compute | Memory | +|------------|---------------|---------|--------| +| 0 .. 
ceil(Ki/k)-1 | MatMul_i is active | base_cost_i * (min(k, Ki) / Ki) | Load LHS/RHS strips for MatMul_i |
+| ceil(Ki/k) .. num_k_steps-1 | MatMul_i is **inactive** | 0 | 0 (no strips loaded) |
+| num_k_steps - 1 (last) | All remaining active MatMuls + all Pointwise ops | matmul_compute + pw_compute | strips + eviction |
+
+**Key properties:**
+- A MatMul that has completed all its k-steps (i.e., `step_idx >= ceil(K_full_op / k)`) contributes **zero compute and zero memory** on subsequent steps
+- A MatMul whose K_full is smaller than the candidate k is clamped: its per-step cost uses `min(k, K_full_op)`, so it pays exactly `base_cost` over its single active step
+- Pointwise ops execute only on the **last k-step** (step `num_k_steps - 1`), as in the uniform-K case
+- The working-set OOM check uses the **worst-case step** (typically step 0, when all MatMuls are active and all input strips must be loaded)
+
+### Example
+
+Subgraph with two MatMuls (K_full=1024 and K_full=4096) and one Pointwise, k=128:
+
+```
+MatMul_A: K_full=1024, num_k_steps_A = ceil(1024/128) = 8
+MatMul_B: K_full=4096, num_k_steps_B = ceil(4096/128) = 32
+Subgraph num_k_steps = max(8, 32) = 32
+
+Steps 0-7: MatMul_A active (compute + memory), MatMul_B active (compute + memory)
+Steps 8-30: MatMul_A done (0 cost), MatMul_B active (compute + memory)
+Step 31: MatMul_B active + Pointwise executes + eviction
+```
+
+On steps 8-30, the memory working set is smaller because MatMul_A's input strips are no longer loaded.
+
+---
+
 ## Latency Model Specification
 
 The latency model implements the roofline evaluation described in PROBLEM.md and must match the C++ `Evaluate()` function exactly. 
@@ -258,9 +337,11 @@ The latency model implements the roofline evaluation described in PROBLEM.md and - `num_tiles_h = ceil(H_out / h)` -- number of spatial tiles along height - `num_spatial_tiles = num_tiles_w * num_tiles_h` -**K-Steps (Split-K)**: For MatMul with reduction dimension `K_full`: -- `num_k_steps = ceil(K_full / k)` -- For Pointwise: `num_k_steps = 1` (k is ignored) +**K-Steps (Split-K / Mixed-K)**: For a subgraph containing MatMul ops: +- Each MatMul op has its own K_full: `num_k_steps_op = ceil(K_full_op / k)` +- The subgraph runs for `num_k_steps = ceil(max(K_full) / k)` total k-steps +- A MatMul op is **active** on step `s` if `s < ceil(K_full_op / k)`, otherwise it contributes nothing +- For Pointwise-only subgraphs: `num_k_steps = 1` (k is ignored) **Total Iterations**: `num_spatial_tiles * num_k_steps` @@ -278,25 +359,26 @@ For each execution step (one spatial tile, one k-step): - When granularity equals native: `base_cost` is the cost per tile, and `num_spatial_tiles` tiles cover the full tensor - When granularity is smaller: `base_cost` is still the cost per tile (hardware pads), but more tiles are needed -**Reduction scaling**: For MatMul, each k-step costs `base_cost * (k / K_full)` where `K_full` is the op's full reduction dimension. Verified against Example 5B: `k=32`, `K_full=128`, `base_cost=2000` per op, compute per step = `2000*(32/128) + 2000*(32/128) = 1000`. +**Reduction scaling**: For MatMul, each k-step costs `base_cost * (k / K_full)` where `K_full` is the op's full reduction dimension, **but only on steps where the op is active** (i.e., `step_idx < ceil(K_full_op / k)`). On steps after the op has completed its reduction, it contributes zero compute. Verified against Example 5B: `k=32`, `K_full=128`, `base_cost=2000` per op, compute per step = `2000*(32/128) + 2000*(32/128) = 1000`. -For Pointwise, k is irrelevant — the op executes **once per spatial tile** (on the last k-step only). 
In a fused subgraph with k-steps from a MatMul, Pointwise compute is added only on the final k-step of each spatial tile, not every step. Verified against Example 1C (pure Pointwise, no k-steps): `base_cost=1000+100=1100` per tile, 4 tiles.
+For Pointwise, k is irrelevant -- the op executes **once per spatial tile** (on the last k-step only). In a fused subgraph with k-steps from a MatMul, Pointwise compute is added only on the final k-step of each spatial tile, not every step. Verified against Example 1C (pure Pointwise, no k-steps): `base_cost=1000+100=1100` per tile, 4 tiles.
 
 **Spatial padding**: if `w < native_w` or `h < native_h`, you still pay full `base_cost` per step (hardware pads), but need more spatial tiles to cover the tensor.
 
-**Summary**:
+**Summary (mixed-K aware)**:
 
 ```
-For each k-step of a spatial tile:
-    matmul_compute = sum(op.base_cost * (k / K_full) for MatMul ops)
-    if is_last_k_step:
+For each k-step (step_idx) of a spatial tile:
+    active_matmuls = { op | op is MatMul AND step_idx < ceil(K_full_op / k) }
+    matmul_compute = sum(op.base_cost * (min(k, K_full_op) / K_full_op) for op in active_matmuls)
+    if is_last_k_step: -- step_idx == ceil(K_max / k) - 1
         compute = matmul_compute + sum(op.base_cost for Pointwise ops)
     else:
         compute = matmul_compute
 ```
 
-Where `K_full_for_this_op` is the inner/reduction dimension of that specific MatMul (the width of the LHS input = height of the RHS input... actually: for MatMul with inputs [LHS, RHS], `K_full = LHS.width = RHS.height`).
+Where `K_full_op` is the inner/reduction dimension of that specific MatMul (for MatMul with inputs [LHS, RHS], `K_full = LHS.width = RHS.height`). The `min(k, K_full_op)` clamp handles mixed-K candidates where k exceeds a smaller MatMul's K_full.
 
-Actually, let me re-examine. 
From the granularity definition: +From the granularity definition: - LHS input slice: width `k`, height `h` - RHS input slice: width `w`, height `k` - Output slice: width `w`, height `h` @@ -312,6 +394,7 @@ For each execution step, we must account for data loaded from slow memory and da - Compute the slice size based on the op type and granularity - If the tensor is already resident in fast memory (retained from previous subgraph, or already loaded in a previous step of the same subgraph via intra-subgraph reuse) it costs 0 - Otherwise: `slice_size / slow_memory_bandwidth` + - **Mixed-K rule (Issue #22)**: If the MatMul consuming this tensor is inactive on this step (has completed its reduction), the tensor's input strips are NOT loaded -- zero memory cost for that tensor on this step **Slice sizes** (for one spatial tile + one k-step): - Pointwise input: `w * h` elements @@ -381,6 +464,8 @@ working_set = sum(slice_size for each boundary input and output tensor that must + sum(size of retained tensors from previous subgraphs) ``` +**Mixed-K working set note (Issue #22)**: The OOM check uses the **worst-case step**, which is typically step 0 when all MatMuls are active and all input strips must be loaded simultaneously. On later steps where some MatMuls have finished, the actual memory usage is lower, but the OOM constraint must be satisfied for the worst case. 
+
+**OOM check**: `working_set <= fast_memory_capacity`
 
 ### Retained Tensors from Previous Subgraphs
 
@@ -444,23 +529,32 @@ num_spatial_tiles = num_tiles_w * num_tiles_h
 ### Number of K-Steps
 
 ```
-For MatMul: num_k_steps = ceil(K_full / k)
+For uniform-K subgraphs: num_k_steps = ceil(K_full / k)
+For mixed-K subgraphs: num_k_steps = ceil(max(K_full_op for each MatMul) / k)
+Per-op active steps: num_k_steps_op = ceil(K_full_op / k)
+    op is active on step s if s < num_k_steps_op
 For Pointwise-only subgraphs: num_k_steps = 1
 ```
 
 ### Compute Cost Per Step
 
+For a given step_idx:
 ```
-compute_per_step = sum for each op in subgraph:
-    if MatMul: base_cost * min(k, K_full_remaining) / native_k
-    if Pointwise: base_cost
+active_matmuls = { op | op is MatMul AND step_idx < ceil(K_full_op / k) }
+matmul_compute = sum(op.base_cost * (min(k, K_full_op) / K_full_op) for op in active_matmuls)
```
-Actually, let me be more precise. From the problem: "choosing k below native simply runs fewer cycles, dividing compute proportionally without waste." So for MatMul:
+On the last k-step (step_idx == num_k_steps - 1), add Pointwise compute:
 ```
-compute_per_matmul_step = base_cost * (k / K_full)
+compute = matmul_compute + sum(op.base_cost for Pointwise ops)
 ```
-where `K_full` is the full reduction dimension of that MatMul.
+
+On all other steps:
+```
+compute = matmul_compute
+```
+
+Note: from the problem statement, "choosing k below native simply runs fewer cycles, dividing compute proportionally without waste." For MatMul, `compute_per_matmul_step = base_cost * (min(k, K_full_op) / K_full_op)` where `K_full_op` is the full reduction dimension of that specific MatMul; the `min` clamp covers mixed-K candidates where k exceeds a smaller MatMul's K_full (the op then pays exactly base_cost on its single active step).
 
 For the spatial dimensions, if `w < native_w` or `h < native_h`, you still pay `base_cost` (padded), but you need more tiles. The examples confirm this. 
@@ -469,7 +563,7 @@ For the spatial dimensions, if `w < native_w` or `h < native_h`, you still pay ` ``` step_latency = max(compute_time, memory_time) where: - compute_time = sum of per-op compute costs for this step + compute_time = sum of per-op compute costs for active ops on this step memory_time = (bytes_in + bytes_out) / slow_memory_bandwidth ``` @@ -497,9 +591,10 @@ total_latency = sum(subgraph_latency for each subgraph) ## Performance Considerations 1. **Rust zero-cost abstractions**: The scheduler performs integer arithmetic, comparisons, and vector operations. Rust compiles to native code with no garbage collection pauses. -2. **Closed-form granularity evaluation (ADR-005)**: Candidate (w, h, k) triples are evaluated using a closed-form latency formula instead of tile-by-tile simulation. This reduces per-candidate evaluation from O(tiles * k_steps) to O(1). The search space remains O(log W * log H * log K) candidates per subgraph, with each evaluation being a constant-time arithmetic expression. See ADR-005 for the derivation and correctness argument. +2. **Closed-form granularity evaluation (ADR-005)**: Candidate (w, h, k) triples are evaluated using a closed-form latency formula instead of tile-by-tile simulation. This reduces per-candidate evaluation from O(tiles * k_steps) to O(1) for uniform-K subgraphs and O(distinct_K) for mixed-K subgraphs. The search space remains O(log W * log H * log K) candidates per subgraph, with each evaluation being a constant-time (or near-constant) arithmetic expression. See ADR-005 for the derivation and correctness argument. 3. **Early termination on OOM**: Before computing the closed-form latency for any candidate, the working set is checked against `fast_memory_capacity`. Infeasible candidates are pruned immediately, avoiding even the O(1) latency computation. For memory-constrained subgraphs, this eliminates the majority of the search space. 4. 
**Cost-based fusion (Issue #16)**: Before merging two subgraphs, the optimizer compares `latency(fused, best_granularity)` against `latency(A, best_A) + latency(B, best_B)`. Individual latencies already include DRAM boundary transfers (eviction from A, loading into B), so no separate boundary cost is added. The cost comparison uses the same closed-form evaluator, so it adds negligible overhead. -5. **Topological sort**: Kahn's algorithm, O(V + E), runs once. -6. **Total optimizer complexity**: O(N^2) for fusion (N = number of ops, each candidate merge requires a cost comparison), O(G) for granularity search per subgraph (G = candidate granularities, each O(1)). All 5 benchmarks complete end-to-end in under 2 seconds. -7. **Static binary**: `cargo build --release` with `lto = true` and `codegen-units = 1` produces a fully optimized, statically linked binary with no runtime dependencies. +5. **Mixed-K fusion (Issue #22)**: Relaxing the K_full consistency constraint allows fusing ops that were previously forced into separate subgraphs. This eliminates DRAM boundary transfers between those subgraphs, improving latency by up to 30% on benchmarks with mixed reduction dimensions (benchmarks 1 and 9). The additional complexity in the closed-form evaluator (phased computation) is O(distinct_K) per candidate, which is negligible. +6. **Topological sort**: Kahn's algorithm, O(V + E), runs once. +7. **Total optimizer complexity**: O(N^2) for fusion (N = number of ops, each candidate merge requires a cost comparison), O(G) for granularity search per subgraph (G = candidate granularities, each O(1) or O(distinct_K)). All 5 benchmarks complete end-to-end in under 2 seconds. +8. **Static binary**: `cargo build --release` with `lto = true` and `codegen-units = 1` produces a fully optimized, statically linked binary with no runtime dependencies. 
diff --git a/solution/docs/decisions/ADR-006-mixed-k-fusion.md b/solution/docs/decisions/ADR-006-mixed-k-fusion.md new file mode 100644 index 0000000..aa4eca0 --- /dev/null +++ b/solution/docs/decisions/ADR-006-mixed-k-fusion.md @@ -0,0 +1,72 @@ +# ADR-006: Mixed-K Fusion (Relax K_full Consistency Constraint) + +## Status + +Accepted + +## Context + +The previous design enforced a **K_full consistency invariant**: all MatMul ops within a single subgraph had to share the same K_full (full reduction dimension). This invariant was enforced during fusion (ops with different K_full were rejected from merging) and validated during evaluation. The rationale was simplicity: a single K_full value gives the subgraph a single, well-defined k-step loop count (`ceil(K_full / k)`). + +However, the problem statement does not require this constraint. From PROBLEM.md: "This single configuration creates a unified execution grid that every operation in the subgraph must conform to." The granularity `(w, h, k)` applies uniformly, but each MatMul individually processes its dot product in `ceil(K_full_op / k)` k-steps. Different MatMuls naturally have different K_full values based on their input tensor dimensions. + +### Impact of the constraint + +The K_full consistency constraint prevented fusing adjacent op chains when their MatMuls had different reduction dimensions. This forced DRAM boundaries between subgraphs that could otherwise share ephemeral intermediates, creating two unnecessary transfers per boundary (eviction from subgraph A + loading into subgraph B). + +Profiling on benchmarks 1 and 9 showed that these artificial DRAM boundaries account for approximately **30% of total latency**. Benchmark 9 in particular has 8 repeating MatMul+Pointwise blocks with mixed tensor sizes (1024-4096), where adjacent blocks frequently have MatMuls with different K_full values. 
+ +## Decision + +**Remove the K_full consistency invariant.** Allow MatMul ops with different K_full values to coexist in the same subgraph. + +### Mixed-K Execution Model + +For a subgraph with MatMuls having K_full values {K1, K2, ..., Kn} and granularity k: + +1. **Total k-steps**: `num_k_steps = ceil(max(K1, ..., Kn) / k)` +2. **Per-op activity**: MatMul_i is active on step `s` if `s < ceil(Ki / k)`, inactive otherwise +3. **Inactive ops contribute zero**: No compute cost, no memory traffic (input strips not loaded) +4. **Pointwise ops**: Execute on the last k-step only (unchanged from uniform-K) +5. **OOM check**: Uses worst-case step (step 0, when all MatMuls are active) + +### Example + +Subgraph with MatMul_A (K=1024) and MatMul_B (K=4096), k=128: +- MatMul_A active for 8 steps, MatMul_B active for 32 steps +- Total: 32 k-steps per spatial tile +- Steps 0-7: both active (full compute + full memory) +- Steps 8-31: only MatMul_B active (reduced compute + reduced memory) + +### Changes Required + +1. **Fusion stage**: Remove the K_full equality check from the merge eligibility filter. The cost-based criterion (`latency_fused < latency_split`) remains the gatekeeper. +2. **Granularity search**: Use `K_max = max(K_full_op)` for the k candidate range. Each candidate is evaluated using the mixed-K latency model. +3. **Closed-form latency evaluator**: Extend to handle per-step variation in active ops. Group consecutive steps with the same active set into phases; compute each phase in O(1). Total cost per candidate: O(distinct_K_values). +4. **Working-set calculator**: No change needed -- already computes worst-case (all ops active). +5. **Evaluator/validator**: Remove the K_full consistency assertion. 
+ +## Consequences + +### Positive + +- **30% latency improvement** on benchmarks 1 and 9 by eliminating artificial DRAM boundaries +- **Better fusion opportunities**: The cost-based criterion still prevents harmful fusions, but now has access to a larger candidate space +- **Correct modeling**: Matches the problem statement's per-op k-step semantics + +### Negative + +- **More complex latency model**: The closed-form evaluator must handle phased execution with varying active op sets. This increases per-candidate evaluation from O(1) to O(distinct_K_values), though in practice distinct_K is small (2-4 values) +- **Harder to reason about**: Uniform-K subgraphs have a clean, regular execution pattern. Mixed-K introduces step-varying behavior that is harder to debug and validate +- **Working-set underestimation risk**: If the worst-case step is not step 0 (e.g., if retained tensors from a finished MatMul interact unexpectedly with later steps), the OOM check could be incorrect. Mitigation: the working set at step 0 is always the maximum because it has the most active input strips. + +### Neutral + +- The uniform-K case is a special case of mixed-K (all ops have the same K_full), so no regression on existing behavior +- The cost-based fusion criterion provides a safety net: even if mixed-K is allowed, a fusion only happens when it provably reduces latency + +## References + +- Issue #22: Relax K_full consistency to allow mixed-K fusion +- Issue #16: Cost-based fusion (prerequisite -- provides the safety net) +- ADR-005: Closed-form latency evaluation (must be extended for phased computation) diff --git a/solution/requirements/mvp-scope.md b/solution/requirements/mvp-scope.md index 56738a7..8d07097 100644 --- a/solution/requirements/mvp-scope.md +++ b/solution/requirements/mvp-scope.md @@ -43,7 +43,7 @@ Goal: Lowest total latency on MLSys-2026 benchmarks | # | Feature | Acceptance Criteria | Est. 
(h) | Depends On | |---|---------|--------------------|----|-----------| -| 1 | **F-14: Topological sort** | Given any valid DAG, returns operations in a valid linearized order (tested on examples 1–5) | 1 | none | +| 1 | **F-14: Topological sort** | Given any valid DAG, returns operations in a valid linearized order (tested on examples 1-5) | 1 | none | | 2 | **F-01: Problem JSON parser** | Reads all 5 benchmark files + example_problem.json without error; reconstructs `Problem` struct with correct tensor/op counts and hardware params | 2 | none | | 3 | **F-02: Latency model** | Passes all 5 worked-example test cases from PROBLEM.md with latencies matching to 0.1 precision; handles compute-bound, memory-bound, and tiled (multi-step) cases correctly | 4 | F-01 | | 4 | **F-03: Working-set calculator** | Given a subgraph definition + resident tensor set + granularity, returns working-set size and raises OOM flag if > `fast_memory_capacity`; verified against all 5 examples | 3 | F-01, F-02 | @@ -51,10 +51,10 @@ Goal: Lowest total latency on MLSys-2026 benchmarks | 6 | **F-11: Solution JSON serializer** | Writes well-formed JSON matching the output schema; round-trips through a JSON validator; `null` traversal_orders serialize correctly | 2 | none | | 7 | **F-04: Baseline scheduler** | Produces one valid subgraph per operation; uses native granularity `[128, 128, K_full]`; `tensors_to_retain = []` for all; latency values match model; no OOM on any benchmark | 2 | F-01, F-02, F-03, F-11, F-14 | | 8 | **F-12: Benchmark runner** | CLI that accepts `--problem FILE --solution FILE`, calls evaluate logic, prints total latency and pass/fail | 3 | F-01, F-11 | -| 9 | **F-05: Op grouping / chain fusion** | Cost-based fusion: group adjacent ops only when fused latency is meaningfully less than `lat_a + lat_b` (relative tolerance to avoid float noise; boundary DRAM already included in individual latencies); verify latency improves vs. 
baseline on all 5 benchmarks | 6 | F-14, F-03, F-02 | +| 9 | **F-05: Op grouping / chain fusion** | Cost-based fusion: group adjacent ops only when fused latency is meaningfully less than `lat_a + lat_b` (relative tolerance to avoid float noise; boundary DRAM already included in individual latencies); **mixed-K fusion allowed (Issue #22)**: ops with different K_full values may be fused into the same subgraph when the cost-based criterion is satisfied, using the mixed-K execution model where `num_k_steps = ceil(max(K_full) / k)` and each MatMul is active for its own `ceil(K_full_op / k)` steps; verify latency improves vs. baseline on all 5 benchmarks | 6 | F-14, F-03, F-02 | | 10 | **F-07: Tensor retention** | After each subgraph, determine which output tensors are consumed by the immediately following subgraph and have sufficient residual capacity; retain them; verify improvement on Example 3C pattern | 4 | F-05, F-03 | | 11 | **F-08: Split-K** | For MatMul subgraphs where full-k working set exceeds capacity, search for the largest `k` divisor that fits; model accumulator as resident across k-steps; verify Example 5B latency | 5 | F-05, F-03, F-02 | -| 12 | **F-06: Granularity search** | For each subgraph, try candidate `[w, h]` values (powers of 2 up to tensor dimensions); **for MatMul subgraphs, also search `k` from `K_full` down to 1 in powers of 2 (Issue #15 fix) -- k must not be hardcoded to 1**; use **closed-form latency evaluation** (ADR-005, Issue #16) instead of tile-by-tile simulation; select the `[w, h, k]` combination that minimizes subgraph latency within the OOM constraint; larger k values reduce the total number of k-steps and total memory reloads; verify Example 1C pattern and that k > 1 is chosen for MatMul ops where the memory budget allows | 8 | F-05, F-03, F-02 | +| 12 | **F-06: Granularity search** | For each subgraph, try candidate `[w, h]` values (powers of 2 up to tensor dimensions); **for MatMul subgraphs, also search `k` from `K_max` down 
to 1 in powers of 2 where `K_max = max(K_full_op)` across all MatMuls in the subgraph (Issue #15 fix, Issue #22 mixed-K) -- k must not be hardcoded to 1**; use **closed-form latency evaluation** (ADR-005, Issue #16) instead of tile-by-tile simulation; for mixed-K subgraphs, the closed-form evaluator groups steps into phases by active op set (ADR-006); select the `[w, h, k]` combination that minimizes subgraph latency within the OOM constraint; larger k values reduce the total number of k-steps and total memory reloads; verify Example 1C pattern and that k > 1 is chosen for MatMul ops where the memory budget allows | 8 | F-05, F-03, F-02 | **Total MVP Estimated Effort: 43 hours** @@ -72,15 +72,21 @@ Goal: Lowest total latency on MLSys-2026 benchmarks - [ ] The `subgraph_latencies` values in every output JSON match the latency model to within floating-point tolerance (validated by `Evaluate()` or the Python re-implementation) - [ ] No solution contains a working set exceeding `fast_memory_capacity` for any benchmark -- [ ] For MatMul subgraphs, `k` is searched from `K_full` down to 1 (powers of 2) and the - `(w, h, k)` triple that minimizes total subgraph latency within `fast_memory_capacity` is - selected (larger `k` is preferred as a tie-breaker when latencies are equal) +- [ ] For MatMul subgraphs, `k` is searched from `K_max` down to 1 (powers of 2) where + `K_max = max(K_full_op)` across all MatMuls in the subgraph, and the `(w, h, k)` triple + that minimizes total subgraph latency within `fast_memory_capacity` is selected (larger `k` + is preferred as a tie-breaker when latencies are equal) +- [ ] Mixed-K fusion is supported (Issue #22): subgraphs may contain MatMuls with different + K_full values; each MatMul is active for `ceil(K_full_op / k)` steps and contributes zero + compute/memory after completing its reduction; the subgraph runs for + `ceil(max(K_full) / k)` total k-steps ### Performance (Issue #16) - [ ] Each of the 5 benchmarks completes end-to-end 
(parse, optimize, serialize) in under 2 seconds on a standard developer machine (M-series Mac or modern x86-64) - [ ] The granularity search stage must not be the performance bottleneck: candidate evaluation - uses closed-form latency (O(1) per candidate, not tile-by-tile simulation) + uses closed-form latency (O(1) per candidate for uniform-K, O(distinct_K) for mixed-K, + not tile-by-tile simulation) - [ ] Cost-based fusion must not regress performance: the merge-vs-split comparison adds negligible overhead relative to the O(N^2) fusion loop @@ -90,6 +96,8 @@ Goal: Lowest total latency on MLSys-2026 benchmarks individual latencies already include boundary DRAM) - [ ] No benchmark exhibits a latency regression from fusion (i.e., fusing never makes things worse than keeping subgraphs separate) +- [ ] Mixed-K fusion (Issue #22) eliminates artificial DRAM boundaries between ops with + different K_full values, targeting approximately 30% improvement on benchmarks 1 and 9 --- @@ -99,11 +107,13 @@ Goal: Lowest total latency on MLSys-2026 benchmarks 1. Scheduler reads problem JSON -> Parses Problem struct with tensors, ops, hw params 2. Topological sort -> Linearized op execution order 3. Baseline schedule generated -> Valid JSON, all ops covered, no OOM -4. Cost-based chain fusion applied -> Adjacent ops merged only when fused latency wins +4. Cost-based chain fusion applied -> Adjacent ops merged only when fused latency wins; + mixed-K fusion allowed (Issue #22) 5. Tensor retention decided -> Downstream-needed tensors flagged as resident 6. Split-K applied to MatMul subgraphs -> k reduced to fit tight memory budgets 7. Granularity search per subgraph -> Best [w, h, k] selected via closed-form evaluation; - k searched from K_full downward for MatMul ops + k searched from K_max downward for MatMul ops; + mixed-K subgraphs use phased evaluation (ADR-006) 8. Latency calculated for each subgraph -> subgraph_latencies list populated 9. 
Solution JSON written -> Ready for Evaluate() call 10. Benchmark runner reports total -> Score vs. baseline shown; validated correct @@ -147,4 +157,5 @@ Goal: Lowest total latency on MLSys-2026 benchmarks | Benchmark 17 (160 tensors, 95+ ops, 500K fast memory) has complex topology that greedy fusion mishandles | M | M | Analyze graph structure in architecture stage; design fusion rules for the attention-like repeating pattern observed in benchmarks 9, 13, 17 | | Working-set formula for subgraphs containing both MatMul and Pointwise ops is incorrectly specified | L | H | Cross-check against Example 5B which has exactly this combination; add a dedicated test case | | Granularity search too slow on large benchmarks (Issue #16) | H | H | Use closed-form latency evaluation (ADR-005) instead of tile-by-tile simulation; early termination on OOM; target < 2s per benchmark | -| Granularity search defaults to k=1 for MatMul ops (Issue #15) | H | H | Search k from K_full downward; select the (w,h,k) minimizing total latency (larger k as tie-breaker); add regression test asserting k > 1 on any benchmark where K_full > 1 and memory allows | +| Granularity search defaults to k=1 for MatMul ops (Issue #15) | H | H | Search k from K_max downward; select the (w,h,k) minimizing total latency (larger k as tie-breaker); add regression test asserting k > 1 on any benchmark where K_full > 1 and memory allows | +| Mixed-K latency model incorrectly computes phased execution (Issue #22) | M | H | Validate against hand-computed examples with known mixed-K subgraphs; ensure uniform-K is a degenerate case that produces identical results to the previous model; add regression tests for mixed-K scenarios |