Add factor-based and independent per-mesh-dim ILP sharding optimizers #320

Draft

fmassa wants to merge 20 commits into main from fmassa/factor_graph

Conversation

Contributor

@fmassa fmassa commented Feb 18, 2026

Summary

Adds two new sharding optimizers as alternatives to the original enumeration-based ILP (ShardingOptimizer):

  • FactorShardingOptimizer in optimize_sharding_new.py
  • IndependentShardingOptimizer in optimize_sharding_independent.py

Both target the scalability bottleneck of the original solver, whose strategy count per node is O((d+1)^k) where d = tensor dims and k = mesh dims. On a 2D mesh with 3D tensors, that's 16 strategies per node; on a 4D mesh it would be 256+. The ILP variable count is even worse because each variable encodes a (producer output, consumer input) placement pair.
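
To make the count concrete, here is a small sketch (a hypothetical helper, not project code) that enumerates the per-node placement combinations:

```python
from itertools import product

def strategies_per_node(tensor_ndim: int, mesh_ndim: int) -> int:
    # Each mesh dim independently picks Replicate or Shard(0..d-1):
    # d + 1 choices per mesh dim, combined as a Cartesian product,
    # giving (d + 1)^k strategies per node.
    choices = ["R"] + [f"S({d})" for d in range(tensor_ndim)]
    return len(list(product(choices, repeat=mesh_ndim)))

print(strategies_per_node(3, 2))  # 2D mesh, 3D tensors -> 16
print(strategies_per_node(3, 4))  # 4D mesh, 3D tensors -> 256
```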

Original ILP recap

The original ShardingOptimizer enumerates all multi-dimensional placement combinations per node and creates binary decision variables x[node, arg, output_strat, input_strat]. Constraints enforce one-hot selection, cross-argument consistency, and dataflow consistency (producer output = consumer input). The cost of each variable directly encodes the exact redistribution cost between the producer's output placement and the consumer's expected input placement.

Strengths:

  • Exact redistribution costs — every (src_placement, dst_placement) pair has a precise communication cost computed from estimate_strategy_comms_cost.
  • Tight LP relaxation — the one-hot encoding means the LP relaxation is the convex hull of integer solutions (each constraint group forms a simplex). Branch-and-bound typically finds the optimum at the root node.
  • Correct joint solutions — considers all mesh dimensions together, so cross-dim interactions (e.g., "TP on dim 1 is only worthwhile if DP handles the batch on dim 0") are captured naturally.

Weaknesses:

  • Strategy explosion — O((d+1)^k) strategies per node, O((d+1)^(2k)) variables per node-argument pair. On a (16, 8) mesh with the Block model: 23,594 variables, 5,113 constraints, 3.26s total time. This becomes prohibitive for larger meshes or models.

Factor-based ILP (FactorShardingOptimizer)

Inspired by Shardy's factor-based propagation, this reformulates the ILP using "factors" — logical computation dimensions (M, N, K for a matmul). Instead of enumerating all placement combinations, it creates binary variables y[factor_gid, mesh_dim] indicating whether a given factor is sharded on a given mesh dimension.

How it works:

  1. Builds multi-dim strategies using existing DTensor op rules (same as original).
  2. Extracts factors generically from OpStrategy objects by inspecting 1D placement atoms at mesh dim 0. Each unique non-replicate atom corresponds to one factor.
  3. Builds a factor graph with spatial edges (producer result dim = consumer operand dim) and reduction edges (Partial propagation through data-preserving ops).
  4. Constructs an ILP with:
    - Factor uniqueness: each factor assigned to at most one mesh dim (≤ 1).
    - Tensor exclusion: per tensor per mesh dim, at most one factor can be sharded.
    - Reduction propagation: consumer Partial ≤ producer Partial.
    - Objective: compute benefit (negative cost for sharding) minus edge disagreement costs (all-gather/reduce-scatter/all-reduce when producer and consumer disagree) minus uncovered reduction exit costs.
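
The constraint structure above can be illustrated with a brute-force toy (two factors of a single matmul, a 2D mesh, made-up costs; the real solver builds an ILP rather than enumerating):

```python
from itertools import product

factors = ["M", "K"]            # toy factors of a single matmul
mesh_dims = [0, 1]
compute_benefit = {"M": 10.0, "K": 4.0}   # made-up benefit when sharded
exit_cost = {"M": 0.0, "K": 6.0}          # made-up collective cost (K is a reduction)

best = None
for bits in product([0, 1], repeat=len(factors) * len(mesh_dims)):
    y = {(f, m): bits[i * len(mesh_dims) + j]
         for i, f in enumerate(factors) for j, m in enumerate(mesh_dims)}
    # Factor uniqueness: each factor assigned to at most one mesh dim.
    if any(sum(y[(f, m)] for m in mesh_dims) > 1 for f in factors):
        continue
    # Tensor exclusion: at most one factor sharded per mesh dim (single tensor here).
    if any(sum(y[(f, m)] for f in factors) > 1 for m in mesh_dims):
        continue
    cost = sum((exit_cost[f] - compute_benefit[f]) * y[(f, m)]
               for f in factors for m in mesh_dims)
    if best is None or cost < best[0]:
        best = (cost, y)

print(best[0])  # -10.0: M gets a mesh dim; sharding K would add 6 - 4 = +2
```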

Strengths vs original:

  • Dramatically fewer ILP variables: O(F × k) instead of O(A × (d+1)^(2k)). On the Block model: 634 y-vars + 678 z-vars = 1,312 total (vs 23,594). This is a ~18x reduction.
  • Captures cross-dim interactions — factors are assigned to mesh dims globally, so the solver naturally discovers TP+DP splits.
  • Solution quality is very close to original — on the Block model, factor and original agree on most nodes. The few disagreements are in how Partial outputs are resolved (reduce-scatter vs all-reduce), which is a consequence of the cost linearization.

Weaknesses vs original:

  • Edge disagreement costs are linearized using auxiliary z-variables (z ≥ y_prod - y_cons), which is an LP relaxation of the exact redistribution cost. The original ILP encodes exact costs per (src, dst) pair. This means:
    • The factor ILP cannot distinguish between different redistribution patterns that have different costs (e.g., Shard(0) → Shard(1) vs Shard(0) → Replicate).
    • The LP relaxation is weaker — fractional z-values can underestimate disagreement costs, potentially requiring more branch-and-bound work.
  • Strategy building still uses the full multi-dim mesh (O((d+1)^k) per node) even though only 1D atoms are inspected. This is a POC limitation; a production version would build strategies on a 1D mesh only.
  • Factor extraction is a heuristic — it assumes OpStrategies are Cartesian products of per-mesh-dim atoms (which is true for ops that go through expand_to_full_mesh_op_strategy, but may not hold for custom ops or local_map).

Solve time: 1.32s total (0.24s solve + 1.05s strategy building). The strategy building dominates — with 1D-only strategy building this would be much faster.

Independent per-mesh-dim ILP (IndependentShardingOptimizer)

This takes the simplest possible decomposition: run k completely independent ShardingOptimizer instances on 1D sub-meshes (one per mesh dimension), then stitch the 1D solutions together into multi-dim placements.

How it works:

  1. For each mesh dim m, extract a 1D sub-mesh via mesh[dim_name].
  2. Create a ShardingOptimizer on the 1D sub-mesh — this builds O(d+1) strategies per node (not O((d+1)^k)).
  3. Project user constraints from multi-dim to 1D: (Shard(0), Replicate()) → dim 0 sees (Shard(0),), dim 1 sees (Replicate(),).
  4. Solve each 1D ILP independently.
  5. Combine: for each node, the multi-dim placement is (p_0, p_1, ..., p_{k-1}) where p_m is the 1D solution for mesh dim m.
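
Step 5 can be sketched as follows (placement strings stand in for real DTensor placements; combine_1d_solutions is a hypothetical helper):

```python
def combine_1d_solutions(per_dim_solutions, nodes):
    """per_dim_solutions[m][node] is the 1D placement chosen for mesh dim m;
    the multi-dim placement is just the tuple across mesh dims."""
    return {node: tuple(sol[node] for sol in per_dim_solutions)
            for node in nodes}

# Toy example: both 1D solvers independently pick batch sharding,
# which is exactly the all-DP failure mode discussed below.
dp = {"mm_1": "Shard(0)", "relu": "Shard(0)"}
tp = {"mm_1": "Shard(0)", "relu": "Shard(0)"}
print(combine_1d_solutions([dp, tp], ["mm_1", "relu"]))
# {'mm_1': ('Shard(0)', 'Shard(0)'), 'relu': ('Shard(0)', 'Shard(0)')}
```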

Strengths vs original:

  • Fastest solver: 0.48s total (0.16s build per dim + 0.07s solve per dim). This is 6.8x faster than the original and 2.75x faster than the factor-based solver.
  • O(d+1) strategies per node on each 1D mesh — no exponential blowup.
  • Tight LP relaxation — inherits the one-hot simplex structure of the original ILP on each 1D sub-problem.
  • Exact redistribution costs within each mesh dimension — no z-variable linearization.
  • Reuses ShardingOptimizer unmodified — no new ILP formulation, no factor extraction, minimal new code.

Strengths vs factor-based:

  • Simpler implementation (~300 lines vs ~1,500 lines).
  • Exact 1D redistribution costs (no z-variable approximation).
  • Tighter LP relaxation per sub-problem.
  • Faster due to both smaller ILP size and no factor extraction overhead.

Weaknesses vs both original and factor-based:

  • Mesh dimensions are fully decoupled — cross-dim interactions are lost entirely. Each 1D solver independently optimizes without knowing what the other dims chose. This causes a systematic problem: both the dp and tp solvers independently discover that batch-dim sharding is cheapest (zero communication cost), so the combined solution is effectively all-data-parallelism (Shard(0) on every mesh dim). The original and factor-based solvers correctly find DP+TP splits because they consider mesh dimensions jointly.
  • Memory constraints are approximate — the true parameter memory ratio is a product over mesh dims (nonlinear), but each 1D solver applies the constraint independently.
  • The solution quality gap is significant in practice: on the Block model with a (16, 8) mesh, nearly every node differs from the original solution. The independent solver chooses (Shard(0), Shard(0)) (128-way DP) where the original chooses (Shard(0), Shard(1)) (16-way DP + 8-way TP).

Summary table

  ┌──────────────────────────────┬─────────────────┬─────────────────────┬────────────────────┐
  │           Property           │    Original     │    Factor-based     │    Independent     │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ ILP variables (Block model)  │ 23,594          │ 1,312               │ 3,724 (2 × 1,862)  │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ ILP constraints              │ 5,113           │ 1,695               │ 1,782 (2 × 891)    │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ Total time                   │ 3.26s           │ 1.32s               │ 0.48s              │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ Strategies per node          │ O((d+1)^k)      │ O((d+1)^k)*         │ O(d+1) per dim     │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ Redistribution costs         │ Exact           │ Linearized (z-vars) │ Exact per dim      │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ LP relaxation tightness      │ Tight (simplex) │ Weaker (z-vars)     │ Tight per dim      │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ Cross-dim interactions       │ Full            │ Full                │ None               │
  ├──────────────────────────────┼─────────────────┼─────────────────────┼────────────────────┤
  │ Solution quality vs original │ —               │ Very close          │ Poor (all-DP bias) │
  └──────────────────────────────┴─────────────────┴─────────────────────┴────────────────────┘
  * Factor-based currently builds full multi-dim strategies for factor extraction;
    could be reduced to O(d+1) with 1D-only building.

Test plan

  • Run python examples/example_autoparallel_factor.py — compares all three solvers side-by-side
  • Verify factor-based and independent solvers produce valid OpSpec dictionaries that can be passed to apply_sharding
  • Test with different mesh shapes (1D, 2D, 3D) to verify correctness
  • Test with input/output/grad-param constraints

The representation now allows multi-dim assignment: nothing in the constraint system prevents y[root, 0] = 1 and y[root, 1] = 1 simultaneously. The issue is in the cost model. Let me trace why.

  The DIFF nodes are weight transposes (t.default). Their dim 0 factor merges with the K (reduction/contraction) factor of the downstream mm. In the current cost model:

  if fac.is_reduction:
      cost += allreduce_cost      # penalty only
  else:
      cost -= compute_benefit     # benefit only

  Reduction factors get only an allreduce penalty, with no compute benefit subtracted. But sharding the K dimension does reduce compute — each device computes a smaller
  matmul. The original enumeration-based optimizer captures this because its per-strategy costs include both comm and compute for each choice. The factor cost model misses
  the compute savings for reduction factors, so the optimizer avoids assigning K to any mesh dim.
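
A hedged sketch of the implied fix, subtracting the compute benefit for reduction factors as well (factor_cost and its inputs are stand-ins for the real estimators, not project code):

```python
def factor_cost(is_reduction: bool, compute_benefit: float,
                allreduce_cost: float) -> float:
    # Sharding any factor shrinks the local computation, so the benefit
    # is subtracted unconditionally; reduction factors additionally pay
    # for the collective needed to materialize the result.
    cost = -compute_benefit
    if is_reduction:
        cost += allreduce_cost
    return cost

print(factor_cost(False, 10.0, 0.0))  # spatial factor: -10.0
print(factor_cost(True, 10.0, 6.0))   # reduction factor: -4.0 (old model gave +6.0)
```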
Three places fixed in optimize_sharding_new.py:

  1. _infer_factor_size — now uses _get_primary_tensor() to extract the first tensor from tuple outputs for shape inference
  2. _placeholder_factor_rule — same, handles tuple val metadata
  3. get_solution — for tuple outputs, produces a tuple of DTensorSpecs, filtering out Shard(d) placements where d exceeds a given output tensor's ndim (since auxiliary
  outputs like logsumexp may have fewer dimensions than the primary output)
Here's a summary of the changes:

  1. Partial placement output — Reduction factors (like K in matmul) now produce Partial() instead of being silently skipped. This matches what the original optimizer outputs
   for nodes like mm_3.

  2. FLOP-based compute cost estimator — Replaced the broken estimate_strategy_runtime_cost(node, None) (which was silently returning 0 for everything) with a direct FLOP
  counter for mm, addmm, bmm, and SDPA. This is why the optimizer wasn't incentivized to shard the batch dimension on both mesh dims — it saw zero compute benefit.

  3. All-gather cost at exit edges — For each spatial factor root, we pre-compute "exit edges" where a producer has the factor but its consumer doesn't (via union-find). When
   the factor is assigned to a mesh dim, an all-gather is needed at each such edge. This cost is added as a linear term in the objective, penalizing factors that would
  require redistribution.
Here's a summary of the three changes:

  1. Reduction factor propagation in _build_factor_graph: When a consumer has a reduction factor with operand_dims[arg_pos] = None (a Partial pass-through atom from
  view/permute/alias strategies), it now merges with the producer's reduction factor. This makes Partial propagate through data-preserving ops — so view_11 will correctly
  show Partial(sum) instead of Replicate().
  2. Reduce-scatter cost replaces allreduce: The old code added allreduce cost (2B·(n-1)/n) unconditionally at every node where a factor was a reduction. Now the cost is only
   added at "exit edges" where the Partial doesn't propagate to the consumer, using reduce-scatter cost (B·(n-1)/n — half of allreduce). Where the reduction factor propagates
   (e.g., mm → view), there's zero cost.
  3. Unified _compute_redistribution_bytes: Returns both ag_bytes (spatial exit edges → all-gather) and rs_bytes (reduction exit edges → reduce-scatter) in a single pass over
   the graph.
Here's a summary of the three changes made:

  1. Tensor exclusion extended to include reduction factors (line 556-560)

  Previously, only spatial factors were included in the per-node per-mesh-dim exclusion constraint. A reduction factor (Partial) and a spatial factor (Shard) could end up
  assigned to the same mesh dim for the same output tensor, which is invalid in DTensor. Now both are included in the sum <= 1 constraint.

  2. Linearized Partial → Replicate cost (lines 649-694)

  For each reduction exit edge (root r, consumer u) on mesh dim m:
  - Base cost = B·(n-1)/n · y[r,m] — reduce-scatter, always incurred
  - Extra cost = B·(n-1)/n · z — upgrades to all-reduce when consumer is Replicate

  The auxiliary continuous variable z satisfies:
  - z ≥ y[r,m] − Σ y[s,m] for consumer's spatial roots s
  - z ≥ 0

  Since z has a positive coefficient and we minimize, the solver naturally sets z = max(0, y[r,m] − Σ y[s,m]):
  - Consumer has spatial factor on m → z=0, total = B (reduce-scatter)
  - Consumer fully replicated on m → z=y[r,m], total = 2B (all-reduce)

  3. New helper methods (lines 798-838)

  - _compute_reduction_exit_info() — returns (reduction_root, consumer_nidx) → bytes for each exit edge
  - _get_spatial_roots_at_node(nidx) — returns the set of spatial factor roots at a given node

  The auxiliary z variables are continuous (not binary), so they add negligible solve cost. The stats now report the breakdown (e.g., "104 y + 42 z").
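
The linearization in change 2 can be sanity-checked numerically; a minimal sketch with toy values, where B is the tensor's bytes already scaled to a cost and n is the mesh-dim size:

```python
def reduction_exit_cost(B: float, n: int, y_rm: int,
                        consumer_spatial_sum: int) -> float:
    base = B * (n - 1) / n * y_rm            # reduce-scatter, always incurred
    z = max(0, y_rm - consumer_spatial_sum)  # what the minimizer drives z to
    extra = B * (n - 1) / n * z              # upgrade to all-reduce
    return base + extra

n, B = 8, 1.0
print(reduction_exit_cost(B, n, 1, 1))  # 0.875: consumer sharded -> reduce-scatter
print(reduction_exit_cost(B, n, 1, 0))  # 1.75: consumer replicated -> all-reduce
```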
Here's what was changed:

  Fix 1: Partial propagation validation (_build_factor_graph + _merge_reduction_factors)

  Problem: For add(o, x) in the transformer, o is Partial (from mm's K factor) but x is Shard(0). The old code merged reduction factors one operand at a time — it saw o has Partial and immediately merged, without checking that x ALSO has
  Partial. The strategy add(Partial, Shard) is invalid in DTensor.

  Fix: Spatial and reduction factor merging are now separate passes. The new _merge_reduction_factors method (line 470) collects ALL operands that need Partial (operand_dims[arg_pos] = None) and only merges when every one of them can provide a
  reduction factor from its producer. For unary ops (view, permute, alias) with one operand, nothing changes. For binary ops (add, mul), Partial only propagates when both operands produce Partial.

  Fix 2: Compute cost via estimate_strategy_runtime_cost (line 769)

  Problem: The old _compute_cost was a manual FLOP estimator covering only mm/addmm/bmm/SDPA. It returned 0 for pointwise ops (add, relu, etc.), giving the optimizer no incentive to shard them. It also returned FLOPs while communication costs
  were in microseconds — the units were inconsistent, making the compute-vs-communication tradeoff arbitrary.

  Fix: Now calls estimate_strategy_runtime_cost(node, None) from compute_estimation.py, which:
  - Counts FLOPs via FlopCounterMode (handles all ops, not just matmuls)
  - Computes memory read/write bytes (captures memory-bound ops like pointwise add/relu)
  - Returns max(compute_time, memory_time) in microseconds
  - Returns 0 for view ops (via _has_zero_cost)

  Results are cached per node. The communication costs in _add_objective are already in microseconds (bytes / BW * 1e6), so now both terms use consistent units.
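
The max(compute_time, memory_time) roofline shape can be sketched as follows (illustrative peak-FLOP and bandwidth constants, not the real estimate_strategy_runtime_cost):

```python
def runtime_us(flops: float, bytes_moved: float,
               peak_flops: float = 300e12, mem_bw: float = 2e12) -> float:
    # Return the max of compute-bound and memory-bound times, in microseconds.
    compute_time = flops / peak_flops * 1e6
    memory_time = bytes_moved / mem_bw * 1e6
    return max(compute_time, memory_time)

# A large matmul is compute-bound; a pointwise add on the same shape is memory-bound.
print(runtime_us(flops=2 * 4096**3, bytes_moved=3 * 4096**2 * 2))  # compute-bound
print(runtime_us(flops=4096**2, bytes_moved=3 * 4096**2 * 4))      # memory-bound
```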
Let me summarize all the changes made:

  Changes summary

  1. import operator added (line 48) — needed to check node.target is operator.getitem in the factor graph builder.

  2. _infer_factor_size (lines 191-204) — now handles multi-output nodes correctly. For tuple/list outputs, it indexes result_dims[ri] against val[ri] to find the correct output tensor's shape, rather than always using the first tensor.

  3. extract_factors_from_strategy (lines 239-299) — the core fix. Previously it extracted only the first output's placement (out_specs[0].placements[0]) and built a single-element result_dims. Now it:
  - Determines num_results from the output specs structure
  - Extracts placements for all output tensors into out_ps
  - Builds result_dims with one entry per output tensor (e.g., [1, 1, 1] for SDPA backward with 3 outputs that all shard on dim 1)
  - Sets num_results correctly on the returned FactorRule

  4. _build_factor_graph (lines 473-502) — for getitem nodes, determines the getitem index from node.args[1] and uses p_fac.result_dims[result_idx] instead of always result_dims[0]. This correctly connects e.g. getitem_4 = sdpa_bwd[0] to the
  first output's factors of sdpa_bwd, and getitem_5 = sdpa_bwd[1] to the second output's factors.

  5. _add_tensor_exclusion (lines 650-672) — now iterates over each result tensor separately (for ri in range(rule.num_results)), generating per-result exclusion constraints. Different output tensors of the same op can independently have
  different shardings.

  6. _compute_redistribution_bytes (line 863) — changed from fac.result_dims[0] is not None to any(d is not None for d in fac.result_dims).

  7. _get_spatial_roots_at_node (line 905) — same pattern, checks any result dim.

  8. add_node_constraint (lines 929, 941) — Shard matching checks any(d == p.dim for d in fac.result_dims), Replicate checking uses any(d is not None for d in fac.result_dims).

  9. get_solution (lines 1063-1091) — multi-output reconstruction now builds per-output-tensor placements using fac.result_dims[ri] for each output index ri, rather than copying a single set of placements and filtering by shape.

  10. Removed the commented-out IPython embed debug line.

  How it works with your example

  For _scaled_dot_product_efficient_attention_backward with 3 gradient outputs (grad_q, grad_k, grad_v), the factor rule now produces:
  - Factor "nheads": result_dims=[1, 1, 1] (dim 1 of all 3 outputs)
  - Factor "batch": result_dims=[0, 0, 0] (dim 0 of all 3 outputs)
  - num_results=3

  When getitem_4 = sdpa_bwd[0] is processed in the factor graph, it uses result_idx=0, so it matches against result_dims[0] of each producer factor — correctly connecting to the first output's dimensions. Similarly, getitem_5 = sdpa_bwd[1] uses
  result_idx=1, connecting to the second output.
  Imports (lines 63, 65):
  - Added OpSpec to the _op_schema import
  - Added tree_flatten to the _pytree import

  get_solution (lines 1012-1103):
  - Return type changed from dict[torch.fx.Node, DTensorSpec] to dict[torch.fx.Node, OpSpec]
  - Output spec construction is unchanged (same logic, now stored in output_specs local var)
  - After building output_specs, calls the new _build_input_specs helper
  - Returns OpSpec(output_specs=..., input_specs=...) per node
  - For nodes with no operands (placeholders, get_attr), input_specs is None

  _build_input_specs (lines 1105-1176) — new method:
  - For each operand oi in range(rule.num_operands):
    - Starts with [Replicate()] * mesh.ndim
    - For each factor assigned to mesh dim m:
        - operand_dims[oi] not None → Shard(dim) (spatial sharding)
      - operand_dims[oi] is None + is_reduction → Partial (reduction pass-through)
      - Otherwise → stays Replicate
    - Gets TensorMeta from the corresponding input node's meta["val"]
    - For getitem nodes consuming a multi-output producer, uses the getitem index to pick the correct tensor from the tuple for TensorMeta

  This matches the contract that apply_sharding.py expects: sharding_placement[node].output_specs for the node's output, and sharding_placement[node].input_specs[c] for the c-th tensor input (in tree_flatten(node.args) order, filtering for
  Nodes).

  Note that example_autoparallel_factor.py will need its comparison code updated — it currently accesses the factor solution directly as DTensorSpec/tuples, but now it should access .output_specs on the returned OpSpec, consistent with how it
  already handles the original optimizer's solution.
The add_grad_param_constraints pass:

1. Iterates over all (param, grad) pairs from get_param_and_grad_nodes
  2. For each spatial factor of the param (matched by result_dims[0]), finds the corresponding factor in the grad with the same dimension index
  3. If they have different roots (not already unified by the factor graph), adds equality constraints y[param_root, m] == y[grad_root, m] for all mesh dims — ensuring the same sharding decision applies to both
Problem: A parameter's memory ratio is 1 / product(mesh_shape[m] for sharded mesh dims m) — a product of binary decisions, which is nonlinear in the y variables.

  Linearization: For each parameter, introduce a binary indicator variable b_S for each possible subset S of mesh dims (2^k variables where k = mesh.ndim). The constraints enforce:

  1. Exactly one subset active (sum_S b_S = 1): the parameter is sharded on exactly one combination of mesh dims
  2. Consistency with factor assignments (sum_{S: m∈S} b_S = s_m): subset S is active iff the parameter is sharded on exactly those mesh dims, where s_m = sum_fi y[root_fi, m] (already 0 or 1 due to tensor exclusion)
  3. Memory contribution: each b_S contributes a precomputed ratio 1 / prod(mesh_shape[m] for m in S) — fully linear

  The final constraint matches the original: low * N <= sum_p ratio_p <= high * N, where N is the number of eligible parameters (those large enough to be fully sharded).

  For a 2D mesh, this adds 4 binary variables and 3 constraints per parameter. For a 3D mesh, 8 variables and 4 constraints. With ~10-20 parameters in a typical model, the overhead is negligible.
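
The subset enumeration and its precomputed ratios can be sketched as follows (stdlib only; returns 2^k entries, matching the 4-variable count quoted for a 2D mesh):

```python
from itertools import combinations

def subset_ratios(mesh_shape):
    """Map each subset S of mesh dims to the memory ratio 1/prod(mesh_shape[m])."""
    dims = range(len(mesh_shape))
    ratios = {}
    for r in range(len(mesh_shape) + 1):
        for S in combinations(dims, r):
            prod = 1
            for m in S:
                prod *= mesh_shape[m]
            ratios[S] = 1.0 / prod
    return ratios

print(subset_ratios((16, 8)))
# {(): 1.0, (0,): 0.0625, (1,): 0.125, (0, 1): 0.0078125}
```

Each subset S then gets one binary variable b_S whose memory contribution is the precomputed ratio, which is what makes the constraint linear.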
Summary of Changes

  File: autoparallel/optimize_sharding_new.py

  Added

  - FactorEdge dataclass — records producer→consumer factor relationships with producer_gid, consumer_gid, node indices, producer node reference, and kind
  ("spatial"/"reduction")
  - _add_reduction_propagation_constraints() — for each reduction edge (pk → ck): y[ck, m] <= y[pk, m], preventing Partial from being created from non-Partial input
  - _add_disabled_reduction_constraints() — forces disabled reduction gids to 0 across all mesh dims

  Modified

  - __init__ — replaced self.uf, self.factor_ops, _collect_factor_metadata() with self._factor_edges and self._disabled_reduction_gids
  - _build_factor_graph() — removed UF make_set/union calls; spatial matches now append FactorEdge objects; calls _record_reduction_edges() instead of
  _merge_reduction_factors()
  - _merge_reduction_factors() → _record_reduction_edges() — appends FactorEdge(kind="reduction") instead of UF union; tracks _disabled_reduction_gids when all_valid fails
  with null operand dims
  - _build_ilp() — uses all_gids (all factor keys) instead of UF roots; adds reduction propagation and disabled reduction constraints
  - _add_tensor_exclusion() — uses gid directly instead of self.uf.find(gid)
  - _add_objective() — rewritten with three components: (A) per-node compute benefit, (B) per-edge disagreement costs with linearized z-variables, (C) uncovered reduction
  exit costs
  - _get_spatial_roots_at_node() → _get_spatial_gids_at_node() — returns gids directly
  - add_node_constraint() — uses gid directly instead of UF roots
  - add_grad_param_constraints() — uses pk/gk directly; pk != gk always true without UF
  - add_parameter_memory_constraint() — uses gid directly
  - get_solution() — assignment is now dict[int, list[int]] (gid → list of mesh dims), supporting multi-dim assignments like (Shard(0), Shard(0))
  - _build_input_specs() — iterates assignment.get(gid, []) instead of single mesh dim
  - get_stats() — reports num_edges; uses len(set(self.factor_keys.values())) for unique factors
  - get_log() — reports edge count; builds reverse lookup for verbose mode instead of using self.factor_ops

  Removed

  - _collect_factor_metadata(), _unique_roots(), _compute_redistribution_bytes(), _compute_reduction_exit_info()

  Kept (unused but harmless)

  - UnionFind class definition

  Additional notes:

  - The factor uniqueness constraint is the real fix for the original view error ((S(0), S(0)) can no longer be chosen by the solver).
  - _build_input_specs must use the consumer's own factor assignments, not the producer's output_specs. This is because apply_sharding.py relies on input_specs to tell it what placement the op needs its inputs in, then redistributes from curr_spec (producer output) to tgt_spec (consumer input). When I was copying producer specs, redistribution was skipped, leaving inputs with incompatible shardings for ops like mul in RMS norm.
Here's a summary of the changes made:

  1. Added imports from torch.distributed.tensor._collective_utils: MeshTopoInfo, allgather_cost, allreduce_cost, reduce_scatter_cost
  2. Added self.mesh_topo in __init__: MeshTopoInfo.build_from_mesh(mesh) — built once, reused throughout
  3. Replaced all comm_unit = bytes_val * (mesh_size - 1) / mesh_size / self._BW * 1e6 with the appropriate collective cost function:
    - Spatial edges: allgather_cost(bytes_gb, mesh_topo, m)
    - Reduction edges: reduce_scatter_cost(bytes_gb, mesh_topo, m) as base, with allreduce_cost(...) - reduce_scatter_cost(...) as the upgrade delta
    - Uncovered reduction exits: same pattern
  4. Removed _BW: float = 50e9 class attribute — no longer needed

  The key improvement: instead of a flat bandwidth constant across all mesh dims, the cost model now uses per-mesh-dim bandwidth and latency from MeshTopoInfo. This means the
   solver can properly distinguish between e.g. NVLink within a node (high bandwidth, low latency) vs. InfiniBand across nodes (lower bandwidth, higher latency), and make
  better sharding decisions accordingly.
  autoparallel/optimize_sharding.py — get_stats:
  - Removed print statements
  - Returns a dict with keys: num_graph_nodes, num_ilp_variables, num_ilp_constraints, mesh_shape

  autoparallel/optimize_sharding_new.py — get_stats:
  - Removed print statements
  - Removed estimated_original_ilp_variables and variable_reduction_ratio (these belong in comparison code, not in the optimizer itself)
  - Renamed keys to align with the old optimizer: num_ilp_variables, num_ilp_constraints (shared), plus factor-specific: num_unique_factors, num_edges, num_y_variables, num_z_variables
  - Updated get_log to use the renamed keys

  examples/example_autoparallel_factor.py:
  - Updated to use orig_opt.get_stats() instead of accessing orig_opt.ds and orig_opt.prob.constraints directly
  - Updated key references to match the new naming
At least 10x faster. HiGHS has a much better branch-and-bound implementation than CBC, and it's also better at exploiting LP relaxation tightening (presolve, cuts, etc.), which is exactly where the z-variable formulation was struggling.
Performance: The independent optimizer is the fastest at 0.48s (vs 3.26s original, 1.32s factor-based) — a 6.8x speedup over the original.

  ILP size: 3,724 variables / 1,782 constraints (vs 23,594 / 5,113 original). Each 1D sub-problem has ~1,862 vars and 891 constraints.

  Placement differences: The independent solver overwhelmingly chooses Shard(dim=0) on both mesh dimensions (i.e., 128-way data parallelism) rather than the expected DP+TP split. This is the expected behavior of the independent decomposition:
  each 1D solver independently sees that batch-dim sharding is cheapest (zero communication), without knowing the other solver already handles batch parallelism. From a pure per-dim cost perspective, Shard(0) on tp is strictly cheaper than
  Shard(1) or Shard(2) because it avoids all-gather/reduce-scatter at matmul boundaries.

  This is the fundamental tradeoff of the independent approach: cross-dim interactions (like "dp already handles batch, so tp should handle hidden dims") are lost.

Two files were created/modified:

  - New: autoparallel/optimize_sharding_independent.py — IndependentShardingOptimizer class
  - Modified: examples/example_autoparallel_factor.py — added third optimizer comparison
meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Feb 18, 2026