Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 175 additions & 36 deletions solution/agent/evaluator.py

Large diffs are not rendered by default.

11 changes: 0 additions & 11 deletions solution/agent/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,17 +268,6 @@ def _cached_best(ops: list[int], cache: dict) -> tuple[Granularity, float]:

merged = sg_ops[i] + sg_ops[i + 1]

# K_full consistency: all MatMuls must share the same K_full
matmul_k_fulls = [
_k_full_for_op(problem.ops[o], problem)
for o in merged if problem.ops[o].op_type == "MatMul"
]
if matmul_k_fulls and len(set(matmul_k_fulls)) > 1:
rejected_merges.add(merge_key)
new_sg_ops.append(sg_ops[i])
i += 1
continue

# Boundary output dimension consistency
boundary_outs = list(_boundary_outputs_for_subgraph(merged, problem))
if boundary_outs:
Expand Down
24 changes: 9 additions & 15 deletions solution/backend/rust/src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ pub fn evaluate(problem: &Problem, solution: &Solution) -> Result<EvaluateResult
}
}

// Validate MatMul K_full consistency and k <= K_full
let matmul_k_fulls: Vec<i64> = sg.ops.iter()
// Validate k does not exceed the maximum K_full across all MatMuls.
// Mixed-K subgraphs (MatMuls with different K_full values) are allowed;
// k only needs to be <= max(K_full) so the step count is well-defined.
let max_k_full: Option<i64> = sg.ops.iter()
.filter_map(|&op_idx| {
let op = &problem.ops[op_idx];
if op.is_matmul() {
Expand All @@ -81,20 +83,12 @@ pub fn evaluate(problem: &Problem, solution: &Solution) -> Result<EvaluateResult
None
}
})
.collect();
if !matmul_k_fulls.is_empty() {
// All MatMuls in a subgraph must share the same K_full
if !matmul_k_fulls.iter().all(|&kf| kf == matmul_k_fulls[0]) {
.max();
if let Some(kf_max) = max_k_full {
if sg.granularity.k > kf_max {
return Err(format!(
"Subgraph {sg_idx}: MatMul ops have inconsistent K_full values: {:?}",
matmul_k_fulls
));
}
// k must not exceed K_full
if sg.granularity.k > matmul_k_fulls[0] {
return Err(format!(
"Subgraph {sg_idx}: granularity k={} exceeds K_full={}",
sg.granularity.k, matmul_k_fulls[0]
"Subgraph {sg_idx}: granularity k={} exceeds max K_full={}",
sg.granularity.k, kf_max
));
}
}
Expand Down
211 changes: 174 additions & 37 deletions solution/backend/rust/src/latency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub fn matmul_compute_per_step(
let op = &problem.ops[op_idx];
if op.is_matmul() {
let op_k_full = k_full_for_matmul(op, &problem.tensors) as f64;
total += op.base_cost as f64 * (k / op_k_full);
total += op.base_cost as f64 * (k.min(op_k_full) / op_k_full);
}
}
total
Expand Down Expand Up @@ -155,6 +155,10 @@ fn build_memory_plan(
&& !dag.tensor_consumers[out_t].is_empty()
&& dag.tensor_consumers[out_t].iter().all(|c| op_set.contains(c));

// Effective k for this op: clamp to its K_full (can't load more than exists)
let op_k_full = k_full_for_matmul(op, &problem.tensors);
let k_eff = k.min(op_k_full);

// LHS input
let lhs_boundary = !dag.tensor_producer[lhs_idx]
.map(|p| op_set.contains(&p))
Expand All @@ -164,15 +168,13 @@ fn build_memory_plan(
if previously_retained.contains(&lhs_idx) {
pre_retained.push(lhs_idx);
} else if output_ephemeral {
// Upstream LHS: we need a ROW STRIP = h rows * full K_full_Op0 width
// h = output height of the FINAL output (the subgraph output)
// K_full_Op0 = lhs.width (the full reduction dimension of this upstream op)
// Upstream LHS: ROW STRIP = h * K_full (full reduction width)
let lhs_width = problem.tensors[lhs_idx].width;
let row_strip_size = h * lhs_width;
full_load.push((lhs_idx, row_strip_size));
} else {
// Standard LHS slice = h * k
k_strip.push((lhs_idx, h * k));
// Standard LHS slice = h * k_eff
k_strip.push((lhs_idx, h * k_eff));
}
}

Expand All @@ -185,13 +187,12 @@ fn build_memory_plan(
if previously_retained.contains(&rhs_idx) {
pre_retained.push(rhs_idx);
} else if output_ephemeral {
// Upstream RHS: col strip of the intermediate = K_full_Op0 * k
// = rhs.height * k (since rhs.height = K_full_Op0 for this upstream op)
// Upstream RHS: rhs.height * k_eff
let rhs_height = problem.tensors[rhs_idx].height;
k_strip.push((rhs_idx, rhs_height * k));
k_strip.push((rhs_idx, rhs_height * k_eff));
} else {
// Standard RHS slice = k * w
k_strip.push((rhs_idx, k * w));
// Standard RHS slice = k_eff * w
k_strip.push((rhs_idx, k_eff * w));
}
}
}
Expand All @@ -213,14 +214,15 @@ fn build_memory_plan(
}
}

/// Compute num_k_steps for a subgraph: ceil(min_K_full / k) across all MatMuls.
/// Compute num_k_steps for a subgraph: ceil(max_K_full / k) across all MatMuls.
/// Uses MAX so the subgraph runs until the longest reduction finishes.
/// Returns 1 if there are no MatMul ops.
pub fn compute_num_k_steps(
subgraph_ops: &[usize],
k: i64,
problem: &Problem,
) -> i64 {
let min_k_full: Option<i64> = subgraph_ops
let max_k_full: Option<i64> = subgraph_ops
.iter()
.filter_map(|&op_idx| {
let op = &problem.ops[op_idx];
Expand All @@ -230,11 +232,11 @@ pub fn compute_num_k_steps(
None
}
})
.min();
.max();
if k <= 0 {
return 1; // Guard against division by zero from malformed input
}
match min_k_full {
match max_k_full {
Some(kf) => (kf + k - 1) / k,
None => 1,
}
Expand Down Expand Up @@ -335,31 +337,166 @@ pub fn subgraph_latency(
if num_k_steps > 1 {
// Split-K mode: all spatial tiles are identical.
//
// First k-step of each spatial tile:
// load = full_load_total + pw_load_total + k_strip_total
// compute = matmul_compute
let first_k_mem = (full_load_total + pw_load_total + k_strip_total) as f64 / bw;
let first_k_lat = f64::max(matmul_compute, first_k_mem);

// Interior k-steps (steps 2 .. num_k_steps-1):
// load = k_strip_total only
// compute = matmul_compute
let interior_k_lat = if num_k_steps > 2 {
let interior_mem = k_strip_total as f64 / bw;
f64::max(matmul_compute, interior_mem)
// Mixed-K support: MatMuls with different K_full values finish at different steps.
// We compute latency in phases where each phase has a different set of active MatMuls.
//
// Build a lookup from tensor_id -> k_strip_size from the memory plan.
// Used to compute per-MatMul k_strip contributions accurately (no compute-ratio proxy).
let k_strip_map: std::collections::HashMap<usize, i64> =
plan.k_strip.iter().map(|&(t, sz)| (t, sz)).collect();

// Collect (K_full, base_cost, k_strip_contribution) tuples for all MatMul ops.
// k_strip_contribution is the sum of k_strip sizes for this op's boundary LHS/RHS
// inputs (using the same deduplication logic as build_memory_plan: each tensor
// is counted for the first MatMul op that claims it).
let mut k_strip_seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
let matmul_phases: Vec<(i64, f64, i64)> = subgraph_ops
.iter()
.filter_map(|&op_idx| {
let op = &problem.ops[op_idx];
if !op.is_matmul() {
return None;
}
let kf = k_full_for_matmul(op, &problem.tensors);
let lhs_idx = op.inputs[0];
let rhs_idx = op.inputs[1];
let mut op_k_strip: i64 = 0;
if !k_strip_seen.contains(&lhs_idx) {
if let Some(&sz) = k_strip_map.get(&lhs_idx) {
op_k_strip += sz;
}
k_strip_seen.insert(lhs_idx);
}
if !k_strip_seen.contains(&rhs_idx) {
if let Some(&sz) = k_strip_map.get(&rhs_idx) {
op_k_strip += sz;
}
k_strip_seen.insert(rhs_idx);
}
Some((kf, op.base_cost as f64, op_k_strip))
})
.collect();

// Check if all MatMuls have identical K_full (fast path, existing formula).
let all_same_k_full = matmul_phases.windows(2).all(|w| w[0].0 == w[1].0);

let per_tile_lat = if all_same_k_full {
// Fast path: uniform K_full — use original formula.
//
// First k-step: load = full_load_total + pw_load_total + k_strip_total
let first_k_mem = (full_load_total + pw_load_total + k_strip_total) as f64 / bw;
let first_k_lat = f64::max(matmul_compute, first_k_mem);

// Interior k-steps: load = k_strip_total only
let interior_k_lat = if num_k_steps > 2 {
let interior_mem = k_strip_total as f64 / bw;
f64::max(matmul_compute, interior_mem)
} else {
0.0
};

// Last k-step: load = k_strip_total, evict output, compute includes PW
let last_k_mem = (k_strip_total + plan.out_evict_size) as f64 / bw;
let last_k_lat = f64::max(matmul_compute + pw_compute, last_k_mem);

first_k_lat + (num_k_steps - 2).max(0) as f64 * interior_k_lat + last_k_lat
} else {
0.0
};
// Mixed-K path: compute phase-by-phase.
//
// Phases are defined by sorted unique step-end boundaries (when each MatMul finishes).
//
// Example: K_full = [4, 8], k = 2
// MatMul-A finishes at step ceil(4/2)=2, MatMul-B at step ceil(8/2)=4
// Phase 1: steps 0..2 — both active (2 steps)
// Phase 2: steps 2..4 — only MatMul-B active (2 steps)
//
// Within a phase all steps have identical cost except:
// - Global step 0: loads full_load_total + pw_load_total additionally
// - Global last step: evicts output + adds PW compute
// Replace the per-step loop with O(1) per-phase arithmetic.
let mut step_ends: Vec<i64> = matmul_phases
.iter()
.map(|(kf, _, _)| (*kf + k - 1) / k)
.collect();
step_ends.sort_unstable();
step_ends.dedup();
// step_ends.last() == num_k_steps (max)

let mut per_tile_lat = 0.0_f64;
let mut prev_end: i64 = 0;

for (phase_idx, &phase_end) in step_ends.iter().enumerate() {
// Active MatMuls: those with step_count >= phase_end.
let active_compute: f64 = matmul_phases
.iter()
.filter(|(kf, _, _)| (*kf + k - 1) / k >= phase_end)
.map(|(kf, cost, _)| cost * ((k as f64).min(*kf as f64) / *kf as f64))
.sum();

// Active k_strip: sum per-op contributions for active MatMuls only.
// This is exact because each op's contribution was precomputed from
// its actual tensor dimensions, not from a compute-ratio proxy.
let active_k_strip_elems: i64 = matmul_phases
.iter()
.filter(|(kf, _, _)| (*kf + k - 1) / k >= phase_end)
.map(|(_, _, ks)| ks)
.sum();
let active_k_strip = active_k_strip_elems as f64 / bw;

let phase_steps = phase_end - prev_end;
let is_last_phase = phase_idx == step_ends.len() - 1;

// O(1) per phase: classify steps as first, interior, or last.
// Special steps: global step 0 (loads full_load + pw_load) and
// global last step (evicts output, adds PW compute).
let has_first = prev_end == 0;
let has_last = is_last_phase; // last phase always contains the last step

// Interior steps: all steps in this phase that are neither first nor last.
let interior_count = (phase_steps
- if has_first { 1 } else { 0 }
- if has_last { 1 } else { 0 })
.max(0);

if has_first {
let mem = (full_load_total + pw_load_total) as f64 / bw + active_k_strip;
// First step is also last only when num_k_steps == 1, which is
// impossible here (we are in the num_k_steps > 1 branch).
per_tile_lat += f64::max(active_compute, mem);
}

if interior_count > 0 {
let interior_lat = f64::max(active_compute, active_k_strip);
per_tile_lat += interior_count as f64 * interior_lat;
}

// Last k-step of each spatial tile:
// load = k_strip_total
// evict = out_evict_size
// compute = matmul_compute + pw_compute
let last_k_mem = (k_strip_total + plan.out_evict_size) as f64 / bw;
let last_k_lat = f64::max(matmul_compute + pw_compute, last_k_mem);
if has_last {
// If this phase has exactly one step and it is also the first step,
// we already accounted for it above; replace that cost with the
// combined first+last cost.
let is_also_first = has_first && phase_steps == 1;
if !is_also_first {
let mem_last = active_k_strip + plan.out_evict_size as f64 / bw;
let compute_last = active_compute + pw_compute;
per_tile_lat += f64::max(compute_last, mem_last);
} else {
// Single-step phase that is both first and last.
// Undo the first-step cost already added, then add combined cost.
let first_mem = (full_load_total + pw_load_total) as f64 / bw + active_k_strip;
per_tile_lat -= f64::max(active_compute, first_mem);
let mem_last = (full_load_total + pw_load_total) as f64 / bw
+ active_k_strip
+ plan.out_evict_size as f64 / bw;
let compute_last = active_compute + pw_compute;
per_tile_lat += f64::max(compute_last, mem_last);
}
}

prev_end = phase_end;
}

// num_k_steps >= 2 here (outer branch guarantees it)
let per_tile_lat = first_k_lat + (num_k_steps - 2).max(0) as f64 * interior_k_lat + last_k_lat;
per_tile_lat
};

num_spatial_tiles as f64 * per_tile_lat
} else {
Expand Down
50 changes: 49 additions & 1 deletion solution/backend/rust/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ fn run_evaluate(args: &[String]) {
#[cfg(test)]
mod tests {
use super::*;
use crate::latency::subgraph_latency;
use crate::latency::{compute_num_k_steps, subgraph_latency};
use crate::models::Granularity;
use std::collections::HashSet;

Expand Down Expand Up @@ -630,6 +630,54 @@ mod tests {
);
}

// Mixed-K: two MatMuls with different K_full fused in one subgraph
#[test]
fn test_mixed_k_two_matmuls() {
    // Fixture:
    //   Op0: MatMul K=64,  Tensor0(64x128)  @ Tensor1(128x64)  -> Tensor2(128x128), cost=1000
    //   Op1: MatMul K=128, Tensor2(128x128) @ Tensor3(128x128) -> Tensor4(128x128), cost=2000
    // Fusing [0,1] at k=32:
    //   Op0 runs ceil(64/32)=2 k-steps, compute/step = 1000*(32/64)  = 500
    //   Op1 runs ceil(128/32)=4 k-steps, compute/step = 2000*(32/128) = 500
    //   total k-steps = max(2,4) = 4
    //   phase 1 (steps 0-1): both ops active, compute = 500+500 = 1000
    //   phase 2 (steps 2-3): only Op1 active, compute = 500
    let json = r#"{
"widths": [64,128,128,128,128],
"heights": [128,64,128,128,128],
"inputs": [[0,1],[2,3]],
"outputs": [[2],[4]],
"base_costs": [1000, 2000],
"op_types": ["MatMul","MatMul"],
"fast_memory_capacity": 100000,
"slow_memory_bandwidth": 10,
"native_granularity": [128, 128]
}"#;
    let problem = parse_problem(json).unwrap();
    let dag = DagInfo::build(&problem).unwrap();

    // Latency of the fused subgraph [0,1] at a given k granularity.
    let latency_at = |k: i64| {
        let gran = Granularity { w: 128, h: 128, k };
        subgraph_latency(&[0, 1], &gran, &[], &HashSet::new(), &problem, &dag)
    };

    // k=32: 4 k-steps over 1 spatial tile; the two phases have different costs,
    // so we only sanity-check positivity here.
    let lat = latency_at(32);
    assert!(lat > 0.0, "Mixed-K latency must be positive, got {lat}");

    // k=64: Op0 finishes after 1 step while Op1 needs 2.
    let lat64 = latency_at(64);
    assert!(lat64 > 0.0, "Mixed-K k=64 latency must be positive, got {lat64}");

    // k=128: both ops take a single step (Op0's k is clamped to its K_full=64).
    let lat128 = latency_at(128);
    assert!(lat128 > 0.0, "Mixed-K k=128 latency must be positive, got {lat128}");

    // num_k_steps = ceil(max K_full / k) = ceil(128 / k).
    assert_eq!(compute_num_k_steps(&[0, 1], 32, &problem), 4);
    assert_eq!(compute_num_k_steps(&[0, 1], 64, &problem), 2);
    assert_eq!(compute_num_k_steps(&[0, 1], 128, &problem), 1);
}

// Benchmark validation: all 5 benchmark solutions should cover all ops and have non-negative latencies
#[test]
fn test_benchmark_solutions_validity() {
Expand Down
Loading
Loading