diff --git a/lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.ts b/lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.ts new file mode 100644 index 00000000..545398cc --- /dev/null +++ b/lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.ts @@ -0,0 +1,286 @@ +import type { Point } from "@tscircuit/math-utils" +import type { GraphicsObject } from "graphics-debug" +import { BaseSolver } from "lib/solvers/BaseSolver/BaseSolver" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import { visualizeInputProblem } from "lib/solvers/SchematicTracePipelineSolver/visualizeInputProblem" +import type { InputPin, InputProblem } from "lib/types/InputProblem" +import { getColorFromString } from "lib/utils/getColorFromString" + +const AXIS_TOLERANCE = 0.01 +const MAX_MERGE_GAP = 0.15 + +type SegmentOrientation = "horizontal" | "vertical" + +type TraceSegment = { + orientation: SegmentOrientation + axis: number + start: number + end: number + sourceTrace: SolvedTracePath + sourceSegmentIndex: number +} + +type AxisGroup = { + orientation: SegmentOrientation + axis: number + segments: TraceSegment[] +} + +type MergeCluster = { + start: number + end: number + segments: TraceSegment[] +} + +export class SameNetTraceCombiningSolver extends BaseSolver { + inputProblem: InputProblem + inputTraces: SolvedTracePath[] + outputTraces: SolvedTracePath[] = [] + + constructor(params: { + inputProblem: InputProblem + inputTraces: SolvedTracePath[] + }) { + super() + this.inputProblem = params.inputProblem + this.inputTraces = params.inputTraces + this.outputTraces = params.inputTraces + } + + override getConstructorParams(): ConstructorParameters< + typeof SameNetTraceCombiningSolver + >[0] { + return { + inputProblem: this.inputProblem, + inputTraces: this.inputTraces, + } + } + + override _step() { + this.outputTraces = this.combineSameNetTraceSegments() + this.solved = true + } + + getOutput() { + return { + traces: this.outputTraces, + } + } + + private combineSameNetTraceSegments(): SolvedTracePath[] { + const tracesByNet = new Map() + + for (const trace of this.inputTraces) { + const key = trace.globalConnNetId + if (!tracesByNet.has(key)) tracesByNet.set(key, []) + tracesByNet.get(key)!.push(trace) + } + + const combinedTraces: SolvedTracePath[] = [] + + for (const [globalConnNetId, traces] of tracesByNet.entries()) { + const axisGroups: AxisGroup[] = [] + + for (const trace of traces) { + for (const segment of this.getTraceSegments(trace)) { + const axisGroup = axisGroups.find( + (group) => + group.orientation === segment.orientation && + Math.abs(group.axis - segment.axis) <= AXIS_TOLERANCE, + ) + + if (axisGroup) { + axisGroup.segments.push(segment) + axisGroup.axis = + axisGroup.segments.reduce((sum, s) => sum + s.axis, 0) / + axisGroup.segments.length + } else { + axisGroups.push({ + orientation: segment.orientation, + axis: segment.axis, + segments: [segment], + }) + } + } + } + + for (let groupIndex = 0; groupIndex < axisGroups.length; groupIndex++) { + const axisGroup = axisGroups[groupIndex]! + const clusters = this.mergeAxisGroup(axisGroup) + + for ( + let clusterIndex = 0; + clusterIndex < clusters.length; + clusterIndex++ + ) { + const cluster = clusters[clusterIndex]! + combinedTraces.push( + this.createTraceFromCluster({ + globalConnNetId, + axisGroup, + cluster, + groupIndex, + clusterIndex, + }), + ) + } + } + } + + return combinedTraces + } + + private getTraceSegments(trace: SolvedTracePath): TraceSegment[] { + const segments: TraceSegment[] = [] + + for (let i = 0; i < trace.tracePath.length - 1; i++) { + const p1 = trace.tracePath[i]! + const p2 = trace.tracePath[i + 1]! + const isHorizontal = Math.abs(p1.y - p2.y) <= AXIS_TOLERANCE + const isVertical = Math.abs(p1.x - p2.x) <= AXIS_TOLERANCE + + if (isHorizontal) { + const start = Math.min(p1.x, p2.x) + const end = Math.max(p1.x, p2.x) + if (Math.abs(end - start) <= AXIS_TOLERANCE) continue + segments.push({ + orientation: "horizontal", + axis: (p1.y + p2.y) / 2, + start, + end, + sourceTrace: trace, + sourceSegmentIndex: i, + }) + } else if (isVertical) { + const start = Math.min(p1.y, p2.y) + const end = Math.max(p1.y, p2.y) + if (Math.abs(end - start) <= AXIS_TOLERANCE) continue + segments.push({ + orientation: "vertical", + axis: (p1.x + p2.x) / 2, + start, + end, + sourceTrace: trace, + sourceSegmentIndex: i, + }) + } + } + + return segments + } + + private mergeAxisGroup(axisGroup: AxisGroup): MergeCluster[] { + const sortedSegments = [...axisGroup.segments].sort((a, b) => { + if (Math.abs(a.start - b.start) > AXIS_TOLERANCE) { + return a.start - b.start + } + return a.end - b.end + }) + const clusters: MergeCluster[] = [] + + for (const segment of sortedSegments) { + const lastCluster = clusters[clusters.length - 1] + + if (!lastCluster || segment.start - lastCluster.end > MAX_MERGE_GAP) { + clusters.push({ + start: segment.start, + end: segment.end, + segments: [segment], + }) + continue + } + + lastCluster.end = Math.max(lastCluster.end, segment.end) + lastCluster.segments.push(segment) + } + + return clusters + } + + private createTraceFromCluster(params: { + globalConnNetId: string + axisGroup: AxisGroup + cluster: MergeCluster + groupIndex: number + clusterIndex: number + }): SolvedTracePath { + const { globalConnNetId, axisGroup, cluster, groupIndex, clusterIndex } = + params + const sourceTraces = cluster.segments.map((s) => s.sourceTrace) + const representative = sourceTraces[0]! + const mspConnectionPairIds = Array.from( + new Set( + sourceTraces.flatMap( + (trace) => trace.mspConnectionPairIds ?? [trace.mspPairId], + ), + ), + ) + const pinIds = Array.from( + new Set(sourceTraces.flatMap((trace) => trace.pinIds ?? [])), + ) + const pins = this.getRepresentativePins(sourceTraces) + const startPoint = this.getPoint(axisGroup, cluster.start) + const endPoint = this.getPoint(axisGroup, cluster.end) + + return { + ...representative, + mspPairId: + mspConnectionPairIds.length === 1 + ? mspConnectionPairIds[0]! + : `same-net-combined-${globalConnNetId}-${axisGroup.orientation}-${groupIndex}-${clusterIndex}`, + dcConnNetId: representative.dcConnNetId, + globalConnNetId, + userNetId: representative.userNetId, + pins, + tracePath: [startPoint, endPoint], + mspConnectionPairIds, + pinIds, + } + } + + private getRepresentativePins( + sourceTraces: SolvedTracePath[], + ): SolvedTracePath["pins"] { + const pinsById = new Map() + + for (const trace of sourceTraces) { + for (const pin of trace.pins) { + pinsById.set(pin.pinId, pin) + } + } + + const pins = Array.from(pinsById.values()) + if (pins.length >= 2) return [pins[0]!, pins[pins.length - 1]!] + if (pins.length === 1) return [pins[0]!, pins[0]!] + + const fallback = sourceTraces[0]!.pins + return [fallback[0]!, fallback[1]!] as [ + InputPin & { chipId: string }, + InputPin & { chipId: string }, + ] + } + + private getPoint(axisGroup: AxisGroup, value: number): Point { + if (axisGroup.orientation === "horizontal") { + return { x: value, y: axisGroup.axis } + } + + return { x: axisGroup.axis, y: value } + } + + override visualize(): GraphicsObject { + const graphics = visualizeInputProblem(this.inputProblem, { + chipAlpha: 0.1, + connectionAlpha: 0.1, + }) + + for (const trace of this.outputTraces) { + graphics.lines!.push({ + points: trace.tracePath, + strokeColor: getColorFromString(trace.globalConnNetId, 0.9), + }) + } + + return graphics + } +} diff --git a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts index 59821f0c..b4359cd6 100644 --- a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts +++ b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts @@ -26,6 +26,7 @@ import { AvailableNetOrientationSolver } from "../AvailableNetOrientationSolver/ import { VccNetLabelCornerPlacementSolver } from "../VccNetLabelCornerPlacementSolver/VccNetLabelCornerPlacementSolver" import { TraceAnchoredNetLabelOverlapSolver } from "../TraceAnchoredNetLabelOverlapSolver/TraceAnchoredNetLabelOverlapSolver" import { NetLabelTraceCollisionSolver } from "../NetLabelTraceCollisionSolver/NetLabelTraceCollisionSolver" +import { SameNetTraceCombiningSolver } from "../SameNetTraceCombiningSolver/SameNetTraceCombiningSolver" type PipelineStep BaseSolver> = { solverName: string @@ -71,6 +72,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { schematicTraceLinesSolver?: SchematicTraceLinesSolver longDistancePairSolver?: LongDistancePairSolver traceOverlapShiftSolver?: TraceOverlapShiftSolver + sameNetTraceCombiningSolver?: SameNetTraceCombiningSolver netLabelPlacementSolver?: NetLabelPlacementSolver labelMergingSolver?: MergedNetLabelObstacleSolver traceLabelOverlapAvoidanceSolver?: TraceLabelOverlapAvoidanceSolver @@ -154,19 +156,29 @@ export class SchematicTracePipelineSolver extends BaseSolver { onSolved: (_solver) => {}, }, ), + definePipelineStep( + "sameNetTraceCombiningSolver", + SameNetTraceCombiningSolver, + (instance) => [ + { + inputProblem: instance.inputProblem, + inputTraces: Object.values( + instance.traceOverlapShiftSolver!.correctedTraceMap, + ), + }, + ], + ), definePipelineStep( "netLabelPlacementSolver", NetLabelPlacementSolver, () => [ { inputProblem: this.inputProblem, - inputTraceMap: - this.traceOverlapShiftSolver?.correctedTraceMap ?? - Object.fromEntries( - this.longDistancePairSolver!.getOutput().allTracesMerged.map( - (p) => [p.mspPairId, p], - ), - ), + inputTraceMap: Object.fromEntries( + this.sameNetTraceCombiningSolver! + .getOutput() + .traces.map((p) => [p.mspPairId, p]), + ), }, ], { @@ -179,14 +191,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { "traceLabelOverlapAvoidanceSolver", TraceLabelOverlapAvoidanceSolver, (instance) => { - const traceMap = - instance.traceOverlapShiftSolver?.correctedTraceMap ?? - Object.fromEntries( - instance - .longDistancePairSolver!.getOutput() - .allTracesMerged.map((p) => [p.mspPairId, p]), - ) - const traces = Object.values(traceMap) + const traces = instance.sameNetTraceCombiningSolver!.getOutput().traces const netLabelPlacements = instance.netLabelPlacementSolver!.netLabelPlacements diff --git a/tests/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.test.ts b/tests/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.test.ts new file mode 100644 index 00000000..f1093854 --- /dev/null +++ b/tests/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.test.ts @@ -0,0 +1,111 @@ +import { expect, test } from "bun:test" +import { SameNetTraceCombiningSolver } from "lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import type { InputProblem } from "lib/types/InputProblem" + +const inputProblem: InputProblem = { + chips: [], + directConnections: [], + netConnections: [], + availableNetLabelOrientations: {}, +} + +const makePin = (pinId: string, x: number, y: number) => ({ + pinId, + chipId: "chip", + x, + y, +}) + +const makeTrace = ( + mspPairId: string, + globalConnNetId: string, + tracePath: Array<{ x: number; y: number }>, +): SolvedTracePath => ({ + mspPairId, + dcConnNetId: globalConnNetId, + globalConnNetId, + userNetId: globalConnNetId, + pins: [ + makePin(`${mspPairId}.1`, tracePath[0]!.x, tracePath[0]!.y), + makePin( + `${mspPairId}.2`, + tracePath[tracePath.length - 1]!.x, + tracePath[tracePath.length - 1]!.y, + ), + ], + tracePath, + mspConnectionPairIds: [mspPairId], + pinIds: [`${mspPairId}.1`, `${mspPairId}.2`], +}) + +test("SameNetTraceCombiningSolver combines adjacent same-net horizontal segments", () => { + const traceA = makeTrace("trace-a", "net-1", [ + { x: -1, y: 0 }, + { x: -1, y: 1.5 }, + { x: 0, y: 1.5 }, + ]) + const traceB = makeTrace("trace-b", "net-1", [ + { x: 0, y: 1.5 }, + { x: 1, y: 1.5 }, + { x: 1, y: 0 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem, + inputTraces: [traceA, traceB], + }) + solver.solve() + + const horizontalTraces = solver + .getOutput() + .traces.filter((trace) => trace.tracePath[0]!.y === trace.tracePath[1]!.y) + + expect(horizontalTraces).toHaveLength(1) + expect(horizontalTraces[0]!.tracePath).toEqual([ + { x: -1, y: 1.5 }, + { x: 1, y: 1.5 }, + ]) + expect(horizontalTraces[0]!.mspConnectionPairIds).toEqual([ + "trace-a", + "trace-b", + ]) +}) + +test("SameNetTraceCombiningSolver combines close vertical segments but not different nets", () => { + const traceA = makeTrace("trace-a", "net-1", [ + { x: 2, y: 0 }, + { x: 2, y: 1 }, + ]) + const traceB = makeTrace("trace-b", "net-1", [ + { x: 2.005, y: 1.1 }, + { x: 2.005, y: 2 }, + ]) + const traceC = makeTrace("trace-c", "net-2", [ + { x: 2, y: 0 }, + { x: 2, y: 1 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem, + inputTraces: [traceA, traceB, traceC], + }) + solver.solve() + + const traces = solver.getOutput().traces + const net1Vertical = traces.find( + (trace) => + trace.globalConnNetId === "net-1" && + trace.mspConnectionPairIds.length === 2, + ) + const net2Vertical = traces.find((trace) => trace.globalConnNetId === "net-2") + + expect(net1Vertical?.tracePath).toEqual([ + { x: 2.0025, y: 0 }, + { x: 2.0025, y: 2 }, + ]) + expect(net2Vertical?.tracePath).toEqual([ + { x: 2, y: 0 }, + { x: 2, y: 1 }, + ]) +})