diff --git a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts index 59821f0c..d043e4e4 100644 --- a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts +++ b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts @@ -12,6 +12,7 @@ import { type SolvedTracePath, } from "../SchematicTraceLinesSolver/SchematicTraceLinesSolver" import { TraceOverlapShiftSolver } from "../TraceOverlapShiftSolver/TraceOverlapShiftSolver" +import { TraceCombineSolver } from "../TraceCombineSolver/TraceCombineSolver" import { NetLabelPlacementSolver } from "../NetLabelPlacementSolver/NetLabelPlacementSolver" import { colorAvailableNetOrientationLabels } from "./colorAvailableNetOrientationLabels" import { visualizeInputProblem } from "./visualizeInputProblem" @@ -71,6 +72,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { schematicTraceLinesSolver?: SchematicTraceLinesSolver longDistancePairSolver?: LongDistancePairSolver traceOverlapShiftSolver?: TraceOverlapShiftSolver + traceCombineSolver?: TraceCombineSolver netLabelPlacementSolver?: NetLabelPlacementSolver labelMergingSolver?: MergedNetLabelObstacleSolver traceLabelOverlapAvoidanceSolver?: TraceLabelOverlapAvoidanceSolver @@ -154,19 +156,21 @@ export class SchematicTracePipelineSolver extends BaseSolver { onSolved: (_solver) => {}, }, ), + definePipelineStep("traceCombineSolver", TraceCombineSolver, () => [ + { + inputProblem: this.inputProblem, + inputTracePaths: Object.values( + this.traceOverlapShiftSolver!.correctedTraceMap, + ), + }, + ]), definePipelineStep( "netLabelPlacementSolver", NetLabelPlacementSolver, () => [ { inputProblem: this.inputProblem, - inputTraceMap: - this.traceOverlapShiftSolver?.correctedTraceMap ?? - Object.fromEntries( - this.longDistancePairSolver!.getOutput().allTracesMerged.map( - (p) => [p.mspPairId, p], - ), - ), + inputTraceMap: this.getRoutedTraceMap(), }, ], { @@ -179,13 +183,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { "traceLabelOverlapAvoidanceSolver", TraceLabelOverlapAvoidanceSolver, (instance) => { - const traceMap = - instance.traceOverlapShiftSolver?.correctedTraceMap ?? - Object.fromEntries( - instance - .longDistancePairSolver!.getOutput() - .allTracesMerged.map((p) => [p.mspPairId, p]), - ) + const traceMap = instance.getRoutedTraceMap() const traces = Object.values(traceMap) const netLabelPlacements = instance.netLabelPlacementSolver!.netLabelPlacements @@ -320,6 +318,21 @@ export class SchematicTracePipelineSolver extends BaseSolver { currentPipelineStepIndex = 0 + getRoutedTraceMap(): Record { + if (this.traceCombineSolver) { + return this.traceCombineSolver.correctedTraceMap + } + if (this.traceOverlapShiftSolver) { + return this.traceOverlapShiftSolver.correctedTraceMap + } + return Object.fromEntries( + this.longDistancePairSolver!.getOutput().allTracesMerged.map((p) => [ + p.mspPairId, + p, + ]), + ) + } + private cloneAndCorrectInputProblem(original: InputProblem): InputProblem { const cloned: InputProblem = structuredClone({ ...original, diff --git a/lib/solvers/TraceCombineSolver/TraceCombineSolver.ts b/lib/solvers/TraceCombineSolver/TraceCombineSolver.ts new file mode 100644 index 00000000..1ac5a85b --- /dev/null +++ b/lib/solvers/TraceCombineSolver/TraceCombineSolver.ts @@ -0,0 +1,76 @@ +import { BaseSolver } from "lib/solvers/BaseSolver/BaseSolver" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import type { MspConnectionPairId } from "lib/solvers/MspConnectionPairSolver/MspConnectionPairSolver" +import { visualizeInputProblem } from "lib/solvers/SchematicTracePipelineSolver/visualizeInputProblem" +import type { InputProblem } from "lib/types/InputProblem" +import { + combineSameNetTraceSegments, + DEFAULT_COMBINE_DISTANCE, +} from "./combineSameNetTraceSegments" + +/** + * Pipeline phase that combines nearby parallel trace segments on the same net + * by snapping them to a shared X or Y coordinate. + */ +export class TraceCombineSolver extends BaseSolver { + inputProblem: InputProblem + inputTracePaths: SolvedTracePath[] + combineDistance: number + + correctedTraceMap: Record = {} + + constructor(params: { + inputProblem: InputProblem + inputTracePaths: SolvedTracePath[] + combineDistance?: number + }) { + super() + this.inputProblem = params.inputProblem + this.inputTracePaths = params.inputTracePaths + this.combineDistance = params.combineDistance ?? DEFAULT_COMBINE_DISTANCE + + for (const tracePath of this.inputTracePaths) { + this.correctedTraceMap[tracePath.mspPairId] = tracePath + } + } + + override getConstructorParams(): ConstructorParameters< + typeof TraceCombineSolver + >[0] { + return { + inputProblem: this.inputProblem, + inputTracePaths: this.inputTracePaths, + combineDistance: this.combineDistance, + } + } + + override _step() { + const combined = combineSameNetTraceSegments( + this.inputTracePaths, + this.combineDistance, + ) + + this.correctedTraceMap = Object.fromEntries( + combined.map((trace) => [trace.mspPairId, trace]), + ) + this.solved = true + } + + getOutput(): { traces: SolvedTracePath[] } { + return { traces: Object.values(this.correctedTraceMap) } + } + + override visualize() { + const graphics = visualizeInputProblem(this.inputProblem) + graphics.lines = graphics.lines || [] + + for (const trace of Object.values(this.correctedTraceMap)) { + graphics.lines.push({ + points: trace.tracePath, + strokeColor: "teal", + }) + } + + return graphics + } +} diff --git a/lib/solvers/TraceCombineSolver/combineSameNetTraceSegments.ts b/lib/solvers/TraceCombineSolver/combineSameNetTraceSegments.ts new file mode 100644 index 00000000..ddaa24f4 --- /dev/null +++ b/lib/solvers/TraceCombineSolver/combineSameNetTraceSegments.ts @@ -0,0 +1,316 @@ +import type { Point } from "@tscircuit/math-utils" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import { simplifyPath } from "lib/solvers/TraceCleanupSolver/simplifyPath" + +const EPS = 1e-6 +export const DEFAULT_COMBINE_DISTANCE = 0.15 + +type Orientation = "horizontal" | "vertical" + +interface SegmentRef { + traceIndex: number + segmentIndex: number + orientation: Orientation + fixedCoordinate: number + rangeStart: number + rangeEnd: number +} + +const getSegmentRef = ( + traceIndex: number, + segmentIndex: number, + start: Point, + end: Point, +): SegmentRef | null => { + if (Math.abs(start.y - end.y) < EPS) { + return { + traceIndex, + segmentIndex, + orientation: "horizontal", + fixedCoordinate: start.y, + rangeStart: Math.min(start.x, end.x), + rangeEnd: Math.max(start.x, end.x), + } + } + + if (Math.abs(start.x - end.x) < EPS) { + return { + traceIndex, + segmentIndex, + orientation: "vertical", + fixedCoordinate: start.x, + rangeStart: Math.min(start.y, end.y), + rangeEnd: Math.max(start.y, end.y), + } + } + + return null +} + +const getSegments = ( + traces: SolvedTracePath[], + traceIndex: number, + options: { includeTerminals: boolean }, +): SegmentRef[] => { + const trace = traces[traceIndex]! + const refs: SegmentRef[] = [] + const startIndex = options.includeTerminals ? 0 : 1 + const endIndex = options.includeTerminals + ? trace.tracePath.length - 1 + : trace.tracePath.length - 2 + + for (let segmentIndex = startIndex; segmentIndex < endIndex; segmentIndex++) { + const start = trace.tracePath[segmentIndex]! + const end = trace.tracePath[segmentIndex + 1]! + const ref = getSegmentRef(traceIndex, segmentIndex, start, end) + if (ref) refs.push(ref) + } + + return refs +} + +const rangesOverlap = (a: SegmentRef, b: SegmentRef) => + Math.min(a.rangeEnd, b.rangeEnd) - Math.max(a.rangeStart, b.rangeStart) > EPS + +const wouldOverlapDifferentNet = ( + traces: SolvedTracePath[], + source: SegmentRef, + fixedCoordinate: number, +) => { + for (let traceIndex = 0; traceIndex < traces.length; traceIndex++) { + const trace = traces[traceIndex]! + if (trace.globalConnNetId === traces[source.traceIndex]!.globalConnNetId) { + continue + } + + for ( + let segmentIndex = 0; + segmentIndex < trace.tracePath.length - 1; + segmentIndex++ + ) { + const ref = getSegmentRef( + traceIndex, + segmentIndex, + trace.tracePath[segmentIndex]!, + trace.tracePath[segmentIndex + 1]!, + ) + if (!ref) continue + if (ref.orientation !== source.orientation) continue + if (Math.abs(ref.fixedCoordinate - fixedCoordinate) > EPS) continue + if (rangesOverlap(source, ref)) return true + } + } + + return false +} + +const findConsolidatableSegmentPair = ( + traces: SolvedTracePath[], + indexA: number, + indexB: number, + combineDistance: number, +): SegmentRef | null => { + const traceA = traces[indexA]! + const traceB = traces[indexB]! + if (traceA.tracePath.length !== 2 || traceB.tracePath.length !== 2) { + return null + } + + const segmentA = getSegmentRef( + indexA, + 0, + traceA.tracePath[0]!, + traceA.tracePath[1]!, + ) + const segmentB = getSegmentRef( + indexB, + 0, + traceB.tracePath[0]!, + traceB.tracePath[1]!, + ) + if (!segmentA || !segmentB) return null + if (segmentA.orientation !== segmentB.orientation) return null + if ( + Math.abs(segmentA.fixedCoordinate - segmentB.fixedCoordinate) > + combineDistance + ) { + return null + } + if (!rangesOverlap(segmentA, segmentB)) return null + + return segmentA +} + +const mergeTracePair = ( + kept: SolvedTracePath, + removed: SolvedTracePath, + canonical: SegmentRef, +): SolvedTracePath => { + const tracePath = kept.tracePath.map((point) => { + const p = { ...point } + if (canonical.orientation === "horizontal") { + p.y = canonical.fixedCoordinate + } else { + p.x = canonical.fixedCoordinate + } + return p + }) + + return { + ...kept, + tracePath: simplifyPath(tracePath), + mspConnectionPairIds: [ + ...new Set([ + ...kept.mspConnectionPairIds, + ...removed.mspConnectionPairIds, + ]), + ], + pinIds: [...new Set([...kept.pinIds, ...removed.pinIds])], + } +} + +const consolidateRedundantParallelTraces = ( + traces: SolvedTracePath[], + combineDistance: number, +): SolvedTracePath[] => { + const result = traces.map((trace) => ({ + ...trace, + tracePath: trace.tracePath.map((point) => ({ ...point })), + })) + + let changed = true + while (changed) { + changed = false + + outer: for (let indexA = 0; indexA < result.length; indexA++) { + for (let indexB = indexA + 1; indexB < result.length; indexB++) { + if ( + result[indexA]!.globalConnNetId !== result[indexB]!.globalConnNetId + ) { + continue + } + + const canonical = findConsolidatableSegmentPair( + result, + indexA, + indexB, + combineDistance, + ) + if (!canonical) continue + if ( + wouldOverlapDifferentNet(result, canonical, canonical.fixedCoordinate) + ) { + continue + } + + result[indexA] = mergeTracePair( + result[indexA]!, + result[indexB]!, + canonical, + ) + result.splice(indexB, 1) + changed = true + break outer + } + } + } + + return result +} + +const snapSegmentFixedCoordinate = ( + trace: SolvedTracePath, + segmentIndex: number, + orientation: Orientation, + fixedCoordinate: number, +) => { + const tracePath = trace.tracePath.map((point) => ({ ...point })) + const start = tracePath[segmentIndex]! + const end = tracePath[segmentIndex + 1]! + + if (orientation === "horizontal") { + start.y = fixedCoordinate + end.y = fixedCoordinate + } else { + start.x = fixedCoordinate + end.x = fixedCoordinate + } + + return { + ...trace, + tracePath: simplifyPath(tracePath), + } +} + +/** Aligns and merges nearby parallel same-net trace segments onto a shared axis. */ +export const combineSameNetTraceSegments = ( + traces: SolvedTracePath[], + combineDistance = DEFAULT_COMBINE_DISTANCE, +): SolvedTracePath[] => { + const combinedTraces = traces.map((trace) => ({ + ...trace, + tracePath: trace.tracePath.map((point) => ({ ...point })), + })) + + const traceIndexesByNet = new Map() + for (let traceIndex = 0; traceIndex < combinedTraces.length; traceIndex++) { + const netId = combinedTraces[traceIndex]!.globalConnNetId + const traceIndexes = traceIndexesByNet.get(netId) ?? [] + traceIndexes.push(traceIndex) + traceIndexesByNet.set(netId, traceIndexes) + } + + for (const traceIndexes of traceIndexesByNet.values()) { + if (traceIndexes.length < 2) continue + + let changed = true + while (changed) { + changed = false + + for (const traceIndex of traceIndexes.slice(1)) { + const candidates = getSegments(combinedTraces, traceIndex, { + includeTerminals: false, + }) + + for (const candidate of candidates) { + const target = traceIndexes + .filter((targetTraceIndex) => targetTraceIndex !== traceIndex) + .flatMap((targetTraceIndex) => + getSegments(combinedTraces, targetTraceIndex, { + includeTerminals: true, + }), + ) + .find( + (other) => + other.orientation === candidate.orientation && + Math.abs(other.fixedCoordinate - candidate.fixedCoordinate) <= + combineDistance && + Math.abs(other.fixedCoordinate - candidate.fixedCoordinate) > + EPS && + rangesOverlap(candidate, other) && + !wouldOverlapDifferentNet( + combinedTraces, + candidate, + other.fixedCoordinate, + ), + ) + + if (!target) continue + + combinedTraces[traceIndex] = snapSegmentFixedCoordinate( + combinedTraces[traceIndex]!, + candidate.segmentIndex, + candidate.orientation, + target.fixedCoordinate, + ) + changed = true + break + } + + if (changed) break + } + } + } + + return consolidateRedundantParallelTraces(combinedTraces, combineDistance) +} diff --git a/tests/assets/TraceCombineSolver_repro29.input.json b/tests/assets/TraceCombineSolver_repro29.input.json new file mode 100644 index 00000000..3680f349 --- /dev/null +++ b/tests/assets/TraceCombineSolver_repro29.input.json @@ -0,0 +1,51 @@ +{ + "inputProblem": { + "chips": [ + { + "chipId": "U1", + "center": { "x": 0, "y": 0 }, + "width": 2, + "height": 2, + "pins": [ + { "pinId": "U1.1", "x": -1, "y": 0.5 }, + { "pinId": "U1.2", "x": 1, "y": 0.5 } + ] + } + ], + "directConnections": [ + { + "pinIds": ["U1.1", "U1.2"], + "netId": "SIG" + } + ], + "netConnections": [] + }, + "inputTracePaths": [ + { + "mspPairId": "trace-a", + "dcConnNetId": "net-sig", + "globalConnNetId": "net-sig", + "userNetId": "SIG", + "pins": [ + { "pinId": "U1.1", "chipId": "U1", "x": -1, "y": 0.5 }, + { "pinId": "U1.2", "chipId": "U1", "x": 1, "y": 0.5 } + ], + "tracePath": [{ "x": 0, "y": 1 }, { "x": 4, "y": 1 }], + "mspConnectionPairIds": ["trace-a"], + "pinIds": ["U1.1", "U1.2"] + }, + { + "mspPairId": "trace-b", + "dcConnNetId": "net-sig", + "globalConnNetId": "net-sig", + "userNetId": "SIG", + "pins": [ + { "pinId": "U1.1", "chipId": "U1", "x": -1, "y": 0.5 }, + { "pinId": "U1.2", "chipId": "U1", "x": 1, "y": 0.5 } + ], + "tracePath": [{ "x": 0, "y": 1.1 }, { "x": 4, "y": 1.1 }], + "mspConnectionPairIds": ["trace-b"], + "pinIds": ["U1.1", "U1.2"] + } + ] +} diff --git a/tests/solvers/TraceCombineSolver/TraceCombineSolver_repro29.test.ts b/tests/solvers/TraceCombineSolver/TraceCombineSolver_repro29.test.ts new file mode 100644 index 00000000..ccc77d86 --- /dev/null +++ b/tests/solvers/TraceCombineSolver/TraceCombineSolver_repro29.test.ts @@ -0,0 +1,22 @@ +import { expect, test } from "bun:test" +import { TraceCombineSolver } from "lib/solvers/TraceCombineSolver/TraceCombineSolver" +import input from "../../assets/TraceCombineSolver_repro29.input.json" +import "tests/fixtures/matcher" + +test("TraceCombineSolver_repro29 combines close parallel same-net traces", () => { + const solver = new TraceCombineSolver({ + inputProblem: input.inputProblem as any, + inputTracePaths: input.inputTracePaths as any, + }) + solver.solve() + + const traces = solver.getOutput().traces + expect(traces).toHaveLength(1) + expect(traces[0]!.tracePath).toEqual([ + { x: 0, y: 1 }, + { x: 4, y: 1 }, + ]) + expect(traces[0]!.mspConnectionPairIds).toEqual(["trace-a", "trace-b"]) + + expect(solver).toMatchSolverSnapshot(import.meta.path) +}) diff --git a/tests/solvers/TraceCombineSolver/__snapshots__/TraceCombineSolver_repro29.snap.svg b/tests/solvers/TraceCombineSolver/__snapshots__/TraceCombineSolver_repro29.snap.svg new file mode 100644 index 00000000..f4ed3633 --- /dev/null +++ b/tests/solvers/TraceCombineSolver/__snapshots__/TraceCombineSolver_repro29.snap.svg @@ -0,0 +1,77 @@ + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/solvers/TraceCombineSolver/combineSameNetTraceSegments.test.ts b/tests/solvers/TraceCombineSolver/combineSameNetTraceSegments.test.ts new file mode 100644 index 00000000..5bf678e7 --- /dev/null +++ b/tests/solvers/TraceCombineSolver/combineSameNetTraceSegments.test.ts @@ -0,0 +1,88 @@ +import { expect, test } from "bun:test" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import { combineSameNetTraceSegments } from "lib/solvers/TraceCombineSolver/combineSameNetTraceSegments" + +const makeTrace = ( + mspPairId: string, + globalConnNetId: string, + tracePath: Array<{ x: number; y: number }>, +): SolvedTracePath => + ({ + mspPairId, + dcConnNetId: globalConnNetId, + globalConnNetId, + pins: [ + { pinId: `${mspPairId}-a`, chipId: "U1", ...tracePath[0]! }, + { + pinId: `${mspPairId}-b`, + chipId: "U1", + ...tracePath[tracePath.length - 1]!, + }, + ], + tracePath, + mspConnectionPairIds: [mspPairId], + pinIds: [`${mspPairId}-a`, `${mspPairId}-b`], + }) as SolvedTracePath + +test("snaps nearby internal same-net horizontal segments onto the same y", () => { + const [first, second] = combineSameNetTraceSegments([ + makeTrace("trace-a", "net-1", [ + { x: 0, y: 0 }, + { x: 0, y: 1 }, + { x: 4, y: 1 }, + { x: 4, y: 0 }, + ]), + makeTrace("trace-b", "net-1", [ + { x: 0, y: 0.12 }, + { x: 0, y: 1.12 }, + { x: 4, y: 1.12 }, + { x: 4, y: 0.12 }, + ]), + ]) + + expect(first!.tracePath[1]!.y).toBe(1) + expect(first!.tracePath[2]!.y).toBe(1) + expect(second!.tracePath[1]!.y).toBe(1) + expect(second!.tracePath[2]!.y).toBe(1) + expect(second!.tracePath[0]!.y).toBe(0.12) + expect(second!.tracePath[3]!.y).toBe(0.12) +}) + +test("merges two close parallel same-net traces into one", () => { + const result = combineSameNetTraceSegments([ + makeTrace("trace-a", "net-1", [ + { x: 0, y: 1 }, + { x: 4, y: 1 }, + ]), + makeTrace("trace-b", "net-1", [ + { x: 0, y: 1.1 }, + { x: 4, y: 1.1 }, + ]), + ]) + + expect(result).toHaveLength(1) + expect(result[0]!.tracePath).toEqual([ + { x: 0, y: 1 }, + { x: 4, y: 1 }, + ]) +}) + +test("does not combine nearby segments from different nets", () => { + const [first, second] = combineSameNetTraceSegments([ + makeTrace("trace-a", "net-1", [ + { x: 0, y: 0 }, + { x: 0, y: 1 }, + { x: 4, y: 1 }, + { x: 4, y: 0 }, + ]), + makeTrace("trace-b", "net-2", [ + { x: 0, y: 0.12 }, + { x: 0, y: 1.12 }, + { x: 4, y: 1.12 }, + { x: 4, y: 0.12 }, + ]), + ]) + + expect(first!.tracePath[1]!.y).toBe(1) + expect(second!.tracePath[1]!.y).toBe(1.12) +})